From b6de54e48eb11308b6db6bead7ef7abd4af018ca Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 7 Jan 2025 13:39:57 -0800 Subject: [PATCH] Create an IFRT wrapper around NanoRT. This will allow NanoRT to be easily used from a caller that depends on IFRT, but we can add faster "pass-through" APIs as needed when we encounter performance defects. PiperOrigin-RevId: 713026345 --- xla/backends/cpu/nanort/BUILD | 72 +- xla/backends/cpu/nanort/ifrt_client.cc | 1420 +++++++++++++++++ xla/backends/cpu/nanort/ifrt_client.h | 197 +++ xla/backends/cpu/nanort/ifrt_client_test.cc | 34 + .../nanort/register_nanort_for_ifrt_tests.cc | 29 + 5 files changed, 1751 insertions(+), 1 deletion(-) create mode 100644 xla/backends/cpu/nanort/ifrt_client.cc create mode 100644 xla/backends/cpu/nanort/ifrt_client.h create mode 100644 xla/backends/cpu/nanort/ifrt_client_test.cc create mode 100644 xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc diff --git a/xla/backends/cpu/nanort/BUILD b/xla/backends/cpu/nanort/BUILD index 55e7bfc2687ce7..2e21559e1d8f5a 100644 --- a/xla/backends/cpu/nanort/BUILD +++ b/xla/backends/cpu/nanort/BUILD @@ -1,7 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/backends/cpu/nanort:package_groups.bzl", "xla_cpu_nanort_packages") load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -106,3 +105,74 @@ cc_library( "@tsl//tsl/profiler/lib:traceme_encode", ], ) + +cc_library( + name = "ifrt_client", + srcs = ["ifrt_client.cc"], + hdrs = ["ifrt_client.h"], + deps = [ + ":nanort_client", + ":nanort_executable", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu:alignment", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_sharding", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_layout", + "//xla/pjrt:utils", + "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt/hlo:hlo_program", + "//xla/python/pjrt_ifrt:pjrt_dtype", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:hlo_module_config", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:fingerprint", + ], +) + +cc_library( + name = "register_nanort_for_ifrt_tests", + testonly = True, + srcs = ["register_nanort_for_ifrt_tests.cc"], + deps = [ + ":ifrt_client", + "//xla/python/ifrt:test_util", + ], + alwayslink = True, +) + +xla_cc_test( + name = "ifrt_client_test", + srcs = ["ifrt_client_test.cc"], + deps = [ + ":register_nanort_for_ifrt_tests", + "//xla/python/ifrt:array_impl_test_lib", + "//xla/python/ifrt:client_impl_test_lib", + "//xla/python/ifrt:test_util", + "//xla/python/ifrt:tuple_impl_test_lib", + "//xla/python/pjrt_ifrt:xla_executable_impl_test_lib", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:test_main", + 
], +) diff --git a/xla/backends/cpu/nanort/ifrt_client.cc b/xla/backends/cpu/nanort/ifrt_client.cc new file mode 100644 index 00000000000000..cf4365656b72f4 --- /dev/null +++ b/xla/backends/cpu/nanort/ifrt_client.cc @@ -0,0 +1,1420 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/nanort/ifrt_client.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/backends/cpu/alignment.h" +#include "xla/backends/cpu/nanort/nanort_executable.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/utils.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/index.h" +#include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" + +namespace xla::cpu { +namespace { + +static const char kMemoryKind[] = ""; + +// Returns a Future that is immediately ready with the given status. This is +// mostly useful because everything NanoRT does is immediately ready. 
+ifrt::Future<> Ready(absl::Status status = absl::OkStatus()) { + return ifrt::Future<>(std::move(status)); +} + +// Base class for all value types. This class doesn't participate in the llvm +// RTTI hierarchy (you can't dynamically cast to it), rather it just +// implements some virtual methods that have the same implementation for all +// NanoRT value types. +template +class NanoValue : public llvm::RTTIExtends { + public: + explicit NanoValue(NanoIfrtClient* client) : client_(client) {} + + ifrt::Client* client() const override { return client_; } + + // Called by subclasses to get access to client() without having to cast. + NanoIfrtClient* nano_client() const { return client_; } + + // All nano values are immediately ready. + ifrt::Future<> GetReadyFuture() const override { return Ready(); } + + // Subclasses must still implement Delete(). + ifrt::Future<> Delete() override = 0; + bool IsDeleted() const override = 0; + + // Helper that returns an error if this value is accessed after it has been + // deleted. Meant to be called with TF_RETURN_IF_ERROR at the top of + // relevant methods. + absl::Status ValidateNotDeleted() const { + if (IsDeleted()) { + return absl::FailedPreconditionError("Tried to access a deleted value."); + } + return absl::OkStatus(); + } + + private: + NanoIfrtClient* client_; +}; + +// Array implementation. +// +// This class always holds a continuous buffer of memory, if a sharding is +// provided, it will be disassembled as needed to satisfy caller expectations. +// +// See ShardedNanoArray for the case where the array is constructed from +// multiple existing shards. +class NanoArray final : public NanoValue { + public: + // A pointer to the underlying buffer. We use a shared_ptr because for some + // operations (like disassembly) we can just alias the memory, but we still + // need to support deletion of the NanoArray that created the buffer. + using DataPtr = std::shared_ptr; + + NanoArray(NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + DataPtr data, std::shared_ptr sharding) + : NanoValue(client), + dtype_(std::move(dtype)), + shape_(std::move(shape)), + data_(std::move(data)), + sharding_(std::move(sharding)) {} + + // Allocates a new array of the given type and shape. + static absl::StatusOr> Allocate( + NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding) { + TF_RET_CHECK(dtype.byte_size().has_value()); + TF_ASSIGN_OR_RETURN( + DataPtr data_ptr, + AllocateData(dtype.byte_size().value() * shape.num_elements())); + return tsl::TakeRef(new NanoArray(client, dtype, shape, std::move(data_ptr), + std::move(sharding))); + } + + // Creates an array from a host buffer. The buffer will be used directly + // without a copy if the copy semantics allow it and the layout is row major + // and dense. + static absl::StatusOr> FromBuffer( + NanoIfrtClient* client, void* data, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding, + std::optional> byte_strides, bool make_copy, + std::function on_done_with_host_buffer) { + auto size = dtype.byte_size().value_or(0) * shape.num_elements(); + TF_RET_CHECK(size > 0); + DataPtr data_ptr; + if (!on_done_with_host_buffer) { + on_done_with_host_buffer = [] {}; + } + bool layout_compatible = LayoutCompatible(dtype, shape, byte_strides); + bool aligned = reinterpret_cast(data) % Align() == 0; + + if (!layout_compatible || !aligned) { + // Input is not aligned, or has a weird layout, so we need to copy it. 
+ make_copy = true; + } + + if (make_copy) { + TF_ASSIGN_OR_RETURN(data_ptr, AllocateData(size)); + if (layout_compatible) { + // Input has a compatible layout, so we can just do a memcpy. + memcpy(data_ptr.get(), data, size); + } else { + // Input has an incompatible layout, so we need to copy it with an + // appropriate stride. + TF_ASSIGN_OR_RETURN(auto dense_strides, DenseByteStrides(dtype, shape)); + TF_RETURN_IF_ERROR(CopyWithByteStrides( + reinterpret_cast(data_ptr.get()), dense_strides, + reinterpret_cast(data), + byte_strides.value_or(dense_strides), shape.dims(), + dtype.byte_size().value())); + } + // We're done with the input buffer, so we can allow the caller to clean + // it up. + on_done_with_host_buffer(); + } else { + // We're allowed to keep the input buffer, and it's dense and row major, + // so we can just use it directly. + data_ptr = DataPtr(data, [done = std::move(on_done_with_host_buffer)]( + void* ptr) { done(); }); + } + TF_RET_CHECK(data_ptr != nullptr); + return tsl::TakeRef(new NanoArray(client, dtype, shape, std::move(data_ptr), + std::move(sharding))); + } + + const DataPtr& data() const { return data_; } + + // Copies a sub-array of the given size from src to dst. The dst array must + // already be allocated and of the correct type and shape. Values outside of + // the specified sub-array of dst will be left untouched. + // + // This is mostly intended to support sharding and assembling. + static absl::Status CopySubArray(NanoArray& dst, + absl::Span dst_loc, + NanoArray& src, + absl::Span src_loc, + absl::Span size) { + // Make sure the arrays are the same type and the type is supported. + TF_RET_CHECK(dst.dtype() == src.dtype()); + TF_RET_CHECK(dst.dtype().byte_size().has_value()); + + // Make sure all the dims are compatible. + TF_RET_CHECK(dst.shape().dims().size() == size.size()); + TF_RET_CHECK(src.shape().dims().size() == size.size()); + TF_RET_CHECK(dst.shape().dims().size() == size.size()); + TF_RET_CHECK(dst_loc.size() == size.size()); + TF_RET_CHECK(src_loc.size() == size.size()); + + // Make sure what we're copying is within the bounds of the arrays. + for (size_t i = 0; i < size.size(); ++i) { + TF_RET_CHECK(dst_loc[i] + size[i] <= dst.shape().dims()[i]); + TF_RET_CHECK(src_loc[i] + size[i] <= src.shape().dims()[i]); + } + + int64_t element_size = dst.dtype().byte_size().value(); + + // Returns the size of a row in bytes for the given shape. + auto row_size = [=](absl::Span shape) { + if (shape.empty()) return element_size; // Scalar. + return shape.back() * element_size; + }; + + // Since this is always row major, we can do one memcpy per row, and rows + // will always be evenly spaces within the arrays. + int64_t src_row_stride = row_size(src.shape().dims()); + int64_t dst_row_stride = row_size(dst.shape().dims()); + int64_t copy_row_size = row_size(size); + + // How many rows do we have to copy? + int64_t copy_num_rows = 1; + for (int64_t i = 0; i + 1 < size.size(); ++i) { + copy_num_rows *= size[i]; + } + + // Returns a pointer to the given position in the array. + auto get_row_ptr = [&](NanoArray& array, + absl::Span position) -> std::byte* { + size_t offset = 0; + size_t stride = 1; + for (int i = position.size() - 1; i >= 0; --i) { + offset += stride * position[i]; + stride *= array.shape().dims()[i]; + } + offset *= element_size; + return static_cast(array.data().get()) + offset; + }; + + // Get the pointers to the start of the rows we're copying. 
+ std::byte* dst_row_start = get_row_ptr(dst, dst_loc); + std::byte* src_row_start = get_row_ptr(src, src_loc); + + // Copy the rows. + for (int64_t i = 0; i < copy_num_rows; ++i) { + memcpy(dst_row_start, src_row_start, copy_row_size); + dst_row_start += dst_row_stride; + src_row_start += src_row_stride; + } + return absl::OkStatus(); + } + + absl::StatusOr>> Disassemble() { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + if (sharding().IsFullyReplicated()) { + if (sharding().devices()->size() == 1) { + // Only one device and one shard, so we can just return a reference to + // this array. + return std::vector>{tsl::FormRef(this)}; + } + + // If the array is fully replicated and there are multiple "devices", we + // need to make one "copy" per device. + std::vector> shards; + shards.reserve(sharding().devices()->size()); + for (auto* device : sharding().devices()->devices()) { + auto one_device_sharding = ifrt::SingleDeviceSharding::Create( + device, sharding().memory_kind()); + shards.push_back( + tsl::TakeRef(new NanoArray(nano_client(), dtype_, shape_, data_, + std::move(one_device_sharding)))); + } + return shards; + } + + // The array is sharded, copy the appropriate sub-arrays. + TF_ASSIGN_OR_RETURN(auto index_domains, sharding().IndexDomains(shape())); + TF_RET_CHECK(index_domains.size() == sharding().devices()->size()); + std::vector> shards; + shards.reserve(index_domains.size()); + for (int i = 0; i < index_domains.size(); ++i) { + const auto& index_domain = index_domains[i]; + auto* device = sharding().devices()->devices()[i]; + auto one_device_sharding = + ifrt::SingleDeviceSharding::Create(device, sharding().memory_kind()); + TF_ASSIGN_OR_RETURN( + auto shard, + NanoArray::Allocate(nano_client(), dtype(), index_domain.shape(), + std::move(one_device_sharding))); + TF_RETURN_IF_ERROR(NanoArray::CopySubArray( + // To the origin of this shard. + *shard, ifrt::Index::Zeros(shape().dims().size()).elements(), + // From the assembled array. + *this, index_domain.origin().elements(), + // The in the shape of this shard. 
+ index_domain.shape().dims())); + shards.push_back(std::move(shard)); + } + return shards; + } + + NanoRtExecutable::Argument AsArgument() { + return NanoRtExecutable::Argument( + reinterpret_cast(data_.get()), + dtype_.byte_size().value() * shape_.num_elements()); + } + + NanoRtExecutable::Result AsResult() { + return NanoRtExecutable::Result( + reinterpret_cast(data_.get()), + dtype_.byte_size().value() * shape_.num_elements()); + } + + std::string DebugString() const override { + return absl::StrCat("NanoArray(", dtype_.DebugString(), ", ", + shape_.DebugString(), ", @", + reinterpret_cast(data_.get()), ")"); + } + + ifrt::Future<> Delete() override { + data_ = nullptr; + return Ready(); + } + + bool IsDeleted() const override { return data_ == nullptr; } + + ifrt::DType dtype() const override { return dtype_; } + + const ifrt::Shape& shape() const override { return shape_; } + + const ifrt::Sharding& sharding() const override { return *sharding_; } + + absl::Nonnull> shared_ptr_sharding() + const override { + return sharding_; + } + + absl::StatusOr> layout() const override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return std::make_shared(xla::Layout(shape().dims())); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + TF_ASSIGN_OR_RETURN(auto shards, Disassemble()); + return std::vector>(shards.begin(), shards.end()); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return DisassembleIntoSingleDeviceArrays(array_copy_semantics); + } + + absl::StatusOr> FullyReplicatedShard( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return tsl::FormRef(this); + } + + ifrt::Future<> CopyToHostBuffer( + void* data, std::optional> byte_strides, + ifrt::ArrayCopySemantics semantics) override { + // Run everything in a lambda so we can use error macros and convert to a + // future once. + return Ready([&] { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_dtype, + ifrt::ToPrimitiveType(dtype())); + if (!byte_strides.has_value() || + xla::HasMajorToMinorLayout(xla_dtype, shape().dims(), + *byte_strides)) { + memcpy(data, data_.get(), + dtype().byte_size().value() * shape().num_elements()); + } else { + TF_ASSIGN_OR_RETURN(auto in_strides, + DenseByteStrides(dtype(), shape())); + TF_RETURN_IF_ERROR(CopyWithByteStrides( + reinterpret_cast(data), *byte_strides, + reinterpret_cast(data_.get()), in_strides, + shape().dims(), dtype().byte_size().value())); + } + return absl::OkStatus(); + }()); + } + + static char ID; // NOLINT + + private: + // Returns true if the given data type, shape, and strides are compatible + // with NanoArray (we can either use this memory directly or memcpy it into + // our own memory). + static bool LayoutCompatible( + ifrt::DType dtype, const ifrt::Shape& shape, + std::optional> byte_strides) { + if (!dtype.byte_size().has_value()) { + return false; + } + auto xla_dtype = ifrt::ToPrimitiveType(dtype); + if (!xla_dtype.ok()) { + return false; + } + if (!byte_strides.has_value()) { + return true; + } + return xla::HasMajorToMinorLayout(*xla_dtype, shape.dims(), *byte_strides); + } + + // Returns the byte strides for a dense array with the given type and shape. 
+ static absl::StatusOr> DenseByteStrides( + ifrt::DType dtype, ifrt::Shape shape) { + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_dtype, + ifrt::ToPrimitiveType(dtype)); + auto xla_shape = xla::ShapeUtil::MakeShape(xla_dtype, shape.dims()); + auto strides = xla::ShapeUtil::ByteStrides(xla_shape); + if (!strides.has_value()) { + return absl::InvalidArgumentError(absl::StrCat( + "Couldn't compute byte strides for shape:", xla_shape.ToString())); + } + return std::move(*strides); + } + + // Allocates an aligned buffer of the given size. + static absl::StatusOr AllocateData(size_t size) { + DataPtr data_ptr(aligned_alloc(Align(), std::max(size, Align())), + [](void* ptr) { free(ptr); }); + if (data_ptr == nullptr) { + return absl::InternalError(absl::StrCat( + "Failed to allocate memory for NanoArray. Errno: ", strerror(errno))); + } + return data_ptr; + } + + // Copies data between two buffers that represent the same shape but have + // different byte strides. This is a recursive method that peels back dims + // until we get to a scalar, which isn't very efficient but the common case + // is expected to be a row major array without padding. + static absl::Status CopyWithByteStrides( + std::byte* dst, absl::Span dst_byte_strides, + const std::byte* src, absl::Span src_byte_strides, + absl::Span dims, int64_t elem_size) { + TF_RET_CHECK(dims.size() == dst_byte_strides.size()); + TF_RET_CHECK(dims.size() == src_byte_strides.size()); + // Scalar. Just copy it. + if (dims.empty()) { + memcpy(dst, src, elem_size); + return absl::OkStatus(); + } + // Peel back dims recursively until we get to a scalar. + for (int64_t i = 0; i < dims[0]; ++i) { + TF_RETURN_IF_ERROR(CopyWithByteStrides(dst, dst_byte_strides.subspan(1), + src, src_byte_strides.subspan(1), + dims.subspan(1), elem_size)); + dst += dst_byte_strides[0]; + src += src_byte_strides[0]; + } + return absl::OkStatus(); + } + + ifrt::DType dtype_; + ifrt::Shape shape_; + DataPtr data_; + std::shared_ptr sharding_; +}; + +char NanoArray::ID = 'A'; // NOLINT + +// Sharded array implementation. Represents an array that should be assembled +// from multiple arrays, but we aren't sure how to assemble it yet. +class ShardedNanoArray final : public NanoValue { + public: + // Creates an array from the given shards. Note that if we can assemble the + // array using the given sharding, this method will return a NanoArray. + static absl::StatusOr> FromShards( + NanoIfrtClient* client, ifrt::Shape shape, + std::shared_ptr sharding, + std::vector> shards) { + if (shards.empty()) { + return absl::InvalidArgumentError( + "Can't create a sharded array with no shards."); + } + xla::ifrt::DType dtype = shards[0]->dtype(); + + auto array = tsl::TakeRef(new ShardedNanoArray( + client, dtype, shape, sharding, std::move(shards))); + + // Try to eagerly assemble the array. Sometimes this cannot be done + // because arrays are loaded with a simple per device sharding and we + // won't know how to assemble it until the program is run. + if (auto dense_array = array->Assemble(sharding); dense_array.ok()) { + return dense_array; + } + + // If we can't assemble the array, we'll just return the sharded array. It + // will be assembled at execution time when we know the actual sharding. + return array; + } + + const std::vector>& shards() { return shards_; } + + // Assembles the array using the given sharding to prepare it as an input to + // execution. If this array has already been assembled using the given + // sharding, this method will return the cached result. 
This optimizes a + // common case where a checkpoint is loaded with an unknown sharding, but + // then we find the real sharding when the program is run. + absl::StatusOr> AssembleForExecution( + std::shared_ptr sharding) { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + absl::call_once(assemble_once_, [this, sharding]() { + assemble_result_ = Assemble(sharding); + }); + TF_RETURN_IF_ERROR(assemble_result_.status()); + if (assemble_result_.value()->shared_ptr_sharding() != sharding) { + // Bleh... We cached the wrong sharding somehow. This means one sharded + // array was an input to two different programs with different + // shardings, this should be unlikely. + return Assemble(sharding); + } + return assemble_result_; + } + + ifrt::Future<> Delete() override { + // Sharded arrays are never borrowed like dense arrays are, so we can just + // clear the shards and let them be destroyed. + shards_.clear(); + assemble_result_ = absl::Status(absl::StatusCode::kUnavailable, ""); + return Ready(); + } + + bool IsDeleted() const override { return shards_.empty(); } + + std::string DebugString() const override { + auto result = + absl::StrCat("ShardedNanoArray(", dtype_.DebugString(), ", ", + shape_.DebugString(), ", ", sharding_->DebugString()); + for (const auto& shard : shards_) { + absl::StrAppend(&result, ", ", shard->DebugString()); + } + absl::StrAppend(&result, ")"); + return result; + } + + ifrt::DType dtype() const override { return dtype_; } + + const ifrt::Shape& shape() const override { return shape_; } + + const ifrt::Sharding& sharding() const override { return *sharding_; } + + absl::Nonnull> shared_ptr_sharding() + const override { + return sharding_; + } + + absl::StatusOr> layout() const override { + return std::make_shared(xla::Layout(shape().dims())); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return std::vector>(shards_.begin(), shards_.end()); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + return DisassembleIntoSingleDeviceArrays(array_copy_semantics); + } + + absl::StatusOr> FullyReplicatedShard( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return tsl::FormRef(this); + } + + ifrt::Future<> CopyToHostBuffer( + void* data, std::optional> byte_strides, + ifrt::ArrayCopySemantics semantics) override { + return Ready( + absl::InternalError("Cannot copy sharded array to host buffer.")); + } + + static char ID; // NOLINT + + private: + ShardedNanoArray(NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding, + std::vector> shards) + : NanoValue(client), + dtype_(std::move(dtype)), + shape_(std::move(shape)), + sharding_(std::move(sharding)), + shards_(std::move(shards)) {} + + absl::StatusOr> Assemble( + std::shared_ptr sharding) { + TF_ASSIGN_OR_RETURN(auto index_domains, sharding->IndexDomains(shape())); + if (index_domains.size() != shards_.size()) { + return absl::FailedPreconditionError( + absl::StrCat("Number of index domains ", index_domains.size(), + " not equal to number of arrays ", shards_.size())); + } + + for (int i = 0; i < index_domains.size(); ++i) { + if (index_domains[i].shape() != shards_[i]->shape()) { + return absl::FailedPreconditionError(absl::StrCat( + "Index domain ", index_domains[i].shape().DebugString(), + " not equal to 
array shape ", shards_[i]->shape().DebugString())); + } + } + + // If the sharding is replicated in any way, this comparator will dedupe + // arrays that have the same logical destination. + struct IndexDomainCmp { + bool operator()(const ifrt::IndexDomain& a, + const ifrt::IndexDomain& b) const { + return std::lexicographical_compare( + a.origin().elements().begin(), a.origin().elements().end(), + b.origin().elements().begin(), b.origin().elements().end()); + } + }; + + // Index the arrays by where we are copying them to. Note that this will + // implicitly filter out replicated shards since they will have the same + // destination in the assembled array. + absl::btree_map + index_domain_device_arrays; + for (int i = 0; i < index_domains.size(); ++i) { + index_domain_device_arrays[index_domains[i]] = shards_[i].get(); + } + + TF_ASSIGN_OR_RETURN(auto result, NanoArray::Allocate(nano_client(), dtype(), + shape(), sharding)); + + // Copy the shards into the final array. + auto shard_origin = ifrt::Index::Zeros(shards_[0]->shape().dims().size()); + for (const auto& [index_domain, shard] : index_domain_device_arrays) { + TF_RETURN_IF_ERROR(NanoArray::CopySubArray( + *result, index_domain.origin().elements(), *shard, + shard_origin.elements(), shard->shape().dims())); + } + + return result; + } + + ifrt::DType dtype_; + ifrt::Shape shape_; + std::shared_ptr sharding_; + std::vector> shards_; + + absl::once_flag assemble_once_; + absl::StatusOr> assemble_result_; +}; + +char ShardedNanoArray::ID = 'A'; // NOLINT + +// Tuple implementation. +class NanoTuple final : public NanoValue { + public: + explicit NanoTuple(NanoIfrtClient* client, + absl::Span> values) + : NanoValue(client), + values_(values.begin(), values.end()) {} + + ifrt::Future<> Delete() override { + for (auto& value : values_) { + value->Delete(); + } + values_.clear(); + deleted_ = true; + return Ready(); + } + + bool IsDeleted() const override { + for (auto& value : values_) { + if (value->IsDeleted()) { + return true; + } + } + return deleted_; + } + + // Returns the arity of the tuple. + int Arity() override { return values_.size(); } + + // Unpacks the tuple into its constituent pieces. + absl::Status Unpack( + absl::Span> values) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + if (values.size() != values_.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Tuple arity mismatch: expected ", values_.size(), + ", got ", values.size())); + } + for (int i = 0; i < values_.size(); ++i) { + values[i] = values_[i]; + } + return absl::OkStatus(); + } + + std::string DebugString() const override { + std::string result = "NanoTuple("; + for (const auto& value : values_) { + absl::StrAppend(&result, value->DebugString(), ", "); + } + absl::StrAppend(&result, ")"); + return result; + } + + static char ID; // NOLINT + + private: + bool deleted_ = false; + std::vector> values_; +}; + +char NanoTuple::ID = 'T'; // NOLINT + +// Executable implementation. +class NanoExecutable final + : public llvm::RTTIExtends { + public: + // Creates a NanoExecutable from an ifrt::Program. 
+ static absl::StatusOr> Create( + NanoIfrtClient* client, std::unique_ptr program) { + auto* xla_program = llvm::dyn_cast(program.get()); + if (xla_program == nullptr) { + return absl::InvalidArgumentError("NanoRT requires an HloProgram"); + } + XlaComputation computation; + TF_RETURN_IF_ERROR(MlirToXlaComputation(xla_program->mlir_module, + computation, false, true, false)); + TF_ASSIGN_OR_RETURN(auto nano_executable, + client->nano_client()->Compile(computation)); + + if (computation.proto().computations().size() != 1) { + return absl::InvalidArgumentError( + absl::StrCat("NanoRT only supports single-computation programs, got ", + computation.proto().computations().size())); + } + + TF_ASSIGN_OR_RETURN(auto program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(auto proto_input_shardings, + GetInputShardings(program_shape, computation)); + TF_ASSIGN_OR_RETURN(auto proto_output_shardings, + GetOutputShardings(program_shape, computation)); + auto input_shardings = + IfrtShardingsFromProto(client, proto_input_shardings); + auto output_shardings = + IfrtShardingsFromProto(client, proto_output_shardings); + + return absl::WrapUnique(new NanoExecutable( + client, std::move(computation), std::move(program_shape), + std::move(nano_executable), std::move(input_shardings), + std::move(output_shardings))); + } + + ifrt::Client* client() const override { return client_; } + + absl::string_view name() const override { return program_.name(); } + + absl::StatusOr Execute( + absl::Span> args, + const ExecuteOptions& options, + std::optional> devices) override { + if (args.size() != input_shardings_.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Number of arguments ", args.size(), + " is not what executable expects ", input_shardings_.size())); + } + + // Convert the ifrt arrays to nano arrays. 'tmp' holds any arrays that had + // to be assembled. + std::vector> tmp; + TF_ASSIGN_OR_RETURN(auto nano_args, + NanoArgumentsFromIfrtArguments(args, tmp)); + + TF_ASSIGN_OR_RETURN(auto result_arrays, AllocateResults()); + std::vector nano_results; + nano_results.reserve(result_arrays.size()); + for (auto& result_array : result_arrays) { + nano_results.push_back( + llvm::dyn_cast(result_array.get())->AsResult()); + } + + auto event = executable_->Execute(nano_args, nano_results, + NanoRtExecutable::PreallocatedTemp{}); + + // TODO(jsoyke): Consider making this non-blocking if we ever use this + // interface for models that require threading, or if we want to delay + // execution until we know where the outputs will be stored. + tsl::BlockUntilReady(event); + + if (event.IsError()) return event.GetError(); + if (!event.IsConcrete()) { + return absl::InternalError("NanoRT result is not concrete."); + } + + ExecuteResult result; + if (options.fill_status) { + result.status = Ready(); + } + result.outputs = std::move(result_arrays); + return result; + } + + // Returns a fingerprint of this executable. 
+ absl::StatusOr> Fingerprint() const override { + return absl::UnimplementedError("Fingerprint is not implemented."); + } + + absl::StatusOr Serialize() const override { + return absl::UnimplementedError("Serialize is not implemented."); + } + + ifrt::Future<> GetReadyFuture() const override { return Ready(); } + + int num_devices() const override { return 1; } + + int64_t SizeOfGeneratedCodeInBytes() const override { return 0; } + + absl::StatusOr GetCompiledMemoryStats() const override { + return absl::UnimplementedError( + "GetCompiledMemoryStats is not implemented."); + } + + std::optional> GetParameterShardings() + const override { + auto shardings = GetInputShardings(program_shape_, program_); + if (!shardings.ok()) return std::nullopt; + return *shardings; + } + + std::optional> GetOutputShardings() const override { + auto shardings = GetOutputShardings(program_shape_, program_); + if (!shardings.ok()) return std::nullopt; + return *shardings; + } + + absl::StatusOr>> + GetParameterLayouts() const override { + std::vector> layouts; + layouts.reserve(program_shape_.parameters().size()); + for (const auto& shape : program_shape_.parameters()) { + layouts.push_back( + std::make_shared(xla::Layout(shape.dimensions()))); + } + return layouts; + } + + absl::StatusOr>> + GetOutputLayouts() const override { + const auto& result_shape = program_shape_.result(); + const auto result_shapes = + result_shape.IsTuple() + ? absl::MakeConstSpan(result_shape.tuple_shapes()) + : absl::MakeConstSpan(&result_shape, 1); + std::vector> layouts; + layouts.reserve(result_shapes.size()); + for (const auto& shape : result_shapes) { + layouts.push_back( + std::make_shared(xla::Layout(shape.dimensions()))); + } + return layouts; + } + + absl::StatusOr>> GetHloModules() + const override { + std::vector> hlo_modules(1); + TF_ASSIGN_OR_RETURN( + hlo_modules[0], + HloModule::CreateFromProto(program_.proto(), HloModuleConfig())); + return hlo_modules; + } + + absl::StatusOr>> + GetOutputMemoryKinds() const override { + std::vector> memory_kinds; + memory_kinds.reserve(output_shardings_.size()); + for (const auto& _ : output_shardings_) { + memory_kinds.push_back({kMemoryKind}); + } + return memory_kinds; + } + + absl::StatusOr GetCostAnalysis() const override { + return absl::UnimplementedError("GetCostAnalysis is not implemented."); + } + + ifrt::Future<> Delete() override { + client_ = nullptr; + program_ = {}; + program_shape_ = {}; + executable_.reset(); + input_shardings_.clear(); + output_shardings_.clear(); + return Ready(); + } + + bool IsDeleted() const override { return executable_ == nullptr; } + + absl::Span addressable_devices() const override { + return client_->addressable_devices(); + } + + static char ID; // NOLINT + + private: + NanoExecutable(NanoIfrtClient* client, XlaComputation program, + ProgramShape program_shape, + std::unique_ptr executable, + std::vector> input_shardings, + std::vector> output_shardings) + : client_(client), + program_(std::move(program)), + program_shape_(std::move(program_shape)), + executable_(std::move(executable)), + input_shardings_(std::move(input_shardings)), + output_shardings_(std::move(output_shardings)) {} + + // Converts an OpSharding proto (from an HLO Instruction) to an ifrt + // sharding. 
+ static std::vector> IfrtShardingsFromProto( + NanoIfrtClient* client, absl::Span shardings) { + std::vector> result; + result.reserve(shardings.size()); + for (const auto& sharding : shardings) { + if (sharding.type() == OpSharding::REPLICATED || + sharding.type() == OpSharding::MAXIMAL) { + result.push_back(client->default_sharding()); + continue; + } + int num_tiles = 1; + for (const auto dim : sharding.tile_assignment_dimensions()) { + num_tiles *= dim; + } + // Repeat the device for each tile. We only have one device anyway so + // just used the first. + auto device_list = ifrt::BasicDeviceList::Create( + ifrt::BasicDeviceList::Devices(num_tiles, client->devices()[0])); + auto xla_sharding = *HloSharding::FromProto(sharding); + result.push_back(ifrt::HloSharding::Create( + std::move(device_list), client->devices()[0]->Memories()[0]->Kind(), + std::move(xla_sharding))); + } + return result; + } + + static absl::StatusOr> GetInputShardings( + const ProgramShape& program_shape, const XlaComputation& computation) { + std::vector shardings(program_shape.parameters().size()); + for (const auto& instruction : + computation.proto().computations(0).instructions()) { + if (instruction.opcode() == "parameter" && instruction.has_sharding()) { + if (instruction.parameter_number() >= shardings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Parameter number ", instruction.parameter_number(), + " is out of range for program with ", + program_shape.parameters().size(), " parameters.")); + } + shardings[instruction.parameter_number()] = instruction.sharding(); + } + } + return shardings; + } + + static absl::StatusOr> GetOutputShardings( + const ProgramShape& program_shape, const XlaComputation& computation) { + const auto& result_shape = program_shape.result(); + + int output_id = computation.proto().computations(0).root_id(); + + std::vector shardings( + (result_shape.IsTuple() ? result_shape.tuple_shapes().size() : 1)); + + for (const auto& instruction : + computation.proto().computations(0).instructions()) { + // We found a sharded output instruction. + if (instruction.id() == output_id && instruction.has_sharding()) { + if (result_shape.IsTuple()) { + TF_RET_CHECK(instruction.sharding().tuple_shardings().size() == + result_shape.tuple_shapes().size()); + for (int i = 0; i < instruction.sharding().tuple_shardings().size(); + ++i) { + shardings[i] = instruction.sharding().tuple_shardings()[i]; + } + } else { + shardings[0] = instruction.sharding(); + } + } + } + return shardings; + } + + // Allocates the results for the program. + absl::StatusOr>> AllocateResults() { + const auto& result_shape = program_shape_.result(); + const auto result_shapes = + result_shape.IsTuple() + ? absl::MakeConstSpan(result_shape.tuple_shapes()) + : absl::MakeConstSpan(&result_shape, 1); + TF_RET_CHECK(result_shapes.size() == output_shardings_.size()); + + std::vector> result_arrays; + result_arrays.reserve(result_shapes.size()); + + for (int i = 0; i < result_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN(auto ifrt_type, + ifrt::ToDType(result_shapes[i].element_type())); + ifrt::Shape ifrt_shape(result_shapes[i].dimensions()); + TF_ASSIGN_OR_RETURN(auto array, + NanoArray::Allocate(client_, ifrt_type, ifrt_shape, + output_shardings_[i])); + result_arrays.push_back(std::move(array)); + } + return result_arrays; + } + + // Converts the ifrt arrays to nano arguments. 'tmp' holds any arrays that + // had to be assembled. 
+ absl::StatusOr> + NanoArgumentsFromIfrtArguments( + absl::Span> args, + std::vector>& tmp) { + std::vector nano_args; + nano_args.reserve(args.size()); + + for (int i = 0; i < args.size(); ++i) { + auto* nano_array = llvm::dyn_cast_or_null(args[i].get()); + if (nano_array == nullptr) { + // The input isn't a nano array, so it must be a sharded array. + auto* sharded_array = + llvm::dyn_cast_or_null(args[i].get()); + if (sharded_array == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Argument is not a NanoArray or ShardedNanoArray: ", + args[i]->DebugString())); + } + TF_ASSIGN_OR_RETURN( + auto dense_array, + sharded_array->AssembleForExecution(input_shardings_[i])); + nano_array = dense_array.get(); + tmp.push_back(std::move(dense_array)); + } + nano_args.push_back(nano_array->AsArgument()); + } + + return nano_args; + } + + NanoIfrtClient* client_; + XlaComputation program_; + ProgramShape program_shape_; + std::unique_ptr executable_; + std::vector> input_shardings_; + std::vector> output_shardings_; +}; + +char NanoExecutable::ID = 'E'; // NOLINT + +// Compiler implementation. +class NanoCompiler final + : public llvm::RTTIExtends { + public: + explicit NanoCompiler(NanoIfrtClient* client) : client_(client) {} + + absl::StatusOr> Compile( + std::unique_ptr program, + std::unique_ptr options) override { + return NanoExecutable::Create(client_, std::move(program)); + } + + absl::StatusOr> Compile( + std::unique_ptr program, const ifrt::Topology& topology, + std::unique_ptr options) override { + return absl::UnimplementedError("Partial compilation is not implemented."); + } + + absl::StatusOr> + DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) override { + return absl::UnimplementedError( + "DeserializeLoadedExecutable is not implemented."); + } + static char ID; // NOLINT + + private: + NanoIfrtClient* client_; +}; + +char NanoCompiler::ID = 'C'; // NOLINT + +// Memory implementation. There is only one address space so this doesn't do +// much. +class NanoMemory final : public llvm::RTTIExtends { + public: + explicit NanoMemory(NanoIfrtClient* client) : client_(client) {} + + ifrt::MemoryId Id() const override { return ifrt::MemoryId(0); } + + const ifrt::MemoryKind& Kind() const override { + static ifrt::MemoryKind mem_kind(kMemoryKind); + return mem_kind; + } + + absl::string_view ToString() const override { return "NanoRT CPU Memory"; } + absl::string_view DebugString() const override { return ToString(); } + absl::Span Devices() const override { + return client_->devices(); + } + + static char ID; // NOLINT + + private: + NanoMemory() = default; + + NanoIfrtClient* client_; +}; + +char NanoMemory::ID = 'M'; // NOLINT + +// Device implementation. There is only one device so this doesn't do much. 
+class NanoDevice final : public llvm::RTTIExtends { + public: + NanoDevice(NanoIfrtClient* client, ifrt::Memory* memory) + : client_(client), memory_(memory) {} + + ifrt::Client* client() const override { return client_; } + + ifrt::DeviceId Id() const override { return ifrt::DeviceId(0); } + + const ifrt::AttributeMap& Attributes() const override { + static auto attributes = new ifrt::AttributeMap({}); + return *attributes; + } + + absl::string_view Kind() const override { return "cpu"; } + + absl::string_view ToString() const override { return "NanoRT CPU"; } + + absl::string_view DebugString() const override { return ToString(); } + + absl::StatusOr DefaultMemory() const override { + return memory_; + } + + absl::Span Memories() const override { + return absl::MakeConstSpan(&memory_, 1); + } + + bool IsAddressable() const override { return true; } + + int ProcessIndex() const override { return 0; } + + static char ID; // NOLINT + + private: + NanoIfrtClient* client_; + ifrt::Memory* memory_; +}; + +char NanoDevice::ID = 'D'; // NOLINT + +} // namespace + +NanoIfrtClient::~NanoIfrtClient() = default; + +std::shared_ptr NanoIfrtClient::Create() { + return CreateWithDevices(1); +} + +std::shared_ptr NanoIfrtClient::CreateWithDevices( + int num_devices) { + return std::shared_ptr(new NanoIfrtClient(num_devices)); +} + +std::shared_ptr NanoIfrtClient::default_sharding() const { + return ifrt::SingleDeviceSharding::Create(device_.get(), ifrt::MemoryKind{}); +} + +absl::StatusOr> +NanoIfrtClient::MakeArrayFromHostBuffer( + const void* data, ifrt::DType dtype, ifrt::Shape shape, + std::optional> byte_strides, + absl::Nonnull> sharding, + HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + bool make_copy = false; + switch (semantics) { + case HostBufferSemantics::kImmutableUntilTransferCompletes: + case HostBufferSemantics::kImmutableOnlyDuringCall: + make_copy = true; + break; + case HostBufferSemantics::kImmutableZeroCopy: + case HostBufferSemantics::kMutableZeroCopy: + make_copy = false; + break; + } + return NanoArray::FromBuffer(this, const_cast(data), dtype, shape, + std::move(sharding), byte_strides, make_copy, + on_done_with_host_buffer); +} + +absl::StatusOr> +NanoIfrtClient::AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) { + std::vector> nano_arrays; + nano_arrays.reserve(arrays.size()); + for (const auto& array : arrays) { + auto* nano_array = llvm::dyn_cast_or_null(array.get()); + if (nano_array == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Array is not a NanoArray: ", array->DebugString())); + } + nano_arrays.push_back(tsl::FormRef(nano_array)); + } + return ShardedNanoArray::FromShards(this, shape, sharding, + std::move(nano_arrays)); +} + +absl::StatusOr> +NanoIfrtClient::AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) { + return AssembleArrayFromSingleDeviceArrays(shape, sharding, arrays, + array_copy_semantics); +} + +absl::StatusOr>> +NanoIfrtClient::CopyArrays( + absl::Span> arrays, + std::optional> devices, + std::optional memory_kind, + ifrt::ArrayCopySemantics semantics) { + std::vector> result; + result.reserve(arrays.size()); + for (const auto& array : arrays) { + tsl::RCReference copy; + TF_ASSIGN_OR_RETURN(auto sharding, 
array->sharding().WithDeviceAssignment( + devices, memory_kind)); + if (auto nano_array = llvm::dyn_cast_or_null(array.get())) { + copy = tsl::TakeRef(new NanoArray(this, nano_array->dtype(), + nano_array->shape(), nano_array->data(), + std::move(sharding))); + } else if (auto sharded_nano_array = + llvm::dyn_cast_or_null(array.get())) { + std::vector> shards_copy; + shards_copy.reserve(sharded_nano_array->shards().size()); + for (const auto& shard : sharded_nano_array->shards()) { + shards_copy.push_back(tsl::TakeRef( + new NanoArray(this, shard->dtype(), shard->shape(), shard->data(), + shard->shared_ptr_sharding()))); + } + TF_ASSIGN_OR_RETURN( + copy, ShardedNanoArray::FromShards(this, sharded_nano_array->shape(), + std::move(sharding), + std::move(shards_copy))); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Array is not a NanoArray or ShardedNanoArray: ", + array->DebugString())); + } + TF_RET_CHECK(copy != nullptr); + result.push_back(copy); + } + return result; +} + +absl::StatusOr>> +NanoIfrtClient::RemapArrays( + const ifrt::RemapPlan& plan, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) { + return absl::UnimplementedError("RemapArrays is not implemented."); +} + +ifrt::Future<> NanoIfrtClient::GetReadyFuture( + absl::Span> values) { + return Ready(); +} + +absl::StatusOr> NanoIfrtClient::MakeTuple( + absl::Span> values) { + return tsl::MakeRef(this, std::move(values)); +} + +absl::string_view NanoIfrtClient::runtime_type() const { return "nano"; } + +absl::string_view NanoIfrtClient::platform_name() const { + return xla::CpuName(); +} + +absl::string_view NanoIfrtClient::platform_version() const { + return xla::CpuName(); +} + +ifrt::PlatformId NanoIfrtClient::platform_id() const { + return tsl::Fingerprint64(platform_name()); +} + +const ifrt::AttributeMap& NanoIfrtClient::Attributes() const { + static auto attributes = new ifrt::AttributeMap({}); + return *attributes; +} + +int NanoIfrtClient::device_count() const { return devices_.size(); } + +int NanoIfrtClient::addressable_device_count() const { return device_count(); } + +absl::Span NanoIfrtClient::devices() const { + return devices_; +} + +absl::Span NanoIfrtClient::addressable_devices() const { + return devices(); +} + +int NanoIfrtClient::process_index() const { return 0; } + +absl::Span NanoIfrtClient::GetAllDevices() const { + return devices(); +} + +absl::StatusOr +NanoIfrtClient::GetDefaultDeviceAssignment(int num_replicas, + int num_partitions) const { + return ifrt::DeviceAssignment(1, 1); +} + +absl::StatusOr NanoIfrtClient::LookupDevice( + ifrt::DeviceId device_id) const { + return LookupAddressableDevice(device_id.value()); +} + +absl::StatusOr NanoIfrtClient::LookupAddressableDevice( + int local_hardware_id) const { + return device_.get(); +} + +ifrt::Compiler* NanoIfrtClient::GetDefaultCompiler() { return compiler_.get(); } + +absl::StatusOr> +NanoIfrtClient::GetTopologyForDevices( + const tsl::RCReference& devices) const { + return absl::UnimplementedError("GetTopologyForDevices is not implemented."); +} + +absl::StatusOr> +NanoIfrtClient::GetDefaultLayout(ifrt::DType dtype, + absl::Span dims, + ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) const { + return std::make_shared(xla::Layout(dims)); +} + +NanoIfrtClient::NanoIfrtClient(int32_t num_devices) + : compiler_(std::make_unique(this)), + memory_(std::make_unique(this)), + device_(std::make_unique(this, memory_.get())), + default_sharding_( + ifrt::SingleDeviceSharding::Create(device_.get(), memory_->Kind())), + 
devices_(num_devices, device_.get()) {} + +char NanoIfrtClient::ID = 'N'; // NOLINT + +} // namespace xla::cpu diff --git a/xla/backends/cpu/nanort/ifrt_client.h b/xla/backends/cpu/nanort/ifrt_client.h new file mode 100644 index 00000000000000..96530d62bdb1bf --- /dev/null +++ b/xla/backends/cpu/nanort/ifrt_client.h @@ -0,0 +1,197 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ +#define XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/backends/cpu/nanort/nanort_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace xla::cpu { + +// NanoIfrtClient is a thin wrapper around NanoRtClient that implements the +// ifrt::Client interface. +// +// Unlike NanoRtClient, this class will honor sharding annotations in XLA +// programs, mostly to satisfy IFRT callers. The sharding will be undone as soon +// as possible and reused (either when the sharded arrays is assembled or when +// it is first accessed by an executable). Even so, this client will have much +// better performance with unsharded inputs. +// +// Note: Array remapping is currently unimplemented. +// +// Note: We may add support for callers to access the underlying executables and +// buffers directly in the future, this would allow the "load path" that +// initializes programs and variables to be reused while still getting the +// performance wins of NanoRt at execution time. +class NanoIfrtClient : public llvm::RTTIExtends { + public: + ~NanoIfrtClient() override; + + // Creates a client with a single device. Typically this is how this client + // should be used. + static std::shared_ptr Create(); + + // Creates a client with the given number of devices, this is provided for + // testing and to allow the client to be used in applications that expect + // programs to be sharded. + static std::shared_ptr CreateWithDevices(int32_t num_devices); + + // Returns a single device sharding. Generally callers should prefer to use + // this when possible for optimal performance. 
+ std::shared_ptr default_sharding() const; + + // Returns the underlying NanoRtClient. + NanoRtClient* nano_client() { return &client_; } + + using HostBufferSemantics = xla::ifrt::Client::HostBufferSemantics; + + // Creates an array from a host buffer. The buffer will be used directly + // without a copy if the copy semantics allow it and the layout is row major + // and dense. + absl::StatusOr> MakeArrayFromHostBuffer( + const void* data, ifrt::DType dtype, ifrt::Shape shape, + std::optional> byte_strides, + absl::Nonnull> sharding, + HostBufferSemantics semantics, + std::function on_done_with_host_buffer) override; + + // Assembles a sharded array from a list of single device arrays. If the + // provided sharding is specific enough to assemble a dense array, this method + // will actually return an assembled array that pretends it is sharded. + // + // Otherwise we will produce an assembled array on demand when it is first + // accessed by an XLA program. + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override; + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override; + + absl::StatusOr>> CopyArrays( + absl::Span> arrays, + std::optional> devices, + std::optional memory_kind, + ifrt::ArrayCopySemantics semantics) override; + + absl::StatusOr>> RemapArrays( + const ifrt::RemapPlan& plan, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override; + + ifrt::Future<> GetReadyFuture( + absl::Span> values) override; + + absl::StatusOr> MakeTuple( + absl::Span> values) override; + + absl::string_view runtime_type() const override; + + absl::string_view platform_name() const override; + absl::string_view platform_version() const override; + ifrt::PlatformId platform_id() const override; + + const ifrt::AttributeMap& Attributes() const override; + + int device_count() const override; + int addressable_device_count() const override; + absl::Span devices() const override; + absl::Span addressable_devices() const override; + int process_index() const override; + + absl::Span GetAllDevices() const override; + + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + absl::StatusOr LookupDevice( + ifrt::DeviceId device_id) const override; + absl::StatusOr LookupAddressableDevice( + int local_hardware_id) const override; + + ifrt::Compiler* GetDefaultCompiler() override; + + absl::StatusOr> GetTopologyForDevices( + const tsl::RCReference& devices) const override; + + absl::StatusOr> GetDefaultLayout( + ifrt::DType dtype, absl::Span dims, ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) const override; + + static char ID; // NOLINT + + private: + explicit NanoIfrtClient(int32_t num_devices); + + // The underlying NanoRtClient. + NanoRtClient client_; + + // The compiler, memory, and device objects. See cc file for implementation + // details. + std::unique_ptr compiler_; + std::unique_ptr memory_; + std::unique_ptr device_; + + // The default sharding for this client. When this sharding is used it + // typically means that we can use an array's contents directly. + std::shared_ptr default_sharding_; + + // Some of the ifrt::Client methods return a span of devices, so we need to + // keep storage for them here. 
Note that this may repeat the device_ pointer + // multiple times if this client is configured with multiple devices. This is + // mostly to make IFRT callers that expect sharded programs to run on multiple + // devices happy. This has the unusual property that we have multiple devices + // but a single device_id, but this seems to work fine and most documentation + // warns that devices may be repeated within a device list or sharding. + std::vector devices_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ diff --git a/xla/backends/cpu/nanort/ifrt_client_test.cc b/xla/backends/cpu/nanort/ifrt_client_test.cc new file mode 100644 index 00000000000000..efe24079a9016a --- /dev/null +++ b/xla/backends/cpu/nanort/ifrt_client_test.cc @@ -0,0 +1,34 @@ +/* Copyright 2023 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/test_util.h" + +// For now, all of the tests we run are provided by IFRT, they use +// NanoIfrtClient via the "register_nanort_for_ifrt_tests" target, which can +// also be used to run NanoIfrtClient in other tests. see the BUILD file for the +// list. We need a main function to filter out one test that doesn't seem worth +// supporting. + +int main(int argc, char** argv) { + // This test expects copies to multiple devices to fail, but we only have one + // device and it doesn't seem worth pretending that we have more. + static constexpr absl::string_view kFilter = + "-ArrayImplTest.CopyMixedSourceDevices"; + xla::ifrt::test_util::SetTestFilterIfNotUserSpecified(kFilter); + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc b/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc new file mode 100644 index 00000000000000..b804c257f79be5 --- /dev/null +++ b/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc @@ -0,0 +1,29 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/nanort/ifrt_client.h" +#include "xla/python/ifrt/test_util.h" + +namespace xla::cpu { +namespace { + +// Link this in to use the NanoIfrtClient as the default IFRT client for tests. +// IFRT tests expect the client to have multiple devices. 
+const bool kUnused = (ifrt::test_util::RegisterClientFactory( + [] { return NanoIfrtClient::CreateWithDevices(4); }), + true); + +} // namespace +} // namespace xla::cpu
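Usage sketch (editorial illustration, not part of the patch): the commit message says the point of this wrapper is to let a caller that depends on IFRT drive NanoRT unchanged, so the sketch below walks that path end to end with the entry points declared in ifrt_client.h: wrap a host buffer as an ifrt::Array, compile a program through the default compiler, execute, and copy the result back. This is a minimal sketch under stated assumptions: the helper name RunThroughNano, the f32[4] program shape, the use of XlaCompileOptions (NanoCompiler ignores its options), and the exact spelling of the option/result types in the surrounding IFRT headers are assumptions and may need adjusting.

#include <memory>
#include <optional>
#include <utility>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/nanort/ifrt_client.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/program.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"

namespace example {

// `program` is assumed to be an ifrt::HloProgram that takes one f32[4]
// parameter and produces one f32[4] result.
absl::Status RunThroughNano(std::unique_ptr<xla::ifrt::Program> program) {
  // A single CPU "device"; CreateWithDevices(n) exists for callers that
  // expect sharded programs.
  std::shared_ptr<xla::cpu::NanoIfrtClient> client =
      xla::cpu::NanoIfrtClient::Create();

  // Wrap an existing host buffer. With zero-copy semantics and a dense,
  // row-major, sufficiently aligned buffer, the wrapper aliases the memory
  // instead of copying it; otherwise it falls back to a copy.
  alignas(64) float input[4] = {1, 2, 3, 4};
  TF_ASSIGN_OR_RETURN(
      auto arg,
      client->MakeArrayFromHostBuffer(
          input, xla::ifrt::DType(xla::ifrt::DType::kF32),
          xla::ifrt::Shape({4}), /*byte_strides=*/std::nullopt,
          client->default_sharding(),
          xla::ifrt::Client::HostBufferSemantics::kImmutableZeroCopy,
          /*on_done_with_host_buffer=*/nullptr));

  // Compile through the default IFRT compiler, which lowers the program with
  // the wrapped NanoRtClient.
  TF_ASSIGN_OR_RETURN(auto executable,
                      client->GetDefaultCompiler()->Compile(
                          std::move(program),
                          std::make_unique<xla::ifrt::XlaCompileOptions>()));

  // Execution blocks on the NanoRT event internally, so the outputs are ready
  // as soon as Execute() returns.
  xla::ifrt::LoadedExecutable::ExecuteOptions options;
  options.fill_status = true;
  TF_ASSIGN_OR_RETURN(auto result,
                      executable->Execute(absl::MakeSpan(&arg, 1), options,
                                          /*devices=*/std::nullopt));

  // Copy the single f32[4] output back to the host.
  float output[4];
  TF_RETURN_IF_ERROR(result.outputs[0]
                         ->CopyToHostBuffer(
                             output, /*byte_strides=*/std::nullopt,
                             xla::ifrt::ArrayCopySemantics::kAlwaysCopy)
                         .Await());
  return absl::OkStatus();
}

}  // namespace example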