Skip to content

Commit

Permalink
Move calls to GpuDriver::GetComputeCapability into CudaExecutor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684930026
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 11, 2024
1 parent f9c0a3c commit 3f0596e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
12 changes: 0 additions & 12 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1080,18 +1080,6 @@ absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr,
return cuda::ToStatus(cuMemGetAddressRange(base, size, dptr));
}

// Queries the CUDA driver for the compute capability of `device` and writes
// it to `*cc_major` / `*cc_minor`, e.g. (3, 5).
//
// Both outputs are reset to 0 before any driver call is made. The major
// version is queried first; if that query fails, its status is returned and
// the minor version is not queried. Otherwise the status of the minor-version
// query is returned.
absl::Status GpuDriver::GetComputeCapability(int* cc_major, int* cc_minor,
                                             CUdevice device) {
  // Performs one cuDeviceGetAttribute lookup on `device` and converts the
  // driver result into a status.
  const auto query_attribute = [device](int* value, auto attribute) {
    return cuda::ToStatus(cuDeviceGetAttribute(value, attribute, device));
  };

  *cc_major = 0;
  *cc_minor = 0;

  TF_RETURN_IF_ERROR(
      query_attribute(cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR));
  return query_attribute(cc_minor,
                         CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR);
}

absl::Status GpuDriver::GetGpuISAVersion(int* version, CUdevice device) {
return absl::Status{
absl::StatusCode::kInternal,
Expand Down
19 changes: 15 additions & 4 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ absl::StatusOr<std::string> GetDeviceName(CUdevice device) {
chars[kCharLimit - 1] = '\0';
return chars.begin();
}

// Looks up the compute capability of `device`, e.g. (3, 5), and stores the
// major and minor parts in `*cc_major` and `*cc_minor`.
//
// Both outputs are zeroed first. A failure while reading the major version is
// returned immediately, without attempting to read the minor version.
absl::Status GetComputeCapability(int* cc_major, int* cc_minor,
                                  CUdevice device) {
  *cc_major = 0;
  *cc_minor = 0;

  // Major version: bail out early if the driver reports an error.
  const auto major_result = cuDeviceGetAttribute(
      cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
  TF_RETURN_IF_ERROR(cuda::ToStatus(major_result));

  // Minor version: its status is the overall result.
  const auto minor_result = cuDeviceGetAttribute(
      cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
  return cuda::ToStatus(minor_result);
}
} // namespace

// Given const GPU memory, returns a libcuda device pointer datatype, suitable
Expand Down Expand Up @@ -300,8 +313,7 @@ absl::Status CudaExecutor::Init() {
TF_ASSIGN_OR_RETURN(Context * context,
CudaContext::Create(device_ordinal(), device_));
set_context(context);
TF_RETURN_IF_ERROR(
GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_));
TF_RETURN_IF_ERROR(GetComputeCapability(&cc_major_, &cc_minor_, device_));
TF_ASSIGN_OR_RETURN(delay_kernels_supported_, DelayKernelIsSupported());
return absl::OkStatus();
}
Expand Down Expand Up @@ -855,8 +867,7 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) {

int cc_major;
int cc_minor;
TF_RETURN_IF_ERROR(
GpuDriver::GetComputeCapability(&cc_major, &cc_minor, device));
TF_RETURN_IF_ERROR(GetComputeCapability(&cc_major, &cc_minor, device));

DeviceDescription desc;

Expand Down
7 changes: 0 additions & 7 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,6 @@ class GpuDriver {

// -- Device-specific calls.

// Returns the compute capability for the device; e.g. (3, 5).
// This is currently done via the deprecated device API.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1ge2091bbac7e1fb18c2821612115607ea
// (supported on CUDA only)
static absl::Status GetComputeCapability(int* cc_major, int* cc_minor,
GpuDeviceHandle device);

// Returns the GPU ISA version for the device; e.g. 803, 900.
// (supported on ROCm only)
static absl::Status GetGpuISAVersion(int* version, GpuDeviceHandle device);
Expand Down

0 comments on commit 3f0596e

Please sign in to comment.