diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index f44edffe66..eaf6de680a 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -57,15 +57,11 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } } -void nvte_cudnn_handle_init() { - auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); -} +void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } namespace detail { -void CreateCuDNNHandle(cudnnHandle_t* handle) { - NVTE_CHECK_CUDNN(cudnnCreate(handle)); -} +void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } } // namespace detail @@ -76,6 +72,6 @@ namespace cudnn_frontend { // This is needed to define the symbol `cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. -void *cudnn_dlhandle = nullptr; +void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bbd2b2ed50..0a78ded306 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -14,8 +14,8 @@ #include #include "../common.h" -#include "../util/logging.h" #include "../util/handle_manager.h" +#include "../util/logging.h" namespace { @@ -47,7 +47,7 @@ uint32_t _getAlignment(uintptr_t address) { } } -inline void CreateCublasHandle(cublasLtHandle_t* handle) { +inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h index 1bdf3e92d2..27666c6948 100644 --- a/transformer_engine/common/util/handle_manager.h +++ b/transformer_engine/common/util/handle_manager.h @@ -7,20 +7,20 @@ #ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ #define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#include + #include #include -#include -#include "logging.h" + #include "cuda_runtime.h" +#include "logging.h" namespace transformer_engine::detail { -template +template class HandleManager { public: - static HandleManager &Instance() { + static HandleManager& Instance() { static thread_local HandleManager instance; return instance; } @@ -29,9 +29,7 @@ class HandleManager { static std::vector flags(handles_.size()); int device_id = cuda::current_device(); NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); - auto init = [&]() { - Create(&(handles_[device_id])); - }; + auto init = [&]() { Create(&(handles_[device_id])); }; std::call_once(flags[device_id], init); return handles_[device_id]; }