Adding Mock Model + Test Directory #24

Merged
Changes from 30 commits
Commits (40)
ee9be29
adding vLLM dockerfile
tstescoTT Oct 15, 2024
efe648b
adding label GHCR and sed import edit for vllm server_example_tt.py
tstescoTT Oct 22, 2024
0e5f431
adding evals instructions and run script
tstescoTT Oct 22, 2024
ffdd77a
move vllm llama 3.1 70b implementation to top level model impl dir
tstescoTT Oct 23, 2024
5ef0d64
update eval instructions
tstescoTT Oct 23, 2024
33340dc
adding llama 3.1 70b benchmarking instructions
tstescoTT Oct 23, 2024
50626e4
adding dir for locust
tstescoTT Oct 23, 2024
11c604a
add doc link for vllm setup
tstescoTT Oct 23, 2024
45f544d
adding pre-commit with ruff linting and formatting
tstescoTT Oct 23, 2024
64281b0
add pre-commit instructions
tstescoTT Oct 23, 2024
f9b4fb3
add GHCR repo connection label to Dockerfile
tstescoTT Oct 23, 2024
663e090
added mock model
mvanniasingheTT Oct 22, 2024
ca65d95
created + add mock model and mock offline inference w/ patches
mvanniasingheTT Oct 24, 2024
b71e261
removed unneeded imports from mock model
mvanniasingheTT Oct 24, 2024
e8ab7a6
update python path to include vllm
mvanniasingheTT Oct 24, 2024
1a72731
uncommented sys path line
mvanniasingheTT Oct 25, 2024
bb284db
update dockerfile to pip install for specific commit of vllm
mvanniasingheTT Oct 25, 2024
9986877
update readme mock instructions
mvanniasingheTT Oct 25, 2024
80138ce
update readme mock instructions
mvanniasingheTT Oct 25, 2024
a03566e
comments for mock patches
mvanniasingheTT Oct 25, 2024
fb50708
update mock to work with latest vllm commit 82dbca6
mvanniasingheTT Oct 28, 2024
ced6860
remove tracing calls in mock model
mvanniasingheTT Oct 28, 2024
0c7d005
remove unneeded imports
mvanniasingheTT Oct 28, 2024
4f8e512
remove checkout specific commit
mvanniasingheTT Oct 28, 2024
77e6699
remove imports and comments
mvanniasingheTT Oct 28, 2024
bf2e08b
cleanup
mvanniasingheTT Oct 28, 2024
fda72f7
resolved conflict
mvanniasingheTT Nov 1, 2024
e9a9f54
resolved conflict
mvanniasingheTT Nov 1, 2024
66352f8
update read me with clearer commit hash instructions
mvanniasingheTT Nov 1, 2024
9b14371
update read me with clearer commit hash instructions
mvanniasingheTT Nov 1, 2024
1ee28bc
Update tests/README.md
mvanniasingheTT Nov 1, 2024
61d872a
Update tests/README.md
mvanniasingheTT Nov 4, 2024
f607fcb
Update tests/README.md
mvanniasingheTT Nov 4, 2024
ab44d3e
remove unused TtLlamaModel import
mvanniasingheTT Nov 4, 2024
dd11f47
Merge branch 'mvanniasinghe/mock_model' of https://github.com/tenstor…
mvanniasingheTT Nov 4, 2024
06c2188
delete_trace should do nothing when called in mock, in worker it is a…
mvanniasingheTT Nov 4, 2024
9ed760f
update prefill forward call
mvanniasingheTT Nov 4, 2024
23d37ad
added some comments about decode
mvanniasingheTT Nov 4, 2024
ce1893f
renamed file by accident
mvanniasingheTT Nov 4, 2024
8b4afcc
ran ruff
mvanniasingheTT Nov 4, 2024
30 changes: 30 additions & 0 deletions tests/README.md
@@ -0,0 +1,30 @@
# Test vLLM via Mock Model

This directory contains scripts for rapid testing and development of benchmarking, evaluation, and stress-testing procedures for available models through [vllm](https://github.com/tenstorrent/vllm/tree/dev).

To run the mock offline inference script `mock_vllm_offline_inference_tt.py`, follow the steps below:

## 1. Build Docker Container

Follow the instructions under `evals/README.md` to build the Docker container. To set the environment variables appropriately for the latest supported versions of tt-metal and vllm, refer to [vllm/tt_metal](https://github.com/tenstorrent/vllm/blob/dev/tt_metal/README.md) when setting:

```bash
export TT_METAL_COMMIT_SHA_OR_TAG=<tt-metal-commit>
export VLLM_COMMIT_SHA=<vllm-commit>
```
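
For orientation, a build invocation might look like the sketch below. The Dockerfile path, image tag, and the assumption that the two variables are passed as build args are placeholders; `evals/README.md` has the authoritative command.

```bash
# Hypothetical sketch only: Dockerfile path, image tag, and build-arg usage are
# assumptions; follow evals/README.md for the exact build command.
docker build \
  --build-arg TT_METAL_COMMIT_SHA_OR_TAG=${TT_METAL_COMMIT_SHA_OR_TAG} \
  --build-arg VLLM_COMMIT_SHA=${VLLM_COMMIT_SHA} \
  -t <vllm-tt-image> \
  -f <path-to-vllm-dockerfile> .
```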

## 2. Run The Docker Container

Mount the `tests` directory into the container by adding the following to the `docker run` command:

```bash
--volume $PWD/tests:/home/user/tests
```
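
For example, a full `docker run` invocation could look like the following sketch; the image name and any additional flags (device mounts, model volumes, environment files) come from `evals/README.md` and are placeholders here.

```bash
# Hypothetical sketch: only the tests/ volume mount is prescribed above; the image
# name and any extra flags from evals/README.md are placeholders.
docker run -it --rm \
  --volume $PWD/tests:/home/user/tests \
  <vllm-tt-image> bash
```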

## 3. Run The Mock Model

Once inside the Docker container, run the mock script with:

```bash
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python /home/user/tests/mock_vllm_offline_inference_tt.py
```
225 changes: 225 additions & 0 deletions tests/mock_vllm_model.py
@@ -0,0 +1,225 @@
import torch
import time
import copy
import os
import sys
from dataclasses import dataclass
from typing import List

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tt_metal.models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
from tt_metal.models.demos.t3000.llama2_70b.tt.llama_common import (
setup_llama_env,
)
from tt_metal.models.demos.t3000.llama2_70b.tt.llama_model_optimized import TtLlamaModel_optimized as TtLlamaModel
from tt_metal.models.demos.t3000.llama2_70b.tt.model_config import (
get_model_config,
)


def new_init_cache_enginer(self):
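    # Presumably monkey-patched over the TT worker's cache-engine initialization by
    # the mock offline inference script: it still builds a TTCacheEngine, but sets
    # tt_cache_path to None so the patched allocator below keeps the KV cache on CPU.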
assert self.cache_config.num_gpu_blocks is not None

# Get cache path from TT model for caching kv blocks
self.cache_config.tt_cache_path = None

from vllm.worker.tt_worker import TTCacheEngine

self.cache_engine = TTCacheEngine(
self.cache_config,
self.model_config,
self.parallel_config,
self.device_config)
self.tt_cache = self.cache_engine.tt_cache

def new_allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device.
The assumption is that KV cache for a layer is packed into one tensor.
We will have a separate tensor for K and V.

In the mock implementation, device is always cpu
"""
# K and V each have the following shape: (num_blocks, num_kv_heads, block_size, head_size)
kv_cache_shape = (num_blocks, self.num_kv_heads, self.block_size, self.head_size)
kv_cache: List[torch.Tensor] = []
num_layers = self.num_attention_layers
if device == "cpu":
for _ in range(num_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# Zero-initialize CPU cache
cache_k = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
cache_v = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
kv_cache.append([cache_k, cache_v])
self.tt_cache = kv_cache # set tt_cache to just be cpu
return kv_cache


class MockModel(TtLlamaModelForGeneration):
# mock implementation of TtLlamaModelForGeneration
# see: tt-metal/models/demos/t3000/llama2_70b/tt/llama_generation.py
# inherits from llama at the moment since only this model is currently used with vllm
def __init__(self, configuration, state_dict, model_args, tt_args, paged_attention_config=None, vllm=False):

# Cache Weights setup
n_layers = model_args.num_layers or 80

self.params = copy.deepcopy(configuration)

# required to setup model config
self.llama_version = model_args.llama_version
self.max_batch_size = model_args.max_batch_size
self.max_kv_context_len = model_args.max_kv_context_len

self.mesh_device = tt_args.mesh_device

# Initial model_config is set in decode mode
# model config is required for vllm
model_config = get_model_config(
llama_version=self.llama_version,
max_batch_size=self.max_batch_size,
max_context_len=self.max_kv_context_len,
vllm=vllm,
)
self.model_config = model_config
del state_dict

@classmethod
def initialize_vllm_model(cls, hf_config, t3k_mesh_device, max_batch_size):
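        # Presumably the entry point vLLM's TT backend uses to construct the model;
        # it builds the reference llama config from params.json in ckpt_dir and
        # skips loading any weights or state dict.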
# TODO: pass in model args and tt args as parameters from vllm
@dataclass
class ModelArgs:
llama_version: str = None
ckpt_dir: str = None
max_batch_size: int = 32 # overwritten by max_num_seqs from vllm
num_layers: int = 80
max_kv_context_len: int = 131072

@dataclass
class TTArgs:
mesh_device: object = None
cache_path: str = None

# setup configs
llama_version = "llama3"
model_config, ckpt_dir, _, cache_path = setup_llama_env(
llama_version=llama_version,
)
# do not look for mesh device

# initialize arg classes
model_args = ModelArgs(llama_version=llama_version, ckpt_dir=ckpt_dir, max_batch_size=max_batch_size)
tt_args = TTArgs(mesh_device=t3k_mesh_device, cache_path=cache_path)

# do not load state dict

# TODO: delete this configuration setup once llama can directly accept hf_config
from models.demos.t3000.llama2_70b.reference.llama.llama.model import ModelArgs as ReferenceModelArgs
from pathlib import Path
import json

with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
configuration = ReferenceModelArgs(
max_seq_len=model_args.max_kv_context_len,
max_batch_size=model_args.max_batch_size,
**params,
)

return cls(
configuration=configuration, state_dict=None, model_args=model_args, tt_args=tt_args, vllm=True
)

def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
# mock out capturing the trace since no TT/GPU device is used; only return logits from the decode pass
tt_logits = self.decode_forward(tokens, start_pos, page_table, kv_cache) # mock out self.tt_model() call
return None, None, None, None, tt_logits, None

def decode_forward_trace(
self,
tokens: torch.Tensor,
start_pos: int,
trace_id,
tt_inp,
rot_mat,
cache_idxs_tt,
tt_logits,
page_table=None,
tt_page_table=None,
):
# mock out executing the trace and only return the logits directly
batch, seqlen = tokens.shape
logits = tt_logits
logits = logits[:batch] # Remove padded users

return logits

def delete_trace(self, trace_id):
# ttnn.release_trace(self.mesh_device, trace_id)
return

def prefill_forward_single_user(
self,
tokens: torch.Tensor,
start_pos: int,
user_id: int,
last_token_idx=None,
page_table=None,
kv_cache=None,
):
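        # The mock has no real prefill: reuse the decode path so a random logits
        # tensor of shape [batch, seqlen, vocab_size] is still returned.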
return self.decode_forward(tokens=tokens, start_pos=start_pos)

def decode_forward(
self,
tokens: torch.Tensor,
start_pos: int,
page_table=None,
kv_cache=None,
):
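        # Mock decode step: return random logits of shape [batch, seqlen, vocab_size],
        # boost the EOT logit for any sequence whose start_pos has passed send_index so
        # generation terminates, and sleep to approximate a fixed tokens-per-second rate.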
assert len(tokens.shape) == 2
batch, seqlen = tokens.shape
forward_start = time.time()
simulated_tps = 10000.0
simulated_duration = 1.0 / simulated_tps
# update the new tokens generated to the input id
# vocab_size = tokenizer.nwords
# logits: [batch, seqlen, vocab_size]
logits = torch.randn((batch, seqlen, 128256))
# constants used to force an EOT token once decoding passes send_index
EOT_ID = 128009
# EOS_ID = 128001
send_index = 200
send_token = EOT_ID
if start_pos is not None:
if isinstance(start_pos, int):
# if single input per batch
cache_idxs = torch.tensor([start_pos for _ in range(batch)], dtype=torch.int64)
else: # if start_pos is a tensor
# if start position is greater than index to send EOT
cache_idxs = start_pos.to(dtype=torch.int64)
send_token_mask = cache_idxs > send_index
# find positions where start_pos passes send_index (i.e. we are done decoding) and make 1D
batch_indices = torch.nonzero(send_token_mask).squeeze()
# assign a high logit at the send_token index so the model selects it and generates EOT, stopping generation
logits[batch_indices, 0, send_token] = 100.0


actual_duration = time.time() - forward_start
# simulate forward latency
time.sleep(max(simulated_duration - actual_duration, 0))
return logits

def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None, prompt_lens=None):
_, seq_len = tokens.shape
if seq_len == 1:
return self.decode_forward(tokens, start_pos, page_table=page_table, kv_cache=kv_cache)
else:
return self.prefill_forward(
tokens, start_pos, page_table=page_table, kv_cache=kv_cache, prompt_lens=prompt_lens
)