From 45abe0f74ee281aea6e5283c1e738061256cfcae Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 26 Nov 2024 16:20:18 +0100 Subject: [PATCH] server : replace behave with pytest (#10416) * server : replace behave with pytest * fix test on windows * misc * add more tests * more tests * styling * log less, fix embd test * added all sequential tests * fix coding style * fix save slot test * add parallel completion test * fix parallel test * remove feature files * update test docs * no cache_prompt for some tests * add test_cache_vs_nocache_prompt --- .devops/nix/python-scripts.nix | 2 +- .github/workflows/server.yml | 9 +- examples/server/tests/.gitignore | 1 + examples/server/tests/README.md | 33 +- examples/server/tests/conftest.py | 15 + .../server/tests/features/ctx_shift.feature | 66 - .../server/tests/features/embeddings.feature | 113 -- examples/server/tests/features/environment.py | 71 - examples/server/tests/features/infill.feature | 36 - examples/server/tests/features/issues.feature | 5 - examples/server/tests/features/lora.feature | 36 - .../server/tests/features/parallel.feature | 131 -- .../server/tests/features/passkey.feature | 56 - examples/server/tests/features/rerank.feature | 42 - .../server/tests/features/results.feature | 118 -- .../server/tests/features/security.feature | 68 - examples/server/tests/features/server.feature | 120 -- .../server/tests/features/slotsave.feature | 58 - examples/server/tests/features/steps/steps.py | 1518 ----------------- .../tests/features/wrong_usages.feature | 25 - examples/server/tests/requirements.txt | 2 +- examples/server/tests/tests.sh | 5 +- examples/server/tests/unit/test_basic.py | 34 + .../server/tests/unit/test_chat_completion.py | 129 ++ examples/server/tests/unit/test_completion.py | 223 +++ examples/server/tests/unit/test_ctx_shift.py | 67 + examples/server/tests/unit/test_embedding.py | 99 ++ examples/server/tests/unit/test_infill.py | 35 + examples/server/tests/unit/test_lora.py | 42 + examples/server/tests/unit/test_rerank.py | 38 + examples/server/tests/unit/test_security.py | 83 + examples/server/tests/unit/test_slot_save.py | 98 ++ examples/server/tests/unit/test_tokenize.py | 59 + examples/server/tests/utils.py | 377 ++++ 34 files changed, 1317 insertions(+), 2497 deletions(-) create mode 100644 examples/server/tests/conftest.py delete mode 100644 examples/server/tests/features/ctx_shift.feature delete mode 100644 examples/server/tests/features/embeddings.feature delete mode 100644 examples/server/tests/features/environment.py delete mode 100644 examples/server/tests/features/infill.feature delete mode 100644 examples/server/tests/features/issues.feature delete mode 100644 examples/server/tests/features/lora.feature delete mode 100644 examples/server/tests/features/parallel.feature delete mode 100644 examples/server/tests/features/passkey.feature delete mode 100644 examples/server/tests/features/rerank.feature delete mode 100644 examples/server/tests/features/results.feature delete mode 100644 examples/server/tests/features/security.feature delete mode 100644 examples/server/tests/features/server.feature delete mode 100644 examples/server/tests/features/slotsave.feature delete mode 100644 examples/server/tests/features/steps/steps.py delete mode 100644 examples/server/tests/features/wrong_usages.feature create mode 100644 examples/server/tests/unit/test_basic.py create mode 100644 examples/server/tests/unit/test_chat_completion.py create mode 100644 examples/server/tests/unit/test_completion.py create mode 100644 examples/server/tests/unit/test_ctx_shift.py create mode 100644 examples/server/tests/unit/test_embedding.py create mode 100644 examples/server/tests/unit/test_infill.py create mode 100644 examples/server/tests/unit/test_lora.py create mode 100644 examples/server/tests/unit/test_rerank.py create mode 100644 examples/server/tests/unit/test_security.py create mode 100644 examples/server/tests/unit/test_slot_save.py create mode 100644 examples/server/tests/unit/test_tokenize.py create mode 100644 examples/server/tests/utils.py diff --git a/.devops/nix/python-scripts.nix b/.devops/nix/python-scripts.nix index 392e9ffe41bf5..56ea182788764 100644 --- a/.devops/nix/python-scripts.nix +++ b/.devops/nix/python-scripts.nix @@ -34,7 +34,7 @@ let # server tests openai - behave + pytest prometheus-client ]; in diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 699ac095d6c83..2e8e3348f4292 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -122,14 +122,14 @@ jobs: id: server_integration_tests run: | cd examples/server/tests - PORT=8888 ./tests.sh + ./tests.sh - name: Slow tests id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | cd examples/server/tests - PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow + SLOW_TESTS=1 ./tests.sh server-windows: @@ -180,11 +180,12 @@ jobs: run: | cd examples/server/tests $env:PYTHONIOENCODING = ":replace" - behave.exe --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp + pytest -v -x - name: Slow tests id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | cd examples/server/tests - behave.exe --stop --no-skipped --no-capture --tags slow + $env:SLOW_TESTS = "1" + pytest -v -x diff --git a/examples/server/tests/.gitignore b/examples/server/tests/.gitignore index 1d17dae13b53a..90ee7fe6d971a 100644 --- a/examples/server/tests/.gitignore +++ b/examples/server/tests/.gitignore @@ -1 +1,2 @@ .venv +tmp diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index 10f22c4471ea7..2930a2e0dea0f 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -1,19 +1,9 @@ # Server tests -Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) -and [behave](https://behave.readthedocs.io/en/latest/): - -* [issues.feature](./features/issues.feature) Pending issues scenario -* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests -* [security.feature](./features/security.feature) Security, CORS and API Key -* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc... +Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/). Tests target GitHub workflows job runners with 4 vCPU. -Requests are -using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) -based http client. - Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`. @@ -39,26 +29,19 @@ It's possible to override some scenario steps values with environment variables: |--------------------------|------------------------------------------------------------------------------------------------| | `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` | | `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` | -| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` | +| `DEBUG` | to enable steps and server verbose mode `--verbose` | | `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` | -### Run @bug, @wip or @wrong_usage annotated scenario - -Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope. - -- `@bug` annotation aims to link a scenario with a GitHub issue. -- `@wrong_usage` are meant to show user issue that are actually an expected behavior -- `@wip` to focus on a scenario working in progress -- `@slow` heavy test, disabled by default - -To run a scenario annotated with `@bug`, start: +To run slow tests: ```shell -DEBUG=ON ./tests.sh --no-skipped --tags bug --stop +SLOW_TESTS=1 ./tests.sh ``` -After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated. +To run with stdout/stderr display in real time (verbose output, but useful for debugging): ```shell -./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile" +DEBUG=1 ./tests.sh -s -v -x ``` + +To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html) diff --git a/examples/server/tests/conftest.py b/examples/server/tests/conftest.py new file mode 100644 index 0000000000000..017d1bb841efd --- /dev/null +++ b/examples/server/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest +from utils import * + + +# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test +@pytest.fixture(autouse=True) +def stop_server_after_each_test(): + # do nothing before each test + yield + # stop all servers after each test + instances = set( + server_instances + ) # copy the set to prevent 'Set changed size during iteration' + for server in instances: + server.stop() diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature deleted file mode 100644 index ae6c6b01b0221..0000000000000 --- a/examples/server/tests/features/ctx_shift.feature +++ /dev/null @@ -1,66 +0,0 @@ -@llama.cpp -@ctx_shift -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a model file test-model.gguf - And a model alias tinyllama-2 - And BOS token is 1 - And 42 as server seed - And 256 KV cache size - And 32 as batch size - And 2 slots - - # the prompt is 301 tokens - # the slot context is 256/2 = 128 tokens - # the prompt is truncated to keep the last 109 tokens - # 64 tokens are generated thanks to shifting the context when it gets full - Scenario: Inference with context shift - And 64 server max tokens to predict - Then the server is starting - Then the server is healthy - Given a prompt: - """ - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - """ - And a completion request with no api error - Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl - And the completion is truncated - And 109 prompt tokens are processed - - Scenario Outline: Inference without context shift - And server max tokens to predict - And disable context shifting - Then the server is starting - Then the server is healthy - Given a prompt: - """ - Hi how are you - """ - And a completion request with no api error - Then tokens are predicted matching twind|Anna - And the completion is truncated - And 8 prompt tokens are processed - Examples: - | n_predict | n_token_output | truncated | - | 64 | 64 | not | - | -1 | 120 | | - - Scenario: Inference without context shift (expected error: prompt too long) - And disable context shifting - Then the server is starting - Then the server is healthy - Given a prompt: - """ - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - """ - And a completion request with 400 api error - diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature deleted file mode 100644 index f4fe2ee4335ff..0000000000000 --- a/examples/server/tests/features/embeddings.feature +++ /dev/null @@ -1,113 +0,0 @@ -@llama.cpp -@embeddings -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf - And a model file bert-bge-small.gguf - And a model alias bert-bge-small - And 42 as server seed - And 2 slots - # the bert-bge-small model has context size of 512 - # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512 - # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20 - And 128 as batch size - And 128 as ubatch size - And 512 KV cache size - And enable embeddings endpoint - Then the server is starting - Then the server is healthy - - Scenario: Embedding - When embeddings are computed for: - """ - What is the capital of Bulgaria ? - """ - Then embeddings are generated - - Scenario: Embedding (error: prompt too long) - When embeddings are computed for: - """ - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - """ - And embeddings request with 500 api error - - Scenario: OAI Embeddings compatibility - Given a model bert-bge-small - When an OAI compatible embeddings computation request for: - """ - What is the capital of Spain ? - """ - Then embeddings are generated - - Scenario: OAI Embeddings compatibility with multiple inputs - Given a model bert-bge-small - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - When an OAI compatible embeddings computation request for multiple inputs - Then embeddings are generated - - Scenario: Multi users embeddings - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - Given concurrent embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated - - Scenario: Multi users OAI compatibility embeddings - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - And a prompt: - """ - What is the biggest US city ? - """ - And a prompt: - """ - What is the capital of Bulgaria ? - """ - And a model bert-bge-small - Given concurrent OAI embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated - - Scenario: All embeddings should be the same - Given 10 fixed prompts - And a model bert-bge-small - Given concurrent OAI embedding requests - Then all embeddings are the same diff --git a/examples/server/tests/features/environment.py b/examples/server/tests/features/environment.py deleted file mode 100644 index e7845dc2f51fc..0000000000000 --- a/examples/server/tests/features/environment.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -import signal -import socket -import sys -import time -import traceback -from contextlib import closing -from subprocess import TimeoutExpired - - -def before_scenario(context, scenario): - context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' - if context.debug: - print("DEBUG=ON") - print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m") - port = 8080 - if 'PORT' in os.environ: - port = int(os.environ['PORT']) - if is_server_listening("localhost", port): - assert False, "Server already started" - - -def after_scenario(context, scenario): - try: - if 'server_process' not in context or context.server_process is None: - return - if scenario.status == "failed": - if 'GITHUB_ACTIONS' in os.environ: - print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n") - if os.path.isfile('llama.log'): - with closing(open('llama.log', 'r')) as f: - for line in f: - print(line) - if not is_server_listening(context.server_fqdn, context.server_port): - print("\x1b[33;101mERROR: Server stopped listening\x1b[0m") - - if context.server_process.poll() is not None: - assert False, f"Server not running pid={context.server_process.pid} ..." - - server_graceful_shutdown(context) # SIGINT - - try: - context.server_process.wait(0.5) - except TimeoutExpired: - print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...") - context.server_process.kill() # SIGKILL - context.server_process.wait() - - while is_server_listening(context.server_fqdn, context.server_port): - time.sleep(0.1) - except Exception: - print("ignoring error in after_scenario:") - traceback.print_exc(file=sys.stdout) - - -def server_graceful_shutdown(context): - print(f"shutting down server pid={context.server_process.pid} ...") - if os.name == 'nt': - interrupt = signal.CTRL_C_EVENT - else: - interrupt = signal.SIGINT - context.server_process.send_signal(interrupt) - - -def is_server_listening(server_fqdn, server_port): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - result = sock.connect_ex((server_fqdn, server_port)) - _is_server_listening = result == 0 - if _is_server_listening: - print(f"server is listening on {server_fqdn}:{server_port}...") - return _is_server_listening diff --git a/examples/server/tests/features/infill.feature b/examples/server/tests/features/infill.feature deleted file mode 100644 index a0bbfef77707b..0000000000000 --- a/examples/server/tests/features/infill.feature +++ /dev/null @@ -1,36 +0,0 @@ -@llama.cpp -@infill -Feature: llama.cpp server - - # The current model is made by adding FIM tokens to the existing stories260K - # We may want to use a better model in the future, maybe something like SmolLM 360M - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models - And a model file test-model-infill.gguf - And a model alias tinyllama-infill - And 42 as server seed - And 1024 as batch size - And 1024 as ubatch size - And 2048 KV cache size - And 64 max tokens to predict - And 0.0 temperature - Then the server is starting - Then the server is healthy - - Scenario: Infill without input_extra - Given a prompt "Complete this" - And an infill input extra none none - And an infill input prefix "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_" - And an infill input suffix "}\n" - And an infill request with no api error - Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird - - Scenario: Infill with input_extra - Given a prompt "Complete this" - And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n" - And an infill input prefix "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_" - And an infill input suffix "}\n" - And an infill request with no api error - Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room" diff --git a/examples/server/tests/features/issues.feature b/examples/server/tests/features/issues.feature deleted file mode 100644 index 7b13e44cad395..0000000000000 --- a/examples/server/tests/features/issues.feature +++ /dev/null @@ -1,5 +0,0 @@ -# List of ongoing issues -# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug -@bug -Feature: Issues - # No confirmed issue at the moment diff --git a/examples/server/tests/features/lora.feature b/examples/server/tests/features/lora.feature deleted file mode 100644 index 7b85988ac6e87..0000000000000 --- a/examples/server/tests/features/lora.feature +++ /dev/null @@ -1,36 +0,0 @@ -@llama.cpp -@lora -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf - And a model file stories15M_MOE-F16.gguf - And a model alias stories15M_MOE - And a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf - And 42 as server seed - And 1024 as batch size - And 1024 as ubatch size - And 2048 KV cache size - And 64 max tokens to predict - And 0.0 temperature - Then the server is starting - Then the server is healthy - - Scenario: Completion LoRA disabled - Given switch off lora adapter 0 - Given a prompt: - """ - Look in thy glass - """ - And a completion request with no api error - Then 64 tokens are predicted matching little|girl|three|years|old - - Scenario: Completion LoRA enabled - Given switch on lora adapter 0 - Given a prompt: - """ - Look in thy glass - """ - And a completion request with no api error - Then 64 tokens are predicted matching eye|love|glass|sun diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature deleted file mode 100644 index 423d0f1d42f55..0000000000000 --- a/examples/server/tests/features/parallel.feature +++ /dev/null @@ -1,131 +0,0 @@ -@llama.cpp -@parallel -Feature: Parallel - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models - And a model file test-model-00001-of-00003.gguf - And 42 as server seed - And 128 as batch size - And 256 KV cache size - And 2 slots - And continuous batching - Then the server is starting - Then the server is healthy - - Scenario Outline: Multi users completion - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And max tokens to predict - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all prompts are predicted with tokens - Examples: - | n_predict | - | 128 | - - Scenario Outline: Multi users OAI completions compatibility - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario Outline: Multi users OAI completions compatibility no v1 - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests no v1 - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario Outline: Multi users with number of prompts exceeding number of slots - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And a prompt: - """ - What is LLM? - """ - And a prompt: - """ - The sky is blue and I love it. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969 - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - And 128 max tokens to predict - Given concurrent completion requests - Then the server is busy - Then the server is idle - Then all prompts are predicted diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature deleted file mode 100644 index ff0a82cc46581..0000000000000 --- a/examples/server/tests/features/passkey.feature +++ /dev/null @@ -1,56 +0,0 @@ -# run with: ./tests.sh --no-skipped --tags passkey -@passkey -@slow -Feature: Passkey / Self-extend with context shift - - Background: Server startup - Given a server listening on localhost:8080 - - # Generates a long text of junk and inserts a secret passkey number inside it. - # Then we query the LLM for the secret passkey. - # see #3856 and #4810 - Scenario Outline: Passkey - Given a model file from HF repo - And as batch size - And as number of junk - And server max tokens to predict - And 42 as seed - And 0.0 temperature - And KV cache size - And 1 slots - And group attention factor to extend context size through self-extend - And group attention width to extend context size through self-extend - # Can be override with N_GPU_LAYERS - And GPU offloaded layers - Then the server is starting - # Higher timeout because the model may need to be downloaded from the internet - Then the server is healthy with timeout 120 seconds - Given available models - Then model 0 is trained on tokens context - Given a prefix prompt: - """ - here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. - """ - And a passkey prompt template: - """ - The pass key is Remember it. is the pass key. - """ - And a junk suffix prompt: - """ - The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. - """ - And a suffix prompt: - """ - What is the pass key? The pass key is - """ - Given a "" passkey challenge prompt with the passkey inserted every junk - And a completion request with no api error - Then tokens are predicted matching - - Examples: - | hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content | - | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 | - | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b | - #| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 | - #| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0 - # 987 | diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature deleted file mode 100644 index c36cc8e215fa6..0000000000000 --- a/examples/server/tests/features/rerank.feature +++ /dev/null @@ -1,42 +0,0 @@ -@llama.cpp -@rerank -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf - And a model file jina-reranker-v1-tiny-en.gguf - And a model alias jina-reranker-v1-tiny-en - And 42 as server seed - And 2 slots - And 512 as batch size - And 512 as ubatch size - And 512 KV cache size - And enable reranking endpoint - Then the server is starting - Then the server is healthy - - Scenario: Rerank - Given a rerank query: - """ - Machine learning is - """ - And a rerank document: - """ - A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines. - """ - And a rerank document: - """ - Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants. - """ - And a rerank document: - """ - Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions. - """ - And a rerank document: - """ - Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine. - """ - When reranking request - Then reranking results are returned - Then reranking highest score is index 2 and lowest score is index 3 diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature deleted file mode 100644 index e8e1b54147b05..0000000000000 --- a/examples/server/tests/features/results.feature +++ /dev/null @@ -1,118 +0,0 @@ -@llama.cpp -@results -Feature: Results - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models - And a model file test-model-00001-of-00003.gguf - And 128 as batch size - And 1024 KV cache size - And 128 max tokens to predict - And continuous batching - - Scenario Outline: consistent results with same seed - Given slots - And 1.0 temperature - Then the server is starting - Then the server is healthy - - Given 4 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42 - - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all predictions are equal - Examples: - | n_slots | - | 1 | - # FIXME: unified KV cache nondeterminism - # | 2 | - - Scenario Outline: different results with different seed - Given slots - And 1.0 temperature - Then the server is starting - Then the server is healthy - - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 43 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 44 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 45 - - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all predictions are different - Examples: - | n_slots | - | 1 | - | 2 | - - Scenario Outline: consistent results with same seed and varying batch size - Given 4 slots - And temperature - # And 0 as draft - Then the server is starting - Then the server is healthy - - Given 1 prompts "Write a very long story about AI." with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Given prompts "Write a very long story about AI." with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Then all predictions are equal - Examples: - | n_parallel | temp | - | 1 | 0.0 | - | 1 | 1.0 | - # FIXME: unified KV cache nondeterminism - # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227 - # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 - # and https://github.com/ggerganov/llama.cpp/pull/7347 . - # | 2 | 0.0 | - # | 4 | 0.0 | - # | 2 | 1.0 | - # | 4 | 1.0 | - - Scenario Outline: consistent token probs with same seed and prompt - Given slots - And KV cache size - And 1.0 temperature - And max tokens to predict - Then the server is starting - Then the server is healthy - - Given 1 prompts "The meaning of life is" with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Given prompts "The meaning of life is" with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Then all token probabilities are equal - Examples: - | n_slots | n_kv | n_predict | n_parallel | - | 4 | 1024 | 1 | 1 | - # FIXME: unified KV cache nondeterminism - # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227 - # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 - # and https://github.com/ggerganov/llama.cpp/pull/7347 . - # | 4 | 1024 | 1 | 4 | - # | 4 | 1024 | 100 | 1 | - # This test still fails even the above patches; the first token probabilities are already different. - # | 4 | 1024 | 100 | 4 | diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature deleted file mode 100644 index ef30007c3eddb..0000000000000 --- a/examples/server/tests/features/security.feature +++ /dev/null @@ -1,68 +0,0 @@ -@llama.cpp -@security -Feature: Security - - Background: Server startup with an api key defined - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a server api key THIS_IS_THE_KEY - Then the server is starting - Then the server is healthy - - Scenario Outline: Completion with some user api key - Given a prompt test - And a user api key - And 4 max tokens to predict - And a completion request with api error - - Examples: Prompts - | api_key | api_error | - | THIS_IS_THE_KEY | no | - | THIS_IS_THE_KEY | no | - | hackeme | raised | - | | raised | - - Scenario Outline: OAI Compatibility - Given a system prompt test - And a user prompt test - And a model test - And 2 max tokens to predict - And streaming is disabled - And a user api key - Given an OAI compatible chat completions request with api error - - Examples: Prompts - | api_key | api_error | - | THIS_IS_THE_KEY | no | - | THIS_IS_THE_KEY | no | - | hackme | raised | - - Scenario Outline: OAI Compatibility (invalid response formats) - Given a system prompt test - And a user prompt test - And a response format - And a model test - And 2 max tokens to predict - And streaming is disabled - Given an OAI compatible chat completions request with raised api error - - Examples: Prompts - | response_format | - | {"type": "sound"} | - | {"type": "json_object", "schema": 123} | - | {"type": "json_object", "schema": {"type": 123}} | - | {"type": "json_object", "schema": {"type": "hiccup"}} | - - - Scenario Outline: CORS Options - Given a user api key THIS_IS_THE_KEY - When an OPTIONS request is sent from - Then CORS header is set to - - Examples: Headers - | origin | cors_header | cors_header_value | - | localhost | Access-Control-Allow-Origin | localhost | - | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | - | origin | Access-Control-Allow-Credentials | true | - | web.mydomain.fr | Access-Control-Allow-Methods | GET, POST | - | web.mydomain.fr | Access-Control-Allow-Headers | * | diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature deleted file mode 100644 index 15e24c624af37..0000000000000 --- a/examples/server/tests/features/server.feature +++ /dev/null @@ -1,120 +0,0 @@ -@llama.cpp -@server -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a model file test-model.gguf - And a model alias tinyllama-2 - And BOS token is 1 - And 42 as server seed - # KV Cache corresponds to the total amount of tokens - # that can be stored across all independent sequences: #4130 - # see --ctx-size and #5568 - And 256 KV cache size - And 32 as batch size - And 2 slots - And 64 server max tokens to predict - And prometheus compatible metrics exposed - Then the server is starting - Then the server is healthy - - Scenario: Health - Then the server is ready - And all slots are idle - - - Scenario Outline: Completion - Given a prompt - And max tokens to predict - And a completion request with no api error - Then tokens are predicted matching - And the completion is truncated - And prompt tokens are processed - And prometheus metrics are exposed - And metric llamacpp:tokens_predicted is - - Examples: Prompts - | prompt | n_predict | re_content | n_prompt | n_predicted | truncated | - | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not | - | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids\|Anna\|forest)+ | 46 | 64 | not | - - Scenario: Completion prompt truncated - Given a prompt: - """ - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - """ - And a completion request with no api error - Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl - And the completion is truncated - And 109 prompt tokens are processed - - - Scenario Outline: OAI Compatibility - Given a model - And a system prompt - And a user prompt - And max tokens to predict - And streaming is - Given an OAI compatible chat completions request with no api error - Then tokens are predicted matching - And prompt tokens are processed - And the completion is truncated - - Examples: Prompts - | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated | - | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not | - | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | | - - - Scenario Outline: OAI Compatibility w/ response format - Given a model test - And a system prompt test - And a user prompt test - And a response format - And 10 max tokens to predict - Given an OAI compatible chat completions request with no api error - Then tokens are predicted matching - - Examples: Prompts - | response_format | n_predicted | re_content | - | {"type": "json_object", "schema": {"const": "42"}} | 6 | "42" | - | {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] | - | {"type": "json_object"} | 10 | \{ " Jacky. | - - - Scenario: Tokenize / Detokenize - When tokenizing: - """ - What is the capital of France ? - """ - Then tokens can be detokenized - And tokens do not begin with BOS - - Scenario: Tokenize w/ BOS - Given adding special tokens - When tokenizing: - """ - What is the capital of Germany? - """ - Then tokens begin with BOS - Given first token is removed - Then tokens can be detokenized - - Scenario: Tokenize with pieces - When tokenizing with pieces: - """ - What is the capital of Germany? - 媽 - """ - Then tokens are given with pieces - - Scenario: Models available - Given available models - Then 1 models are supported - Then model 0 is identified by tinyllama-2 - Then model 0 is trained on 128 tokens context diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature deleted file mode 100644 index 1c281c0741afe..0000000000000 --- a/examples/server/tests/features/slotsave.feature +++ /dev/null @@ -1,58 +0,0 @@ -@llama.cpp -@slotsave -Feature: llama.cpp server slot management - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And prompt caching is enabled - And 2 slots - And . as slot save path - And 2048 KV cache size - And 42 as server seed - And 24 max tokens to predict - Then the server is starting - Then the server is healthy - - Scenario: Save and Restore Slot - # First prompt in slot 1 should be fully processed - Given a user prompt "What is the capital of France?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed - When the slot 1 is saved with filename "slot1.bin" - Then the server responds with status code 200 - # Since we have cache, this should only process the last tokens - Given a user prompt "What is the capital of Germany?" - And a completion request with no api error - Then 24 tokens are predicted matching (Thank|special) - And 7 prompt tokens are processed - # Loading the original cache into slot 0, - # we should only be processing 1 prompt token and get the same output - When the slot 0 is restored with filename "slot1.bin" - Then the server responds with status code 200 - Given a user prompt "What is the capital of France?" - And using slot id 0 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 1 prompt tokens are processed - # For verification that slot 1 was not corrupted during slot 0 load, same thing - Given a user prompt "What is the capital of Germany?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Thank|special) - And 1 prompt tokens are processed - - Scenario: Erase Slot - Given a user prompt "What is the capital of France?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed - When the slot 1 is erased - Then the server responds with status code 200 - Given a user prompt "What is the capital of France?" - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py deleted file mode 100644 index 687b163f487b6..0000000000000 --- a/examples/server/tests/features/steps/steps.py +++ /dev/null @@ -1,1518 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import asyncio -import json -import os -import re -import socket -import subprocess -import sys -import threading -import time -import requests -from collections.abc import Sequence -from contextlib import closing -from re import RegexFlag -from typing import Any, Literal, cast - -import aiohttp -import numpy as np -import openai -from openai.types.chat import ChatCompletionChunk -from behave import step # pyright: ignore[reportAttributeAccessIssue] -from behave.api.async_step import async_run_until_complete -from prometheus_client import parser - -# pyright: reportRedeclaration=false - -DEFAULT_TIMEOUT_SECONDS = aiohttp.ClientTimeout(total=600) - -@step("a server listening on {server_fqdn}:{server_port}") -def step_server_config(context, server_fqdn: str, server_port: str): - context.server_fqdn = server_fqdn - context.server_port = int(server_port) - context.n_threads = None - context.n_gpu_layer = None - if 'PORT' in os.environ: - context.server_port = int(os.environ['PORT']) - print(f"$PORT set, overriding server port with to {context.server_port}") - if 'FQDN' in os.environ: - context.server_fqdn = os.environ['FQDN'] - print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}") - if 'N_GPU_LAYERS' in os.environ: - context.n_gpu_layer = int(os.environ['N_GPU_LAYERS']) - print(f"$N_GPU_LAYERS set, overriding n_gpu_layer with to {context.n_gpu_layer}") - - context.base_url = f'http://{context.server_fqdn}:{context.server_port}' - - context.model_alias = None - context.model_file = None - context.model_hf_repo = None - context.model_hf_file = None - context.model_url = None - context.n_batch = None - context.n_ubatch = None - context.n_ctx = None - context.n_ga = None - context.n_ga_w = None - context.n_predict = None - context.n_prompts = 0 - context.n_server_predict = None - context.slot_save_path = None - context.id_slot = None - context.cache_prompt = None - context.n_slots = None - context.prompt_prefix = None - context.prompt_suffix = None - context.server_api_key = None - context.server_continuous_batching = False - context.server_embeddings = False - context.server_reranking = False - context.server_metrics = False - context.server_process = None - context.seed = None - context.draft = None - context.server_seed = None - context.user_api_key = None - context.response_format = None - context.temperature = None - context.lora_file = None - context.disable_ctx_shift = False - - # infill - context.infill_input_extra = None - context.infill_input_suffix = '' - context.infill_input_prefix = '' - - context.tasks_result = [] - context.concurrent_tasks = [] - context.prompts = [] - - context.reranking_query = None - context.reranking_documents = [] - context.reranking_results = None - - -@step('a model file {hf_file} from HF repo {hf_repo}') -def step_download_hf_model(context, hf_file: str, hf_repo: str): - context.model_hf_repo = hf_repo - context.model_hf_file = hf_file - context.model_file = os.path.basename(hf_file) - -@step('a lora adapter file from {lora_file_url}') -def step_download_lora_file(context, lora_file_url: str): - file_name = lora_file_url.split('/').pop() - context.lora_file = f'../../../{file_name}' - with open(context.lora_file, 'wb') as f: - f.write(requests.get(lora_file_url).content) - -@step('a model file {model_file}') -def step_model_file(context, model_file: str): - context.model_file = model_file - - -@step('a model url {model_url}') -def step_model_url(context, model_url: str): - context.model_url = model_url - - -@step('a model alias {model_alias}') -def step_model_alias(context, model_alias: str): - context.model_alias = model_alias - - -@step('{seed:d} as server seed') -def step_seed(context, seed: int): - context.server_seed = seed - - -@step('{ngl:d} GPU offloaded layers') -def step_n_gpu_layer(context, ngl: int): - if 'N_GPU_LAYERS' in os.environ: - new_ngl = int(os.environ['N_GPU_LAYERS']) - if context.debug: - print(f"-ngl upgraded from {ngl} to {new_ngl}") - ngl = new_ngl - context.n_gpu_layer = ngl - - -@step('{n_threads:d} threads') -def step_n_threads(context, n_threads: int): - context.n_thread = n_threads - - -@step('{draft:d} as draft') -def step_draft(context, draft: int): - context.draft = draft - - -@step('{n_ctx:d} KV cache size') -def step_n_ctx(context, n_ctx: int): - context.n_ctx = n_ctx - - -@step('{n_slots:d} slots') -def step_n_slots(context, n_slots: int): - context.n_slots = n_slots - - -@step('{n_predict:d} server max tokens to predict') -def step_server_n_predict(context, n_predict: int): - context.n_server_predict = n_predict if n_predict > 0 else None - - -@step('{slot_save_path} as slot save path') -def step_slot_save_path(context, slot_save_path: str): - context.slot_save_path = slot_save_path - - -@step('using slot id {id_slot:d}') -def step_id_slot(context, id_slot: int): - context.id_slot = id_slot - - -@step('prompt caching is enabled') -def step_enable_prompt_cache(context): - context.cache_prompt = True - - -@step('continuous batching') -def step_server_continuous_batching(context): - context.server_continuous_batching = True - - -@step('enable embeddings endpoint') -def step_server_embeddings(context): - context.server_embeddings = True - -@step('enable reranking endpoint') -def step_server_reranking(context): - context.server_reranking = True - -@step('prometheus compatible metrics exposed') -def step_server_metrics(context): - context.server_metrics = True - -@step('disable context shifting') -def step_server_disable_ctx_shift(context): - context.disable_ctx_shift = True - -@step("the server is starting") -def step_start_server(context): - start_server_background(context) - attempts = 0 - max_attempts = 20 - if 'GITHUB_ACTIONS' in os.environ: - max_attempts *= 2 - - addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM) - family, typ, proto, _, sockaddr = addrs[0] - - while True: - with closing(socket.socket(family, typ, proto)) as sock: - result = sock.connect_ex(sockaddr) - if result == 0: - print("\x1b[33;46mserver started!\x1b[0m") - return - attempts += 1 - if attempts > max_attempts: - assert False, "server not started" - print(f"waiting for server to start, connect error code = {result}...") - time.sleep(0.1) - - -async def wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int): - match expecting_status: - case 'healthy': - await wait_for_slots_status(context, context.base_url, 200, - timeout=timeout) - - case 'ready' | 'idle': - await wait_for_slots_status(context, context.base_url, 200, - timeout=timeout, - params={'fail_on_no_slot': 1}, - slots_idle=context.n_slots, - slots_processing=0) - case 'busy': - await wait_for_slots_status(context, context.base_url, 503, - params={'fail_on_no_slot': 1}, - slots_idle=0, - slots_processing=context.n_slots) - case _: - assert False, "unknown status" - - -@step("the server is {expecting_status} with timeout {timeout:d} seconds") -@async_run_until_complete -async def step_wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int): - await wait_for_server_status_with_timeout(context, expecting_status, timeout) - - -@step("the server is {expecting_status}") -@async_run_until_complete -async def step_wait_for_server_status(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): - await wait_for_server_status_with_timeout(context, expecting_status, 30) - - -@step('all slots are {expected_slot_status_string}') -@async_run_until_complete -async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str): - match expected_slot_status_string: - case 'idle': - expected_slot_status = False - case 'busy': - expected_slot_status = True - case _: - assert False, "unknown status" - - expected_slots = [{'id': slot_id, 'is_processing': expected_slot_status} - for slot_id in range(context.n_slots)] - await request_slots_status(context, expected_slots) - - -@step('a completion request with {api_error} api error') -@async_run_until_complete -async def step_request_completion(context, api_error: Literal['raised'] | str): - expect_api_error = api_error == 'raised' or api_error != 'no' - seeds = await completions_seed(context, num_seeds=1) - completion = await request_completion(context.prompts.pop(), - seeds[0] if seeds is not None else seeds, - context.base_url, - debug=context.debug, - n_predict=context.n_predict, - cache_prompt=context.cache_prompt, - id_slot=context.id_slot, - expect_api_error=expect_api_error, - user_api_key=context.user_api_key, - temperature=context.temperature) - context.tasks_result.append(completion) - if context.debug: - print(f"Completion response: {completion}") - if api_error == 'raised': - assert completion == 401, f"completion must be an 401 status code: {completion}" - elif api_error.isdigit(): - api_error_code = int(api_error) - assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" - - -@step('an infill request with {api_error} api error') -@async_run_until_complete -async def step_request_completion(context, api_error: Literal['raised'] | str): - if api_error != 'no': - raise ValueError(f'api_error={api_error} is not yet implemented') - payload = { - "prompt": context.prompts[0], - "input_suffix": context.infill_input_suffix, - "input_prefix": context.infill_input_prefix, - "n_predict": context.n_predict, - "seed": context.seed, - "temperature": context.temperature, - } - if context.infill_input_extra is not None: - payload['input_extra'] = context.infill_input_extra - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/infill', - json=payload) as response: - assert response.status == 200 - context.tasks_result = [await response.json()] - - -@step('{predicted_n:d} tokens are predicted matching {re_content}') -def step_n_tokens_predicted_with_content(context, predicted_n, re_content): - context.completion = context.tasks_result.pop() - assert_n_tokens_predicted(context.completion, predicted_n, re_content) - - -@step('{predicted_n:d} tokens are predicted') -def step_n_tokens_predicted(context, predicted_n): - context.completion = context.tasks_result.pop() - assert_n_tokens_predicted(context.completion, predicted_n) - - -@step('all predictions are equal') -@async_run_until_complete -async def step_predictions_equal(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_predictions_equal(context.tasks_result) - context.tasks_result = [] - - -@step('all predictions are different') -@async_run_until_complete -async def step_predictions_different(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_predictions_different(context.tasks_result) - context.tasks_result = [] - - -@step('all token probabilities are equal') -@async_run_until_complete -async def step_token_probabilities_equal(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_token_probabilities_equal(context.tasks_result) - context.tasks_result = [] - - -@step('the completion is truncated') -def step_assert_completion_truncated(context): - step_assert_completion_truncated(context, '') - - -@step('the completion is {truncated} truncated') -def step_assert_completion_truncated(context, truncated): - truncated = truncated != "not" - assert context.completion['truncated'] == truncated, f'{context.completion}' - - -@step('{n_prompt:d} prompt tokens are processed') -def step_impl(context, n_prompt): - assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}" - - -@step('a user prompt {user_prompt}') -def step_user_prompt(context, user_prompt): - context.prompts.append(user_prompt) - context.n_prompts = len(context.prompts) - - -@step('a system prompt {system_prompt}') -def step_system_prompt(context, system_prompt): - context.system_prompt = system_prompt - - -@step('a model {model}') -def step_model(context, model): - context.model = model - - -@step('{max_tokens:d} max tokens to predict') -def step_max_tokens(context, max_tokens): - context.n_predict = max_tokens - - -@step('a response format {response_format}') -def step_response_format(context, response_format): - context.response_format = json.loads(response_format) - - -@step('{temperature:f} temperature') -def step_temperature(context, temperature): - context.temperature = temperature - - -@step('streaming is {enable_streaming}') -def step_streaming(context, enable_streaming): - context.enable_streaming = enable_streaming == 'enabled' - - -@step('a user api key {user_api_key}') -def step_user_api_key(context, user_api_key): - context.user_api_key = user_api_key - - -@step('no user api key') -def step_no_user_api_key(context): - context.user_api_key = None - - -@step('a user api key ') -def step_no_user_api_key_space(context): - context.user_api_key = None - - -@step('a server api key {server_api_key}') -def step_server_api_key(context, server_api_key): - context.server_api_key = server_api_key - - -@step('{n_junk:d} as number of junk') -def step_n_junk(context, n_junk): - context.n_junk = n_junk - - -@step('{n_batch:d} as batch size') -def step_n_batch(context, n_batch): - context.n_batch = n_batch - - -@step('{n_ubatch:d} as ubatch size') -def step_n_ubatch(context, n_ubatch): - context.n_ubatch = n_ubatch - - -@step('{seed:d} as seed') -def step_seed(context, seed): - if context.seed is None: - context.seed = [seed] - else: - context.seed.append(seed) - - -@step('BOS token is {bos:d}') -def step_bos_token(context, bos): - context.bos = bos - - -@step('a prefix prompt') -def step_prompt_prefix(context): - context.prompt_prefix = context_text(context) - - -@step('a junk suffix prompt') -def step_prompt_junk_suffix(context): - context.prompt_junk_suffix = context_text(context) - - -@step('a suffix prompt') -def step_prompt_suffix(context): - context.prompt_suffix = context_text(context) - - -@step('{n_ga:d} group attention factor' - ' to extend context size through self-extend') -def step_impl(context, n_ga): - context.n_ga = n_ga - - -@step('{n_ga_w:d} group attention width to extend context size through self-extend') -def step_impl(context, n_ga_w): - context.n_ga_w = n_ga_w - - -@step('a passkey prompt template') -def step_prompt_passkey(context): - context.prompt_passkey = context_text(context) - -@step('a rerank query') -def step_set_rerank_query(context): - context.reranking_query = context_text(context) - context.reranking_documents = [] - -@step('a rerank document') -def step_set_rerank_document(context): - context.reranking_documents.append(context_text(context)) - -@step('{n_prompts:d} fixed prompts') -def step_fixed_prompts(context, n_prompts): - context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)]) - context.n_prompts = n_prompts - - -@step('a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk') -def step_prompt_passkey(context, passkey, i_pos): - prompt = "" - for i in range(context.n_junk): - if i % context.n_junk == i_pos: - prompt += context.prompt_passkey # the passkey is already substituted - prompt += context.prompt_junk_suffix - if context.debug: - passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" - print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```") - context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) - context.n_prompts = len(context.prompts) - - -@step('an OAI compatible chat completions request with {api_error} api error') -@async_run_until_complete -async def step_oai_chat_completions(context, api_error): - if context.debug: - print(f"Submitting OAI compatible completions request...") - expect_api_error = api_error == 'raised' - seeds = await completions_seed(context, num_seeds=1), - completion = await oai_chat_completions(context.prompts.pop(), - seeds[0] if seeds is not None else seeds, - context.system_prompt, - context.base_url, - '/v1/chat', - False, - model=context.model if hasattr(context, 'model') else None, - - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - - response_format=context.response_format - if hasattr(context, 'response_format') else None, - - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None, - - expect_api_error=expect_api_error) - context.tasks_result.append(completion) - if context.debug: - print(f"Completion response: {completion}") - if expect_api_error: - assert completion == 401, f"completion must be an 401 status code: {completion}" - - if context.debug: - print(f"Completion response: {completion}") - - -@step('a prompt') -def step_a_prompt(context): - context.prompts.append(context_text(context)) - context.n_prompts = len(context.prompts) - - -@step('a prompt {prompt}') -def step_a_prompt_prompt(context, prompt): - context.prompts.append(prompt) - context.n_prompts = len(context.prompts) - - -# TODO: allow this to be repeated -@step('an infill input extra {filename} {text}') -def step_infill_input_extra(context, filename, text): - if filename == 'none': - context.infill_input_extra = None - else: - context.infill_input_extra = [{'filename': filename, 'text': text}] - - -@step('an infill input suffix {text}') -def step_infill_input_suffix(context, text): - context.infill_input_suffix = text - - -@step('an infill input prefix {text}') -def step_infill_input_prefix(context, text): - context.infill_input_prefix = text - - -@step('{num_prompts:d} prompts {prompt} with seed {seed:d}') -def step_many_prompts(context, num_prompts, prompt, seed): - if context.seed is None: - context.seed = [] - for _ in range(num_prompts): - context.seed.append(seed) - context.prompts.append(prompt) - context.n_prompts = len(context.prompts) - - -@step('concurrent completion requests') -@async_run_until_complete() -async def step_concurrent_completion_requests(context): - await concurrent_requests( - context, - request_completion, - # prompt is inserted automatically - context.base_url, - debug=context.debug, - prompt_prefix=context.prompt_prefix, - prompt_suffix=context.prompt_suffix, - n_predict=context.n_predict if hasattr(context, 'n_predict') else None, - user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None, - temperature=context.temperature, - ) - - -@step('concurrent OAI completions requests') -@async_run_until_complete -async def step_oai_chat_completions(context): - await concurrent_requests(context, oai_chat_completions, - # user_prompt is inserted automatically - context.system_prompt, - context.base_url, - '/v1/chat/completions', - True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - response_format=context.response_format - if hasattr(context, 'response_format') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) - - -@step('concurrent OAI completions requests no v1') -@async_run_until_complete -async def step_oai_chat_completions(context): - await concurrent_requests(context, oai_chat_completions, - # user_prompt is inserted automatically - context.system_prompt, - context.base_url, - '/chat/completions', - True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - response_format=context.response_format - if hasattr(context, 'response_format') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) - - -@step('all prompts are predicted') -@async_run_until_complete -async def step_all_prompts_are_predicted(context): - await all_prompts_are_predicted(context) - - -@step('all prompts are predicted with {n_expected_predicted:d} tokens') -@async_run_until_complete -async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted): - await all_prompts_are_predicted(context, n_expected_predicted) - - -async def all_prompts_are_predicted(context, expected_predicted_n=None): - n_completions = await gather_tasks_results(context) - assert n_completions > 0 - for i in range(n_completions): - assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n) - assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests" - - -@step('embeddings are computed for') -@async_run_until_complete -async def step_compute_embedding(context): - context.n_prompts = 1 - context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url) - - -@step('reranking request') -@async_run_until_complete -async def step_compute_reranking(context): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/reranking', - json={ - "query": context.reranking_query, - "documents": context.reranking_documents, - }) as response: - if response.status == 200: - response_json = await response.json() - context.reranking_results = response_json['results'] - else: - context.reranking_results = response.status - - -@step('all embeddings are the same') -@async_run_until_complete -async def step_all_embeddings_are_the_same(context): - n_embedding_requests = await gather_tasks_results(context) - assert n_embedding_requests > 0 - embeddings = [] - for i in range(n_embedding_requests): - embedding = context.tasks_result.pop().pop() - embeddings.append(embedding) - assert_embeddings(embedding) - n = len(embeddings) - for i in range(n-1): - for j in range(i+1, n): - embedding1 = np.array(embeddings[i]) - embedding2 = np.array(embeddings[j]) - if context.debug: - print(f"embedding1: {embedding1[-8:]}") - print(f"embedding2: {embedding2[-8:]}") - similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) - msg = f"Similarity between {i} and {j}: {similarity:.10f}" - if context.debug: - print(f"{msg}") - assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg - - -@step('embeddings are generated') -def step_assert_embeddings(context): - assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n" - f"context.n_prompts={context.n_prompts}\n" - f"context.embeddings={context.embeddings}") - for embedding in context.embeddings: - assert_embeddings(embedding) - -@step('embeddings request with {api_error_code:d} api error') -def step_assert_embeddings(context, api_error_code: int): - assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}" - -@step('an OAI compatible embeddings computation request for') -@async_run_until_complete -async def step_oai_compute_embeddings(context): - context.n_prompts = 1 - context.embeddings = await request_oai_embeddings(context_text(context), None, - base_url=context.base_url, - user_api_key=context.user_api_key, - model=context.model) - - -@step('an OAI compatible embeddings computation request for multiple inputs') -@async_run_until_complete -async def step_oai_compute_embeddings_multiple_inputs(context): - context.embeddings = await request_oai_embeddings(context.prompts, None, - base_url=context.base_url, - user_api_key=context.user_api_key, - model=context.model) - context.prompts.clear() - - -@step('concurrent embedding requests') -@async_run_until_complete() -async def step_concurrent_embedding_requests(context): - await concurrent_requests(context, - request_embedding, - # prompt is inserted automatically - base_url=context.base_url) - - -@step('concurrent OAI embedding requests') -@async_run_until_complete() -async def step_concurrent_oai_embedding_requests(context): - await concurrent_requests(context, - request_oai_embeddings, - # prompt is inserted automatically - base_url=context.base_url, - async_client=True, - model=context.model) - - -@step('all embeddings are generated') -@async_run_until_complete() -async def all_embeddings_are_generated(context): - n_embedding_requests = await gather_tasks_results(context) - assert n_embedding_requests == context.n_prompts - for i in range(n_embedding_requests): - assert_embeddings(context.tasks_result.pop().pop()) - -@step('reranking results are returned') -def reranking_results_are_returned(context): - assert len(context.reranking_results) == len(context.reranking_documents) - -@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}') -def reranking_results_are_returned(context, idx_high: int, idx_low: int): - max_score, max_idx = 0, 0 - min_score, min_idx = 0, 0 - for res in context.reranking_results: - if max_score < res['relevance_score']: - max_score = res['relevance_score'] - max_idx = res['index'] - if min_score > res['relevance_score']: - min_score = res['relevance_score'] - min_idx = res['index'] - print(context.reranking_results) - assert max_idx == idx_high - assert min_idx == idx_low - -@step('adding special tokens') -def step_tokenize_set_add_special(context): - context.tokenize_add_special = True - - -@step("tokenizing with pieces") -@async_run_until_complete -async def step_tokenize_with_pieces(context): - context.tokenized_text = context_text(context) - async with aiohttp.ClientSession() as session: - tokenize_args = {"content": context.tokenized_text, "with_pieces": True} - if getattr(context, "tokenize_add_special", None) is not None: - tokenize_args["add_special"] = context.tokenize_add_special - - async with session.post( - f"{context.base_url}/tokenize", json=tokenize_args - ) as response: - assert response.status == 200 - tokenize_json = await response.json() - context.tokens_with_pieces = tokenize_json["tokens"] - - -@step("tokens are given with pieces") -@async_run_until_complete -async def step_tokenize_with_pieces(context): - # Verify that the response contains both token IDs and pieces - assert all( - "id" in token and "piece" in token for token in context.tokens_with_pieces - ) - - -@step('tokenizing') -@async_run_until_complete -async def step_tokenize(context): - context.tokenized_text = context_text(context) - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - tokenize_args = { - "content": context.tokenized_text, - } - if getattr(context, 'tokenize_add_special', None) is not None: - tokenize_args['add_special'] = context.tokenize_add_special - async with session.post(f'{context.base_url}/tokenize', - json=tokenize_args) as response: - assert response.status == 200 - tokenize_json = await response.json() - context.tokens = tokenize_json['tokens'] - - -@step('tokens can be detokenized') -@async_run_until_complete -async def step_detokenize(context): - assert len(context.tokens) > 0 - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/detokenize', - json={ - "tokens": context.tokens, - }) as response: - assert response.status == 200 - detokenize_json = await response.json() - # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15 - assert context.tokenized_text == detokenize_json['content'].strip() - - -@step('tokens begin with BOS') -def step_strings_for_tokenization(context): - assert context.tokens[0] == context.bos - - -@step('tokens do not begin with BOS') -def step_strings_for_tokenization(context): - assert context.tokens[0] != context.bos - - -@step('first token is removed') -def step_strings_for_tokenization(context): - context.tokens = context.tokens[1:] - - -@step('an OPTIONS request is sent from {origin}') -@async_run_until_complete -async def step_options_request(context, origin): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin} - async with session.options(f'{context.base_url}/v1/chat/completions', - headers=headers) as response: - assert response.status == 200 - context.options_response = response - - -@step('CORS header {cors_header} is set to {cors_header_value}') -def step_check_options_header_value(context, cors_header, cors_header_value): - assert context.options_response.headers[cors_header] == cors_header_value - - -@step('prometheus metrics are exposed') -@async_run_until_complete -async def step_prometheus_metrics_exported(context): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with await session.get(f'{context.base_url}/metrics') as metrics_response: - assert metrics_response.status == 200 - assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4" - metrics_raw = await metrics_response.text() - metric_exported = False - if context.debug: - print(f"/metrics answer:\n{metrics_raw}") - context.metrics = {} - for metric in parser.text_string_to_metric_families(metrics_raw): - match metric.name: - case "llamacpp:kv_cache_usage_ratio": - assert len(metric.samples) > 0 - metric_exported = True - context.metrics[metric.name] = metric - assert int(metrics_response.headers["Process-Start-Time-Unix"]) > 0, "no header process start time" - assert metric_exported, "No metrics exported" - - -@step('metric {metric_name} is {metric_value:d}') -def step_assert_metric_value(context, metric_name, metric_value): - if metric_name not in context.metrics: - assert False, f"no metric {metric_name} in {context.metrics.keys()}" - assert context.metrics[metric_name].samples[0].value == metric_value, f"metric: {context.metrics[metric_name]}" - - -@step('available models') -def step_available_models(context): - # openai client always expects an api_key - openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope' - openai.base_url = f'{context.base_url}/v1/' - context.models = openai.models.list().data - - -@step('{n_model:d} models are supported') -def step_supported_models(context, n_model): - if context.debug: - print("server models available:", context.models) - assert len(context.models) == n_model - - -@step('model {i_model:d} is {param} {preposition} {param_value}') -def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str): - assert i_model < len(context.models) - model = context.models[i_model] - - param_value = param_value.split(' ', 1)[0] - match param: - case 'identified': - value = model.id - case 'trained': - value = str(model.meta["n_ctx_train"]) - case _: - assert False, "param {param} not supported" - assert param_value == value, f"model param {param} {value} != {param_value}" - - -async def concurrent_requests(context, f_completion, *args, **kwargs): - context.n_prompts = len(context.prompts) - if context.debug: - print(f"starting {context.n_prompts} concurrent completion requests...") - assert context.n_prompts > 0 - seeds = await completions_seed(context) - assert seeds is not None - for prompt_no in range(context.n_prompts): - shifted_args = [context.prompts.pop(), seeds[prompt_no], *args] - context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) - await asyncio.sleep(0.01) - - -@step('the slot {slot_id:d} is saved with filename "{filename}"') -@async_run_until_complete -async def step_save_slot(context, slot_id, filename): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', - json={"filename": filename}, - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('the slot {slot_id:d} is restored with filename "{filename}"') -@async_run_until_complete -async def step_restore_slot(context, slot_id, filename): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', - json={"filename": filename}, - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('the slot {slot_id:d} is erased') -@async_run_until_complete -async def step_erase_slot(context, slot_id): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('switch {on_or_off} lora adapter {lora_id:d}') -@async_run_until_complete -async def toggle_lora_adapter(context, on_or_off: str, lora_id: int): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/lora-adapters', - json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}], - headers={"Content-Type": "application/json"}) as response: - context.response = response - print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}]) - - -@step('the server responds with status code {status_code:d}') -def step_server_responds_with_status_code(context, status_code): - assert context.response.status == status_code - - -async def request_completion(prompt, - seed, - base_url, - debug=False, - prompt_prefix=None, - prompt_suffix=None, - n_predict=None, - cache_prompt=False, - id_slot=None, - expect_api_error=None, - user_api_key=None, - temperature=None) -> int | dict[str, Any]: - if debug: - print(f"Sending completion request: {prompt}") - origin = "my.super.domain" - headers = { - 'Origin': origin - } - if user_api_key is not None: - if debug: - print(f"Set user_api_key: {user_api_key}") - headers['Authorization'] = f'Bearer {user_api_key}' - - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/completion', - json={ - "input_prefix": prompt_prefix, - "prompt": prompt, - "input_suffix": prompt_suffix, - "n_predict": n_predict if n_predict is not None else -1, - "cache_prompt": cache_prompt, - "id_slot": id_slot, - "seed": seed if seed is not None else 42, - "temperature": temperature if temperature is not None else 0.8, - "n_probs": 2, - }, - headers=headers) as response: - if expect_api_error is None or not expect_api_error: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - return await response.json() - else: - return response.status - - -async def oai_chat_completions(user_prompt, - seed, - system_prompt, - base_url: str, - base_path: str, - async_client, - debug=False, - temperature=None, - model=None, - n_predict=None, - enable_streaming=None, - response_format=None, - user_api_key=None, - expect_api_error=None) -> int | dict[str, Any]: - if debug: - print(f"Sending OAI Chat completions request: {user_prompt}") - # openai client always expects an api key - user_api_key = user_api_key if user_api_key is not None else 'nope' - seed = seed if seed is not None else 42 - enable_streaming = enable_streaming if enable_streaming is not None else False - payload = { - "messages": [ - { - "role": "system", - "content": system_prompt, - }, - { - "role": "user", - "content": user_prompt, - } - ], - "model": model, - "max_tokens": n_predict, - "stream": enable_streaming, - "temperature": temperature if temperature is not None else 0.0, - "seed": seed, - } - if response_format is not None: - payload['response_format'] = response_format - completion_response = { - 'content': '', - 'timings': { - 'predicted_n': 0, - 'prompt_n': 0 - } - } - if async_client: - origin = 'llama.cpp' - headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}{base_path}', - json=payload, - headers=headers) as response: - if enable_streaming: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "text/event-stream" - event_received = True - while event_received: - event_received = False - async for line_in_bytes in response.content: - line = line_in_bytes.decode('utf-8') - line = line.rstrip('\n').rstrip('\r') - if line == '': - continue - event_data = line.split(': ', 1) - assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```' - chunk_raw = event_data[1] - if chunk_raw == '[DONE]': - break - - chunk = json.loads(chunk_raw) - assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```" - delta = chunk['choices'][0]['delta'] - if 'content' in delta: - completion_response['content'] += delta['content'] - completion_response['timings']['predicted_n'] += 1 - else: - if expect_api_error is None or not expect_api_error: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "application/json; charset=utf-8" - chat_completion_raw = await response.json() - completion_response = { - 'content': chat_completion_raw['choices'][0]['message'], - 'timings': { - 'predicted_n': chat_completion_raw['usage']['completion_tokens'], - 'prompt_n': chat_completion_raw['usage']['prompt_tokens'] - } - } - else: - return response.status - else: - try: - openai.api_key = user_api_key - openai.base_url = f'{base_url}{base_path.removesuffix("chat")}' - assert model is not None - chat_completion = openai.chat.completions.create( - messages=payload['messages'], - model=model, - max_tokens=n_predict, - stream=enable_streaming, - response_format=payload.get('response_format') or openai.NOT_GIVEN, - seed=seed, - temperature=payload['temperature'] - ) - except openai.AuthenticationError as e: - if expect_api_error is not None and expect_api_error: - return 401 - else: - assert False, f'error raised: {e}' - - if enable_streaming: - chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion) - for chunk in chat_completion: - assert len(chunk.choices) == 1 - delta = chunk.choices[0].delta - if delta.content is not None: - completion_response['content'] += delta.content - completion_response['timings']['predicted_n'] += 1 - completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop' - else: - assert len(chat_completion.choices) == 1 - assert chat_completion.usage is not None - completion_response = { - 'content': chat_completion.choices[0].message.content, - 'timings': { - 'predicted_n': chat_completion.usage.completion_tokens, - 'prompt_n': chat_completion.usage.prompt_tokens - }, - 'truncated': chat_completion.choices[0].finish_reason != 'stop' - } - if debug: - print("OAI response formatted to llama.cpp:", completion_response) - return completion_response - - -async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int: - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/embedding', - json={ - "content": content, - }) as response: - if response.status == 200: - response_json = await response.json() - return [response_json['embedding']] - else: - return response.status - - -async def request_oai_embeddings(input, seed, - base_url=None, user_api_key=None, - model=None, async_client=False) -> list[list[float]]: - # openai client always expects an api_key - user_api_key = user_api_key if user_api_key is not None else 'nope' - if async_client: - origin = 'llama.cpp' - headers=[] - if user_api_key is not None: - headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/v1/embeddings', - json={ - "input": input, - "model": model, - }, - headers=headers) as response: - assert response.status == 200, f"received status code not expected: {response.status}" - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "application/json; charset=utf-8" - response_json = await response.json() - assert response_json['model'] == model, f"invalid model received: {response_json['model']}" - assert response_json['object'] == 'list' - if isinstance(input, Sequence): - embeddings = [] - for an_oai_embeddings in response_json['data']: - embeddings.append(an_oai_embeddings['embedding']) - else: - embeddings = [response_json['data']['embedding']] - return embeddings - else: - openai.api_key = user_api_key - openai.base_url = f'{base_url}/v1/' - assert model is not None - oai_embeddings = openai.embeddings.create( - model=model, - input=input, - ) - - return [e.embedding for e in oai_embeddings.data] - - -def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): - content = completion_response['content'] - n_predicted = completion_response['timings']['predicted_n'] - assert len(content) > 0, "no token predicted" - if re_content is not None: - p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL) - matches = p.finditer(content) - last_match = 0 - highlighted = '' - for match in matches: - start, end = match.span() - highlighted += content[last_match: start] - highlighted += '\x1b[33m' - highlighted += content[start: end] - highlighted += '\x1b[0m' - last_match = end - highlighted += content[last_match:] - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - print(f"Checking completion response: {highlighted}") - assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```' - if expected_predicted_n and expected_predicted_n > 0: - assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' - f' {n_predicted} <> {expected_predicted_n}') - -def assert_all_predictions_equal(completion_responses): - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - print(f"content {i}: {content_i}") - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - content_j = response_j['content'] - assert content_i == content_j, "contents not equal" - - -def assert_all_predictions_different(completion_responses): - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - print(f"content {i}: {content_i}") - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - content_j = response_j['content'] - assert content_i != content_j, "contents not different" - - -def assert_all_token_probabilities_equal(completion_responses): - n_predict = len(completion_responses[0]['completion_probabilities']) - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for pos in range(n_predict): - for i, response_i in enumerate(completion_responses): - probs_i = response_i['completion_probabilities'][pos]['probs'] - print(f"pos {pos}, probs {i}: {probs_i}") - for pos in range(n_predict): - for i, response_i in enumerate(completion_responses): - probs_i = response_i['completion_probabilities'][pos]['probs'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - probs_j = response_j['completion_probabilities'][pos]['probs'] - assert probs_i == probs_j, "contents not equal" - - -async def gather_tasks_results(context): - n_tasks = len(context.concurrent_tasks) - if context.debug: - print(f"Waiting for all {n_tasks} tasks results...") - for task_no in range(n_tasks): - context.tasks_result.append(await context.concurrent_tasks.pop()) - n_completions = len(context.tasks_result) - return n_completions - - -async def wait_for_slots_status(context, - base_url, - expected_http_status_code, - timeout=3, - params=None, - slots_idle=None, - slots_processing=None): - if context.debug: - print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}") - interval = 0.5 - counter = 0 - if 'GITHUB_ACTIONS' in os.environ: - timeout *= 2 - - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - while True: - headers = {'Authorization': f'Bearer {context.server_api_key}'} - async with await session.get(f'{base_url}/slots', params=params, headers=headers) as slots_response: - status_code = slots_response.status - slots = await slots_response.json() - if context.debug: - print(f"slots responses {slots}\n") - if status_code == 503 and status_code == expected_http_status_code: - return - if status_code == 200 and status_code == expected_http_status_code: - n_slots_idle = sum(1 if not slot["is_processing"] else 0 for slot in slots) - n_slots_processing = sum(1 if slot["is_processing"] else 0 for slot in slots) - if ((slots_idle is None or slots_idle == n_slots_idle) - and (slots_processing is None or slots_processing == n_slots_processing)): - return - await asyncio.sleep(interval) - - counter += interval - if counter >= timeout: - # Sometimes health requests are triggered after completions are predicted - if expected_http_status_code == 503: - if len(context.tasks_result) == 0: - print("\x1b[5;37;43mWARNING: forcing concurrent tasks," - " busy health check missed, probably too fast inference\x1b[0m\n") - n_completions = await gather_tasks_results(context) - if n_completions > 0: - return - - assert False, f'slots check timeout exceeded {counter}s>={timeout}' - - -def assert_embeddings(embeddings): - assert len(embeddings) > 0 - embeddings_computed = False - for emb in embeddings: - if not isinstance(emb, float): - assert False, f"Bad embeddings: {embeddings}" - if emb != 0: - embeddings_computed = True - assert embeddings_computed, f"Embeddings: {embeddings}" - - -async def request_slots_status(context, expected_slots): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with await session.get(f'{context.base_url}/slots') as slots_response: - assert slots_response.status == 200 - slots = await slots_response.json() - assert_slots_status(slots, expected_slots) - - -def assert_slots_status(slots, expected_slots): - assert len(slots) == len(expected_slots) - for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)): - for key in expected: - assert expected[key] == slot[key], (f"invalid slot {slot_id}" - f" expected[{key}] != slot[{key}]" - f" = {expected[key]} != {slot[key]}") - - -async def completions_seed(context, num_seeds=None): - if hasattr(context, "seed") and context.seed is not None: - assert len(context.seed) == context.n_prompts - if num_seeds is None: - num_seeds = context.n_prompts - assert num_seeds <= context.n_prompts - seeds = context.seed[:num_seeds] - context.seed = context.seed[num_seeds:] if num_seeds < context.n_prompts else None - return seeds - - if hasattr(context, "server_seed") and context.server_seed is not None: - if num_seeds is None: - return [context.server_seed] * context.n_prompts - else: - return [context.server_seed] * num_seeds - return None - - -def context_text(context): - return context.text.replace('\r', '') - - -def start_server_background(context): - if os.name == 'nt': - context.server_path = '../../../build/bin/Release/llama-server.exe' - else: - context.server_path = '../../../build/bin/llama-server' - if 'LLAMA_SERVER_BIN_PATH' in os.environ: - context.server_path = os.environ['LLAMA_SERVER_BIN_PATH'] - server_listen_addr = context.server_fqdn - server_args = [ - '--slots', # requires to get slot status via /slots endpoint - '--host', server_listen_addr, - '--port', context.server_port, - ] - if context.model_file: - server_args.extend(['--model', context.model_file]) - if context.model_url: - server_args.extend(['--model-url', context.model_url]) - if context.model_hf_repo: - server_args.extend(['--hf-repo', context.model_hf_repo]) - if context.model_hf_file: - server_args.extend(['--hf-file', context.model_hf_file]) - if context.n_batch: - server_args.extend(['--batch-size', context.n_batch]) - if context.n_ubatch: - server_args.extend(['--ubatch-size', context.n_ubatch]) - if context.n_threads: - server_args.extend(['--threads', context.threads]) - if context.n_gpu_layer: - server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) - if context.draft is not None: - server_args.extend(['--draft', context.draft]) - if context.server_continuous_batching: - server_args.append('--cont-batching') - if context.server_embeddings: - server_args.append('--embedding') - if context.server_reranking: - server_args.append('--reranking') - if context.server_metrics: - server_args.append('--metrics') - if context.model_alias: - server_args.extend(['--alias', context.model_alias]) - if context.n_ctx: - server_args.extend(['--ctx-size', context.n_ctx]) - if context.n_slots: - server_args.extend(['--parallel', context.n_slots]) - if context.n_server_predict: - server_args.extend(['--n-predict', context.n_server_predict]) - if context.slot_save_path: - server_args.extend(['--slot-save-path', context.slot_save_path]) - if context.server_api_key: - server_args.extend(['--api-key', context.server_api_key]) - if context.n_ga: - server_args.extend(['--grp-attn-n', context.n_ga]) - if context.n_ga_w: - server_args.extend(['--grp-attn-w', context.n_ga_w]) - if context.debug: - server_args.append('--verbose') - if context.lora_file: - server_args.extend(['--lora', context.lora_file]) - if context.disable_ctx_shift: - server_args.extend(['--no-context-shift']) - - args = [str(arg) for arg in [context.server_path, *server_args]] - print(f"bench: starting server with: {' '.join(args)}") - - flags = 0 - if 'nt' == os.name: - flags |= subprocess.DETACHED_PROCESS - flags |= subprocess.CREATE_NEW_PROCESS_GROUP - flags |= subprocess.CREATE_NO_WINDOW - - pkwargs = { - 'creationflags': flags, - 'stdout': subprocess.PIPE, - 'stderr': subprocess.PIPE - } - context.server_process = subprocess.Popen( - [str(arg) for arg in [context.server_path, *server_args]], - **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] - - def server_log(in_stream, out_stream): - for line in iter(in_stream.readline, b''): - print(line.decode('utf-8'), end='', file=out_stream) - - thread_stdout = threading.Thread(target=server_log, args=(context.server_process.stdout, sys.stdout)) - thread_stdout.start() - - thread_stderr = threading.Thread(target=server_log, args=(context.server_process.stderr, sys.stderr)) - thread_stderr.start() - - print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}") diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature deleted file mode 100644 index 61d5f315e1567..0000000000000 --- a/examples/server/tests/features/wrong_usages.feature +++ /dev/null @@ -1,25 +0,0 @@ -# run with: ./tests.sh --no-skipped --tags wrong_usage -@wrong_usage -Feature: Wrong usage of llama.cpp server - - #3969 The user must always set --n-predict option - # to cap the number of tokens any completion request can generate - # or pass n_predict/max_tokens in the request. - Scenario: Infinite loop - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And 42 as server seed - And 2048 KV cache size - # Uncomment below to fix the issue - #And 64 server max tokens to predict - Then the server is starting - Then the server is healthy - Given a prompt: - """ - Go to: infinite loop - """ - # Uncomment below to fix the issue - #And 128 max tokens to predict - Given concurrent completion requests - Then the server is idle - Then all prompts are predicted diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 5539548720ff1..935a79114b45e 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,5 +1,5 @@ aiohttp~=3.9.3 -behave~=1.2.6 +pytest~=8.3.3 huggingface_hub~=0.23.2 numpy~=1.26.4 openai~=1.30.3 diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 72a0fbad827db..1e285dcdac14b 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -4,8 +4,7 @@ set -eu if [ $# -lt 1 ] then - # Start @llama.cpp scenario - behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp + pytest -v -x else - behave "$@" + pytest "$@" fi diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py new file mode 100644 index 0000000000000..84db5ca1ca192 --- /dev/null +++ b/examples/server/tests/unit/test_basic.py @@ -0,0 +1,34 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_server_start_simple(): + global server + server.start() + res = server.make_request("GET", "/health") + assert res.status_code == 200 + + +def test_server_props(): + global server + server.start() + res = server.make_request("GET", "/props") + assert res.status_code == 200 + assert res.body["total_slots"] == server.n_slots + + +def test_server_models(): + global server + server.start() + res = server.make_request("GET", "/models") + assert res.status_code == 200 + assert len(res.body["data"]) == 1 + assert res.body["data"][0]["id"] == server.model_alias diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py new file mode 100644 index 0000000000000..d7aeb288d45cc --- /dev/null +++ b/examples/server/tests/unit/test_chat_completion.py @@ -0,0 +1,129 @@ +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +@pytest.mark.parametrize( + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", + [ + ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), + ] +) +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + }) + assert res.status_code == 200 + assert res.body["usage"]["prompt_tokens"] == n_prompt + assert res.body["usage"]["completion_tokens"] == n_predicted + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]) + if truncated: + assert choice["finish_reason"] == "length" + else: + assert choice["finish_reason"] == "stop" + + +@pytest.mark.parametrize( + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated", + [ + ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False), + ] +) +def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated): + global server + server.start() + res = server.make_stream_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "stream": True, + }) + content = "" + for data in res: + choice = data["choices"][0] + if choice["finish_reason"] in ["stop", "length"]: + assert data["usage"]["prompt_tokens"] == n_prompt + assert data["usage"]["completion_tokens"] == n_predicted + assert "content" not in choice["delta"] + assert match_regex(re_content, content) + # FIXME: not sure why this is incorrect in stream mode + # if truncated: + # assert choice["finish_reason"] == "length" + # else: + # assert choice["finish_reason"] == "stop" + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] + + +def test_chat_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=8, + seed=42, + temperature=0.8, + ) + print(res) + assert res.choices[0].finish_reason == "stop" + assert res.choices[0].message.content is not None + assert match_regex("(Suddenly)+", res.choices[0].message.content) + + +@pytest.mark.parametrize("response_format,n_predicted,re_content", [ + ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), + ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), + ({"type": "json_object"}, 10, "(\\{|John)+"), + ({"type": "sound"}, 0, None), + # invalid response format (expected to fail) + ({"type": "json_object", "schema": 123}, 0, None), + ({"type": "json_object", "schema": {"type": 123}}, 0, None), + ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None), +]) +def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "response_format": response_format, + }) + if re_content is not None: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + assert "error" in res.body + diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py new file mode 100644 index 0000000000000..2fa30dd033431 --- /dev/null +++ b/examples/server/tests/unit/test_completion.py @@ -0,0 +1,223 @@ +import pytest +import time +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), +]) +def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == n_prompt + assert res.body["timings"]["predicted_n"] == n_predicted + assert res.body["truncated"] == truncated + assert match_regex(re_content, res.body["content"]) + + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), +]) +def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + "stream": True, + }) + content = "" + for data in res: + if data["stop"]: + assert data["timings"]["prompt_n"] == n_prompt + assert data["timings"]["predicted_n"] == n_predicted + assert data["truncated"] == truncated + assert match_regex(re_content, content) + else: + content += data["content"] + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_consistent_result_same_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_different_result_different_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for seed in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": seed, + "temperature": 1.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] != last_res.body["content"] + last_res = res + + +@pytest.mark.parametrize("n_batch", [16, 32]) +@pytest.mark.parametrize("temperature", [0.0, 1.0]) +def test_consistent_result_different_batch_size(n_batch: int, temperature: float): + global server + server.n_batch = n_batch + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": temperature, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.skip(reason="This test fails on linux, need to be fixed") +def test_cache_vs_nocache_prompt(): + global server + server.start() + res_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": True, + }) + res_no_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }) + assert res_cache.body["content"] == res_no_cache.body["content"] + + +def test_completion_with_tokens_input(): + global server + server.temperature = 0.0 + server.start() + prompt_str = "I believe the meaning of life is" + res = server.make_request("POST", "/tokenize", data={ + "content": prompt_str, + "add_special": True, + }) + assert res.status_code == 200 + tokens = res.body["tokens"] + + # single completion + res = server.make_request("POST", "/completion", data={ + "prompt": tokens, + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + # batch completion + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, tokens], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens in one sequence + res = server.make_request("POST", "/completion", data={ + "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 3), + (2, 2), + (2, 4), + (4, 2), # some slots must be idle + (4, 6), +]) +def test_completion_parallel_slots(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.temperature = 0.0 + server.start() + + PROMPTS = [ + ("Write a very long book.", "(very|special|big)+"), + ("Write another a poem.", "(small|house)+"), + ("What is LLM?", "(Dad|said)+"), + ("The sky is blue and I love it.", "(climb|leaf)+"), + ("Write another very long music lyrics.", "(friends|step|sky)+"), + ("Write a very long joke.", "(cat|Whiskers)+"), + ] + def check_slots_status(): + should_all_slots_busy = n_requests >= n_slots + time.sleep(0.1) + res = server.make_request("GET", "/slots") + n_busy = sum([1 for slot in res.body if slot["is_processing"]]) + if should_all_slots_busy: + assert n_busy == n_slots + else: + assert n_busy <= n_slots + + tasks = [] + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": prompt, + "seed": 42, + "temperature": 1.0, + }))) + tasks.append((check_slots_status, ())) + results = parallel_function_calls(tasks) + + # check results + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + res = results[i] + assert res.status_code == 200 + assert type(res.body["content"]) == str + assert len(res.body["content"]) > 10 + # FIXME: the result is not deterministic when using other slot than slot 0 + # assert match_regex(re_content, res.body["content"]) diff --git a/examples/server/tests/unit/test_ctx_shift.py b/examples/server/tests/unit/test_ctx_shift.py new file mode 100644 index 0000000000000..be93a6d31f410 --- /dev/null +++ b/examples/server/tests/unit/test_ctx_shift.py @@ -0,0 +1,67 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +LONG_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. +""".strip() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.n_ctx = 256 + server.n_slots = 2 + + +def test_ctx_shift_enabled(): + # the prompt is 301 tokens + # the slot context is 256/2 = 128 tokens + # the prompt is truncated to keep the last 109 tokens + # 64 tokens are generated thanks to shifting the context when it gets full + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == 109 + assert res.body["timings"]["predicted_n"] == 64 + assert res.body["truncated"] is True + + +@pytest.mark.parametrize("n_predict,n_token_output,truncated", [ + (64, 64, False), + (-1, 120, True), +]) +def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): + global server + server.disable_ctx_shift = True + server.n_predict = -1 + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": "Hi how are you", + }) + assert res.status_code == 200 + assert res.body["timings"]["predicted_n"] == n_token_output + assert res.body["truncated"] == truncated + + +def test_ctx_shift_disabled_long_prompt(): + global server + server.disable_ctx_shift = True + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code != 200 + assert "error" in res.body + assert "exceeds the available context size" in res.body["error"]["message"] diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py new file mode 100644 index 0000000000000..fc7c20064ddfc --- /dev/null +++ b/examples/server/tests/unit/test_embedding.py @@ -0,0 +1,99 @@ +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.bert_bge_small() + +EPSILON = 1e-3 + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.bert_bge_small() + + +def test_embedding_single(): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": "I believe the meaning of life is", + }) + assert res.status_code == 200 + assert len(res.body['data']) == 1 + assert 'embedding' in res.body['data'][0] + assert len(res.body['data'][0]['embedding']) > 1 + + # make sure embedding vector is normalized + assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + + +def test_embedding_multiple(): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +def test_embedding_openai_library_single(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") + assert len(res.data) == 1 + assert len(res.data[0].embedding) > 1 + + +def test_embedding_openai_library_multiple(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + res = client.embeddings.create(model="text-embedding-3-small", input=[ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ]) + assert len(res.data) == 4 + for d in res.data: + assert len(d.embedding) > 1 + + +def test_embedding_error_prompt_too_long(): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": "This is a test " * 512, + }) + assert res.status_code != 200 + assert "too large" in res.body["error"]["message"] + + +def test_same_prompt_give_same_result(): + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 5 + for i in range(1, len(res.body['data'])): + v0 = res.body['data'][0]['embedding'] + vi = res.body['data'][i]['embedding'] + for x, y in zip(v0, vi): + assert abs(x - y) < EPSILON diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py new file mode 100644 index 0000000000000..38ce6c42954ed --- /dev/null +++ b/examples/server/tests/unit/test_infill.py @@ -0,0 +1,35 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama_infill() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama_infill() + +def test_infill_without_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "prompt": "Complete this", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) + +def test_infill_with_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "prompt": "Complete this", + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"]) diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py new file mode 100644 index 0000000000000..7496154493917 --- /dev/null +++ b/examples/server/tests/unit/test_lora.py @@ -0,0 +1,42 @@ +import pytest +import os +from utils import * + +server = ServerPreset.stories15m_moe() + +LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.stories15m_moe() + # download lora file if needed + file_name = LORA_FILE_URL.split('/').pop() + lora_file = f'../../../{file_name}' + if not os.path.exists(lora_file): + print(f"Downloading {LORA_FILE_URL} to {lora_file}") + with open(lora_file, 'wb') as f: + f.write(requests.get(LORA_FILE_URL).content) + print(f"Done downloading lora file") + server.lora_files = [lora_file] + + +@pytest.mark.parametrize("scale,re_content", [ + # without applying lora, the model should behave like a bedtime story generator + (0.0, "(little|girl|three|years|old)+"), + # with lora, the model should behave like a Shakespearean text generator + (1.0, "(eye|love|glass|sun)+"), +]) +def test_lora(scale: float, re_content: str): + global server + server.start() + res_lora_control = server.make_request("POST", "/lora-adapters", data=[ + {"id": 0, "scale": scale} + ]) + assert res_lora_control.status_code == 200 + res = server.make_request("POST", "/completion", data={ + "prompt": "Look in thy glass", + }) + assert res.status_code == 200 + assert match_regex(re_content, res.body["content"]) + diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py new file mode 100644 index 0000000000000..3a49fd3ac6bdf --- /dev/null +++ b/examples/server/tests/unit/test_rerank.py @@ -0,0 +1,38 @@ +import pytest +from utils import * + +server = ServerPreset.jina_reranker_tiny() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.jina_reranker_tiny() + + +def test_rerank(): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": [ + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + ] + }) + assert res.status_code == 200 + assert len(res.body["results"]) == 4 + + most_relevant = res.body["results"][0] + least_relevant = res.body["results"][0] + for doc in res.body["results"]: + if doc["relevance_score"] > most_relevant["relevance_score"]: + most_relevant = doc + if doc["relevance_score"] < least_relevant["relevance_score"]: + least_relevant = doc + + assert most_relevant["relevance_score"] > least_relevant["relevance_score"] + assert most_relevant["index"] == 2 + assert least_relevant["index"] == 3 diff --git a/examples/server/tests/unit/test_security.py b/examples/server/tests/unit/test_security.py new file mode 100644 index 0000000000000..620b25376bd81 --- /dev/null +++ b/examples/server/tests/unit/test_security.py @@ -0,0 +1,83 @@ +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + +TEST_API_KEY = "sk-this-is-the-secret-key" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + + +@pytest.mark.parametrize("endpoint", ["/health", "/models"]) +def test_access_public_endpoint(endpoint: str): + global server + server.start() + res = server.make_request("GET", endpoint) + assert res.status_code == 200 + assert "error" not in res.body + + +@pytest.mark.parametrize("api_key", [None, "invalid-key"]) +def test_incorrect_api_key(api_key: str): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {api_key}" if api_key else None, + }) + assert res.status_code == 401 + assert "error" in res.body + assert res.body["error"]["type"] == "authentication_error" + + +def test_correct_api_key(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + +def test_openai_library_correct_api_key(): + global server + server.start() + client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a chatbot."}, + {"role": "user", "content": "What is the meaning of life?"}, + ], + ) + assert len(res.choices) == 1 + + +@pytest.mark.parametrize("origin,cors_header,cors_header_value", [ + ("localhost", "Access-Control-Allow-Origin", "localhost"), + ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"), + ("origin", "Access-Control-Allow-Credentials", "true"), + ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"), + ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"), +]) +def test_cors_options(origin: str, cors_header: str, cors_header_value: str): + global server + server.start() + res = server.make_request("OPTIONS", "/completions", headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization", + }) + assert res.status_code == 200 + assert cors_header in res.headers + assert res.headers[cors_header] == cors_header_value diff --git a/examples/server/tests/unit/test_slot_save.py b/examples/server/tests/unit/test_slot_save.py new file mode 100644 index 0000000000000..38704f5ece35a --- /dev/null +++ b/examples/server/tests/unit/test_slot_save.py @@ -0,0 +1,98 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.slot_save_path = "./tmp" + server.temperature = 0.0 + + +def test_slot_save_restore(): + global server + server.start() + + # First prompt in slot 1 should be fully processed + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # Save state of slot 1 + res = server.make_request("POST", "/slots/1?action=save", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_saved"] == 84 + + # Since we have cache, this should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # Loading the saved cache into slot 0 + res = server.make_request("POST", "/slots/0?action=restore", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_restored"] == 84 + + # Since we have cache, slot 0 should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 0, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # For verification that slot 1 was not corrupted during slot 0 load, same thing should work + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 1 + + +def test_slot_erase(): + global server + server.start() + + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # erase slot 1 + res = server.make_request("POST", "/slots/1?action=erase") + assert res.status_code == 200 + + # re-run the same prompt, it should process all tokens again + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed diff --git a/examples/server/tests/unit/test_tokenize.py b/examples/server/tests/unit/test_tokenize.py new file mode 100644 index 0000000000000..382457c9d602f --- /dev/null +++ b/examples/server/tests/unit/test_tokenize.py @@ -0,0 +1,59 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_tokenize_detokenize(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content + }) + assert res_tok.status_code == 200 + assert len(res_tok.body["tokens"]) > 5 + # detokenize + res_detok = server.make_request("POST", "/detokenize", data={ + "tokens": res_tok.body["tokens"], + }) + assert res_detok.status_code == 200 + assert res_detok.body["content"].strip() == content + + +def test_tokenize_with_bos(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + bosId = 1 + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "add_special": True, + }) + assert res_tok.status_code == 200 + assert res_tok.body["tokens"][0] == bosId + + +def test_tokenize_with_pieces(): + global server + server.start() + # tokenize + content = "This is a test string with unicode 媽 and emoji 🤗" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "with_pieces": True, + }) + assert res_tok.status_code == 200 + for token in res_tok.body["tokens"]: + assert "id" in token + assert token["id"] > 0 + assert "piece" in token + assert len(token["piece"]) > 0 diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py new file mode 100644 index 0000000000000..bc590bcb31547 --- /dev/null +++ b/examples/server/tests/utils.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# type: ignore[reportUnusedImport] + +import subprocess +import os +import re +import json +import sys +import threading +import requests +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import ( + Any, + Callable, + ContextManager, + Iterable, + Iterator, + List, + Literal, + Tuple, + Set, +) +from re import RegexFlag + + +class ServerResponse: + headers: dict + status_code: int + body: dict | Any + + +class ServerProcess: + # default options + debug: bool = False + server_port: int = 8080 + server_host: str = "127.0.0.1" + model_hf_repo: str = "ggml-org/models" + model_hf_file: str = "tinyllamas/stories260K.gguf" + model_alias: str = "tinyllama-2" + temperature: float = 0.8 + seed: int = 42 + + # custom options + model_alias: str | None = None + model_url: str | None = None + model_file: str | None = None + n_threads: int | None = None + n_gpu_layer: int | None = None + n_batch: int | None = None + n_ubatch: int | None = None + n_ctx: int | None = None + n_ga: int | None = None + n_ga_w: int | None = None + n_predict: int | None = None + n_prompts: int | None = 0 + slot_save_path: str | None = None + id_slot: int | None = None + cache_prompt: bool | None = None + n_slots: int | None = None + server_continuous_batching: bool | None = False + server_embeddings: bool | None = False + server_reranking: bool | None = False + server_metrics: bool | None = False + draft: int | None = None + api_key: str | None = None + response_format: str | None = None + lora_files: List[str] | None = None + disable_ctx_shift: int | None = False + + # session variables + process: subprocess.Popen | None = None + + def __init__(self): + if "N_GPU_LAYERS" in os.environ: + self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"]) + if "DEBUG" in os.environ: + self.debug = True + if "PORT" in os.environ: + self.server_port = int(os.environ["PORT"]) + + def start(self, timeout_seconds: int = 10) -> None: + if "LLAMA_SERVER_BIN_PATH" in os.environ: + server_path = os.environ["LLAMA_SERVER_BIN_PATH"] + elif os.name == "nt": + server_path = "../../../build/bin/Release/llama-server.exe" + else: + server_path = "../../../build/bin/llama-server" + server_args = [ + "--slots", # requires to get slot status via /slots endpoint + "--host", + self.server_host, + "--port", + self.server_port, + "--temp", + self.temperature, + "--seed", + self.seed, + ] + if self.model_file: + server_args.extend(["--model", self.model_file]) + if self.model_url: + server_args.extend(["--model-url", self.model_url]) + if self.model_hf_repo: + server_args.extend(["--hf-repo", self.model_hf_repo]) + if self.model_hf_file: + server_args.extend(["--hf-file", self.model_hf_file]) + if self.n_batch: + server_args.extend(["--batch-size", self.n_batch]) + if self.n_ubatch: + server_args.extend(["--ubatch-size", self.n_ubatch]) + if self.n_threads: + server_args.extend(["--threads", self.n_threads]) + if self.n_gpu_layer: + server_args.extend(["--n-gpu-layers", self.n_gpu_layer]) + if self.draft is not None: + server_args.extend(["--draft", self.draft]) + if self.server_continuous_batching: + server_args.append("--cont-batching") + if self.server_embeddings: + server_args.append("--embedding") + if self.server_reranking: + server_args.append("--reranking") + if self.server_metrics: + server_args.append("--metrics") + if self.model_alias: + server_args.extend(["--alias", self.model_alias]) + if self.n_ctx: + server_args.extend(["--ctx-size", self.n_ctx]) + if self.n_slots: + server_args.extend(["--parallel", self.n_slots]) + if self.n_predict: + server_args.extend(["--n-predict", self.n_predict]) + if self.slot_save_path: + server_args.extend(["--slot-save-path", self.slot_save_path]) + if self.n_ga: + server_args.extend(["--grp-attn-n", self.n_ga]) + if self.n_ga_w: + server_args.extend(["--grp-attn-w", self.n_ga_w]) + if self.debug: + server_args.append("--verbose") + if self.lora_files: + for lora_file in self.lora_files: + server_args.extend(["--lora", lora_file]) + if self.disable_ctx_shift: + server_args.extend(["--no-context-shift"]) + if self.api_key: + server_args.extend(["--api-key", self.api_key]) + + args = [str(arg) for arg in [server_path, *server_args]] + print(f"bench: starting server with: {' '.join(args)}") + + flags = 0 + if "nt" == os.name: + flags |= subprocess.DETACHED_PROCESS + flags |= subprocess.CREATE_NEW_PROCESS_GROUP + flags |= subprocess.CREATE_NO_WINDOW + + self.process = subprocess.Popen( + [str(arg) for arg in [server_path, *server_args]], + creationflags=flags, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={**os.environ, "LLAMA_CACHE": "tmp"}, + ) + server_instances.add(self) + + def server_log(in_stream, out_stream): + for line in iter(in_stream.readline, b""): + print(line.decode("utf-8"), end="", file=out_stream) + + thread_stdout = threading.Thread( + target=server_log, args=(self.process.stdout, sys.stdout), daemon=True + ) + thread_stdout.start() + + thread_stderr = threading.Thread( + target=server_log, args=(self.process.stderr, sys.stderr), daemon=True + ) + thread_stderr.start() + + print(f"server pid={self.process.pid}, pytest pid={os.getpid()}") + + # wait for server to start + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = self.make_request("GET", "/slots", headers={ + "Authorization": f"Bearer {self.api_key}" if self.api_key else None + }) + if response.status_code == 200: + self.ready = True + return # server is ready + except Exception as e: + pass + print(f"Waiting for server to start...") + time.sleep(0.5) + raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") + + def stop(self) -> None: + server_instances.remove(self) + if self.process: + print(f"Stopping server with pid={self.process.pid}") + self.process.kill() + self.process = None + + def make_request( + self, + method: str, + path: str, + data: dict | Any | None = None, + headers: dict | None = None, + ) -> ServerResponse: + url = f"http://{self.server_host}:{self.server_port}{path}" + parse_body = False + if method == "GET": + response = requests.get(url, headers=headers) + parse_body = True + elif method == "POST": + response = requests.post(url, headers=headers, json=data) + parse_body = True + elif method == "OPTIONS": + response = requests.options(url, headers=headers) + else: + raise ValueError(f"Unimplemented method: {method}") + result = ServerResponse() + result.headers = dict(response.headers) + result.status_code = response.status_code + result.body = response.json() if parse_body else None + print("Response from server", result.body) + return result + + def make_stream_request( + self, + method: str, + path: str, + data: dict | None = None, + headers: dict | None = None, + ) -> Iterator[dict]: + url = f"http://{self.server_host}:{self.server_port}{path}" + if method == "POST": + response = requests.post(url, headers=headers, json=data, stream=True) + else: + raise ValueError(f"Unimplemented method: {method}") + for line_bytes in response.iter_lines(): + line = line_bytes.decode("utf-8") + if '[DONE]' in line: + break + elif line.startswith('data: '): + data = json.loads(line[6:]) + print("Partial response from server", data) + yield data + + +server_instances: Set[ServerProcess] = set() + + +class ServerPreset: + @staticmethod + def tinyllama2() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K.gguf" + server.model_alias = "tinyllama-2" + server.n_ctx = 256 + server.n_batch = 32 + server.n_slots = 2 + server.n_predict = 64 + server.seed = 42 + return server + + @staticmethod + def bert_bge_small() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" + server.model_alias = "bert-bge-small" + server.n_ctx = 512 + server.n_batch = 128 + server.n_ubatch = 128 + server.n_slots = 2 + server.seed = 42 + server.server_embeddings = True + return server + + @staticmethod + def tinyllama_infill() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_alias = "tinyllama-infill" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def stories15m_moe() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/stories15M_MOE" + server.model_hf_file = "stories15M_MOE-F16.gguf" + server.model_alias = "stories15m-moe" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def jina_reranker_tiny() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf" + server.model_alias = "jina-reranker" + server.model_file = "./tmp/jina-reranker-v1-tiny-en.gguf" + server.n_ctx = 512 + server.n_batch = 512 + server.n_slots = 1 + server.seed = 42 + server.server_reranking = True + return server + + +def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: + """ + Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS. + + Example usage: + + results = parallel_function_calls([ + (func1, (arg1, arg2)), + (func2, (arg3, arg4)), + ]) + """ + results = [None] * len(function_list) + exceptions = [] + + def worker(index, func, args): + try: + result = func(*args) + results[index] = result + except Exception as e: + exceptions.append((index, str(e))) + + with ThreadPoolExecutor() as executor: + futures = [] + for i, (func, args) in enumerate(function_list): + future = executor.submit(worker, i, func, args) + futures.append(future) + + # Wait for all futures to complete + for future in as_completed(futures): + pass + + # Check if there were any exceptions + if exceptions: + print("Exceptions occurred:") + for index, error in exceptions: + print(f"Function at index {index}: {error}") + + return results + + +def match_regex(regex: str, text: str) -> bool: + return ( + re.compile( + regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL + ).search(text) + is not None + )