From 71cd6d55337b1541f602c1afffa6baf6dd75b09c Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 31 Dec 2024 09:32:22 +0100 Subject: [PATCH] Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688) * update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace --- candle-flash-attn/build.rs | 1 + candle-flash-attn/cutlass | 2 +- candle-flash-attn/kernels/block_info.h | 8 ++-- candle-flash-attn/kernels/flash.h | 13 ++---- .../flash_fwd_hdim128_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim128_bf16_sm80.cu | 2 +- .../flash_fwd_hdim128_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim128_fp16_sm80.cu | 2 +- .../flash_fwd_hdim160_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim160_bf16_sm80.cu | 2 +- .../flash_fwd_hdim160_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim160_fp16_sm80.cu | 2 +- .../flash_fwd_hdim192_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim192_bf16_sm80.cu | 2 +- .../flash_fwd_hdim192_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim192_fp16_sm80.cu | 2 +- .../flash_fwd_hdim224_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim224_bf16_sm80.cu | 2 +- .../flash_fwd_hdim224_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim224_fp16_sm80.cu | 2 +- .../flash_fwd_hdim256_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim256_bf16_sm80.cu | 2 +- .../flash_fwd_hdim256_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim256_fp16_sm80.cu | 2 +- .../flash_fwd_hdim32_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim32_bf16_sm80.cu | 2 +- .../flash_fwd_hdim32_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim32_fp16_sm80.cu | 2 +- .../flash_fwd_hdim64_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim64_bf16_sm80.cu | 2 +- .../flash_fwd_hdim64_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim64_fp16_sm80.cu | 2 +- .../flash_fwd_hdim96_bf16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim96_bf16_sm80.cu | 2 +- .../flash_fwd_hdim96_fp16_causal_sm80.cu | 2 +- .../kernels/flash_fwd_hdim96_fp16_sm80.cu | 2 +- candle-flash-attn/kernels/flash_fwd_kernel.h | 30 ++++++------- .../kernels/flash_fwd_launch_template.h | 15 ++++--- candle-flash-attn/kernels/hardware_info.h | 42 +++++++++++++++++++ candle-flash-attn/kernels/kernel_traits.h | 30 ++++++------- candle-flash-attn/kernels/utils.h | 18 ++++++++ 41 files changed, 140 insertions(+), 83 deletions(-) create mode 100644 candle-flash-attn/kernels/hardware_info.h diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 53fec5deab..37247646e3 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -54,6 +54,7 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=kernels/kernel_traits.h"); println!("cargo:rerun-if-changed=kernels/block_info.h"); println!("cargo:rerun-if-changed=kernels/static_switch.h"); + println!("cargo:rerun-if-changed=kernels/hardware_info.h"); let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { Err(_) => diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass index 7d49e6c7e2..4c42f73fda 160000 --- a/candle-flash-attn/cutlass +++ b/candle-flash-attn/cutlass @@ -1 +1 @@ -Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc +Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h index 3a23a1e1f2..cf60d653c3 100644 --- 
a/candle-flash-attn/kernels/block_info.h +++ b/candle-flash-attn/kernels/block_info.h @@ -18,8 +18,9 @@ struct BlockInfo { , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } @@ -30,13 +31,14 @@ struct BlockInfo { template __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; } const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int leftpad_k; const int seqlen_k_cache; const int actual_seqlen_k; }; diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index 88c2f22a59..f21e4d6205 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,13 +7,7 @@ #include #include -// #ifdef OLD_GENERATOR_PATH -// #include -// #else -// #include -// #endif -// -// #include // For at::cuda::philox::unpack +// #include // For at::Generator and at::PhiloxCudaState constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; @@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params { // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; @@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +// template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu index f19049b496..9383c10249 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. 
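Note on the BlockInfo change above: params.leftpad_k threads a per-batch left-padding amount through the K indexing, so seqlen_k_cache (and, when seqused_k is given, actual_seqlen_k) is reduced by leftpad_k and k_offset skips the padded prefix of the K/V layout. Below is a minimal host-side sketch of that offset arithmetic; the names KOffsetArgs and compute_k_offset are illustrative only and do not exist in the kernels.

// Host-side sketch of BlockInfo::k_offset with left padding (illustrative names).
#include <cstdint>
#include <cstdio>

using index_t = int64_t;

struct KOffsetArgs {
    int sum_s_k;      // -1 when cu_seqlens_k is not used (fixed-length batch)
    int leftpad_k;    // per-batch left padding, taken from params.leftpad_k[bidb]
};

// Mirrors BlockInfo::k_offset: left padding shifts the starting row of K,
// either within the batch-strided layout or within the packed varlen layout.
index_t compute_k_offset(const KOffsetArgs &a, index_t batch_stride,
                         index_t row_stride, int bidb) {
    return a.sum_s_k == -1
        ? bidb * batch_stride + index_t(a.leftpad_k) * row_stride
        : index_t(uint32_t(a.sum_s_k + a.leftpad_k)) * row_stride;
}

int main() {
    KOffsetArgs fixed_len{-1, 4}, varlen{128, 4};
    std::printf("fixed-length, batch 2: %lld\n",
                (long long)compute_k_offset(fixed_len, /*batch_stride=*/4096, /*row_stride=*/64, /*bidb=*/2));
    std::printf("varlen (cu_seqlens):   %lld\n",
                (long long)compute_k_offset(varlen, 4096, 64, 2));
    return 0;
}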
// Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu index cb13574195..f03abda486 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu index dfb04b78b8..c616628c87 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu index 6df16b2c34..4ff6b9fbfb 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu index 230af9069c..d6d4371bfb 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu index cf1ffad209..5af68ac38f 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu index 1fc5ac5970..1ef511a6b7 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu index a9796aded8..96abfbd8a1 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu index 94792d4d3b..077d25d091 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu index 76d5136b1d..ea5f265fe3 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu index 9e5b21e022..a4a7bc2422 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu index b4019a0bef..c30c4a14fe 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu index a12a5f4ad7..db69f21cdf 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu index 8690bdb1a4..9a11724b2b 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu index f01dad09cf..d02edae078 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu index 7ec1e16b7f..28150ed0ad 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu index 3d816ab608..f84e978c91 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu index c6c55229c3..c52f0417b9 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu index 0149abacd2..f96f7edc67 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu index 9c9a1715e7..9c7c6b93d8 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu index 29097ac3a1..e21d0408ca 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu index cb52f34fa9..f377a5b8fa 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu index 7bdadefbea..74e4d66ae9 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index 44b3881610..e85db18e39 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu index 99cd728bcf..9297e8bb68 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu index c11096ac12..8364b1e7ee 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. 
+// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu index 2fbcd44e65..1c6ed7ef02 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu index 7b65a9c9ec..3c87573ba2 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu index 6fb3cf6427..49fae856a5 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu index e696b2f2cd..c5af1cf634 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu index bb3b744d15..b0d6c9928e 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu index 5f3accc300..c97aa33f8b 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023, Tri Dao. +// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index 1bf77f81d3..b6b26d5207 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,6 +4,8 @@ #pragma once +// #include "philox_unpack.cuh" // For at::cuda::philox::unpack + #include #include @@ -22,14 +24,6 @@ namespace flash { using namespace cute; -template -__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ - #pragma unroll - for (int i = 0; i < size(tensor); ++i) { - tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } mask.template apply_mask( @@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); @@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); @@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. 
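Note on the softcap change above: the kernel-local apply_softcap helper is deleted from flash_fwd_kernel.h and both call sites now use flash::apply_softcap, which this patch moves into utils.h (see the end of the diff). A standalone sketch of the elementwise transform it performs, x <- tanh(x * softcap), follows; tanhf stands in for cutlass::fast_tanh so the snippet needs no CUTLASS headers, and apply_softcap_ref is an illustrative name.

#include <cmath>
#include <cstdio>

// Same loop shape as flash::apply_softcap, applied to a plain array of scores.
void apply_softcap_ref(float *scores, int n, float softcap) {
    for (int i = 0; i < n; ++i) {
        scores[i] = tanhf(scores[i] * softcap);
    }
}

int main() {
    float s[4] = {-8.f, -0.5f, 0.5f, 8.f};
    apply_softcap_ref(s, 4, /*softcap=*/0.1f);  // large scores saturate toward +/-1
    for (float v : s) std::printf("%.4f ", v);
    std::printf("\n");
    return 0;
}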
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), @@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } @@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); @@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { constexpr int kBlockN = kNThreads / kBlockM; using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 9e5449d736..bb581eb369 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -3,11 +3,11 @@ ******************************************************************************/ #pragma once - -// #include +// #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "error.h" #include "static_switch.h" +#include "hardware_info.h" #include "flash.h" #include "flash_fwd_kernel.h" @@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() { template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; - bool is_sm8x = cuda_is_sm8x(); + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { @@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, 
cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - bool is_sm8x = cuda_is_sm8x(); + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if constexpr(!Is_dropout) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; - bool is_sm8x = cuda_is_sm8x(); + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), diff --git a/candle-flash-attn/kernels/hardware_info.h b/candle-flash-attn/kernels/hardware_info.h new file mode 100644 index 0000000000..d5c48d3517 --- /dev/null +++ b/candle-flash-attn/kernels/hardware_info.h @@ -0,0 +1,42 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#if !defined(__CUDACC_RTC__) +#include "cuda_runtime.h" +#endif + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + + +inline int get_current_device() { + int device; + CHECK_CUDA(cudaGetDevice(&device)); + return device; +} + +inline std::tuple get_compute_capability(int device) { + int capability_major, capability_minor; + CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device)); + return {capability_major, capability_minor}; +} + +inline int get_num_sm(int device) { + int multiprocessor_count; + CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + return multiprocessor_count; +} diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 5a7b74911d..8c0897488d 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom; + using SmemCopyAtomO = Copy_Atom, Element>; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); @@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + 
make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store @@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemLayoutAtomRotcossin = GmemLayoutAtom; @@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load }; @@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base { composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - using SmemCopyAtomPdS = Copy_Atom; + using SmemCopyAtomPdS = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); @@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom; + using SmemCopyAtomdKV = Copy_Atom, elem_type>; using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, @@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom; + using SmemCopyAtomdQ = Copy_Atom, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); @@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< @@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h index 708aeddfa3..b7408ec444 100644 --- a/candle-flash-attn/kernels/utils.h +++ b/candle-flash-attn/kernels/utils.h @@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template +__forceinline__ __device__ void calculate_dtanh(Tensor &src_tensor, Tensor &dst_tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(src_tensor); ++i) { + dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash
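Note on the utils.h additions above: apply_softcap caps the forward scores with x <- tanh(x * softcap), and calculate_dtanh produces the matching backward factor softcap * (1 - y * y) from the already-capped values y. A scalar check of that chain-rule identity, with all names local to the snippet:

// If y = tanh(c * x), then dy/dx = c * (1 - y * y), which is exactly what
// calculate_dtanh computes elementwise from the forward soft-capped scores.
#include <cmath>
#include <cstdio>

int main() {
    const float c = 0.1f, x = 3.0f, eps = 1e-3f;
    float y = tanhf(c * x);
    float analytic = c * (1.f - y * y);                                   // calculate_dtanh's formula
    float numeric  = (tanhf(c * (x + eps)) - tanhf(c * (x - eps))) / (2.f * eps);
    std::printf("analytic=%.6f numeric=%.6f\n", analytic, numeric);
    return 0;
}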
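Note on the kernel_traits.h changes: the copy atoms that previously relied on the default copy type (DefaultCopy) now name AutoVectorizingCopyWithAssumedAlignment<128>, making the 128-bit alignment assumption of the vectorized copies explicit in the type. A compile-only sketch of the new spelling, assuming the bundled CUTLASS/CuTe headers from the cutlass submodule are on the include path and the file is built with nvcc; the aliases below mirror the forms used in the patch and are not additional kernel code.

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

using Element      = cutlass::half_t;
using ElementAccum = float;

// New spellings, as used for the O / O-accum copy atoms in kernel_traits.h.
using SmemCopyAtomO      = cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemCopyAtomOaccum = cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

int main() { return 0; }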
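Note on the new hardware_info.h: the launch templates now derive is_sm8x from get_compute_capability(get_current_device()) instead of the previous cuda_is_sm8x() call. A small usage sketch of the three helpers the header defines, assuming it is compiled with nvcc (or linked against the CUDA runtime) with the kernels directory on the include path:

#include <cstdio>
#include <tuple>
#include "hardware_info.h"

int main() {
    int device = get_current_device();
    auto [cc_major, cc_minor] = get_compute_capability(device);
    // Mirrors the gate in flash_fwd_launch_template.h: sm86/sm89, but not sm80,
    // prefer the 64 x 64 tile configurations for causal attention.
    bool is_sm8x = cc_major == 8 && cc_minor > 0;
    std::printf("device %d: sm_%d%d, %d SMs, is_sm8x=%d\n",
                device, cc_major, cc_minor, get_num_sm(device), int(is_sm8x));
    return 0;
}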