diff --git a/python/mall/llm.py b/python/mall/llm.py index 2191ac8..0433b3c 100644 --- a/python/mall/llm.py +++ b/python/mall/llm.py @@ -50,8 +50,7 @@ def llm_call(x, msg, use, preview=False, valid_resps="", convert=None, data_type options=use.get("options"), ) - if preview: - print(call) + if preview: print(call) cache = "" if use.get("_cache") != "": diff --git a/python/tests/test_classify.py b/python/tests/test_classify.py index f31fae1..6cd286a 100644 --- a/python/tests/test_classify.py +++ b/python/tests/test_classify.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_classify(): diff --git a/python/tests/test_extract.py b/python/tests/test_extract.py index 5c58424..5190e03 100644 --- a/python/tests/test_extract.py +++ b/python/tests/test_extract.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_extract_list(): @@ -36,3 +35,11 @@ def test_extract_one(): x["extract"][0] == "You are a helpful text extraction engine. Extract the a being referred to on the text. I expect 1 item exactly. No capitalization. No explanations. The answer is based on the following text:\n{}" ) + +def test_extract_expand(): + df = pl.DataFrame(dict(x="x | y")) + df.llm.use("test", "echo", _cache="_test_cache") + x = df.llm.extract("x", ["a", "b"], expand_cols = True) + assert ( + x["a"][0] == "x " + ) \ No newline at end of file diff --git a/python/tests/test_sentiment.py b/python/tests/test_sentiment.py index a21699b..c76716c 100644 --- a/python/tests/test_sentiment.py +++ b/python/tests/test_sentiment.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_sentiment_simple(): diff --git a/python/tests/test_summarize.py b/python/tests/test_summarize.py index 6d28578..9a54273 100644 --- a/python/tests/test_summarize.py +++ b/python/tests/test_summarize.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_summarize_prompt(): diff --git a/python/tests/test_translate.py b/python/tests/test_translate.py index 1230688..6672cea 100644 --- a/python/tests/test_translate.py +++ b/python/tests/test_translate.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_translate_prompt(): diff --git a/python/tests/test_verify.py b/python/tests/test_verify.py index e520ab9..b55a0f7 100644 --- a/python/tests/test_verify.py +++ b/python/tests/test_verify.py @@ -4,8 +4,7 @@ import shutil import os -if os._exists("_test_cache"): - shutil.rmtree("_test_cache", ignore_errors=True) +if os._exists("_test_cache"): shutil.rmtree("_test_cache", ignore_errors=True) def test_verify():