Skip to content

Commit

Permalink
Implement on_device_shape() and logical_on_device_shape().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720712341
  • Loading branch information
matthiaskramm authored and Google-ML-Automation committed Jan 28, 2025
1 parent 32791d7 commit 0b0cd6e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
22 changes: 22 additions & 0 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,28 @@ std::shared_ptr<const PjRtLayout> PjRtCApiBuffer::layout() const {
return layout_;
}

const Shape& PjRtCApiBuffer::on_device_shape() const {
if (!on_device_shape_.has_value()) {
Shape shape(element_type(), dimensions(), is_dynamic_dimension(),
/*tuple_shapes=*/{});
*shape.mutable_layout() = layout()->xla_layout();
absl::MutexLock lock(&mu_);
on_device_shape_ = shape;
}
return *on_device_shape_;
}

absl::StatusOr<Shape> PjRtCApiBuffer::logical_on_device_shape() {
absl::StatusOr<std::vector<int64_t>> dims = logical_dimensions();
if (!dims.ok()) {
return dims.status();
}
Shape result(element_type(), *dims, is_dynamic_dimension(),
/*tuple_shapes=*/{});
*result.mutable_layout() = layout()->xla_layout();
return result;
}

bool PjRtCApiBuffer::has_dynamic_dimensions() const {
PJRT_Buffer_DynamicDimensionIndices_Args args;
args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE;
Expand Down
11 changes: 4 additions & 7 deletions xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,20 +485,15 @@ class PjRtCApiBuffer : public PjRtBuffer {
// PJRT C API doesn't support tuple buffers.
bool IsTuple() const override { return false; }

const Shape& on_device_shape() const override {
LOG(FATAL) << "PjRtBuffer::on_device_shape() not implemented in PJRT C API";
}
const Shape& on_device_shape() const override;

bool has_dynamic_dimensions() const override;

absl::Span<const bool> is_dynamic_dimension() const override;

absl::StatusOr<std::vector<int64_t>> logical_dimensions() override;

absl::StatusOr<Shape> logical_on_device_shape() override {
LOG(FATAL) << "PjRtBuffer::on_logical_device_shape() not implemented in "
"PJRT C API";
}
absl::StatusOr<Shape> logical_on_device_shape() override;

PjRtMemorySpace* memory_space() const override;

Expand Down Expand Up @@ -584,6 +579,8 @@ class PjRtCApiBuffer : public PjRtBuffer {
is_dynamic_dimension_;
// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mu_;
// Cached result of on_device_shape();
mutable std::optional<Shape> on_device_shape_;
};

class PjRtCApiExternalReference : public PjRtBuffer::ExternalReference {
Expand Down
23 changes: 23 additions & 0 deletions xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,29 @@ TEST(PjRtCApiClientTest, IsDynamicDimension) {

EXPECT_THAT(is_dynamic_dimension,
::testing::ElementsAreArray(dims_are_dynamic));
EXPECT_EQ(result_buffer->on_device_shape(),
ShapeUtil::MakeShape(S32, {2, 3}, {false, true}));
EXPECT_EQ(*result_buffer->logical_on_device_shape(),
ShapeUtil::MakeShape(S32, {2, 2}, {false, true}));
}

TEST(PjRtCApiClientTest, OnDeviceShape) {
SetUpCpuPjRtApi();
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
GetCApiClient("cpu"));
std::vector<int32_t> data{1, 2, 3, 4, 5, 6};
for (PrimitiveType t : {F32, F16, S8, BF16}) {
Shape shape = ShapeUtil::MakeShape(t, {3, 2});
TF_ASSERT_OK_AND_ASSIGN(
auto buffer,
client->BufferFromHostBuffer(
data.data(), shape.element_type(), shape.dimensions(),
/*byte_strides=*/std::nullopt,
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
client->addressable_devices()[0]));
EXPECT_EQ(buffer->on_device_shape(), shape);
EXPECT_EQ(*buffer->logical_on_device_shape(), shape);
}
}

TEST(PjRtCApiClientTest, PlatformId) {
Expand Down

0 comments on commit 0b0cd6e

Please sign in to comment.