
flashinfer: switch to plan API
This change doesn't switch `forward` to `run` yet, since that requires
access to the softmax scale and the logit softcap outside the model.
danieldk authored and Narsil committed Jan 17, 2025
1 parent d61f14f commit 0ce2dff
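
For orientation, a minimal sketch of the API migration this commit starts: flashinfer's `begin_forward`/`forward`/`end_forward` triple is replaced by `plan`/`run`, with batch-level parameters such as the softmax scale and logit softcap consumed by `plan()` rather than by the per-layer call. This is a sketch, not code from the commit: it assumes flashinfer >= 0.2 on a CUDA device, all shapes and values are dummies, and the exact kwarg placement is inferred from the commit message rather than checked against a specific release.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch, pages = 2, 4  # two sequences, four KV pages in total

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
state = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda")
indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")
last_page_len = torch.tensor([7, 3], dtype=torch.int32, device="cuda")

# Old API: state.begin_forward(...) once per batch, then per layer
#   state.forward(q, kv_cache, sm_scale=..., logits_soft_cap=...),
# then state.end_forward(). In the plan/run API those per-call knobs
# move into plan(), which is why the commit message says the softmax
# scale and logit softcap must be known outside the model before
# forward() can become run().
state.plan(
    indptr,
    indices,
    last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    sm_scale=head_dim**-0.5,
    logits_soft_cap=30.0,  # dummy value for illustration
)

q = torch.randn(batch, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(
    pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
out = state.run(q, kv_cache)  # replaces forward(); no end_forward() needed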
Showing 2 changed files with 2 additions and 5 deletions.
server/text_generation_server/layers/attention/cuda.py (1 change: 0 additions & 1 deletion)
@@ -235,7 +235,6 @@ def attention(
         paged_kv_cache=(kv_cache.key, kv_cache.value),
         logits_soft_cap=softcap,
         sm_scale=softmax_scale,
-        window_left=window_size_left,
         k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
         v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
     )
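
The collapsed middle of the `plan()` calls below is not shown, so where `window_left` ends up is not visible in this diff; a plausible reading, stated purely as an assumption, is that the sliding-window width is now fixed at plan time, making the per-call kwarg removed above redundant:

# Assumption, not confirmed by the visible hunks: under the plan/run API
# the sliding window travels with the other batch-level settings, roughly
#
#     state.plan(
#         qo_indptr=cu_seqlens,
#         ...,
#         window_left=window_size_left,
#     )
#
# so cuda.py no longer passes window_left per attention call.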
server/text_generation_server/layers/attention/flashinfer.py (6 changes: 2 additions & 4 deletions)
@@ -84,7 +84,7 @@ def use_prefill_with_paged_kv_state(
 
     token = prefill_with_paged_kv_state.set(state)
     try:
-        state.begin_forward(
+        state.plan(
             qo_indptr=cu_seqlens,
             paged_kv_indptr=indptr,
             paged_kv_indices=block_tables,
@@ -99,7 +99,6 @@ def use_prefill_with_paged_kv_state(
         )
         yield
     finally:
-        state.end_forward()
         if token is not None:
             prefill_with_paged_kv_state.reset(token)
 

@@ -200,7 +199,7 @@ def use_decode_state(
     token = decode_state.set(state)
 
     try:
-        state.begin_forward(
+        state.plan(
             indptr=indptr,
             indices=block_tables,
             last_page_len=last_page_len,
@@ -214,6 +213,5 @@ def use_decode_state(
         )
         yield
     finally:
-        state.end_forward()
         if token is not None:
             decode_state.reset(token)
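
Both context managers follow the same contextvar hand-off: the wrapper state is planned once per batch and published through a `ContextVar` so attention code deep inside the model (e.g. `cuda.py` above) can fetch it without threading it through every call. A self-contained sketch of that pattern, using illustrative names rather than TGI's real signatures:

from contextlib import contextmanager
from contextvars import ContextVar

decode_state: ContextVar = ContextVar("decode_state")

class StubState:
    """Stands in for flashinfer's decode wrapper in this sketch."""

    def plan(self, **kwargs):
        self.kwargs = kwargs  # in flashinfer, this builds scheduling metadata

@contextmanager
def use_decode_state(state, **plan_args):
    token = decode_state.set(state)  # make the state visible to inner layers
    try:
        state.plan(**plan_args)  # plan once per batch; no end_forward() needed
        yield
    finally:
        decode_state.reset(token)

# An attention layer deep in the model would retrieve the planned state:
with use_decode_state(StubState(), page_size=16):
    assert decode_state.get().kwargs["page_size"] == 16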
