
Commit

Modified with stickbreaking.
Yikang Shen [email protected] committed Jun 24, 2024
1 parent f8b6d17 commit bae01ec
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/nanotron/models/llama.py
@@ -195,7 +195,7 @@ def forward(
         kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
     ):
         # from flash_attn.flash_attn_interface import flash_attn_varlen_func
-        from sb_varlen import sb_flash_attn_varlen
+        from .sb_varlen import sb_flash_attn_varlen
         import math
         # TODO @thomasw21: Compute once, instead of computing for each layers.
         # cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
@@ -218,14 +218,18 @@ def forward(
             level=logging.WARNING,
             rank=0,
         )
-        attn_output = sb_flash_attn_varlen(
+        v_ = value_states.permute(1, 0, 2)
+        attn_output, rem = sb_flash_attn_varlen(
             q=query_states.permute(1, 0, 2),
             k=key_states.permute(1, 0, 2),
-            v=value_states.permute(1, 0, 2),
+            v=v_,
             cu_seqlens=cu_seqlens,
             inv_temp=sb_scale,
             zero_start=False
         )
+
+        attn_output = attn_output + rem[..., None] * v_
+        attn_output = attn_output.permute(1, 0, 2)
         """
         attn_output = flash_attn_varlen_func(
             q=query_states,
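For readers of the diff: stickbreaking attention weights do not have to sum to one, so `sb_flash_attn_varlen` appears to return the unallocated mass `rem` alongside the attention output, and the added lines fold that remainder back in as a weighted self-connection to each token's own value vector. Below is a minimal sketch of the recombination, with the `[num_heads, total_tokens, head_dim]` layout inferred from the diff's `permute(1, 0, 2)` calls rather than from the `sb_varlen` source; the tensor sizes are hypothetical.

```python
import torch

# Hypothetical sizes, for illustration only.
num_heads, total_tokens, head_dim = 4, 16, 8

# Stand-ins for the diff's tensors after permute(1, 0, 2):
# [total_tokens, num_heads, head_dim] -> [num_heads, total_tokens, head_dim].
attn_output = torch.randn(num_heads, total_tokens, head_dim)
v_ = torch.randn(num_heads, total_tokens, head_dim)

# rem: leftover stickbreaking mass per (head, token), assumed in [0, 1].
rem = torch.rand(num_heads, total_tokens)

# rem[..., None] broadcasts [num_heads, total_tokens] to [..., 1], so each
# token adds rem times its own value vector, as in the added lines above.
attn_output = attn_output + rem[..., None] * v_

# Permute back to [total_tokens, num_heads, head_dim] for downstream use.
attn_output = attn_output.permute(1, 0, 2)
print(attn_output.shape)  # torch.Size([16, 4, 8])
```

One plausible reading of the design: if the stickbreaking weights for a query sum to `1 - rem`, routing the leftover mass to that token's own value keeps the total weight per query at one even when the attention underallocates.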
