Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Evangelos Lamprou <[email protected]>
  • Loading branch information
vagos committed Jul 17, 2024
1 parent e241017 commit 9df883f
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# CI workflow: installs the package with its test extra and runs pytest
# once per Python version listed in the matrix below.
name: Test

# Triggered by pushes and by pull requests.
on: [push, pull_request]

# Restrict the workflow token to read access on repository contents.
permissions:
contents: read

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
# One job is run for each of these interpreter versions.
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
# pip caching is enabled, with setup.py given as the dependency file.
cache: pip
cache-dependency-path: setup.py
- name: Cache models
uses: actions/cache@v3
with:
# Keeps ~/.cache/torch between runs; the key varies only by runner OS.
path: ~/.cache/torch
key: ${{ runner.os }}-torch-
- name: Install dependencies
run: |
pip install -e '.[test]'
- name: Run tests
# -s turns off pytest's output capturing.
run: |
pytest -s
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import llm
from llm.plugins import pm
import pytest


class LengthSummary(llm.Model):
    """Stub model for tests: its response reports the prompt's length."""

    model_id = "length-summary"

    def execute(self, prompt, stream, response, conversation):
        """Return a one-item list holding ``"Length: <n>"``.

        ``n`` is ``len(prompt.prompt)``; ``stream``, ``response`` and
        ``conversation`` are accepted but not used.
        """
        prompt_length = len(prompt.prompt)
        return [f"Length: {prompt_length}"]


class SimpleEmbeddings(llm.EmbeddingModel):
    """Deterministic embedding model for tests, built from word lengths."""

    model_id = "simple-embeddings"

    def embed_batch(self, texts):
        """Yield one 16-element list of ints for each text in ``texts``.

        Each list holds the lengths of the text's first 16
        whitespace-separated words, padded on the right with zeros.
        """
        for item in texts:
            lengths = list(map(len, item.split()[:16]))
            # Right-pad with zeros so every vector has exactly 16 entries.
            padding = [0] * (16 - len(lengths))
            yield lengths + padding


@pytest.fixture(autouse=True)
def register_models():
    """Register the stub models with llm's plugin manager for each test.

    A throwaway plugin exposing ``LengthSummary`` and ``SimpleEmbeddings``
    is registered before the test body runs and unregistered afterwards,
    even if the test raises.
    """
    plugin_name = "undo-demo-plugin"

    class ModelsPlugin:
        __name__ = "ModelsPlugin"

        @llm.hookimpl
        def register_embedding_models(self, register):
            register(SimpleEmbeddings())

        @llm.hookimpl
        def register_models(self, register):
            register(LengthSummary())

    plugin = ModelsPlugin()
    pm.register(plugin, name=plugin_name)
    try:
        yield
    finally:
        # Always undo the registration so the plugin does not outlive the test.
        pm.unregister(name=plugin_name)
97 changes: 97 additions & 0 deletions tests/test_llm_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
import click
import pytest
import sqlite_utils
import numpy as np
from llm.cli import cli
from llm import Collection
from click.testing import CliRunner


@pytest.fixture
def db_path(tmpdir):
    """Create ``data.db`` under ``tmpdir`` with ten stored embeddings.

    The rows go into an ``entries`` collection that uses the
    ``simple-embeddings`` model. Returns the path to the database file.
    """
    path = tmpdir / "data.db"
    rows = [
        (1, "one word"),
        (2, "two words"),
        (3, "three thing"),
        (4, "fourth thing"),
        (5, "fifth thing"),
        (6, "sixth thing"),
        (7, "seventh thing"),
        (8, "eighth thing"),
        (9, "ninth thing"),
        (10, "tenth thing"),
    ]
    Collection(
        "entries",
        sqlite_utils.Database(str(path)),
        model_id="simple-embeddings",
    ).embed_multi(rows, store=True)
    return path


@pytest.mark.parametrize("n", (2, 5, 10))
def test_interpolate_linear(db_path, n):
    """Run ``interpolate --method linear`` against an explicit --database.

    Checks that the command exits cleanly, prints a JSON list of ``n``
    items, and that the last item is the end id ``"10"``.
    """
    # Sanity-check the fixture: all ten rows were stored.
    assert sqlite_utils.Database(str(db_path))["embeddings"].count == 10

    cli_args = ["interpolate", "entries", "1", "10"]
    cli_args += ["-n", str(n)]
    cli_args += ["--method", "linear"]
    cli_args += ["--database", str(db_path)]

    result = CliRunner().invoke(cli, cli_args)
    assert result.exit_code == 0, result.output

    points = json.loads(result.output)
    assert len(points) == n
    assert points[-1] == "10"


def test_interpolate_linear_no_db_env(monkeypatch, tmpdir):
    """Run ``interpolate`` without --database, using LLM_EMBEDDINGS_DB.

    Seeds a fresh database with ten entries, sets ``LLM_EMBEDDINGS_DB`` to
    its path, and checks the command still succeeds and prints five ids
    ending in ``"10"``.
    """
    database_file = tmpdir / "data.db"
    rows = [
        (1, "one word"),
        (2, "two words"),
        (3, "three thing"),
        (4, "fourth thing"),
        (5, "fifth thing"),
        (6, "sixth thing"),
        (7, "seventh thing"),
        (8, "eighth thing"),
        (9, "ninth thing"),
        (10, "tenth thing"),
    ]
    Collection(
        "entries",
        sqlite_utils.Database(str(database_file)),
        model_id="simple-embeddings",
    ).embed_multi(rows, store=True)

    # No --database option is passed below; only the environment variable
    # names the database file.
    monkeypatch.setenv("LLM_EMBEDDINGS_DB", str(database_file))

    cli_args = ["interpolate", "entries", "1", "10", "-n", "5", "--method", "linear"]
    result = CliRunner().invoke(cli, cli_args)
    assert result.exit_code == 0, result.output

    points = json.loads(result.output)
    assert len(points) == 5
    assert points[-1] == "10"

0 comments on commit 9df883f

Please sign in to comment.