From 2a33f86209544f59e202deff75e74eb3a97b8d9d Mon Sep 17 00:00:00 2001 From: Xida Date: Thu, 26 Dec 2024 17:38:02 +0000 Subject: [PATCH] clean up shortfin llm integration tests --- .../integration_tests/llm/device_settings.py | 17 ++ .../integration_tests/llm/model_management.py | 227 ++++++++++++++++++ .../llm/server_management.py | 97 ++++++++ .../llm/shortfin/conftest.py | 227 ++++++------------ .../llm/shortfin/cpu_llm_server_test.py | 198 --------------- .../llm/shortfin/test_llm_server.py | 104 ++++++++ 6 files changed, 522 insertions(+), 348 deletions(-) create mode 100644 app_tests/integration_tests/llm/device_settings.py create mode 100644 app_tests/integration_tests/llm/model_management.py create mode 100644 app_tests/integration_tests/llm/server_management.py delete mode 100644 app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py create mode 100755 app_tests/integration_tests/llm/shortfin/test_llm_server.py diff --git a/app_tests/integration_tests/llm/device_settings.py b/app_tests/integration_tests/llm/device_settings.py new file mode 100644 index 000000000..a7983665f --- /dev/null +++ b/app_tests/integration_tests/llm/device_settings.py @@ -0,0 +1,17 @@ +from typing import Tuple +from dataclasses import dataclass + + +@dataclass +class DeviceSettings: + compile_flags: Tuple[str] + server_flags: Tuple[str] + + +CPU = DeviceSettings( + compile_flags=( + "-iree-hal-target-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu=host", + ), + server_flags=("--device=local-task",), +) diff --git a/app_tests/integration_tests/llm/model_management.py b/app_tests/integration_tests/llm/model_management.py new file mode 100644 index 000000000..08088b1af --- /dev/null +++ b/app_tests/integration_tests/llm/model_management.py @@ -0,0 +1,227 @@ +"""Module for managing model artifacts through various processing stages.""" +import logging +from pathlib import Path +import subprocess +from dataclasses import dataclass +from typing import Optional, Tuple +from enum import Enum, auto + +logger = logging.getLogger(__name__) + + +class ModelSource(Enum): + HUGGINGFACE = auto() + LOCAL = auto() + AZURE = auto() + + +@dataclass +class AzureConfig: + """Configuration for Azure blob storage downloads.""" + + account_name: str + container_name: str + blob_path: str + auth_mode: str = "key" + + +@dataclass +class ModelConfig: + """Configuration for model source and settings.""" + + model_file: str + tokenizer_id: str + batch_sizes: Tuple[int, ...] 
+ device_settings: "DeviceSettings" + source: ModelSource + repo_id: Optional[str] = None + local_path: Optional[Path] = None + azure_config: Optional[AzureConfig] = None + + def __post_init__(self): + if self.source == ModelSource.HUGGINGFACE and not self.repo_id: + raise ValueError("repo_id required for HuggingFace models") + elif self.source == ModelSource.LOCAL and not self.local_path: + raise ValueError("local_path required for local models") + elif self.source == ModelSource.AZURE and not self.azure_config: + raise ValueError("azure_config required for Azure models") + + +@dataclass +class ModelArtifacts: + """Container for all paths related to model artifacts.""" + + weights_path: Path + tokenizer_path: Path + mlir_path: Path + vmfb_path: Path + config_path: Path + + +class ModelStageManager: + """Manages different stages of model processing with caching behavior.""" + + def __init__(self, base_dir: Path, config: ModelConfig): + self.base_dir = base_dir + self.config = config + self.model_dir = self._get_model_dir() + self.model_dir.mkdir(parents=True, exist_ok=True) + + def _get_model_dir(self) -> Path: + """Creates and returns appropriate model directory based on source.""" + if self.config.source == ModelSource.HUGGINGFACE: + return self.base_dir / self.config.repo_id.replace("/", "_") + elif self.config.source == ModelSource.LOCAL: + return self.base_dir / "local" / self.config.local_path.stem + elif self.config.source == ModelSource.AZURE: + return ( + self.base_dir + / "azure" + / self.config.azure_config.blob_path.replace("/", "_") + ) + raise ValueError(f"Unsupported model source: {self.config.source}") + + def _download_from_huggingface(self) -> Path: + """Downloads model from HuggingFace.""" + model_path = self.model_dir / self.config.model_file + if not model_path.exists(): + logger.info(f"Downloading model {self.config.repo_id} from HuggingFace") + subprocess.run( + f"huggingface-cli download --local-dir {self.model_dir} {self.config.repo_id} {self.config.model_file}", + shell=True, + check=True, + ) + return model_path + + def _copy_from_local(self) -> Path: + """Copies model from local filesystem.""" + import shutil + + model_path = self.model_dir / self.config.model_file + if not model_path.exists(): + logger.info(f"Copying local model from {self.config.local_path}") + shutil.copy2(self.config.local_path, model_path) + return model_path + + def _download_from_azure(self) -> Path: + """Downloads model from Azure blob storage.""" + model_path = self.model_dir / self.config.model_file + if not model_path.exists(): + logger.info( + f"Downloading model from Azure blob storage: {self.config.azure_config.blob_path}" + ) + subprocess.run( + [ + "az", + "storage", + "blob", + "download", + "--account-name", + self.config.azure_config.account_name, + "--container-name", + self.config.azure_config.container_name, + "--name", + self.config.azure_config.blob_path, + "--file", + str(model_path), + "--auth-mode", + self.config.azure_config.auth_mode, + ], + check=True, + ) + return model_path + + def prepare_tokenizer(self) -> Path: + """Downloads and prepares tokenizer.""" + tokenizer_path = self.model_dir / "tokenizer.json" + if not tokenizer_path.exists(): + logger.info(f"Downloading tokenizer {self.config.tokenizer_id}") + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id) + tokenizer.save_pretrained(self.model_dir) + return tokenizer_path + + def export_model(self, weights_path: Path) -> Tuple[Path, Path]: + """Exports 
model to MLIR format.""" + bs_string = ",".join(map(str, self.config.batch_sizes)) + mlir_path = self.model_dir / "model.mlir" + config_path = self.model_dir / "config.json" + + logger.info( + "Exporting model with following settings:\n" + f" MLIR Path: {mlir_path}\n" + f" Config Path: {config_path}\n" + f" Batch Sizes: {bs_string}" + ) + + subprocess.run( + [ + "python", + "-m", + "sharktank.examples.export_paged_llm_v1", + "--block-seq-stride=16", + f"--{weights_path.suffix.strip('.')}-file={weights_path}", + f"--output-mlir={mlir_path}", + f"--output-config={config_path}", + f"--bs={bs_string}", + ], + check=True, + ) + + logger.info(f"Model successfully exported to {mlir_path}") + return mlir_path, config_path + + def compile_model(self, mlir_path: Path) -> Path: + """Compiles model to VMFB format.""" + vmfb_path = self.model_dir / "model.vmfb" + logger.info(f"Compiling model to {vmfb_path}") + + compile_command = [ + "iree-compile", + str(mlir_path), + "-o", + str(vmfb_path), + ] + compile_command.extend(self.config.device_settings.compile_flags) + + subprocess.run(compile_command, check=True) + logger.info(f"Model successfully compiled to {vmfb_path}") + return vmfb_path + + +class ModelProcessor: + """Main interface for processing models through all stages.""" + + def __init__(self, base_dir: Path): + self.base_dir = Path(base_dir) + + def process_model(self, config: ModelConfig) -> ModelArtifacts: + """Process model through all stages and return paths to all artifacts.""" + manager = ModelStageManager(self.base_dir, config) + + # Stage 1: Download weights and tokenizer (cached) + if config.source == ModelSource.HUGGINGFACE: + weights_path = manager._download_from_huggingface() + elif config.source == ModelSource.LOCAL: + weights_path = manager._copy_from_local() + elif config.source == ModelSource.AZURE: + weights_path = manager._download_from_azure() + else: + raise ValueError(f"Unsupported model source: {config.source}") + + tokenizer_path = manager.prepare_tokenizer() + + # Stage 2: Export model (fresh every time) + mlir_path, config_path = manager.export_model(weights_path) + + # Stage 3: Compile model (fresh every time) + vmfb_path = manager.compile_model(mlir_path) + + return ModelArtifacts( + weights_path=weights_path, + tokenizer_path=tokenizer_path, + mlir_path=mlir_path, + vmfb_path=vmfb_path, + config_path=config_path, + ) diff --git a/app_tests/integration_tests/llm/server_management.py b/app_tests/integration_tests/llm/server_management.py new file mode 100644 index 000000000..7cbfc390b --- /dev/null +++ b/app_tests/integration_tests/llm/server_management.py @@ -0,0 +1,97 @@ +"""Handles server lifecycle and configuration.""" +import json +import socket +from contextlib import closing +from dataclasses import dataclass, field +import subprocess +import time +import requests +from pathlib import Path +import sys +from typing import Optional + +from .device_settings import DeviceSettings +from .model_management import ModelArtifacts + + +@dataclass +class ServerConfig: + """Configuration for server instance.""" + + port: int + artifacts: ModelArtifacts + device_settings: DeviceSettings + + # things we need to write to config + prefix_sharing_algorithm: str = "none" + + +class ServerManager: + """Manages server lifecycle and configuration.""" + + @staticmethod + def find_available_port() -> int: + """Finds an available port for the server.""" + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, 
socket.SO_REUSEADDR, 1)
+            return s.getsockname()[1]
+
+    def __init__(self, config: ServerConfig):
+        self.config = config
+
+    def write_config(self) -> Path:
+        """Creates server config by extending the exported model config."""
+        source_config_path = self.config.artifacts.config_path
+        server_config_path = (
+            source_config_path.parent
+            / f"server_config_{self.config.prefix_sharing_algorithm}.json"
+        )
+
+        # Read the exported config as base
+        with open(source_config_path) as f:
+            config = json.load(f)
+
+        # Update with server-specific settings
+        config.update(
+            {
+                "paged_kv_cache": {
+                    "prefix_sharing_algorithm": self.config.prefix_sharing_algorithm
+                }
+            }
+        )
+
+        # Write the extended config
+        with open(server_config_path, "w") as f:
+            json.dump(config, f)
+
+        return server_config_path
+
+    def start(self) -> subprocess.Popen:
+        """Starts the server process."""
+        config_path = self.write_config()
+        cmd = [
+            sys.executable,
+            "-m",
+            "shortfin_apps.llm.server",
+            f"--tokenizer_json={self.config.artifacts.tokenizer_path}",
+            f"--model_config={config_path}",
+            f"--vmfb={self.config.artifacts.vmfb_path}",
+            f"--parameters={self.config.artifacts.weights_path}",
+            f"--port={self.config.port}",
+        ]
+        cmd.extend(self.config.device_settings.server_flags)
+        process = subprocess.Popen(cmd)
+        self._wait_for_server(timeout=10)
+        return process
+
+    def _wait_for_server(self, timeout: int = 10):
+        """Waits for server to be ready."""
+        start = time.time()
+        while time.time() - start < timeout:
+            try:
+                requests.get(f"http://localhost:{self.config.port}/health")
+                return
+            except requests.exceptions.ConnectionError:
+                time.sleep(1)
+        raise TimeoutError(f"Server failed to start within {timeout} seconds")
diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py
index 55c9e8bdc..d58b604bd 100644
--- a/app_tests/integration_tests/llm/shortfin/conftest.py
+++ b/app_tests/integration_tests/llm/shortfin/conftest.py
@@ -1,165 +1,92 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import hashlib
-import json
-import logging
+"""Test fixtures and configurations."""
 import pytest
+from pathlib import Path
+import hashlib
 
-pytest.importorskip("transformers")
-from ..utils import (
-    download_huggingface_model,
-    download_tokenizer,
-    export_paged_llm_v1,
-    compile_model,
-    find_available_port,
-    start_llm_server,
-    start_log_group,
-    end_log_group,
+from ..model_management import (
+    ModelProcessor,
+    ModelConfig,
+    ModelSource,
+    AzureConfig,
+    ModelArtifacts,
 )
-
-logger = logging.getLogger(__name__)
-
-MODEL_DIR_CACHE = {}
+from ..server_management import ServerManager, ServerConfig
+from .. 
import device_settings + +# Example model configurations +TEST_MODELS = { + "open_llama_3b": ModelConfig( + source=ModelSource.HUGGINGFACE, + repo_id="SlyEcho/open_llama_3b_v2_gguf", + model_file="open-llama-3b-v2-f16.gguf", + tokenizer_id="openlm-research/open_llama_3b_v2", + batch_sizes=(1, 4), + device_settings=device_settings.CPU, + ), + "llama3.1_8b": ModelConfig( + source=ModelSource.LOCAL, + local_path=Path("/data/llama3.1/8b/llama8b_f16.irpa"), + model_file="llama8b_f16.irpa", + tokenizer_id="NousResearch/Meta-Llama-3.1-8B", + batch_sizes=(1, 4), + device_settings=device_settings.CPU, + ), + "azure_llama": ModelConfig( + source=ModelSource.AZURE, + azure_config=AzureConfig( + account_name="sharkblobs", + container_name="halo-models", + blob_path="llm-dev/llama3_8b/8b_f16.irpa", + ), + model_file="azure-llama.irpa", + tokenizer_id="openlm-research/open_llama_3b_v2", + batch_sizes=(1, 4), + device_settings=device_settings.CPU, + ), +} @pytest.fixture(scope="module") -def model_test_dir(request, tmp_path_factory): - """Prepare model artifacts for starting the LLM server. - - Args: - request (FixtureRequest): The following params are accepted: - - repo_id (str): The Hugging Face repo ID. - - model_file (str): The model file to download. - - tokenizer_id (str): The tokenizer ID to download. - - settings (dict): The settings for sharktank export. - - batch_sizes (list): The batch sizes to use for the model. - tmp_path_factory (TempPathFactory): Temp dir to save artifacts to. - - Yields: - Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir. - """ - logger.info( - "Preparing model artifacts..." + start_log_group("Preparing model artifacts") - ) - - param_key = hashlib.md5(str(request.param).encode()).hexdigest() - if (directory := MODEL_DIR_CACHE.get(param_key)) is not None: - logger.info( - f"Reusing existing model artifacts directory: {directory}" + end_log_group() +def model_artifacts(tmp_path_factory, request): + """Prepares model artifacts in a cached directory.""" + model_config = TEST_MODELS[request.param] + cache_key = hashlib.md5(str(model_config).encode()).hexdigest() + + cache_dir = tmp_path_factory.mktemp("model_cache") + model_dir = cache_dir / cache_key + + # Return cached artifacts if available + if model_dir.exists(): + return ModelArtifacts( + weights_path=model_dir / model_config.model_file, + tokenizer_path=model_dir / "tokenizer.json", + mlir_path=model_dir / "model.mlir", + vmfb_path=model_dir / "model.vmfb", + config_path=model_dir / "config.json", ) - yield MODEL_DIR_CACHE[param_key] - return - - repo_id = request.param["repo_id"] - model_file = request.param["model_file"] - tokenizer_id = request.param["tokenizer_id"] - settings = request.param["settings"] - batch_sizes = request.param["batch_sizes"] - tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") - - # Download model if it doesn't exist - model_path = tmp_dir / model_file - download_huggingface_model(tmp_dir, repo_id, model_file) - - # Set up tokenizer if it doesn't exist - download_tokenizer(tmp_dir, tokenizer_id) - - # Export model - mlir_path = tmp_dir / "model.mlir" - config_path = tmp_dir / "config.json" - export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) - - # Compile model - vmfb_path = tmp_dir / "model.vmfb" - compile_model(mlir_path, vmfb_path, settings) - logger.info("Model artifacts setup successfully" + end_log_group()) - MODEL_DIR_CACHE[param_key] = tmp_dir - yield tmp_dir + # Process model and create artifacts + processor = ModelProcessor(cache_dir) + 
return processor.process_model(model_config) @pytest.fixture(scope="module") -def write_config(request, model_test_dir): - batch_sizes = request.param["batch_sizes"] - prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"] - - # Construct the new config filename - config_path = ( - model_test_dir - / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json" - ) - - # Read the base config file - base_config_path = model_test_dir / "config.json" - with open(base_config_path, "r") as f: - config = json.load(f) - - # Override specific fields - config.update( - { - "prefill_batch_sizes": batch_sizes, - "decode_batch_sizes": batch_sizes, - "paged_kv_cache": { - **config.get( - "paged_kv_cache", {} - ), # Preserve other paged_kv_cache settings - "prefix_sharing_algorithm": prefix_sharing_algorithm, - }, - } +def server(model_artifacts, request): + """Starts and manages the test server.""" + model_id = request.param["model"] + model_config = TEST_MODELS[model_id] + + server_config = ServerConfig( + port=ServerManager.find_available_port(), + artifacts=model_artifacts, + device_settings=model_config.device_settings, + prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"), ) - logger.info(f"Saving edited config to: {config_path}\n") - logger.info(f"Config: {json.dumps(config, indent=2)}") - with open(config_path, "w") as f: - json.dump(config, f) - - yield config_path - - -@pytest.fixture(scope="module") -def available_port(): - return find_available_port() - - -@pytest.fixture(scope="module") -def llm_server(request, model_test_dir, write_config, available_port): - """Start the LLM server. - Args: - request (FixtureRequest): The following params are accepted: - - model_file (str): The model file to download. - - settings (dict): The settings for starting the server. - model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. - available_port (int): The available port to start the server on. + server_manager = ServerManager(server_config) + process = server_manager.start() - Yields: - subprocess.Popen: The server process that was started. - """ - logger.info("Starting LLM server..." + start_log_group("Starting LLM server")) - tmp_dir = model_test_dir - config_path = write_config + yield process, server_config.port - model_file = request.param["model_file"] - settings = request.param["settings"] - - tokenizer_path = tmp_dir / "tokenizer.json" - vmfb_path = tmp_dir / "model.vmfb" - parameters_path = tmp_dir / model_file - - # Start llm server - server_process = start_llm_server( - available_port, - tokenizer_path, - config_path, - vmfb_path, - parameters_path, - settings, - ) - logger.info("LLM server started!" + end_log_group()) - yield server_process - # Teardown: kill the server - server_process.terminate() - server_process.wait() + process.terminate() + process.wait() diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py deleted file mode 100644 index 021bbf4f0..000000000 --- a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import concurrent.futures -import logging -import os -import pytest -import requests -import uuid - -from ..utils import AccuracyValidationException, start_log_group, end_log_group - -logger = logging.getLogger(__name__) - -CPU_SETTINGS = { - "device_flags": [ - "-iree-hal-target-backends=llvm-cpu", - "--iree-llvmcpu-target-cpu=host", - ], - "device": "local-task", -} -IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100") -gpu_settings = { - "device_flags": [ - "-iree-hal-target-backends=rocm", - f"--iree-hip-target={IREE_HIP_TARGET}", - ], - "device": "hip", -} - - -def do_generate(prompt, port, concurrent_requests=1): - logger.info("Generating request...") - headers = {"Content-Type": "application/json"} - # Create a GenerateReqInput-like structure - data = { - "text": prompt, - "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7}, - "rid": uuid.uuid4().hex, - "return_logprob": False, - "logprob_start_len": -1, - "top_logprobs_num": 0, - "return_text_in_logprobs": False, - "stream": False, - } - logger.info("Prompt text:") - logger.info(data["text"]) - BASE_URL = f"http://localhost:{port}" - - response_data = [] - with concurrent.futures.ThreadPoolExecutor( - max_workers=concurrent_requests - ) as executor: - futures = [ - executor.submit( - lambda: requests.post( - f"{BASE_URL}/generate", headers=headers, json=data - ) - ) - for _ in range(concurrent_requests) - ] - for future in concurrent.futures.as_completed(futures): - response = future.result() - - logger.info(f"Generate endpoint status code: {response.status_code}") - if response.status_code == 200: - logger.info("Generated text:") - data = response.text - assert data.startswith("data: ") - data = data[6:] - assert data.endswith("\n\n") - data = data[:-2] - logger.info(data) - response_data.append(data) - else: - response.raise_for_status() - - return response_data - - -@pytest.mark.parametrize( - "model_test_dir,write_config,llm_server", - [ - pytest.param( - { - "repo_id": "SlyEcho/open_llama_3b_v2_gguf", - "model_file": "open-llama-3b-v2-f16.gguf", - "tokenizer_id": "openlm-research/open_llama_3b_v2", - "settings": CPU_SETTINGS, - "batch_sizes": [1, 4], - }, - {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"}, - {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ), - pytest.param( - { - "repo_id": "SlyEcho/open_llama_3b_v2_gguf", - "model_file": "open-llama-3b-v2-f16.gguf", - "tokenizer_id": "openlm-research/open_llama_3b_v2", - "settings": CPU_SETTINGS, - "batch_sizes": [1, 4], - }, - {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"}, - {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ), - ], - indirect=True, -) -def test_llm_server(llm_server, available_port): - # Here you would typically make requests to your server - # and assert on the responses - assert llm_server.poll() is None - PROMPT = "1 2 3 4 5 " - expected_output_prefix = "6 7 8" - logger.info( - "Sending HTTP Generation Request" - + start_log_group("Sending HTTP Generation Request") - ) - output = do_generate(PROMPT, available_port)[0] - # log to GITHUB_STEP_SUMMARY if we are in a GitHub Action - if "GITHUB_ACTION" in os.environ: - with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: - # log prompt - f.write("LLM results:\n") - f.write(f"- llm_prompt:`{PROMPT}`\n") - f.write(f"- llm_output:`{output}`\n") - if not output.startswith(expected_output_prefix): - raise AccuracyValidationException( - f"Expected 
'{output}' to start with '{expected_output_prefix}'" - ) - logger.info("HTTP Generation Request Successful" + end_log_group()) - - -@pytest.mark.parametrize( - "model_test_dir,write_config,llm_server", - [ - pytest.param( - { - "repo_id": "SlyEcho/open_llama_3b_v2_gguf", - "model_file": "open-llama-3b-v2-f16.gguf", - "tokenizer_id": "openlm-research/open_llama_3b_v2", - "settings": CPU_SETTINGS, - "batch_sizes": [1, 4], - }, - {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"}, - {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ), - pytest.param( - { - "repo_id": "SlyEcho/open_llama_3b_v2_gguf", - "model_file": "open-llama-3b-v2-f16.gguf", - "tokenizer_id": "openlm-research/open_llama_3b_v2", - "settings": CPU_SETTINGS, - "batch_sizes": [1, 4], - }, - {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"}, - {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, - ), - ], - indirect=True, -) -@pytest.mark.parametrize( - "concurrent_requests", - [2, 4, 8], -) -@pytest.mark.xfail( - raises=AccuracyValidationException, - reason="Concurreny issues in Shortfin batch processing", -) -def test_llm_server_concurrent(llm_server, available_port, concurrent_requests): - logger.info("Testing concurrent invocations") - - assert llm_server.poll() is None - PROMPT = "1 2 3 4 5 " - expected_output_prefix = "6 7 8" - logger.info( - "Sending HTTP Generation Request" - + start_log_group("Sending HTTP Generation Request") - ) - outputs = do_generate(PROMPT, available_port, concurrent_requests) - - for output in outputs: - # log to GITHUB_STEP_SUMMARY if we are in a GitHub Action - if "GITHUB_ACTION" in os.environ: - with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: - # log prompt - f.write("LLM results:\n") - f.write(f"- llm_prompt:`{PROMPT}`\n") - f.write(f"- llm_output:`{output}`\n") - - if not output.startswith(expected_output_prefix): - raise AccuracyValidationException( - f"Expected '{output}' to start with '{expected_output_prefix}'" - ) - logger.info("HTTP Generation Request Successful" + end_log_group()) diff --git a/app_tests/integration_tests/llm/shortfin/test_llm_server.py b/app_tests/integration_tests/llm/shortfin/test_llm_server.py new file mode 100755 index 000000000..ebec0d8be --- /dev/null +++ b/app_tests/integration_tests/llm/shortfin/test_llm_server.py @@ -0,0 +1,104 @@ +"""Main test module for LLM server functionality.""" +import pytest +import requests +import uuid +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class TestLLMServer: + """Test suite for LLM server functionality.""" + + @pytest.mark.parametrize( + "model_artifacts,server", + [ + ("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "none"}), + ("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "trie"}), + ("llama3.1_8b", {"model": "llama3.1_8b", "prefix_sharing": "none"}), + ("llama3.1_8b", {"model": "llama3.1_8b", "prefix_sharing": "trie"}), + ], + ids=[ + "open_llama_3b_none", + "open_llama_3b_trie", + "llama31_8b_none", + "llama31_8b_trie", + ], + indirect=True, + ) + def test_basic_generation(self, server: tuple[Any, int]) -> None: + """Tests basic text generation capabilities. 
+ + Args: + server: Tuple of (process, port) from server fixture + """ + process, port = server + assert process.poll() is None, "Server process terminated unexpectedly" + + response = self._generate("1 2 3 4 5 ", port) + assert response.startswith("6 7 8"), f"Unexpected response: {response}" + + @pytest.mark.parametrize( + "model_artifacts,server", + [ + ("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "none"}), + ("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "trie"}), + ], + indirect=True, + ) + @pytest.mark.parametrize("concurrent_requests", [2, 4, 8]) + def test_concurrent_generation( + self, server: tuple[Any, int], concurrent_requests: int + ) -> None: + """Tests concurrent text generation requests. + + Args: + server: Tuple of (process, port) from server fixture + concurrent_requests: Number of concurrent requests to test + """ + process, port = server + assert process.poll() is None, "Server process terminated unexpectedly" + + prompt = "1 2 3 4 5 " + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + futures = [ + executor.submit(self._generate, prompt, port) + for _ in range(concurrent_requests) + ] + + for future in as_completed(futures): + response = future.result() + assert response.startswith("6 7 8"), f"Unexpected response: {response}" + + def _generate(self, prompt: str, port: int) -> str: + """Helper method to make generation request to server. + + Args: + prompt: Input text prompt + port: Server port number + + Returns: + Generated text response + + Raises: + requests.exceptions.RequestException: If request fails + """ + response = requests.post( + f"http://localhost:{port}/generate", + headers={"Content-Type": "application/json"}, + json={ + "text": prompt, + "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7}, + "rid": uuid.uuid4().hex, + "stream": False, + }, + timeout=30, # Add reasonable timeout + ) + response.raise_for_status() + + # Parse streaming response + data = response.text + assert data.startswith("data: "), f"Invalid response format: {data}" + return data[6:].rstrip("\n")
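
For reference, the pieces added here compose outside of pytest as well. The
following is a minimal sketch, not applied by this patch; the import root,
cache directory, and model choice are illustrative assumptions:

    import uuid
    from pathlib import Path

    import requests

    from app_tests.integration_tests.llm import device_settings
    from app_tests.integration_tests.llm.model_management import (
        ModelConfig,
        ModelProcessor,
        ModelSource,
    )
    from app_tests.integration_tests.llm.server_management import (
        ServerConfig,
        ServerManager,
    )

    # Stage weights and tokenizer, export to MLIR, and compile to a VMFB.
    config = ModelConfig(
        source=ModelSource.HUGGINGFACE,
        repo_id="SlyEcho/open_llama_3b_v2_gguf",
        model_file="open-llama-3b-v2-f16.gguf",
        tokenizer_id="openlm-research/open_llama_3b_v2",
        batch_sizes=(1, 4),
        device_settings=device_settings.CPU,
    )
    artifacts = ModelProcessor(Path("/tmp/shortfin_models")).process_model(config)

    # Launch a shortfin LLM server against the compiled artifacts.
    server_config = ServerConfig(
        port=ServerManager.find_available_port(),
        artifacts=artifacts,
        device_settings=device_settings.CPU,
        prefix_sharing_algorithm="trie",
    )
    process = ServerManager(server_config).start()
    try:
        # Same request shape the tests send to the /generate endpoint.
        response = requests.post(
            f"http://localhost:{server_config.port}/generate",
            json={
                "text": "1 2 3 4 5 ",
                "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
                "rid": uuid.uuid4().hex,
                "stream": False,
            },
            timeout=30,
        )
        response.raise_for_status()
        print(response.text)
    finally:
        process.terminate()
        process.wait()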
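
Extending coverage to another model follows the same pattern: register a
ModelConfig under TEST_MODELS and parametrize the fixtures indirectly. A rough
sketch, where the model name, local path, and tokenizer id are placeholders:

    # conftest.py
    TEST_MODELS["my_model"] = ModelConfig(
        source=ModelSource.LOCAL,
        local_path=Path("/path/to/my_model.gguf"),
        model_file="my_model.gguf",
        tokenizer_id="some-org/some-tokenizer",
        batch_sizes=(1, 4),
        device_settings=device_settings.CPU,
    )

    # test_llm_server.py, inside TestLLMServer
    @pytest.mark.parametrize(
        "model_artifacts,server",
        [("my_model", {"model": "my_model", "prefix_sharing": "none"})],
        indirect=True,
    )
    def test_my_model_generation(self, server):
        process, port = server
        assert process.poll() is None, "Server process terminated unexpectedly"
        assert self._generate("1 2 3 4 5 ", port).startswith("6 7 8")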