From b6828818ef9cbd390a0d69e367afe2611da150ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 00:28:47 +0000 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) - [github.com/psf/black: 24.4.2 → 24.10.0](https://github.com/psf/black/compare/24.4.2...24.10.0) - [github.com/pre-commit/mirrors-clang-format: v18.1.6 → v19.1.6](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.6...v19.1.6) --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100755 => 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100755 new mode 100644 index bbe486fac7..26e2d8cd60 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-merge-conflict - id: check-added-large-files @@ -24,7 +24,7 @@ repos: files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$ - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black name: Format python code @@ -32,7 +32,7 @@ repos: types: [python] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.6 + rev: v19.1.6 hooks: - id: clang-format entry: clang-format -i From f0138bfdbf5602d59ace863ec19a99229750cf9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 00:29:24 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/common.h | 72 +++++++++++++++------ transformer_engine/pytorch/csrc/type_shim.h | 4 +- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index d47ce472e5..8c7a590aef 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -120,35 +120,51 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kByte: { \ using type = unsigned char; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt32: { \ using type = int32_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kInt64: { \ using type = int64_t; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -159,23 +175,33 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -186,11 +212,15 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat8E5M2: { \ using type = fp8e5m2; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E4M3: { \ using type = fp8e4m3; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ default: \ NVTE_ERROR("Invalid type."); \ @@ -201,15 +231,21 @@ struct TypeInfo { using namespace transformer_engine; \ case DType::kFloat32: { \ using type = float; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat16: { \ using type = fp16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kBFloat16: { \ using type = bf16; \ - { __VA_ARGS__ } \ + { \ + __VA_ARGS__ \ + } \ } break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: { \ diff --git a/transformer_engine/pytorch/csrc/type_shim.h b/transformer_engine/pytorch/csrc/type_shim.h index 8100f0e4a2..6efe4194df 100644 --- a/transformer_engine/pytorch/csrc/type_shim.h +++ b/transformer_engine/pytorch/csrc/type_shim.h @@ -292,7 +292,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1, final = x[tid] + x[tid + 32]; else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); @@ -333,7 +333,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1, final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); else final = val; - // __SYNCWARP(); + // __SYNCWARP(); #pragma unroll for (int i = 16; i >= lanes; i >>= 1)