diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100755 new mode 100644 index bbe486fac7..26e2d8cd60 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-merge-conflict - id: check-added-large-files @@ -24,7 +24,7 @@ repos: files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$ - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black name: Format python code @@ -32,7 +32,7 @@ repos: types: [python] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.6 + rev: v19.1.6 hooks: - id: clang-format entry: clang-format -i diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index d47ce472e5..8c7a590aef 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -120,35 +120,51 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kByte: { \ using type = unsigned char; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt32: { \ using type = int32_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt64: { \ using type = int64_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -159,23 +175,33 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -186,11 +212,15 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -201,15 +231,21 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: { \ diff --git a/transformer_engine/pytorch/csrc/type_shim.h b/transformer_engine/pytorch/csrc/type_shim.h index 8100f0e4a2..6efe4194df 100644 --- a/transformer_engine/pytorch/csrc/type_shim.h +++ b/transformer_engine/pytorch/csrc/type_shim.h @@ -292,7 +292,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1, final = x[tid] + x[tid + 32]; else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); @@ -333,7 +333,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1, final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1)