diff --git a/vllm-tt-metal-llama3-70b/README.md b/vllm-tt-metal-llama3-70b/README.md
index 1e6b77a..a00189d 100644
--- a/vllm-tt-metal-llama3-70b/README.md
+++ b/vllm-tt-metal-llama3-70b/README.md
@@ -19,7 +19,6 @@ This implementation supports Llama 3.1 70B with vLLM at https://github.com/tenst
 
 If first run setup has already been completed, start here. If first run setup has not been run please see the instructions below for [First run setup](#first-run-setup).
 
-
 ### Docker Run - vLLM llama3 inference server
 
 Run the container from the project root at `tt-inference-server`:
@@ -40,12 +39,33 @@ docker run \
   ghcr.io/tenstorrent/tt-inference-server/tt-metal-llama3-70b-src-base-vllm:v0.0.1-tt-metal-685ef1303b5a-54b9157d852b
 ```
 
+By default the Docker container will start running the entrypoint command wrapped in `src/run_vllm_api_server.py`.
+This can be run manually if you override the container's default command with an interactive shell via `bash`.
+In an interactive shell you can start the vLLM API server via:
 ```bash
 # run server manually
-python examples/offline_inference_tt.py
+python src/run_vllm_api_server.py
+```
+
+The vLLM inference API server takes 3-5 minutes to start up (~40-60 minutes on first run when generating caches), then it will start serving requests. To send HTTP requests to the inference server, run the example scripts in a separate bash shell.
+
+### Example clients
+
+You can use `docker exec -it <container-id> bash` to create a shell in the Docker container, or run the client scripts on the host (ensuring the correct port mappings and Python dependencies are available):
+
+#### Run example clients from within Docker container:
+```bash
+# one-liner to enter an interactive shell on the most recently started container
+docker exec -it $(docker ps -q | head -n1) bash
+
+# inside the interactive shell, run an example client script that makes requests to the vLLM server:
+cd ~/src
+# this example runs a single request from alpaca eval, expecting and parsing the streaming response
+python example_requests_client_alpaca_eval.py --stream True --n_samples 1 --num_full_iterations 1 --batch_size 1
+# this example runs a full-dataset stress test with 32 simultaneous users making requests
+python example_requests_client_alpaca_eval.py --stream True --n_samples 805 --num_full_iterations 1 --batch_size 32
 ```
-The vLLM inference API server takes 3-5 minutes to start up (~60 minutes on first run when generating caches) then will start serving requests. To send HTTP requests to the inference server run the example scripts in a separate bash shell. You can use `docker exec -it bash` to create a shell in the docker container or run the client scripts on the host ensuring the correct port mappings and python dependencies are available:
 
 ## First run setup
@@ -80,11 +100,32 @@ sudo cpupower frequency-set -g performance
 
 ### 4. Docker image
 
+Either pull the pre-built Docker image from GitHub Container Registry (Option A below) or build it locally from the Dockerfile (Option B below).
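+As an optional sanity check (not required for setup), you can first see whether the image is already present locally:
+```bash
+# list any locally cached copies of the inference server image
+docker images | grep tt-metal-llama3-70b-src-base-vllm
+```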
+ +#### Option A: GitHub Container Registry + ```bash # pull image from GHCR docker pull ghcr.io/tenstorrent/tt-inference-server/tt-metal-llama3-70b-src-base-vllm:v0.0.1-tt-metal-685ef1303b5a-54b9157d852b ``` +#### Option B: Build Docker Image + +```bash +# build image +export TT_METAL_DOCKERFILE_VERSION=v0.53.0-rc27 +export TT_METAL_COMMIT_SHA_OR_TAG=685ef1303b5abdfda63183fdd4fd6ed51b496833 +export TT_METAL_COMMIT_DOCKER_TAG=${TT_METAL_COMMIT_SHA_OR_TAG:0:12} +export TT_VLLM_COMMIT_SHA_OR_TAG=54b9157d852b0fa219613c00abbaa5a35f221049 +export TT_VLLM_COMMIT_DOCKER_TAG=${TT_VLLM_COMMIT_SHA_OR_TAG:0:12} +docker build \ + -t ghcr.io/tenstorrent/tt-inference-server/tt-metal-llama3-70b-src-base-vllm:v0.0.1-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}-${TT_VLLM_COMMIT_DOCKER_TAG} \ + --build-arg TT_METAL_DOCKERFILE_VERSION=${TT_METAL_DOCKERFILE_VERSION} \ + --build-arg TT_METAL_COMMIT_SHA_OR_TAG=${TT_METAL_COMMIT_SHA_OR_TAG} \ + --build-arg TT_VLLM_COMMIT_SHA_OR_TAG=${TT_VLLM_COMMIT_SHA_OR_TAG} \ + . -f vllm.llama3.src.base.inference.v0.52.0.Dockerfile +``` + ### 5. Automated Setup: environment variables and weights files The script `vllm-tt-metal-llama3-70b/setup.sh` automates: diff --git a/vllm-tt-metal-llama3-70b/requirements.txt b/vllm-tt-metal-llama3-70b/requirements.txt new file mode 100644 index 0000000..730bd03 --- /dev/null +++ b/vllm-tt-metal-llama3-70b/requirements.txt @@ -0,0 +1,5 @@ +# inference server requirements +pyjwt==2.7.0 +requests==2.32.3 +datasets==3.1.0 +openai==1.53.1 diff --git a/vllm-tt-metal-llama3-70b/src/__init__.py b/vllm-tt-metal-llama3-70b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm-tt-metal-llama3-70b/src/example_openai_client_alpaca_eval.py b/vllm-tt-metal-llama3-70b/src/example_openai_client_alpaca_eval.py new file mode 100644 index 0000000..1affddc --- /dev/null +++ b/vllm-tt-metal-llama3-70b/src/example_openai_client_alpaca_eval.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +import threading +import logging +import time + +from openai import OpenAI + +from example_requests_client_alpaca_eval import ( + parse_args, + get_api_base_url, + load_dataset_samples, + get_authorization, + test_api_call_threaded_full_queue, +) + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Thread-safe data collection +responses_lock = threading.Lock() +responses = [] + + +def call_inference_api(prompt, response_idx, stream=True, headers=None, client=None): + # set API prompt and optional parameters + req_time = time.time() + full_text = "" + num_tokens = 0 + try: + # Use OpenAI client to call API + completion = client.completions.create( + model="meta-llama/Meta-Llama-3.1-70B", + prompt=prompt, + temperature=1, + max_tokens=2048, + top_p=0.9, + stop=["<|eot_id|>"], + stream=stream, + ) + if stream: + for event in completion: + if event.choices[0].finish_reason is not None: + break + if num_tokens == 0: + first_token_time = time.time() + ttft = first_token_time - req_time + num_tokens += 1 + content = event.choices[0].text + full_text += content + else: + full_text = completion.choices[0].text + # Assuming tokens were returned with response (using len to mock token length) + num_tokens = len(full_text.split()) + first_token_time = req_time # Simplify for non-stream + ttft = time.time() - req_time + except Exception as e: + logger.error(f"Error calling 
API: {e}") + elapsed_time = time.time() - req_time + logger.error( + f"Before error: elapsed_time={elapsed_time}, num_tokens: {num_tokens}, full_text: {full_text}" + ) + full_text = "ERROR" + num_tokens = 0 + first_token_time = time.time() + ttft = 0.001 + + num_tokens = max(num_tokens, 2) + throughput_time = max(time.time() - first_token_time, 0.0001) + response_data = { + "response_idx": response_idx, + "prompt": prompt, + "response": full_text, + "num_tokens": num_tokens, + "tps": (num_tokens - 1) / throughput_time, + "ttft": ttft, + } + + with responses_lock: + responses.append(response_data) + return response_data + + +if __name__ == "__main__": + logger.info( + "Note: OpenAI API client adds additional latency of ~10 ms to the API call." + ) + args = parse_args() + prompts = load_dataset_samples(args.n_samples) + headers = {"Authorization": f"Bearer {get_authorization()}"} + base_url = get_api_base_url() + logging.info(f"BASE_API_URL: {base_url}") + client = OpenAI( + base_url=base_url, + api_key=get_authorization(), + ) + test_api_call_threaded_full_queue( + prompts=prompts, + batch_size=args.batch_size, + num_full_iterations=args.num_full_iterations, + call_func=call_inference_api, + call_func_kwargs={ + "stream": args.stream, + "headers": headers, + "client": client, + }, + ) diff --git a/vllm-tt-metal-llama3-70b/src/example_requests_client_alpaca_eval.py b/vllm-tt-metal-llama3-70b/src/example_requests_client_alpaca_eval.py new file mode 100644 index 0000000..7d90d22 --- /dev/null +++ b/vllm-tt-metal-llama3-70b/src/example_requests_client_alpaca_eval.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +import os +import getpass +import threading +import logging +import json +import argparse +import time +from datetime import datetime +import requests +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed + +from datasets import load_dataset +import jwt + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run Alpaca Evaluation Inference.") + parser.add_argument( + "--stream", type=bool, default=False, help="Set stream to True or False." + ) + parser.add_argument( + "--n_samples", + type=int, + default=805, + help="Number of samples to use from the dataset.", + ) + parser.add_argument( + "--num_full_iterations", + type=int, + default=100, + help="Number of full iterations to run over the dataset.", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size for concurrent requests." + ) + return parser.parse_args() + + +def load_dataset_samples(n_samples): + # Load alpaca_eval dataset with specified number of samples + alpaca_ds = load_dataset( + "tatsu-lab/alpaca_eval", + "alpaca_eval", + split=f"eval[:{n_samples}]", + ) + return alpaca_ds["instruction"] + + +def get_authorization(): + authorization = os.getenv("AUTHORIZATION", None) + if authorization is None: + jwt_secret = os.getenv("JWT_SECRET", None) + if jwt_secret is None: + raise ValueError( + "Neither AUTHORIZATION or JWT_SECRET environment variables are set." 
+            )
+        json_payload = json.loads('{"team_id": "tenstorrent", "token_id":"debug-test"}')
+        encoded_jwt = jwt.encode(json_payload, jwt_secret, algorithm="HS256")
+        authorization = f"{encoded_jwt}"
+    return authorization
+
+
+def get_api_base_url():
+    DEPLOY_URL = os.getenv("DEPLOY_URL", "http://127.0.0.1")
+    base_url = f"{DEPLOY_URL}:{os.getenv('SERVICE_PORT', '8000')}/v1"
+    return base_url
+
+
+def get_api_url():
+    base_url = get_api_base_url()
+    api_url = f"{base_url}/completions"
+    return api_url
+
+
+# Thread-safe data collection
+responses_lock = threading.Lock()
+responses = []
+
+
+def call_inference_api(prompt, response_idx, stream=True, headers=None, api_url=None):
+    # set API prompt and optional parameters
+    json_data = {
+        "model": "meta-llama/Meta-Llama-3.1-70B",
+        "prompt": prompt,
+        "temperature": 1,
+        "top_k": 20,
+        "top_p": 0.9,
+        "max_tokens": 2048,
+        "stream": stream,
+        "stop": ["<|eot_id|>"],
+    }
+    req_time = time.time()
+    # using requests stream=True, make sure to set a timeout
+    response = requests.post(
+        api_url, json=json_data, headers=headers, stream=stream, timeout=600
+    )
+    # Handle chunked response
+    full_text = ""
+    num_tokens = 0
+    if stream:
+        if response.headers.get("transfer-encoding") == "chunked":
+            for line in response.iter_lines(decode_unicode=True):
+                # Process each line of data as it's received
+                if line:
+                    # Remove the 'data: ' prefix
+                    if line.startswith("data: "):
+                        if num_tokens == 0:
+                            first_token_time = time.time()
+                            ttft = first_token_time - req_time
+                        num_tokens += 1
+                        data_str = line[len("data: ") :].strip()
+                        if data_str == "[DONE]":
+                            num_tokens -= 1
+                            break
+                        try:
+                            # Parse the JSON data
+                            data = json.loads(data_str)
+                            # Extract text from the 'choices' field
+                            content = data["choices"][0].get("text", "")
+                            full_text += content
+                        except json.JSONDecodeError as e:
+                            print(f"Failed to decode JSON: {e}")
+                            continue
+        else:
+            # If not chunked, you can access the entire response body at once
+            logger.info(response.text)
+            raise ValueError("Response is not chunked")
+
+    else:
+        full_text = response.text
+        # TODO: get tokens from tokenizer
+        num_tokens = 2
+        # non-streaming: treat the whole response as arriving at once so the
+        # timing fields used below are always defined
+        first_token_time = req_time
+        ttft = time.time() - req_time
+
+    num_tokens = max(num_tokens, 2)
+    throughput_time = max(time.time() - first_token_time, 0.0001)
+    response_data = {
+        "response_idx": response_idx,
+        "prompt": prompt,
+        "response": full_text,
+        "num_tokens": num_tokens,
+        "tps": (num_tokens - 1) / throughput_time,
+        "ttft": ttft,
+    }
+
+    with responses_lock:
+        responses.append(response_data)
+    return response_data
+
+
+def check_json_fpath(json_fpath):
+    directory = os.path.dirname(json_fpath)
+    user = getpass.getuser()
+    if os.access(directory, os.W_OK):
+        try:
+            with open(json_fpath, "w") as f:
+                f.write("")  # Attempt to write an empty string to the file
+            logger.info(f"The file '{json_fpath}' can be created and is writable.")
+            return True, ""
+        except IOError as err:
+            err_msg = f"Cannot write to the file '{json_fpath}'. 
Reason: {err}" + else: + err_msg = ( + f"User:={user} cannot write to file:={json_fpath} in directory:={directory}" + ) + logger.error(err_msg) + return False, err_msg + + +def test_api_call_threaded_full_queue( + prompts, + batch_size, + num_full_iterations, + call_func, + call_func_kwargs, +): + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + cache_root = Path(os.getenv("CACHE_ROOT", ".")) + json_fpath = cache_root / f"alpaca_eval_responses_{timestamp}.json" + logger.info(f"Will write output to: {json_fpath}") + can_write, err_msg = check_json_fpath(json_fpath) + if not can_write: + err_msg += ( + f"\nNote: CACHE_ROOT:={cache_root}, consider setting in this shell to $PWD" + ) + assert can_write, err_msg + with open(json_fpath, "a") as f: + f.write("[\n") + + total_instructions = len(prompts) * num_full_iterations + response_counter = 0 + logger.info( + f"Running {total_instructions} prompts in full queue with batch size {batch_size}." + ) + with ThreadPoolExecutor(max_workers=batch_size) as executor: + futures = [] + for _ in range(num_full_iterations): + for response_idx, instruction in enumerate(prompts): + future = executor.submit( + call_func, instruction, response_idx, **call_func_kwargs + ) + futures.append(future) + + for future in as_completed(futures): + try: + response_data = future.result() + # Write the response data to the JSONL file + with responses_lock: + with open(json_fpath, "a") as f: + if response_counter > 0: + f.write(",") + json.dump(response_data, f, indent=4) + response_counter += 1 + logger.info( + f"Processed {response_counter}/{total_instructions} responses. Avg. TPS: {response_data['tps']:.2f}, TTFT: {response_data['ttft']:.2f}, Num Tokens: {response_data['num_tokens']}" + ) + except Exception as e: + logger.error(f"Error processing a response: {e}") + + logger.info(f"Finished all requests, total responses: {response_counter}") + with open(json_fpath, "a") as f: + f.write("\n]") + + +if __name__ == "__main__": + args = parse_args() + prompts = load_dataset_samples(args.n_samples) + headers = {"Authorization": f"Bearer {get_authorization()}"} + api_url = get_api_url() + logging.info(f"API_URL: {api_url}") + test_api_call_threaded_full_queue( + prompts=prompts, + batch_size=args.batch_size, + num_full_iterations=args.num_full_iterations, + call_func=call_inference_api, + call_func_kwargs={ + "stream": args.stream, + "headers": headers, + "api_url": api_url, + }, + ) diff --git a/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py new file mode 100644 index 0000000..3c6968d --- /dev/null +++ b/vllm-tt-metal-llama3-70b/src/run_vllm_api_server.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +import os +import sys +import runpy +import json + +import jwt +from vllm import ModelRegistry + +# importing from tt-metal install path +from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration + +# register the model +ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration) + + +def get_encoded_api_key(jwt_secret): + if jwt_secret is None: + return None + json_payload = json.loads('{"team_id": "tenstorrent", "token_id":"debug-test"}') + encoded_jwt = jwt.encode(json_payload, jwt_secret, algorithm="HS256") + return encoded_jwt + + +def main(): + # vLLM CLI arguments + args = { + "model": "meta-llama/Meta-Llama-3.1-70B", + "block_size": "64", + "max_num_seqs": "32", + "max_model_len": "131072", 
+ "max_num_batched_tokens": "131072", + "num_scheduler_steps": "10", + "port": os.getenv("SERVICE_PORT", "8000"), + "download-dir": os.getenv("CACHE_DIR", None), + "api-key": get_encoded_api_key(os.getenv("JWT_SECRET", None)), + } + for key, value in args.items(): + if value is not None: + sys.argv.extend(["--" + key, value]) + + # runpy uses the same process and environment so the registered models are available + runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__") + + +if __name__ == "__main__": + main() diff --git a/vllm-tt-metal-llama3-70b/vllm.llama3.src.base.inference.v0.52.0.Dockerfile b/vllm-tt-metal-llama3-70b/vllm.llama3.src.base.inference.v0.52.0.Dockerfile index 1e92105..99863b9 100644 --- a/vllm-tt-metal-llama3-70b/vllm.llama3.src.base.inference.v0.52.0.Dockerfile +++ b/vllm-tt-metal-llama3-70b/vllm.llama3.src.base.inference.v0.52.0.Dockerfile @@ -96,4 +96,13 @@ RUN git clone https://github.com/tenstorrent/vllm.git ${vllm_dir}\ # extra vllm dependencies RUN /bin/bash -c "source ${PYTHON_ENV_DIR}/bin/activate && pip install compressed-tensors" -WORKDIR ${vllm_dir} +ENV PYTHONPATH=$PYTHONPATH:$vllm_dir +ARG APP_DIR="${HOME_DIR}/app" +WORKDIR ${APP_DIR} +COPY --chown=user:user "src" "${APP_DIR}/src" +COPY --chown=user:user "requirements.txt" "${APP_DIR}/requirements.txt" +RUN /bin/bash -c "source ${PYTHON_ENV_DIR}/bin/activate \ +&& pip install --default-timeout=240 --no-cache-dir -r requirements.txt" + +WORKDIR "${APP_DIR}/src" +CMD ["/bin/bash", "-c", "source ${PYTHON_ENV_DIR}/bin/activate && python run_vllm_api_server.py"]