From c1cdf92ae67982db71c8ebcd1b86ad6746d281fc Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Fri, 24 Jan 2025 10:00:04 -0800 Subject: [PATCH] Implement on_device_shape() and logical_on_device_shape(). PiperOrigin-RevId: 719342037 --- xla/pjrt/pjrt_c_api_client.cc | 22 ++++++++++++++++++++++ xla/pjrt/pjrt_c_api_client.h | 11 ++++------- xla/pjrt/pjrt_c_api_client_test.cc | 23 +++++++++++++++++++++++ 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index adb4a4909248cf..efe72578105276 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2161,6 +2161,28 @@ std::shared_ptr PjRtCApiBuffer::layout() const { return layout_; } +const Shape& PjRtCApiBuffer::on_device_shape() const { + if (!on_device_shape_.has_value()) { + Shape shape(element_type(), dimensions(), is_dynamic_dimension(), + /*tuple_shapes=*/{}); + *shape.mutable_layout() = layout()->xla_layout(); + absl::MutexLock lock(&mu_); + on_device_shape_ = shape; + } + return *on_device_shape_; +} + +absl::StatusOr PjRtCApiBuffer::logical_on_device_shape() { + absl::StatusOr> dims = logical_dimensions(); + if (!dims.ok()) { + return dims.status(); + } + Shape result(element_type(), *dims, is_dynamic_dimension(), + /*tuple_shapes=*/{}); + *result.mutable_layout() = layout()->xla_layout(); + return result; +} + bool PjRtCApiBuffer::has_dynamic_dimensions() const { PJRT_Buffer_DynamicDimensionIndices_Args args; args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index ad278249f2a9b3..b2ee8b1954c506 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -485,9 +485,7 @@ class PjRtCApiBuffer : public PjRtBuffer { // PJRT C API doesn't support tuple buffers. bool IsTuple() const override { return false; } - const Shape& on_device_shape() const override { - LOG(FATAL) << "PjRtBuffer::on_device_shape() not implemented in PJRT C API"; - } + const Shape& on_device_shape() const override; bool has_dynamic_dimensions() const override; @@ -495,10 +493,7 @@ class PjRtCApiBuffer : public PjRtBuffer { absl::StatusOr> logical_dimensions() override; - absl::StatusOr logical_on_device_shape() override { - LOG(FATAL) << "PjRtBuffer::on_logical_device_shape() not implemented in " - "PJRT C API"; - } + absl::StatusOr logical_on_device_shape() override; PjRtMemorySpace* memory_space() const override; @@ -584,6 +579,8 @@ class PjRtCApiBuffer : public PjRtBuffer { is_dynamic_dimension_; // Used to synchronize concurrent setting of cached values. mutable absl::Mutex mu_; + // Cached result of on_device_shape(); + mutable std::optional on_device_shape_; }; class PjRtCApiExternalReference : public PjRtBuffer::ExternalReference { diff --git a/xla/pjrt/pjrt_c_api_client_test.cc b/xla/pjrt/pjrt_c_api_client_test.cc index 115dc5a35e2ea2..bb68c866eb48ad 100644 --- a/xla/pjrt/pjrt_c_api_client_test.cc +++ b/xla/pjrt/pjrt_c_api_client_test.cc @@ -111,6 +111,29 @@ TEST(PjRtCApiClientTest, IsDynamicDimension) { EXPECT_THAT(is_dynamic_dimension, ::testing::ElementsAreArray(dims_are_dynamic)); + EXPECT_EQ(result_buffer->on_device_shape(), + ShapeUtil::MakeShape(S32, {2, 3}, {false, true})); + EXPECT_EQ(*result_buffer->logical_on_device_shape(), + ShapeUtil::MakeShape(S32, {2, 2}, {false, true})); +} + +TEST(PjRtCApiClientTest, OnDeviceShape) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + std::vector data{1, 2, 3, 4, 5, 6}; + for (PrimitiveType t : {F32, F16, S8, BF16}) { + Shape shape = ShapeUtil::MakeShape(t, {3, 2}); + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, + client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + client->addressable_devices()[0])); + EXPECT_EQ(buffer->on_device_shape(), shape); + EXPECT_EQ(*buffer->logical_on_device_shape(), shape); + } } TEST(PjRtCApiClientTest, PlatformId) {