Skip to content

Commit

Permalink
Adds classify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edgararuiz committed Oct 15, 2024
1 parent c9be4ab commit 7b12d14
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
Binary file modified python/.coverage
Binary file not shown.
29 changes: 29 additions & 0 deletions python/tests/test_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import mall
import polars as pl
import pyarrow
import shutil
import os

if os._exists("_test_cache"):
shutil.rmtree("_test_cache", ignore_errors=True)


def test_classify():
df = pl.DataFrame(dict(x=["one", "two", "three"]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.classify("x", ["one", "two"])
assert (
x.select("classify").to_pandas().to_string()
== " classify\n0 one\n1 two\n2 None"
)


def test_classify_dict():
df = pl.DataFrame(dict(x=[1, 2, 3]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.classify("x", {"one": 1, "two": 2})
assert (
x.select("classify").to_pandas().to_string()
== " classify\n0 1.0\n1 2.0\n2 NaN"
)
1 change: 0 additions & 1 deletion python/tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import mall
import polars as pl
import pyarrow

import shutil
import os

Expand Down
20 changes: 7 additions & 13 deletions python/tests/test_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,28 @@
import mall
import polars as pl
import pyarrow

import shutil
import os

if os._exists("_test_cache"):
shutil.rmtree("_test_cache", ignore_errors=True)


def sim_verify():
df = pl.DataFrame(dict(x=[1,1,0,2]))
df.llm.use("test", "echo", _cache="_test_cache")
return df


def test_verify():
x = sim_verify()
x = x.llm.verify("x", "this is my question")
df = pl.DataFrame(dict(x=[1, 1, 0, 2]))
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.verify("x", "this is my question")
assert (
x.select("verify").to_pandas().to_string()
== ' verify\n0 1.0\n1 1.0\n2 0.0\n3 NaN'
== " verify\n0 1.0\n1 1.0\n2 0.0\n3 NaN"
)


def test_verify_yn():
df = pl.DataFrame(dict(x=["y", "n", "y", "x"]))
df.llm.use("test", "echo", _cache="_test_cache")
df.llm.use("test", "echo", _cache="_test_cache")
x = df.llm.verify("x", "this is my question", ["y", "n"])
assert (
x.select("verify").to_pandas().to_string()
== ' verify\n0 y\n1 n\n2 y\n3 None'
== " verify\n0 y\n1 n\n2 y\n3 None"
)

0 comments on commit 7b12d14

Please sign in to comment.