Skip to content

Commit

Permalink
Use StreamExecutor::CreateMemoryAllocator and StreamExecutorAllocator…
Browse files Browse the repository at this point in the history
… in common_runtime instead of DeviceMemAllocator.

PiperOrigin-RevId: 718541716
  • Loading branch information
klucke authored and Google-ML-Automation committed Jan 24, 2025
1 parent 70ce106 commit 5127a43
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
3 changes: 2 additions & 1 deletion xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ cc_library(
"//xla/stream_executor:platform",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/integrations:device_mem_allocator",
"//xla/stream_executor/integrations:stream_executor_allocator",
"//xla/tsl/framework:allocator",
"//xla/tsl/framework:bfc_allocator",
"//xla/tsl/framework:device_id_impl",
"//xla/tsl/platform:statusor",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
22 changes: 16 additions & 6 deletions xla/pjrt/gpu/gpu_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ limitations under the License.
#include "xla/service/platform_util.h"
#include "xla/stream_executor/integrations/device_host_allocator.h"
#include "xla/stream_executor/integrations/device_mem_allocator.h"
#include "xla/stream_executor/integrations/stream_executor_allocator.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/framework/bfc_allocator.h"
#include "xla/tsl/framework/device_id.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/env_var.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -98,11 +99,20 @@ absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
}

int device_ordinal = executor->device_ordinal();
auto sub_allocator = std::make_unique<se::DeviceMemAllocator>(
executor, tsl::PlatformDeviceId(device_ordinal),
/*memory_type=*/
enable_unified_memory ? stream_executor::MemoryType::kUnified
: stream_executor::MemoryType::kDevice);
std::unique_ptr<tsl::SubAllocator> sub_allocator;

if (enable_unified_memory) {
TF_ASSIGN_OR_RETURN(
auto unified_memory_allocator,
executor->CreateMemoryAllocator(stream_executor::MemoryType::kUnified));
sub_allocator = std::make_unique<se::StreamExecutorAllocator>(
std::move(unified_memory_allocator),
stream_executor::MemoryType::kUnified, device_ordinal);
} else {
sub_allocator = std::make_unique<se::DeviceMemAllocator>(
executor, tsl::PlatformDeviceId(device_ordinal),
stream_executor::MemoryType::kDevice);
}

int64_t free_memory;
int64_t total_memory;
Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/integrations/stream_executor_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class StreamExecutorAllocator : public tsl::SubAllocator {
public:
StreamExecutorAllocator(std::unique_ptr<MemoryAllocator> memory_allocator,
MemoryType memory_type, int index,
const std::vector<Visitor>& alloc_visitors,
const std::vector<Visitor>& free_visitors);
const std::vector<Visitor>& alloc_visitors = {},
const std::vector<Visitor>& free_visitors = {});

~StreamExecutorAllocator() override = default;
void* Alloc(size_t alignment, size_t num_bytes,
Expand Down

0 comments on commit 5127a43

Please sign in to comment.