Skip to content

Commit

Permalink
Working on Prompt rework
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Dec 23, 2024
1 parent 968fe75 commit 8bc71ce
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion chatsky/conditions/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from chatsky.core import BaseCondition, Context
from chatsky.core.script_function import AnyResponse
from chatsky.llm.methods import BaseMethod
from chatsky.llm.utils import context_to_history, message_to_langchain
from chatsky.llm.langchain_context import context_to_history, message_to_langchain
from chatsky.llm.filters import BaseHistoryFilter, DefaultFilter


Expand Down
13 changes: 13 additions & 0 deletions chatsky/llm/utils.py → chatsky/llm/langchain_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from chatsky.core.script_function import ConstResponse
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter
from chatsky.llm.prompt import Prompt, DesaultPositionConfig


async def message_to_langchain(
Expand Down Expand Up @@ -80,3 +81,15 @@ async def context_to_history(
history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
return history

async def get_langchain_context(
    system_prompt: Prompt,
    ctx: Context,
    call_prompt,
    prompt_misc_filter: str = r"prompt",  # regex: MISC keys matching this are treated as prompts
    postition_config: "Optional[DesaultPositionConfig]" = None,  # NOTE(review): param name typo kept for interface stability
    **history_args,
):
    """
    Build the list of langchain messages to pass to the LLM from the
    context and the configured prompts. Called from LLM_API.

    :param system_prompt: Prompt placed according to ``postition_config.system_prompt``.
    :param ctx: Current dialog context; its history is converted via ``context_to_history``.
    :param call_prompt: Prompt attached to this particular LLM call.
    :param prompt_misc_filter: Regex selecting which ``ctx.misc`` keys hold extra prompts.
    :param postition_config: Ordering weights for the message sections;
        defaults to a fresh ``DesaultPositionConfig()`` (avoids a shared mutable default).
    :param history_args: Forwarded as keyword arguments to ``context_to_history``.
    :return: List of langchain messages (currently the converted history only).
    """
    if postition_config is None:
        postition_config = DesaultPositionConfig()
    # context_to_history is a coroutine: it must be awaited, and history_args
    # must be unpacked as keyword arguments (passing the dict positionally
    # would bind it to the first positional parameter).
    history = await context_to_history(ctx, **history_args)
    # TODO(review): merge system_prompt, MISC prompts (selected by
    # prompt_misc_filter), call_prompt and the last response into `history`
    # ordered by the float weights in postition_config.
    return history
32 changes: 32 additions & 0 deletions chatsky/llm/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional, Union
from pydantic import BaseModel, model_validator
from chatsky.core import BaseResponse, BasePriority, AnyPriority, AnyResponse, MessageInitTypes, Message


class DefaultPositionConfig(BaseModel):
    """
    Default ordering weights for the sections of an LLM prompt.

    Each field is a float position; sections are sorted ascending, so a
    lower value places the section earlier in the final message list.
    """

    system_prompt: float = 0
    history: float = 1
    misc_prompt: float = 2
    call_prompt: float = 3
    last_response: float = 4


# Backward-compatible alias for the original (misspelled) class name,
# so existing imports of ``DesaultPositionConfig`` keep working.
DesaultPositionConfig = DefaultPositionConfig


class Prompt(BaseModel):
    """
    A single prompt entry: the response that produces the prompt text plus
    an optional position controlling where it is placed in the message list.
    """

    # Response object that renders the prompt message.
    prompt: AnyResponse
    # Ordering weight; None means "use the default position for this kind of prompt".
    position: Optional[AnyPriority] = None

    def __init__(
        self,
        prompt: Union[MessageInitTypes, BaseResponse],
        position: Optional[Union[float, BasePriority]] = None
    ):
        # Override pydantic's keyword-only constructor so the prompt can be
        # passed positionally: Prompt("text") / Prompt(FilledTemplate(), -2).
        super().__init__(prompt=prompt, position=position)

    @model_validator(mode="before")
    @classmethod
    def validate_from_message(cls, data):
        """
        Allow constructing a Prompt directly from a bare value: a ``str``,
        a ``Message`` or a ``BaseResponse`` is wrapped into ``{"prompt": ...}``
        before normal field validation runs (``mode="before"``).
        """
        # MISC: {"prompt": "message", "prompt": Message("text"), "prompt": FilledTemplate(), "prompt": Prompt(prompt=FilledTemplate(), position=-2)
        # Prompt.model_validate
        if isinstance(data, (str, Message, BaseResponse)):
            return {"prompt": data}
        # Dicts (and anything else) pass through untouched for field validation.
        return data
2 changes: 1 addition & 1 deletion chatsky/responses/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from chatsky.core.message import Message
from chatsky.core.context import Context
from chatsky.llm.utils import message_to_langchain, context_to_history
from chatsky.llm.langchain_context import message_to_langchain, context_to_history
from chatsky.llm._langchain_imports import check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter, DefaultFilter
from chatsky.core.script_function import BaseResponse, AnyResponse
Expand Down
2 changes: 1 addition & 1 deletion tests/llm/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from chatsky.conditions.llm import LLMCondition
from chatsky.slots.llm import LLMSlot, LLMGroupSlot
from chatsky.slots.slots import SlotNotExtracted, ExtractedGroupSlot
from chatsky.llm.utils import message_to_langchain, context_to_history
from chatsky.llm.langchain_context import message_to_langchain, context_to_history
from chatsky.llm.filters import IsImportant, FromModel
from chatsky.llm.methods import Contains, LogProb, BaseMethod
from chatsky.core.message import Message
Expand Down

0 comments on commit 8bc71ce

Please sign in to comment.