Added test for conditions + fixed some bugs
NotBioWaste905 committed Dec 11, 2024
1 parent e723334 commit 6b1ffed
Showing 3 changed files with 127 additions and 36 deletions.
2 changes: 1 addition & 1 deletion chatsky/conditions/llm.py
@@ -50,7 +50,7 @@ async def call(self, ctx: Context) -> bool:
        if model.system_prompt == "":
            history_messages = []
        else:
-            history_messages = [message_to_langchain(model.system_prompt, ctx=ctx, source="system")]
+            history_messages = [await message_to_langchain(model.system_prompt, ctx=ctx, source="system")]

        if not (self.history == 0 or len(ctx.responses) == 0 or len(ctx.requests) == 0):
            history_messages.extend(
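
Note (editor's sketch, not part of the commit): the fix above adds a missing `await`. `message_to_langchain` is a coroutine, so calling it bare would put a coroutine object into `history_messages` instead of a langchain message. A minimal self-contained illustration of that failure mode, using a hypothetical stub rather than chatsky code:

import asyncio

async def to_langchain_stub(text: str) -> str:
    # Stand-in coroutine for message_to_langchain; the return type is irrelevant here.
    return f"SystemMessage({text!r})"

async def main():
    broken = [to_langchain_stub("prompt")]       # bug: a coroutine object, never awaited
    fixed = [await to_langchain_stub("prompt")]  # fix: the actual message value
    print(type(broken[0]).__name__)              # -> "coroutine"
    print(fixed[0])                              # -> SystemMessage('prompt')
    broken[0].close()                            # silence the "never awaited" warning

asyncio.run(main())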
3 changes: 3 additions & 0 deletions chatsky/llm/utils.py
@@ -9,6 +9,7 @@

from chatsky.core.context import Context
from chatsky.core.message import Message
+from chatsky.core.script_function import ConstResponse
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter

@@ -28,6 +29,8 @@ async def message_to_langchain(
    check_langchain_available()
    if isinstance(message, str):
        message = Message(text=message)
+    if isinstance(message, ConstResponse):
+        message = message.root

    if message.text is None:
        content = []
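
Note (editor's sketch): the `utils.py` change lets `message_to_langchain` accept a `ConstResponse` by unwrapping it to the underlying `Message` via `.root`, consistent with a pydantic root-model wrapper. A sketch of the normalization step, assuming chatsky is importable; `normalize` is a hypothetical helper, not library code:

from chatsky.core.message import Message
from chatsky.core.script_function import ConstResponse

def normalize(message):
    # Mirrors the dispatch at the top of message_to_langchain:
    # strings and ConstResponse wrappers both collapse to a Message.
    if isinstance(message, str):
        message = Message(text=message)
    if isinstance(message, ConstResponse):
        message = message.root  # unwrap to the wrapped Message
    return message

assert normalize("hi").text == "hi"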
158 changes: 123 additions & 35 deletions tests/llm/test_llm.py
@@ -4,10 +4,12 @@
from chatsky.llm._langchain_imports import langchain_available
from chatsky.llm.llm_api import LLM_API
from chatsky.responses.llm import LLMResponse
+from chatsky.conditions.llm import LLMCondition
from chatsky.slots.llm import LLMGroupSlot, LLMSlot
from chatsky.slots.slots import SlotNotExtracted, ExtractedGroupSlot
from chatsky.llm.utils import message_to_langchain, context_to_history
from chatsky.llm.filters import IsImportant, FromModel
-from chatsky.llm.methods import Contains, LogProb
+from chatsky.llm.methods import Contains, LogProb, BaseMethod
from chatsky.core.message import Message
from chatsky.core.context import Context
from chatsky.core.script import Node
@@ -29,21 +31,64 @@ async def ainvoke(self, history: list = [""]):
content=f"Mock response with history: {[message.content[0]['text'] for message in history]}"
)
return response

async def agenerate(self, history: list, logprobs=True, top_logprobs=10):
return LLMResult(
generations=[
[
ChatGeneration(
message=HumanMessage(content=f"Mock generation without history."),
generation_info={
"logprobs": {
"content": [
{
"top_logprobs": [
{"token": "true", "logprob": 0.1},
{"token": "false", "logprob": 0.5},
]
}
]
}
},
)
]
]
)

def with_structured_output(self, message_schema):
return MockedStructuredModel(root_model=message_schema)

def respond(self, history: list = [""]):
async def respond(self, history: list, message_schema=None):
return self.ainvoke(history)

async def condition(self, history: list, method: BaseMethod, return_schema=None):
result = await method(history, await self.model.agenerate(history, logprobs=True, top_logprobs=10))
return result



class MockedStructuredModel:
    def __init__(self, root_model):
        self.root = root_model

    async def ainvoke(self, history):
-        inst = self.root(history=history)
-        return inst()
+        if isinstance(history, list):
+            inst = self.root(history=history)
+        else:
+            # For LLMSlot
+            if hasattr(self.root, "value"):
+                inst = self.root(value="mocked_value")
+            # For LLMGroupSlot
+            else:
+                inst = self.root(
+                    name="John",
+                    age=25,
+                    nested={"city": "New York"}
+                )
+        return inst

    def with_structured_output(self, message_schema):
        return message_schema


class MessageSchema(BaseModel):
@@ -57,6 +102,29 @@ def __call__(self):
def mock_structured_model():
    return MockedStructuredModel

+@pytest.fixture
+def llmresult():
+    return LLMResult(
+        generations=[
+            [
+                ChatGeneration(
+                    message=HumanMessage(content="this is a very IMPORTANT message"),
+                    generation_info={
+                        "logprobs": {
+                            "content": [
+                                {
+                                    "top_logprobs": [
+                                        {"token": "true", "logprob": 0.1},
+                                        {"token": "false", "logprob": 0.5},
+                                    ]
+                                }
+                            ]
+                        }
+                    },
+                )
+            ]
+        ]
+    )

async def test_structured_output(monkeypatch, mock_structured_model):
    # Create a mock LLM_API instance
@@ -80,7 +148,7 @@ def mock_model():

class MockPipeline:
    def __init__(self, mock_model):
-        self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(mock_structured_model)}
+        self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(MockChatOpenAI)}
        # self.models = {"test_model": LLM_API(mock_model)}


@@ -182,6 +250,22 @@ async def test_context_to_history(context):
    assert res == expected


+async def test_conditions(context):
+    cond1 = LLMCondition(
+        model_name="test_model",
+        prompt="test_prompt",
+        method=Contains(pattern="history"),
+    )
+    cond2 = LLMCondition(
+        model_name="test_model",
+        prompt="test_prompt",
+        method=Contains(pattern="abrakadabra"),
+    )
+    assert await cond1(ctx=context) == True
+    assert await cond2(ctx=context) == False


def test_is_important_filter(filter_context):
    filter_func = IsImportant()
    ctx = filter_context
@@ -205,29 +289,6 @@ def test_model_filter(filter_context):
    assert filter_func(ctx, ctx.requests[2], ctx.responses[3], model_name="test_model")


-@pytest.fixture
-def llmresult():
-    return LLMResult(
-        generations=[
-            [
-                ChatGeneration(
-                    message=HumanMessage(content="this is a very IMPORTANT message"),
-                    generation_info={
-                        "logprobs": {
-                            "content": [
-                                {
-                                    "top_logprobs": [
-                                        {"token": "true", "logprob": 0.1},
-                                        {"token": "false", "logprob": 0.5},
-                                    ]
-                                }
-                            ]
-                        }
-                    },
-                )
-            ]
-        ]
-    )


async def test_base_method(llmresult):
@@ -251,10 +312,37 @@ async def test_logprob_method(filter_context, llmresult):
    assert not await c(ctx, llmresult)


-# async def test_llm_slot(pipeline, context):
-#     slot = LLMSlot(caption="test_caption", model="struct_model")
-#     res = await slot.extract_value(context)


-# async def test_llm_group_slot():
-#     pass
+async def test_llm_slot(pipeline, context):
+    slot = LLMSlot(caption="test_caption", model="struct_model")
+    # Test empty request
+    context.add_request("")
+    assert isinstance(await slot.extract_value(context), SlotNotExtracted)
+
+    # Test normal request
+    context.add_request("test request")
+    result = await slot.extract_value(context)
+    assert isinstance(result, str)
+
+
+async def test_llm_group_slot(pipeline, context):
+    slot = LLMGroupSlot(
+        model="struct_model",
+        __pydantic_extra__={
+            "name": LLMSlot(caption="Extract person's name"),
+            "age": LLMSlot(caption="Extract person's age"),
+            "nested": LLMGroupSlot(
+                model="struct_model",
+                __pydantic_extra__={
+                    "city": LLMSlot(caption="Extract person's city")
+                }
+            )
+        }
+    )
+
+    context.add_request("John is 25 years old and lives in New York")
+    result = await slot.get_value(context)
+
+    assert isinstance(result, ExtractedGroupSlot)
+    assert result.name.extracted_value == "John"
+    assert result.age.extracted_value == 25
+    assert result.nested.city.extracted_value == "New York"
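
Note (editor's sketch): in `test_conditions`, the mocked model always replies with text of the form "Mock response with history: [...]", which explains the two assertions if `Contains` reduces to a pattern-in-text check on the generated reply (an assumption about `chatsky.llm.methods.Contains`, not confirmed by this diff). A minimal illustration:

# Hypothetical illustration, not chatsky code.
reply = "Mock response with history: ['test_prompt']"
assert ("history" in reply) is True         # Contains(pattern="history") -> condition passes
assert ("abrakadabra" in reply) is False    # Contains(pattern="abrakadabra") -> condition fails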
