Fix for multiple GPUs per thread
Signed-off-by: Przemek Tredak <[email protected]>
ptrendx committed Jan 6, 2025
1 parent 7b868b0 commit 529aec2
Showing 6 changed files with 86 additions and 47 deletions.
10 changes: 9 additions & 1 deletion transformer_engine/common/cudnn_utils.cpp
@@ -58,9 +58,17 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
 }
 
 void nvte_cudnn_handle_init() {
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto _ = cudnnExecutionPlanManager::Instance().GetHandle();
 }
 
+namespace detail {
+
+void CreateCuDNNHandle(cudnnHandle_t* handle) {
+  NVTE_CHECK_CUDNN(cudnnCreate(handle));
+}
+
+}  // namespace detail
+
 }  // namespace transformer_engine
 
 namespace cudnn_frontend {
30 changes: 9 additions & 21 deletions transformer_engine/common/cudnn_utils.h
@@ -10,37 +10,25 @@
 #include <cudnn.h>
 #include <cudnn_frontend.h>
 #include <cudnn_frontend_utils.h>
-
-#include <cstdint>
-#include <mutex>
+#include <cudnn_graph.h>
 
 #include "transformer_engine/transformer_engine.h"
+#include "util/handle_manager.h"
 
 namespace transformer_engine {
 
-cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
+namespace detail {
 
-cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
+void CreateCuDNNHandle(cudnnHandle_t* handle);
 
-class cudnnExecutionPlanManager {
- public:
-  static cudnnExecutionPlanManager &Instance() {
-    static thread_local cudnnExecutionPlanManager instance;
-    return instance;
-  }
+}  // namespace detail
 
-  cudnnHandle_t GetCudnnHandle() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] { cudnnCreate(&handle_); });
-    return handle_;
-  }
+cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
 
-  ~cudnnExecutionPlanManager() {}
+cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
 
- private:
-  cudnnHandle_t handle_ = nullptr;
-};
+using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
 
 }  // namespace transformer_engine
 
-#endif
+#endif  // TRANSFORMER_ENGINE_CUDNN_UTILS_H_
12 changes: 6 additions & 6 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -304,7 +304,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
     t = input_QKV->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
@@ -386,7 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
     t = input_QKV->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
@@ -486,7 +486,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
     t_kv = input_KV->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
@@ -577,7 +577,7 @@ void nvte_fused_attn_bwd_kvpacked(
     t_kv = input_KV->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
@@ -674,7 +674,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     t_kv = input_K->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
@@ -761,7 +761,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     t_kv = input_K->data.shape[0];
   }
 
-  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
24 changes: 6 additions & 18 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -15,6 +15,7 @@
 
 #include "../common.h"
 #include "../util/logging.h"
+#include "../util/handle_manager.h"
 
 namespace {
 
@@ -46,28 +47,15 @@ uint32_t _getAlignment(uintptr_t address) {
   }
 }
 
+inline void CreateCublasHandle(cublasLtHandle_t* handle) {
+  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
+}
+
 }  // namespace
 
 namespace transformer_engine {
 
-class cublasHandleManager {
- public:
-  static cublasHandleManager &Instance() {
-    static thread_local cublasHandleManager instance;
-    return instance;
-  }
-
-  cublasLtHandle_t GetHandle() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] { NVTE_CHECK_CUBLAS(cublasLtCreate(&handle_)); });
-    return handle_;
-  }
-
-  ~cublasHandleManager() {}
-
- private:
-  cublasLtHandle_t handle_ = nullptr;
-};
+using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
 
 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                  const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
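The call sites in this file are untouched by the commit because the new alias keeps the old GetHandle() name; the body of cublas_gemm (not shown in the hunks above) presumably keeps fetching its handle along these lines, now resolved per device instead of from a single per-thread handle:

  // Sketch of the unchanged call site inside cublas_gemm (not part of this diff).
  // GetHandle() now returns the cublasLt handle belonging to the GPU that is
  // currently selected on this thread.
  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();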
2 changes: 1 addition & 1 deletion transformer_engine/common/normalization/common.cpp
@@ -190,7 +190,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
 
-  _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
+  _handle = cudnnExecutionPlanManager::Instance().GetHandle();
 
   _graph.set_io_data_type(get_cudnn_fe_dtype(itype))
       .set_intermediate_data_type(get_cudnn_fe_dtype(ctype))
55 changes: 55 additions & 0 deletions transformer_engine/common/util/handle_manager.h
@@ -0,0 +1,55 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
+#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
+
+#include <mutex>
+#include <vector>
+#include <cuda_runtime_api.h>
+#include "logging.h"
+#include "cuda_runtime.h"
+
+namespace transformer_engine::detail {
+
+template <typename Handle,
+          void Create(Handle*),
+          void Destroy(Handle) = nullptr>
+class HandleManager {
+ public:
+  static HandleManager &Instance() {
+    static thread_local HandleManager instance;
+    return instance;
+  }
+
+  Handle GetHandle() {
+    static std::vector<std::once_flag> flags(handles_.size());
+    int device_id = cuda::current_device();
+    NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID");
+    auto init = [&]() {
+      Create(&(handles_[device_id]));
+    };
+    std::call_once(flags[device_id], init);
+    return handles_[device_id];
+  }
+
+  ~HandleManager() {
+    if (Destroy != nullptr) {
+      for (auto& handle : handles_) {
+        Destroy(handle);
+      }
+    }
+  }
+
+ private:
+  HandleManager() : handles_(cuda::num_devices(), nullptr) {}
+
+  std::vector<Handle> handles_ = nullptr;
+};
+
+}  // namespace transformer_engine::detail
+
+#endif  // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
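For illustration, a minimal sketch of how a caller can use the new per-device HandleManager. The handle type and creation callback below (my_handle_t, CreateMyHandle, example) are hypothetical stand-ins; the real instantiations added by this commit are the cublasHandleManager and cudnnExecutionPlanManager aliases shown in the diffs above.

  #include "util/handle_manager.h"

  // Hypothetical handle type and creator, standing in for cublasLtHandle_t or
  // cudnnHandle_t together with cublasLtCreate / cudnnCreate.
  using my_handle_t = void*;
  inline void CreateMyHandle(my_handle_t* handle) { *handle = nullptr; }

  using MyHandleManager =
      transformer_engine::detail::HandleManager<my_handle_t, CreateMyHandle>;

  void example() {
    // GetHandle() indexes the per-device table with cuda::current_device() and
    // lazily creates the handle for that device via std::call_once, so a thread
    // that switches GPUs with cudaSetDevice() receives the matching handle.
    my_handle_t handle = MyHandleManager::Instance().GetHandle();
    (void)handle;
  }

Each GPU gets its own lazily created handle rather than one handle per thread, which is what lets a single thread drive multiple GPUs as the commit title describes.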
