Commit
Revert "Update the cache functions"
This reverts commit 4a30634.
archana-ramalingam committed Jan 24, 2025
1 parent b61a747 commit 7893014
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions sharktank/sharktank/export_layer/export_paged_attention.py
@@ -44,14 +44,25 @@ def paged_attention(
     # Full sequence length.
     kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride
 
-    xk, xv = attention_block.transact_cache(
-        xk_cache_update=xk,
-        xv_cache_update=xv,
-        seq_block_ids=seq_block_ids,
-        kv_seq_len=kv_seq_len,
-        start_positions=start_positions,
-        cache_state=cache_state,
-    )
+    if attention_block.cache.is_paged:
+        xk, xv = attention_block.transact_cache_paged(
+            xk_cache_update=xk,
+            xv_cache_update=xv,
+            seq_block_ids=seq_block_ids,
+            kv_seq_len=kv_seq_len,
+            start_positions=start_positions,
+            cache_state=cache_state,
+        )
+    elif attention_block.cache.is_direct:
+        xk, xv = attention_block.transact_cache_direct(
+            xk_cache_update=xk,
+            xv_cache_update=xv,
+            start_positions=start_positions,
+            kv_seq_len=kv_seq_len,
+            cache_state=cache_state,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported KV cache type: {type(attention_block.cache)}")
 
     # Expand kv heads for GQA.
     gqa_n_rep = attention_block.head_count // attention_block.head_count_kv
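
A note on the arithmetic in the hunk's context lines: kv_seq_len covers the full sequence, so with, say, block_seq_stride = 16 and seq_block_ids of shape (bs, 8), kv_seq_len = 8 * 16 = 128.

The restored code branches on the cache's is_paged/is_direct flags instead of calling a unified transact_cache, and the direct path takes no seq_block_ids block table. Below is a minimal sketch of that dispatch pattern, not sharktank's implementation: the classes and stub bodies are hypothetical, and only the flag names, the transact_cache_paged/transact_cache_direct entry points, and their keyword arguments mirror the diff above.

class _PagedCache:
    # Hypothetical stand-in: only the two flags mirror the diff above.
    is_paged = True
    is_direct = False

class _DirectCache:
    is_paged = False
    is_direct = True

class _AttentionBlock:
    # Hypothetical stand-in; the transact_* stubs just return the updates.
    def __init__(self, cache):
        self.cache = cache

    def transact_cache_paged(self, *, xk_cache_update, xv_cache_update,
                             seq_block_ids, kv_seq_len, start_positions,
                             cache_state):
        # A real paged cache would scatter the updates into the pages named
        # by seq_block_ids and read the full kv_seq_len of K/V back out.
        return xk_cache_update, xv_cache_update

    def transact_cache_direct(self, *, xk_cache_update, xv_cache_update,
                              start_positions, kv_seq_len, cache_state):
        # A direct cache indexes by position and takes no block table.
        return xk_cache_update, xv_cache_update

def transact(block, **kwargs):
    # The dispatch this revert restores: branch on the cache kind, dropping
    # seq_block_ids for the direct path, whose signature does not accept it.
    if block.cache.is_paged:
        return block.transact_cache_paged(**kwargs)
    elif block.cache.is_direct:
        kwargs.pop("seq_block_ids", None)
        return block.transact_cache_direct(**kwargs)
    raise NotImplementedError(f"Unsupported KV cache type: {type(block.cache)}")

With _AttentionBlock(_PagedCache()) the call routes to the paged path; with _DirectCache() the block table is dropped and the direct path is used.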
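
The trailing context lines set up grouped-query attention: gqa_n_rep is how many times each KV head is repeated so keys and values match the query head count (e.g. head_count = 32 and head_count_kv = 8 give gqa_n_rep = 4). The helper below is a hypothetical PyTorch sketch of that standard expansion, not the code from this file.

import torch

def repeat_kv_heads(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bs, seq_len, n_kv_heads, head_dim) to
    # (bs, seq_len, n_kv_heads * n_rep, head_dim) by repeating each KV head.
    if n_rep == 1:
        return x
    bs, seq_len, n_kv_heads, head_dim = x.shape
    return (
        x.unsqueeze(3)                                     # add a repeat axis
        .expand(bs, seq_len, n_kv_heads, n_rep, head_dim)  # view, no copy yet
        .reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)
    )

# repeat_kv_heads(torch.randn(1, 128, 8, 64), 4).shape
# -> torch.Size([1, 128, 32, 64])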
