diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index c6b5393ee92..eb5556a1ac9 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -30,7 +30,6 @@
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -38,6 +37,12 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+def _grouped_size_compiled_for_decode_kernels(
+    num_qo_heads: int, num_kv_heads: int
+) -> bool:  # TODO: Remove me! https://github.com/flashinfer-ai/flashinfer/issues/549
+    return (num_qo_heads // num_kv_heads) in [1, 2, 4, 8]
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
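
The helper inlined above replaces a private FlashInfer function that was removed upstream (see the linked issue): it reports whether FlashInfer's specialized decode kernels were pre-compiled for a given GQA group size (query heads per KV head). A minimal usage sketch follows; the use_tensor_cores fallback shown is an illustrative assumption about how a backend might consume the result, not necessarily the exact logic in flashinfer_backend.py:

    # Illustrative sketch (assumed usage, not verbatim backend code).
    num_qo_heads = 40  # query/output heads
    num_kv_heads = 8   # KV heads -> GQA group size of 40 // 8 == 5

    # Group size 5 is not among the pre-compiled sizes [1, 2, 4, 8],
    # so a backend would fall back to the tensor-core (prefill-style)
    # kernels for decode instead of the specialized decode kernels.
    use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
        num_qo_heads, num_kv_heads
    )
    assert use_tensor_cores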