diff --git a/examples/apps/fastapi_server.py b/examples/apps/fastapi_server.py index 972cb90d6..59d0226bf 100755 --- a/examples/apps/fastapi_server.py +++ b/examples/apps/fastapi_server.py @@ -79,22 +79,27 @@ async def __call__(self, host, port): @click.command() @click.argument("model_dir") +@click.argument("tokenizer_path") @click.option("--host", type=str, default=None) @click.option("--port", type=int, default=8000) @click.option("--max_beam_width", type=int, default=1) @click.option("--tp_size", type=int, default=1) +@click.option("--max_batch_size", type=int, default=10) def entrypoint(model_dir: str, + tokenizer_path: str, host: Optional[str] = None, port: int = 8000, max_beam_width: int = 1, - tp_size: int = 1): + tp_size: int = 1, + max_batch_size: int = 10): host = host or "0.0.0.0" port = port or 8000 logging.info(f"Starting server at {host}:{port}") - build_config = BuildConfig(max_batch_size=10, max_beam_width=max_beam_width) + build_config = BuildConfig(max_batch_size=max_batch_size, max_beam_width=max_beam_width) llm = LLM(model_dir, + tokenizer_path, tensor_parallel_size=tp_size, build_config=build_config)