diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 005d3162..51bbea17 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -195,7 +195,7 @@ def forward(
         kv_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
     ):
         # from flash_attn.flash_attn_interface import flash_attn_varlen_func
-        from sb_varlen import sb_flash_attn_varlen
+        from .sb_varlen import sb_flash_attn_varlen
         import math
         # TODO @thomasw21: Compute once, instead of computing for each layers.
         # cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
@@ -218,14 +218,18 @@ def forward(
             level=logging.WARNING,
             rank=0,
         )
-        attn_output = sb_flash_attn_varlen(
+        v_ = value_states.permute(1, 0, 2)
+        attn_output, rem = sb_flash_attn_varlen(
             q=query_states.permute(1, 0, 2),
             k=key_states.permute(1, 0, 2),
-            v=value_states.permute(1, 0, 2),
+            v=v_,
             cu_seqlens=cu_seqlens,
             inv_temp=sb_scale,
             zero_start=False
         )
+
+        attn_output = attn_output + rem[..., None] * v_
+        attn_output = attn_output.permute(1, 0, 2)
         """
         attn_output = flash_attn_varlen_func(
             q=query_states,
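
Note on the change above: `sb_flash_attn_varlen` returns both the attention output and a per-token remainder `rem`, i.e. the stick-breaking attention mass not assigned to any key, and the new lines fold that remainder back in via `rem[..., None] * v_`, which effectively gives the leftover mass to the token's own value. The sketch below is a minimal dense PyTorch reference of that computation for a single sequence, written from the stick-breaking attention equations rather than from the Triton varlen kernel; the function name `stickbreaking_attention_ref`, the strictly-causal masking convention, and the [heads, seq, dim] layout are illustrative assumptions, not the kernel's actual interface.

# Minimal dense reference (assumed semantics) of stick-breaking attention with a
# remainder term, for one sequence; not the sb_varlen Triton kernel's implementation.
import torch
import torch.nn.functional as F

def stickbreaking_attention_ref(q, k, v, inv_temp):
    """q, k, v: [heads, seq_len, head_dim]; returns (out, rem) with rem: [heads, seq_len]."""
    H, T, D = q.shape
    logits = torch.einsum("hid,hjd->hij", q, k) * inv_temp                # [H, T, T]
    # Strictly causal mask: query i attends to keys j < i (assumed convention).
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device), diagonal=-1)
    log_beta = F.logsigmoid(logits).masked_fill(~causal, float("-inf"))   # log sigmoid(z_ij)
    log_keep = (-F.softplus(logits)).masked_fill(~causal, 0.0)            # log(1 - sigmoid(z_ij))
    # Suffix sum over keys k > j of log(1 - beta_ik), so that
    # weight_ij = beta_ij * prod_{j < k < i} (1 - beta_ik).
    suffix = log_keep.flip(-1).cumsum(-1).flip(-1) - log_keep
    weights = torch.exp(log_beta + suffix)                                # [H, T, T]
    rem = 1.0 - weights.sum(-1)                                           # leftover stick mass per query
    out = torch.einsum("hij,hjd->hid", weights, v)
    return out, rem

# Folding the remainder back in, mirroring the added lines in the diff:
H, T, D = 4, 16, 32
q, k, v = torch.randn(H, T, D), torch.randn(H, T, D), torch.randn(H, T, D)
out, rem = stickbreaking_attention_ref(q, k, v, inv_temp=D ** -0.5)
out = out + rem[..., None] * v    # remainder times the token's own value

Because the weights of stick-breaking attention sum to at most 1, `rem` is non-negative; adding `rem[..., None] * v` keeps the per-query output a convex combination of the values (including the token's own), which appears to be the intent of the added lines.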