Selective merged prefill #643

Open. Wants to merge 25 commits into base: mlperf_features
33 changes: 29 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -5,6 +5,7 @@
import random
import time
from typing import List, Optional
import os

import pandas as pd
import torch
@@ -71,17 +72,31 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
if os.path.splitext(dataset_path)[1] == ".json":
with open(dataset_path) as f:
dataset = json.load(f)
elif os.path.splitext(dataset_path)[1] == ".pkl":
import pandas as pd
dataset = pd.read_pickle(dataset_path)
dataset = dataset[['input', 'output']].to_dict(orient="records")
for data in dataset:
data["conversations"] = [
{"value": data["input"]},
{"value": data["output"]}
]

# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
prompt_lens = []
for data in dataset:
if len(filtered_dataset) == num_requests:
if args.sort_by_len:
filtered_dataset = sorted(filtered_dataset, key=lambda x: x.prompt_len)
break

# Only keep the first two turns of each conversation.
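Note for reviewers: the new .pkl branch above assumes a pandas-readable pickle with 'input' and 'output' columns. A minimal sketch of producing a compatible file (the file name and sample rows are illustrative, not part of the PR):

import pandas as pd

# Build a tiny DataFrame with the two columns the .pkl branch reads.
df = pd.DataFrame({
    "input": ["What is the capital of France?", "Summarize this paragraph ..."],
    "output": ["Paris.", "The paragraph argues ..."],
})
df.to_pickle("dataset.pkl")  # illustrative path; pass it as the benchmark's dataset argument

# The loader then mirrors the JSON path by rebuilding 'conversations':
records = df[["input", "output"]].to_dict(orient="records")
for rec in records:
    rec["conversations"] = [{"value": rec["input"]}, {"value": rec["output"]}]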
@@ -120,7 +135,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))
prompt_lens.append(prompt_len)
print("!!!!prompt length are: ", pd.Series(prompt_lens).describe())

# for i, data in enumerate(filtered_dataset):
# print(i, data.prompt)
return filtered_dataset


@@ -151,9 +170,9 @@ def run_vllm(
use_beam_search = False

if not use_beam_search:
for _ in range(2):
for _ in range(1):
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
llm.generate(prompts, sampling_params, use_tqdm=False)
end = time.perf_counter()
else:
prompts = [request.prompt for request in requests]
@@ -445,6 +464,12 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--sort-by-len",
action='store_true',
default=False)
parser.add_argument("--bucket-selective",
action='store_true',
default=False)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
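Note for reviewers: the two new flags are added without help text. --sort-by-len is consumed in sample_requests above (once num_requests have been collected, the requests are sorted by prompt length before the early exit); the consumer of --bucket-selective is not shown in this file. A small, self-contained sketch of the effect the sort has on the batch (the SampleRequest stand-in below is simplified for illustration):

from dataclasses import dataclass

@dataclass
class SampleRequest:              # simplified stand-in, not the benchmark's class
    prompt: str
    prompt_len: int
    expected_output_len: int

requests = [
    SampleRequest("a" * 512, 512, 128),
    SampleRequest("b" * 64, 64, 128),
    SampleRequest("c" * 256, 256, 128),
]
# What --sort-by-len triggers: ascending sort by prompt length, which groups
# similar-length prompts and reduces padding when they are later bucketed/merged.
requests.sort(key=lambda r: r.prompt_len)
print([r.prompt_len for r in requests])   # [64, 256, 512]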
52 changes: 48 additions & 4 deletions vllm/attention/backends/hpu_attn.py
@@ -30,6 +30,40 @@
"vLLM will use native implementation.")


def prompt_fsdpa(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
valid_seq_lengths: Optional[torch.Tensor] = None,
fsdpa_op=None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get(
'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
# TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
key = ops.repeat_kv(key, int(query_heads // kv_heads))
value = ops.repeat_kv(value, int(query_heads // kv_heads))
softmax_mode = 'fast'
recompute_mode = True
attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
scale, softmax_mode, recompute_mode, None,
'right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights
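Note for reviewers: the query_heads != kv_heads branch exists because the fused SDPA kernel is not yet trusted with grouped-query attention, so the KV tensors are expanded to the query head count first (see the TODO above). A standalone sketch of the equivalent expansion in plain PyTorch (an illustration only; ops.repeat_kv's actual implementation may differ):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: [batch, kv_heads, seq_len, head_dim]
    # returns: [batch, kv_heads * n_rep, seq_len, head_dim]
    if n_rep == 1:
        return x
    b, kv_heads, seq_len, head_dim = x.shape
    x = x[:, :, None, :, :].expand(b, kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(b, kv_heads * n_rep, seq_len, head_dim)

q = torch.randn(1, 32, 128, 64)          # 32 query heads
k = torch.randn(1, 8, 128, 64)           # 8 KV heads (GQA)
k = repeat_kv_sketch(k, q.size(1) // k.size(1))
assert k.size(1) == q.size(1)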


class HPUAttentionBackend(AttentionBackend):

@staticmethod
@@ -83,6 +117,9 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
enable_merged_prefill: bool = False
actual_num_prefills: Optional[torch.Tensor] = None
repeated_idx_tensor: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -213,6 +250,7 @@ def forward(
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
enable_merged_prefill = attn_metadata.enable_merged_prefill
if block_indices is None:
block_indices = attn_metadata.block_indices
if block_offsets is None:
@@ -221,7 +259,7 @@ def forward(
seq_lens_tensor = attn_metadata.seq_lens_tensor
if attn_bias is None: # This is the case for prompt run
attn_bias = attn_metadata.attn_bias
if attn_metadata.is_prompt:
if attn_metadata.is_prompt and not enable_merged_prefill:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
@@ -232,9 +270,9 @@ def forward(
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)
block_offsets)

if attn_metadata.is_prompt:
# Prompt run.
@@ -253,10 +291,16 @@ def forward(
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
elif enable_merged_prefill:
pass
else:
attn_bias = None

out = ops.prompt_attention(
if enable_merged_prefill and self.prefill_use_fusedsdpa:
prompt_attn_func = prompt_fsdpa
else:
prompt_attn_func = ops.prompt_attention
out = prompt_attn_func(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
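Background note: when merged prefill is enabled, several prompts are packed into one long sequence, and the attention bias carried in attn_metadata is kept as-is (the elif enable_merged_prefill: pass branch above) instead of being reset to None, so it must already stop tokens from attending across prompt boundaries. A hedged sketch of what such a block-diagonal causal bias looks like (an illustration of the idea only; the actual mask construction is not shown in these hunks):

import torch

def merged_causal_bias(seq_lens, dtype=torch.float32):
    # Additive bias for several prompts packed into one sequence:
    # -inf everywhere, except a standard causal pattern inside each prompt's block.
    total = sum(seq_lens)
    bias = torch.full((total, total), float("-inf"), dtype=dtype)
    start = 0
    for n in seq_lens:
        causal = torch.triu(
            torch.full((n, n), float("-inf"), dtype=dtype), diagonal=1)
        bias[start:start + n, start:start + n] = causal
        start += n
    return bias

# Two prompts of length 3 and 2 packed into a length-5 sequence.
print(merged_causal_bias([3, 2]))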