Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take token count quantization of fused attention into consideration for CP results correction #1396

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions transformer_engine/common/fused_attn/thd_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct ReadLseFunctor {

template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int total_tokens) {
int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
Expand All @@ -85,15 +85,15 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * second_half_lse_seqlen + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];

idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
idx = row * lse_seqlen + col + seq_len;
half_idx = row * second_half_lse_seqlen + col;
}

Functor::run(lse, half_lse, idx, half_idx);
Expand All @@ -108,7 +108,8 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int lse_seqlen) {
int num_heads, int dim_per_head, int lse_seqlen,
int lse_per_step_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
Expand All @@ -128,13 +129,13 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float

if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
idx_per_step = head_id * lse_per_step_seqlen + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
idx_per_step = row * lse_per_step_seqlen + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);

Expand Down
13 changes: 11 additions & 2 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2473,6 +2473,10 @@ def forward(

torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

second_half_lse_seqlen = None
if causal and rank < (cp_size - 1):
second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]

softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
if i <= rank or not causal:
Expand Down Expand Up @@ -2621,6 +2625,7 @@ def forward(
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format
ctx.second_half_lse_seqlen = second_half_lse_seqlen
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
Expand Down Expand Up @@ -2670,10 +2675,14 @@ def backward(ctx, dout):
attn_dbias = None
attn_dbias_ = None

if causal:
softmax_lse_ = None
if causal and ctx.second_half_lse_seqlen is not None:
if ctx.qkv_format == "thd":
softmax_lse_ = tex.thd_read_second_half_lse(
softmax_lse, cu_seqlens_q_padded, ctx.softmax_lse_in_packed_format
softmax_lse,
cu_seqlens_q_padded,
ctx.softmax_lse_in_packed_format,
ctx.second_half_lse_seqlen,
)
else:
# [b, np, sq] -> [b, np, 2, sq//2]
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
const at::Tensor &cu_seqlens, bool lse_packed);

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed);
bool lse_packed, int second_half_lse_seqlen);

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
Expand Down
51 changes: 29 additions & 22 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1420,94 +1420,99 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);

int batch, num_heads, total_tokens;
int batch, num_heads, lse_seqlen, second_half_lse_seqlen;

if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);

batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
lse_seqlen = lse.size(1);
second_half_lse_seqlen = lse_per_step.size(1);

NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2);
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);

batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
lse_seqlen = lse.size(2);
second_half_lse_seqlen = lse_per_step.size(2);

NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}

constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, total_tokens);
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
}
}

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed) {
bool lse_packed, int second_half_lse_seqlen) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);

int batch, num_heads, total_tokens;
int batch, num_heads, lse_seqlen;
std::vector<int64_t> shape;

if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);

batch = cu_seqlens.size(0) - 1;
num_heads = lse.size(0);
total_tokens = lse.size(1);
lse_seqlen = lse.size(1);

shape = {num_heads, total_tokens / 2};
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);

shape = {num_heads, second_half_lse_seqlen};
} else {
NVTE_CHECK(lse.dim() == 3);

batch = lse.size(0);
num_heads = lse.size(1);
total_tokens = lse.size(2);
lse_seqlen = lse.size(2);

NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);

shape = {batch, num_heads, total_tokens / 2};
shape = {batch, num_heads, second_half_lse_seqlen};
}

at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));

constexpr unsigned int block = 256;
unsigned int grid_x = (total_tokens / 2 + block - 1) / block;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, total_tokens);
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, total_tokens);
num_heads, lse_seqlen, second_half_lse_seqlen);
}

return half_lse;
Expand All @@ -1534,23 +1539,25 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);

int batch, lse_seqlen;
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = lse.size(1);
lse_per_step_seqlen = lse_per_step.size(1);

NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
lse_per_step_seqlen = lse_per_step.size(2);

NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}

Expand All @@ -1565,13 +1572,13 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen);
dim_per_head, lse_seqlen, lse_per_step_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen);
dim_per_head, lse_seqlen, lse_per_step_seqlen);
}
}

Expand Down
Loading