Skip to content

Commit

Permalink
flashinfer: pass window size and dtype (#2574)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk authored Sep 28, 2024
1 parent 5b6b74e commit 1028996
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
19 changes: 13 additions & 6 deletions server/text_generation_server/layers/attention/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
Expand Down Expand Up @@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
Expand Down Expand Up @@ -119,7 +121,8 @@ def use_prefill_state(
num_heads: int,
num_kv_heads: int,
head_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
Expand All @@ -135,7 +138,8 @@ def use_prefill_state(
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=query_dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
Expand Down Expand Up @@ -200,7 +204,8 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
query_dtype: str = "float16",
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
Expand Down Expand Up @@ -235,7 +240,9 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
q_data_type=query_dtype,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
Expand Down
4 changes: 4 additions & 0 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,8 @@ def _forward_context(
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
dtype=self.dtype,
window_left=self.sliding_window,
)
else:
assert input_lengths_tensor is not None
Expand All @@ -1971,6 +1973,8 @@ def _forward_context(
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
dtype=self.dtype,
window_left=self.sliding_window,
)


Expand Down

0 comments on commit 1028996

Please sign in to comment.