From 6b1ffedcdf0c4578febac3003109669c450a48ff Mon Sep 17 00:00:00 2001
From: NotBioWaste905
Date: Wed, 11 Dec 2024 12:53:12 +0300
Subject: [PATCH] Added test for conditions + fixed some bugs

---
 chatsky/conditions/llm.py |   2 +-
 chatsky/llm/utils.py      |   3 +
 tests/llm/test_llm.py     | 158 +++++++++++++++++++++++++++++---------
 3 files changed, 127 insertions(+), 36 deletions(-)

diff --git a/chatsky/conditions/llm.py b/chatsky/conditions/llm.py
index 3db5578e6..522154db9 100644
--- a/chatsky/conditions/llm.py
+++ b/chatsky/conditions/llm.py
@@ -50,7 +50,7 @@ async def call(self, ctx: Context) -> bool:
         if model.system_prompt == "":
             history_messages = []
         else:
-            history_messages = [message_to_langchain(model.system_prompt, ctx=ctx, source="system")]
+            history_messages = [await message_to_langchain(model.system_prompt, ctx=ctx, source="system")]
 
         if not (self.history == 0 or len(ctx.responses) == 0 or len(ctx.requests) == 0):
             history_messages.extend(
diff --git a/chatsky/llm/utils.py b/chatsky/llm/utils.py
index a33fd55bc..7bfbdae40 100644
--- a/chatsky/llm/utils.py
+++ b/chatsky/llm/utils.py
@@ -9,6 +9,7 @@
 
 from chatsky.core.context import Context
 from chatsky.core.message import Message
+from chatsky.core.script_function import ConstResponse
 from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
 from chatsky.llm.filters import BaseHistoryFilter
 
@@ -28,6 +29,8 @@ async def message_to_langchain(
     check_langchain_available()
     if isinstance(message, str):
         message = Message(text=message)
+    if isinstance(message, ConstResponse):
+        message = message.root
 
     if message.text is None:
         content = []
diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py
index b70283759..a57a3177c 100644
--- a/tests/llm/test_llm.py
+++ b/tests/llm/test_llm.py
@@ -4,10 +4,12 @@
 from chatsky.llm._langchain_imports import langchain_available
 from chatsky.llm.llm_api import LLM_API
 from chatsky.responses.llm import LLMResponse
+from chatsky.conditions.llm import LLMCondition
 from chatsky.slots.llm import LLMGroupSlot, LLMSlot
+from chatsky.slots.slots import SlotNotExtracted, ExtractedGroupSlot
 from chatsky.llm.utils import message_to_langchain, context_to_history
 from chatsky.llm.filters import IsImportant, FromModel
-from chatsky.llm.methods import Contains, LogProb
+from chatsky.llm.methods import Contains, LogProb, BaseMethod
 from chatsky.core.message import Message
 from chatsky.core.context import Context
 from chatsky.core.script import Node
@@ -29,12 +31,40 @@ async def ainvoke(self, history: list = [""]):
             content=f"Mock response with history: {[message.content[0]['text'] for message in history]}"
         )
         return response
+
+    async def agenerate(self, history: list, logprobs=True, top_logprobs=10):
+        return LLMResult(
+            generations=[
+                [
+                    ChatGeneration(
+                        message=HumanMessage(content="Mock generation without history."),
+                        generation_info={
+                            "logprobs": {
+                                "content": [
+                                    {
+                                        "top_logprobs": [
+                                            {"token": "true", "logprob": 0.1},
+                                            {"token": "false", "logprob": 0.5},
+                                        ]
+                                    }
+                                ]
+                            }
+                        },
+                    )
+                ]
+            ]
+        )
 
     def with_structured_output(self, message_schema):
         return MockedStructuredModel(root_model=message_schema)
 
-    def respond(self, history: list = [""]):
+    async def respond(self, history: list, message_schema=None):
         return self.ainvoke(history)
+
+    async def condition(self, history: list, method: BaseMethod, return_schema=None):
+        result = await method(history, await self.model.agenerate(history, logprobs=True, top_logprobs=10))
+        return result
+
 
 
 class MockedStructuredModel:
@@ -42,8 +72,23 @@ def __init__(self, root_model):
         self.root = root_model
 
     async def ainvoke(self, history):
-        inst = self.root(history=history)
-        return inst()
+        if isinstance(history, list):
+            inst = self.root(history=history)
+        else:
+            # For LLMSlot
+            if hasattr(self.root, "value"):
+                inst = self.root(value="mocked_value")
+            # For LLMGroupSlot
+            else:
+                inst = self.root(
+                    name="John",
+                    age=25,
+                    nested={"city": "New York"}
+                )
+        return inst
+
+    def with_structured_output(self, message_schema):
+        return message_schema
 
 
 class MessageSchema(BaseModel):
@@ -57,6 +102,29 @@ def __call__(self):
 def mock_structured_model():
     return MockedStructuredModel
 
+@pytest.fixture
+def llmresult():
+    return LLMResult(
+        generations=[
+            [
+                ChatGeneration(
+                    message=HumanMessage(content="this is a very IMPORTANT message"),
+                    generation_info={
+                        "logprobs": {
+                            "content": [
+                                {
+                                    "top_logprobs": [
+                                        {"token": "true", "logprob": 0.1},
+                                        {"token": "false", "logprob": 0.5},
+                                    ]
+                                }
+                            ]
+                        }
+                    },
+                )
+            ]
+        ]
+    )
 
 async def test_structured_output(monkeypatch, mock_structured_model):
     # Create a mock LLM_API instance
@@ -80,7 +148,7 @@ def mock_model():
 
 class MockPipeline:
     def __init__(self, mock_model):
-        self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(mock_structured_model)}
+        self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(MockChatOpenAI)}
         # self.models = {"test_model": LLM_API(mock_model)}
 
 
@@ -182,6 +250,22 @@ async def test_context_to_history(context):
     assert res == expected
 
 
+async def test_conditions(context):
+    cond1 = LLMCondition(
+        model_name="test_model",
+        prompt="test_prompt",
+        method=Contains(pattern="history"),
+    )
+    cond2 = LLMCondition(
+        model_name="test_model",
+        prompt="test_prompt",
+        method=Contains(pattern="abrakadabra"),
+    )
+    assert await cond1(ctx=context)
+    assert not await cond2(ctx=context)
+
+
+
 def test_is_important_filter(filter_context):
     filter_func = IsImportant()
     ctx = filter_context
@@ -205,29 +289,6 @@ def test_model_filter(filter_context):
     assert filter_func(ctx, ctx.requests[2], ctx.responses[3], model_name="test_model")
 
 
-@pytest.fixture
-def llmresult():
-    return LLMResult(
-        generations=[
-            [
-                ChatGeneration(
-                    message=HumanMessage(content="this is a very IMPORTANT message"),
-                    generation_info={
-                        "logprobs": {
-                            "content": [
-                                {
-                                    "top_logprobs": [
-                                        {"token": "true", "logprob": 0.1},
-                                        {"token": "false", "logprob": 0.5},
-                                    ]
-                                }
-                            ]
-                        }
-                    },
-                )
-            ]
-        ]
-    )
 
 
 async def test_base_method(llmresult):
@@ -251,10 +312,37 @@ async def test_logprob_method(filter_context, llmresult):
     assert not await c(ctx, llmresult)
 
 
-# async def test_llm_slot(pipeline, context):
-#     slot = LLMSlot(caption="test_caption", model="struct_model")
-#     res = await slot.extract_value(context)
-
-
-# async def test_llm_group_slot():
-#     pass
+async def test_llm_slot(pipeline, context):
+    slot = LLMSlot(caption="test_caption", model="struct_model")
+    # Test empty request
+    context.add_request("")
+    assert isinstance(await slot.extract_value(context), SlotNotExtracted)
+
+    # Test normal request
+    context.add_request("test request")
+    result = await slot.extract_value(context)
+    assert isinstance(result, str)
+
+
+async def test_llm_group_slot(pipeline, context):
+    slot = LLMGroupSlot(
+        model="struct_model",
+        __pydantic_extra__={
+            "name": LLMSlot(caption="Extract person's name"),
+            "age": LLMSlot(caption="Extract person's age"),
+            "nested": LLMGroupSlot(
+                model="struct_model",
+                __pydantic_extra__={
+                    "city": LLMSlot(caption="Extract person's city")
+                }
+            )
+        }
+    )
+
+    context.add_request("John is 25 years old and lives in New York")
+    result = await slot.get_value(context)
+
+    assert isinstance(result, ExtractedGroupSlot)
+    assert result.name.extracted_value == "John"
+    assert result.age.extracted_value == 25
+    assert result.nested.city.extracted_value == "New York"
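
Usage note for reviewers (not part of the patch): the LLMCondition exercised by test_conditions above can be constructed and awaited the same way outside the test suite. This is a minimal sketch only; it assumes a pipeline whose models mapping registers an LLM_API instance under "test_model", exactly as the tests do, and the helper name and prompt text below are illustrative rather than taken from this change. The constructor arguments (model_name, prompt, method=Contains(pattern=...)) and the call signature cond(ctx=...) are the ones used in the new test.

    from chatsky.conditions.llm import LLMCondition
    from chatsky.core.context import Context
    from chatsky.llm.methods import Contains


    async def user_agreed(ctx: Context) -> bool:
        # Hypothetical helper: asks the model registered as "test_model" whether
        # the last user turn expresses agreement, and reduces the reply to a bool
        # with the Contains method, mirroring the construction in test_conditions.
        condition = LLMCondition(
            model_name="test_model",
            prompt="Answer 'agree' if the user accepted the offer, otherwise 'decline'.",
            method=Contains(pattern="agree"),
        )
        return await condition(ctx=ctx)

In a script, such a condition would typically be attached to a transition so the LLM's verdict decides the next node; the exact wiring depends on the surrounding pipeline setup and is not shown here.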