fix some pylint error
Signed-off-by: Xiaowei Ren <[email protected]>
xrennvidia committed Dec 30, 2024
1 parent d9443d4 commit 09b4f4f
Showing 1 changed file with 3 additions and 2 deletions.
transformer_engine/pytorch/attention.py (5 changes: 3 additions & 2 deletions)
@@ -2692,6 +2692,7 @@ def backward(ctx, dout):
         # [t, np] -> [t, np, 1]
         softmax_lse.unsqueeze_(-1)
 
+        dq = None
         dout_dtype = dout.dtype
         fused_attn_backend = None
         fused_attn_qkv_dtype = None
@@ -2836,7 +2837,8 @@ def backward(ctx, dout):
             )
 
             kv = p2p_comm_buffers[i % 2][0]
-            dk_, dv_ = None, None
+            q_, kv_, out_, dout_ = None, None, None, None
+            dq_, dk_, dv_ = None, None, None
             if ctx.fp8 and ctx.use_fused_attention:
                 fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
                 fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
@@ -3767,7 +3769,6 @@ def backward(ctx, dout):
                         deterministic=ctx.deterministic,
                     )
                 else:
-                    batch_size = k_.shape[0]
                     dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                         torch.empty_like(x) for x in [q_, k_, v_]
                     ]
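These edits follow a common pattern for silencing pylint: variables that are only assigned inside a conditional branch or a loop body are pre-initialized to None so that pylint can prove they are bound before use (likely its possibly-used-before-assignment check, E0606), while the batch_size line is dropped because the binding was never read (unused-variable, W0612). A minimal sketch of the pre-initialization pattern, using hypothetical names (backward_step, compute_grad) that merely stand in for the real attention code:

def compute_grad(i):
    # Stand-in for the real per-step gradient computation.
    return float(i)

def backward_step(num_steps, use_fused_attention):
    # Without this pre-initialization, pylint reports E0606
    # (possibly-used-before-assignment) at the return below, since
    # dq_ is only bound when the branch inside the loop is taken.
    dq_ = None
    for i in range(num_steps):
        if use_fused_attention:
            dq_ = compute_grad(i)
    return dq_

print(backward_step(4, True))   # 3.0
print(backward_step(4, False))  # None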
