Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 6, 2025
1 parent 529aec2 commit 44c7d70
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 18 deletions.
10 changes: 3 additions & 7 deletions transformer_engine/common/cudnn_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,11 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
}
}

void nvte_cudnn_handle_init() {
auto _ = cudnnExecutionPlanManager::Instance().GetHandle();
}
void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); }

namespace detail {

void CreateCuDNNHandle(cudnnHandle_t* handle) {
NVTE_CHECK_CUDNN(cudnnCreate(handle));
}
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }

} // namespace detail

Expand All @@ -76,6 +72,6 @@ namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
void* cudnn_dlhandle = nullptr;

} // namespace cudnn_frontend
4 changes: 2 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
#include <mutex>

#include "../common.h"
#include "../util/logging.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"

namespace {

Expand Down Expand Up @@ -47,7 +47,7 @@ uint32_t _getAlignment(uintptr_t address) {
}
}

inline void CreateCublasHandle(cublasLtHandle_t* handle) {
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

Expand Down
16 changes: 7 additions & 9 deletions transformer_engine/common/util/handle_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_

#include <cuda_runtime_api.h>

#include <mutex>
#include <vector>
#include <cuda_runtime_api.h>
#include "logging.h"

#include "cuda_runtime.h"
#include "logging.h"

namespace transformer_engine::detail {

template <typename Handle,
void Create(Handle*),
void Destroy(Handle) = nullptr>
template <typename Handle, void Create(Handle*), void Destroy(Handle) = nullptr>
class HandleManager {
public:
static HandleManager &Instance() {
static HandleManager& Instance() {
static thread_local HandleManager instance;
return instance;
}
Expand All @@ -29,9 +29,7 @@ class HandleManager {
static std::vector<std::once_flag> flags(handles_.size());
int device_id = cuda::current_device();
NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID");
auto init = [&]() {
Create(&(handles_[device_id]));
};
auto init = [&]() { Create(&(handles_[device_id])); };
std::call_once(flags[device_id], init);
return handles_[device_id];
}
Expand Down

0 comments on commit 44c7d70

Please sign in to comment.