diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..6cb4ff6
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,30 @@
+# Test vLLM via Mock Model
+
+This directory contains scripts for rapid testing and development of benchmarking, evaluation, and stress-testing procedures for models available through [vLLM](https://github.com/tenstorrent/vllm/tree/dev).
+
+To run the mock offline inference script `mock_vllm_offline_inference_tt.py`, follow the steps below.
+
+## 1. Build Docker Container
+
+Follow the instructions in `evals/README.md` to build the Docker container. To set the environment variables for the latest supported versions of tt-metal and vLLM, refer to [vllm/tt_metal](https://github.com/tenstorrent/vllm/blob/dev/tt_metal/README.md) when setting:
+
+```bash
+export TT_METAL_COMMIT_SHA_OR_TAG=
+export VLLM_COMMIT_SHA=
+```
+
+## 2. Run The Docker Container
+
+Mount the `tests` directory into the container by adding the following to the `docker run` command:
+
+```bash
+--volume $PWD/tests:/home/user/tests
+```
+
+## 3. Run The Mock Model
+
+Once inside the Docker container, run the mock script with:
+
+```bash
+WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python /home/user/tests/mock_vllm_offline_inference_tt.py
+```
\ No newline at end of file
diff --git a/tests/mock_vllm_model.py b/tests/mock_vllm_model.py
new file mode 100644
index 0000000..728ec15
--- /dev/null
+++ b/tests/mock_vllm_model.py
@@ -0,0 +1,316 @@
+import copy
+import os
+import sys
+import time
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from loguru import logger
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from tt_metal.models.demos.t3000.llama2_70b.tt.llama_common import (
+    setup_llama_env,
+)
+from tt_metal.models.demos.t3000.llama2_70b.tt.llama_generation import (
+    TtLlamaModelForGeneration,
+    get_padded_prefill_len,
+)
+from tt_metal.models.demos.t3000.llama2_70b.tt.model_config import (
+    get_model_config,
+)
+
+
+def new_init_cache_engine(self):
+    assert self.cache_config.num_gpu_blocks is not None
+
+    # The cache path would normally come from the TT model for caching KV blocks;
+    # it is not needed in the mock.
+    self.cache_config.tt_cache_path = None
+
+    from vllm.worker.tt_worker import TTCacheEngine
+
+    self.cache_engine = TTCacheEngine(
+        self.cache_config, self.model_config, self.parallel_config, self.device_config
+    )
+    self.tt_cache = self.cache_engine.tt_cache
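+
+
+# Illustrative note: the helpers in this module are meant to be monkey-patched onto
+# vLLM's TT worker classes, mirroring how mock_vllm_offline_inference_tt.py applies them:
+#
+#   @patch.object(TTWorker, "_init_cache_engine", new=new_init_cache_engine)
+#   @patch.object(TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache)
+#
+# so that no TT device (or device-side KV cache) is required to exercise the engine.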
+
+
+def new_allocate_kv_cache(
+    self,
+    num_blocks: int,
+    device: str,
+) -> List[torch.Tensor]:
+    """Allocates the KV cache on the specified device.
+    The assumption is that the KV cache for a layer is packed into one tensor.
+    We keep a separate tensor for K and for V.
+
+    In the mock implementation, device is always "cpu".
+    """
+    # K and V each have shape: (num_blocks, num_kv_heads, block_size, head_size)
+    kv_cache_shape = (num_blocks, self.num_kv_heads, self.block_size, self.head_size)
+    kv_cache: List[torch.Tensor] = []
+    num_layers = self.num_attention_layers
+    if device == "cpu":
+        for _ in range(num_layers):
+            # The null block in CpuGpuBlockAllocator requires at least that block
+            # to be zeroed out, so zero-initialize the CPU cache.
+            cache_k = torch.zeros(kv_cache_shape, dtype=self.dtype, device=device)
+            cache_v = torch.zeros(kv_cache_shape, dtype=self.dtype, device=device)
+            kv_cache.append([cache_k, cache_v])
+        self.tt_cache = kv_cache  # the CPU cache stands in for the TT cache in the mock
+    return kv_cache
+
+
+class MockModel(TtLlamaModelForGeneration):
+    # Mock implementation of TtLlamaModelForGeneration
+    # see: tt-metal/models/demos/t3000/llama2_70b/tt/llama_generation.py
+    # Inherits from the Llama model for now, since only this model is currently used with vLLM.
+    def __init__(
+        self,
+        configuration,
+        state_dict,
+        model_args,
+        tt_args,
+        paged_attention_config=None,
+        vllm=False,
+    ):
+        self.params = copy.deepcopy(configuration)
+
+        # required to set up the model config
+        self.llama_version = model_args.llama_version
+        self.max_batch_size = model_args.max_batch_size
+        self.max_kv_context_len = model_args.max_kv_context_len
+
+        self.mesh_device = tt_args.mesh_device
+
+        # The initial model_config is set up for decode mode; it is required by vLLM.
+        model_config = get_model_config(
+            llama_version=self.llama_version,
+            max_batch_size=self.max_batch_size,
+            max_context_len=self.max_kv_context_len,
+            vllm=vllm,
+        )
+        self.model_config = model_config
+
+    @classmethod
+    def initialize_vllm_model(cls, hf_config, t3k_mesh_device, max_batch_size):
+        # TODO: pass in model args and tt args as parameters from vllm
+        # Note: since this is a mock, do not load the state dict and do not look for a mesh device.
+        @dataclass
+        class ModelArgs:
+            llama_version: str = None
+            ckpt_dir: str = None
+            max_batch_size: int = 32  # overwritten by max_num_seqs from vllm
+            num_layers: int = 80
+            max_kv_context_len: int = 131072
+
+        @dataclass
+        class TTArgs:
+            mesh_device: object = None
+            cache_path: str = None
+
+        # set up configs
+        llama_version = "llama3"
+        model_config, ckpt_dir, _, cache_path = setup_llama_env(
+            llama_version=llama_version,
+        )
+        # initialize arg classes
+        model_args = ModelArgs(
+            llama_version=llama_version,
+            ckpt_dir=ckpt_dir,
+            max_batch_size=max_batch_size,
+        )
+        tt_args = TTArgs(mesh_device=t3k_mesh_device, cache_path=cache_path)
+
+        # TODO: delete this configuration setup once llama can directly accept hf_config
+        import json
+        from pathlib import Path
+
+        from models.demos.t3000.llama2_70b.reference.llama.llama.model import (
+            ModelArgs as ReferenceModelArgs,
+        )
+
+        with open(Path(ckpt_dir) / "params.json", "r") as f:
+            params = json.loads(f.read())
+        configuration = ReferenceModelArgs(
+            max_seq_len=model_args.max_kv_context_len,
+            max_batch_size=model_args.max_batch_size,
+            **params,
+        )
+
+        return cls(
+            configuration=configuration,
+            state_dict=None,
+            model_args=model_args,
+            tt_args=tt_args,
+            vllm=True,
+        )
+
+    def capture_trace(
+        self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None
+    ):
+        """
+        Called in TTModelRunner to capture the trace for the first decode execution.
+        """
+        # Mock out computing the trace since no TT/GPU device is used; only return
+        # the logits from a decode pass (mocks the self.tt_model() call).
+        tt_logits = self.decode_forward(tokens, start_pos, page_table, kv_cache)
+        return None, None, None, None, tt_logits, None
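+
+    # Illustrative note: the placeholder tuple returned by capture_trace appears to stand in
+    # for the trace objects (trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table)
+    # consumed by decode_forward_trace below; only tt_logits carries real data in the mock.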
+
+    def decode_forward_trace(
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        trace_id,
+        tt_inp,
+        rot_mat,
+        cache_idxs_tt,
+        tt_logits,
+        page_table=None,
+        tt_page_table=None,
+    ):
+        """
+        Runs the model in TTModelRunner by executing the trace.
+        """
+        # Mock out executing the trace and return the logits directly.
+        batch, seqlen = tokens.shape
+        logits = tt_logits
+        logits = logits[:batch]  # Remove padded users
+
+        return logits
+
+    def delete_trace(self, trace_id):
+        """
+        Called to delete the trace in TTModelRunner.
+        """
+        return
+
+    def prefill_forward_single_user(
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        user_id: int,
+        last_token_idx=None,
+        page_table=None,
+        kv_cache=None,
+    ):
+        return self.decode_forward(tokens=tokens, start_pos=start_pos)
+
+    def prefill_forward(
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        page_table=None,
+        kv_cache=None,
+        prompt_lens=None,
+    ):
+        """
+        Called in forward when seq_len != 1.
+        Finds the correct padding and calls prefill forward for each user in the batch.
+        """
+        batch, batch_seq_len = tokens.shape
+        output_logits = torch.zeros(batch, 1, self.params.vocab_size)
+        prompt_lens = (
+            prompt_lens
+            if prompt_lens is not None
+            else torch.tensor([batch_seq_len] * batch)
+        )
+        for user_id in range(batch):
+            seq_len = prompt_lens[user_id]
+            prefill_seq_len = get_padded_prefill_len(seq_len)
+            prefill_ids = torch.cat(
+                [
+                    tokens[user_id : user_id + 1, :seq_len],
+                    torch.zeros(1, prefill_seq_len - seq_len).long(),
+                ],
+                dim=-1,
+            )
+            logger.info(f"Filling kv cache for user {user_id + 1}")
+            last_token_idx = seq_len - 1
+            logits = self.prefill_forward_single_user(
+                prefill_ids,
+                start_pos,
+                user_id,
+                last_token_idx=last_token_idx,
+                page_table=page_table,
+                kv_cache=kv_cache,
+            )
+            # Since we give the unpadded seq_len, only the tile containing the last token is returned.
+            output_logits[user_id] = logits[
+                :, last_token_idx % 32 : last_token_idx % 32 + 1, :
+            ]
+
+        return output_logits
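+
+    # Illustrative example (not from the original source): the `% 32` indexing above assumes
+    # logits come back one 32-token tile at a time. E.g. for seq_len = 70, last_token_idx = 69
+    # falls in the tile covering tokens 64..95, so logits[:, 69 % 32 : 69 % 32 + 1, :]
+    # (i.e. logits[:, 5:6, :]) selects the last real token's logits.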
+
+    def decode_forward(
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        page_table=None,
+        kv_cache=None,
+    ):
+        """
+        Mock forward pass covering both the prefill and decode stages.
+        """
+        assert len(tokens.shape) == 2
+        batch, seqlen = tokens.shape
+        forward_start = time.time()
+        simulated_tps = 10000.0
+        simulated_duration = 1.0 / simulated_tps
+        # produce random logits with shape [batch, seqlen, vocab_size]
+        # (vocab_size = 128256 for Llama 3.1)
+        logits = torch.randn((batch, seqlen, 128256))
+        # force an end-of-turn token once a user has decoded past send_index positions
+        EOT_ID = 128009
+        # EOS_ID = 128001
+        send_index = 200
+        send_token = EOT_ID
+        if start_pos is not None:
+            if isinstance(start_pos, int):
+                # start_pos is the same across the batch, i.e. now in prefill
+                cache_idxs = torch.tensor(
+                    [start_pos for _ in range(batch)], dtype=torch.int64
+                )
+            else:  # start_pos is a tensor, i.e. differs across the batch: now in decode mode
+                cache_idxs = start_pos.to(dtype=torch.int64)
+                # mark users whose start position has passed the index at which to send EOT
+                send_token_mask = cache_idxs > send_index
+                # find positions where start_pos passes send_index (i.e. done decoding) and make them 1D
+                batch_indices = torch.nonzero(send_token_mask).squeeze()
+                # assign a high logit at the send_token index so the model selects it,
+                # generates the EOT, and generation stops
+                logits[batch_indices, 0, send_token] = 100.0
+
+        actual_duration = time.time() - forward_start
+        # simulate forward latency
+        time.sleep(max(simulated_duration - actual_duration, 0))
+        return logits
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        page_table=None,
+        kv_cache=None,
+        prompt_lens=None,
+    ):
+        """
+        Called in TTModelRunner if trace mode is not on.
+        """
+        _, seq_len = tokens.shape
+        if seq_len == 1:
+            return self.decode_forward(
+                tokens, start_pos, page_table=page_table, kv_cache=kv_cache
+            )
+        else:
+            return self.prefill_forward(
+                tokens,
+                start_pos,
+                page_table=page_table,
+                kv_cache=kv_cache,
+                prompt_lens=prompt_lens,
+            )
diff --git a/tests/mock_vllm_offline_inference_tt.py b/tests/mock_vllm_offline_inference_tt.py
new file mode 100644
index 0000000..990d06c
--- /dev/null
+++ b/tests/mock_vllm_offline_inference_tt.py
@@ -0,0 +1,239 @@
+import argparse
+import json
+import time
+from unittest.mock import patch
+
+import uvloop
+from mock_vllm_model import MockModel, new_allocate_kv_cache, new_init_cache_engine
+from tqdm import tqdm
+from vllm import LLM, ModelRegistry, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args,
+)
+from vllm.inputs.data import TokensPrompt
+from vllm.utils import merge_async_iterators
+from vllm.worker.tt_worker import TTCacheEngine, TTWorker
+
+ModelRegistry.register_model("TTLlamaForCausalLM", MockModel)
+
+
+@patch.object(
+    TTWorker, "init_device", new=lambda x: None
+)  # Patch to skip TT device initialization
+@patch.object(TTWorker, "_init_cache_engine", new=new_init_cache_engine)
+@patch.object(
+    TTCacheEngine, "_allocate_kv_cache", new=new_allocate_kv_cache
+)  # Patch to skip allocation on the TT device, since none is present
+def run_inference(
+    prompts_json,
+    max_tokens=128,
+    max_seqs_in_batch=32,
+    num_repeat_prompts=2,
+    measure_perf=False,
+    perf_prompt_len=None,
+    greedy_sampling=False,  # Option to use greedy decoding instead of top-k/p
+    async_engine=False,
+):
+    # LLM args
+    engine_kw_args = {
+        "model": "meta-llama/Meta-Llama-3.1-70B",
+        "block_size": 64,
+        "max_num_seqs": max_seqs_in_batch,
+        "max_model_len": 131072,
+        "disable_log_stats": False,
+        "max_num_batched_tokens": 131072,
+        "log_global_stats": True if measure_perf else False,
+        "num_scheduler_steps": 10,
+        "disable_async_output_proc": True,
+    }
+
+    # Generation args
+    ignore_eos = True if measure_perf else False
+
+    if greedy_sampling:
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens, ignore_eos=ignore_eos, temperature=0.0
+        )
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens,
+            ignore_eos=ignore_eos,
+            top_k=10,
+            top_p=0.9,
+            temperature=1.0,
+        )
+
+    # Prepare inputs
+    if not measure_perf:
+        # Load prompts from a JSON file
+        with open(prompts_json, "r") as file:
+            prompts = json.load(file)
+        assert isinstance(prompts, list), "Prompts must be a list of strings"
+        if num_repeat_prompts is not None:
+            prompts = prompts * num_repeat_prompts
+        print("Number of prompts:", len(prompts))
+    else:
+        assert (
+            perf_prompt_len is not None
+        ), "perf_prompt_len is required to generate dummy prompts"
+        print("Measuring performance with dummy prompts of length", perf_prompt_len)
+        prompt_token_ids = [[0] * perf_prompt_len] * max_seqs_in_batch  # dummy prompts
+        sampling_params = (
+            sampling_params[:max_seqs_in_batch]
+            if isinstance(sampling_params, list)
+            else sampling_params
+        )
+
+        # Set an arbitrary max_tokens to simulate generating multiple tokens consecutively
+        print("Generating prompts with output length", max_tokens)
+        sampling_params.max_tokens = max_tokens
+
+        max_model_len = engine_kw_args["max_model_len"]
+        assert_str = f"prompt length ({perf_prompt_len}) + num generated tokens ({sampling_params.max_tokens}) will exceed max_model_len ({max_model_len})"
+        assert perf_prompt_len + sampling_params.max_tokens <= max_model_len, assert_str
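+
+    # Illustrative format for prompts_json (implied by the list-of-strings assert above):
+    # a plain JSON list of prompt strings, e.g.
+    #   ["What is Tenstorrent?", "Explain KV caching in one sentence."]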
+
+    # Create and run LLM
+    if not async_engine:
+        llm = LLM(**engine_kw_args)
+        if not measure_perf:
+            generate_tokens(llm, prompts, sampling_params, print_output=True)
+        else:
+            run_inference_perf(llm, prompt_token_ids, sampling_params)
+    else:
+        print("Using async engine")
+        engine_args = AsyncEngineArgs(**engine_kw_args)
+
+        async def _run_inference_async():
+            async with build_async_engine_client_from_engine_args(engine_args) as llm:
+                if not measure_perf:
+                    await generate_tokens_async(
+                        llm, prompts, sampling_params, print_output=True
+                    )
+                else:
+                    await run_inference_perf_async(
+                        llm, prompt_token_ids, sampling_params
+                    )
+
+        uvloop.run(_run_inference_async())
+
+
+def run_inference_perf(
+    llm: LLM,
+    prompt_token_ids,
+    sampling_params,
+    N_warmup=1,
+    N_inference=4,
+):
+    for i in tqdm(range(N_inference), desc="Inference runs"):
+        if i == N_warmup:
+            start_time = time.perf_counter()
+        generate_tokens(
+            llm, None, sampling_params, prompt_token_ids, print_output=False
+        )
+    avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
+    print(f"Average time taken per inference run: {avg_time:.2f} s")
+
+
+async def run_inference_perf_async(
+    llm: LLM,
+    prompt_token_ids,
+    sampling_params,
+    N_warmup=1,
+    N_inference=4,
+):
+    for i in tqdm(range(N_inference), desc="Inference runs"):
+        if i == N_warmup:
+            start_time = time.perf_counter()
+        await generate_tokens_async(
+            llm, None, sampling_params, prompt_token_ids, print_output=False
+        )
+    avg_time = (time.perf_counter() - start_time) / (N_inference - N_warmup)
+    print(f"Average time taken per inference run: {avg_time:.2f} s")
+
+
+def generate_tokens(
+    llm: LLM, prompts, sampling_params, prompt_token_ids=None, print_output=True
+):
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params, prompt_token_ids)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        if print_output:
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
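+
+
+# Illustrative note: the perf paths above call
+#   generate_tokens(llm, None, sampling_params, prompt_token_ids, print_output=False)
+# i.e. prompts may be None when pre-tokenized dummy prompts are supplied instead.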
+
+
+async def generate_tokens_async(
+    llm: MQLLMEngineClient,
+    prompts,
+    sampling_params,
+    prompt_token_ids=None,
+    print_output=True,
+):
+    # Use tokenized prompts if provided
+    if prompt_token_ids is not None:
+        prompts = []
+        for single_prompt_token_ids in prompt_token_ids:
+            prompts.append(TokensPrompt(prompt_token_ids=single_prompt_token_ids))
+
+    if not isinstance(sampling_params, list):
+        sampling_params = [sampling_params] * len(prompts)
+
+    generators = []
+    for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+        generator = llm.generate(prompt, sp, request_id=f"test{i}")
+        generators.append(generator)
+    all_gens = merge_async_iterators(*generators)
+    async for i, res in all_gens:
+        prompt = res.prompt
+        generated_text = res.outputs[0].text
+        if print_output and res.finished:
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--prompts_json",
+        type=str,
+        default="/home/user/vllm/tt_metal/prompts.json",
+        help="Path to JSON file containing prompts",
+    )
+    parser.add_argument(
+        "--measure_perf", action="store_true", help="Measure performance"
+    )
+    parser.add_argument(
+        "--perf_prompt_len",
+        type=int,
+        default=128,
+        help="Length of dummy prompts for performance measurement",
+    )
+    parser.add_argument("--max_tokens", type=int, default=128, help="Length of outputs")
+    parser.add_argument(
+        "--greedy_sampling",
+        action="store_true",
+        help="Use greedy decoding instead of top-k/p",
+    )
+    parser.add_argument(
+        "--max_seqs_in_batch",
+        type=int,
+        default=32,
+        help="Maximum batch size for inference",
+    )
+    parser.add_argument("--async_engine", action="store_true", help="Use async engine")
+    args = parser.parse_args()
+
+    run_inference(
+        args.prompts_json,
+        measure_perf=args.measure_perf,
+        perf_prompt_len=args.perf_prompt_len,
+        max_tokens=args.max_tokens,
+        greedy_sampling=args.greedy_sampling,
+        max_seqs_in_batch=args.max_seqs_in_batch,
+        async_engine=args.async_engine,
+    )
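+
+    # Illustrative invocation (env var and paths taken from tests/README.md):
+    #   WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml \
+    #       python /home/user/tests/mock_vllm_offline_inference_tt.py --measure_perf --perf_prompt_len 128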