Skip to content

Commit

Permalink
Change MI300A to use hipMalloc per LC tips
Browse files Browse the repository at this point in the history
  • Loading branch information
zatkins-dev committed Feb 4, 2025
1 parent ff4acf1 commit d0fa0f5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 90 deletions.
132 changes: 45 additions & 87 deletions backends/hip-ref/ceed-hip-ref-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
CeedVector_Hip *impl;

CeedCallBackend(CeedVectorGetData(vec, &impl));
CeedCheck(impl->h_array && !impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "Unified shared memory should only use host pointers");
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
CeedCheck(impl->d_array && !impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND,
"Unified shared memory should only use device pointers");
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -155,8 +157,8 @@ static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

switch (mem_type) {
case CEED_MEM_HOST:
Expand All @@ -179,8 +181,8 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, Cee
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

switch (mem_type) {
case CEED_MEM_HOST:
Expand Down Expand Up @@ -239,8 +241,8 @@ static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));

// Use host memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

switch (local_mem_type) {
case CEED_MEM_HOST:
Expand All @@ -267,7 +269,6 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
CeedVector_Hip *impl;
Ceed_Hip *hip_data;
hipblasHandle_t handle;
CeedScalar *d_array;

CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(vec), &handle));
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
Expand All @@ -280,11 +281,8 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
length = length_vec < length_copy ? length_vec : length_copy;
}

// Use host memory for unified memory
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;

// Set value for synced device/host array
if (d_array) {
if (impl->d_array) {
CeedScalar *copy_array;

// Number of values to copy
Expand All @@ -293,12 +291,11 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
#if defined(CEED_SCALAR_IS_FP32)
CeedCallHipblas(CeedVectorReturnCeed(vec),
hipblasScopy_64(handle, (int64_t)length, d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
hipblasScopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
#else
CeedCallHipblas(CeedVectorReturnCeed(vec),
hipblasDcopy_64(handle, (int64_t)length, d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
hipblasDcopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
#endif
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
} else if (impl->h_array) {
CeedScalar *copy_array;
Expand Down Expand Up @@ -331,7 +328,6 @@ int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val)
static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
CeedSize length;
CeedVector_Hip *impl;
CeedScalar *d_array;
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
Expand All @@ -352,15 +348,11 @@ static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
}
}

// Use host memory for unified memory
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;

if (d_array) {
CeedCallBackend(CeedDeviceSetValue_Hip(d_array, length, val));
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
if (!hip_data->has_unified_addressing) impl->h_array = NULL;
if (impl->d_array) {
CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
impl->h_array = NULL;
}
if (impl->h_array && d_array != impl->h_array) {
if (impl->h_array) {
CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
impl->d_array = NULL;
}
Expand All @@ -387,20 +379,15 @@ static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSiz
CeedSize length;
CeedVector_Hip *impl;
Ceed_Hip *hip_data;
CeedScalar *d_array;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));
CeedCallBackend(CeedVectorGetLength(vec, &length));

// Use host memory for unified memory
d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;

// Set value for synced device/host array
if (d_array) {
CeedCallBackend(CeedDeviceSetValueStrided_Hip(d_array, start, step, length, val));
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
if (!hip_data->has_unified_addressing) impl->h_array = NULL;
if (impl->d_array) {
CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, step, length, val));
impl->h_array = NULL;
} else if (impl->h_array) {
CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, step, length, val));
impl->d_array = NULL;
Expand All @@ -420,8 +407,8 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedSca
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

// Sync array to requested mem_type
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
Expand Down Expand Up @@ -453,8 +440,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec, CeedMemType mem_type
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

// Sync array to requested mem_type
CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
Expand Down Expand Up @@ -489,8 +476,8 @@ static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_ty
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, local_mem_type, array));
CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
Expand All @@ -517,8 +504,8 @@ static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType m
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));

// Use host memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : mem_type;
// Use device memory for unified memory
local_mem_type = hip_data->has_unified_addressing ? CEED_MEM_DEVICE : mem_type;

CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, local_mem_type, &has_array_of_type));
if (!has_array_of_type) {
Expand Down Expand Up @@ -557,8 +544,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));

// Compute norm
CeedMemType mem_type = hip_data->has_unified_addressing ? CEED_MEM_HOST : CEED_MEM_DEVICE;
CeedCallBackend(CeedVectorGetArrayRead(vec, mem_type, &d_array));
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));

*norm = 0.0;
switch (type) {
Expand Down Expand Up @@ -625,20 +611,14 @@ static int CeedVectorReciprocal_Hip(CeedVector vec) {
CeedSize length;
CeedVector_Hip *impl;
Ceed_Hip *hip_data;
CeedScalar *d_array;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));
CeedCallBackend(CeedVectorGetData(vec, &impl));
CeedCallBackend(CeedVectorGetLength(vec, &length));

d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;

// Set value for synced device/host array
if (d_array) {
CeedCallBackend(CeedDeviceReciprocal_Hip(d_array, length));
CeedCallHip(CeedVectorReturnCeed(vec), hipDeviceSynchronize());
}
if (impl->h_array && d_array != impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -658,25 +638,21 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
CeedVector_Hip *impl;
Ceed_Hip *hip_data;
hipblasHandle_t handle;
CeedScalar *d_array;

CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(x), &hip_data));
CeedCallBackend(CeedVectorGetData(x, &impl));
CeedCallBackend(CeedVectorGetLength(x, &length));

d_array = hip_data->has_unified_addressing ? impl->h_array : impl->d_array;

// Set value for synced device/host array
if (d_array) {
if (impl->d_array) {
#if defined(CEED_SCALAR_IS_FP32)
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, d_array, 1));
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
#else
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, d_array, 1));
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
#endif
CeedCallHip(CeedVectorReturnCeed(x), hipDeviceSynchronize());
}
if (impl->h_array && d_array != impl->h_array) CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
if (impl->h_array) CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
return CEED_ERROR_SUCCESS;
}

Expand All @@ -696,7 +672,6 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
CeedVector_Hip *y_impl, *x_impl;
Ceed_Hip *hip_data;
hipblasHandle_t handle;
CeedScalar *x_d_array, *y_d_array;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(y), &hip_data));
CeedCallBackend(CeedVectorGetData(y, &y_impl));
Expand All @@ -705,20 +680,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {

CeedCallBackend(CeedVectorGetLength(y, &length));

y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;

// Set value for synced device/host array
if (y_d_array) {
if (y_impl->d_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
#if defined(CEED_SCALAR_IS_FP32)
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_d_array, 1, y_d_array, 1));
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
#else
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_d_array, 1, y_d_array, 1));
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
#endif
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
}
if (y_impl->h_array && y_d_array != y_impl->h_array) {
if (y_impl->h_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
}
Expand Down Expand Up @@ -746,7 +717,6 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,
CeedVector_Hip *y_impl, *x_impl;
Ceed_Hip *hip_data;
hipblasHandle_t handle;
CeedScalar *x_d_array, *y_d_array;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(y), &hip_data));
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
Expand All @@ -755,16 +725,12 @@ static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta,

CeedCallBackend(CeedVectorGetLength(y, &length));

y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;

// Set value for synced device/host array
if (y_d_array) {
if (y_impl->d_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
CeedCallBackend(CeedDeviceAXPBY_Hip(y_d_array, alpha, beta, x_d_array, length));
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
}
if (y_impl->h_array && y_d_array != y_impl->h_array) {
if (y_impl->h_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
}
Expand All @@ -790,7 +756,6 @@ int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedSc
static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
CeedSize length;
CeedVector_Hip *w_impl, *x_impl, *y_impl;
CeedScalar *w_d_array, *x_d_array, *y_d_array;
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(CeedVectorReturnCeed(x), &hip_data));
Expand All @@ -804,17 +769,12 @@ static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y)
CeedCallBackend(CeedVectorSetValue(w, 0.0));
}

w_d_array = hip_data->has_unified_addressing ? w_impl->h_array : w_impl->d_array;

if (w_d_array) {
if (w_impl->d_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
x_d_array = hip_data->has_unified_addressing ? x_impl->h_array : x_impl->d_array;
y_d_array = hip_data->has_unified_addressing ? y_impl->h_array : y_impl->d_array;
CeedCallBackend(CeedDevicePointwiseMult_Hip(w_d_array, x_d_array, y_d_array, length));
CeedCallHip(CeedVectorReturnCeed(y), hipDeviceSynchronize());
CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
}
if (w_impl->h_array && w_d_array != w_impl->h_array) {
if (w_impl->h_array) {
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
Expand All @@ -832,9 +792,7 @@ static int CeedVectorDestroy_Hip(const CeedVector vec) {
CeedCallBackend(CeedVectorGetData(vec, &impl));
CeedCallBackend(CeedGetData(CeedVectorReturnCeed(vec), &hip_data));

if (!hip_data->has_unified_addressing) {
CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
}
CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
CeedCallBackend(CeedFree(&impl->h_array_owned));
CeedCallBackend(CeedFree(&impl));
return CEED_ERROR_SUCCESS;
Expand Down
3 changes: 0 additions & 3 deletions backends/hip/ceed-hip-compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ int CeedRunKernel_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, cons

CeedCallBackend(CeedGetData(ceed, &data));
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size, 1, 1, 0, NULL, args, NULL));
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
return CEED_ERROR_SUCCESS;
}

Expand All @@ -184,7 +183,6 @@ int CeedRunKernelDim_Hip(Ceed ceed, hipFunction_t kernel, const int grid_size, c

CeedCallBackend(CeedGetData(ceed, &data));
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, 0, NULL, args, NULL));
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
return CEED_ERROR_SUCCESS;
}

Expand All @@ -197,7 +195,6 @@ int CeedRunKernelDimShared_Hip(Ceed ceed, hipFunction_t kernel, const int grid_s

CeedCallBackend(CeedGetData(ceed, &data));
CeedCallHip(ceed, hipModuleLaunchKernel(kernel, grid_size, 1, 1, block_size_x, block_size_y, block_size_z, shared_mem_size, NULL, args, NULL));
if (data->has_unified_addressing) CeedCallHip(ceed, hipDeviceSynchronize());
return CEED_ERROR_SUCCESS;
}

Expand Down

0 comments on commit d0fa0f5

Please sign in to comment.