Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump LiteralAI version #1376

Merged
merged 15 commits into from
Oct 2, 2024
227 changes: 168 additions & 59 deletions backend/chainlit/data/literalai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
import json
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast
from typing import Dict, List, Literal, Optional, Union, cast

import aiofiles
from httpx import HTTPStatusError, RequestError
from literalai import (
Attachment as LiteralAttachment,
Score as LiteralScore,
Step as LiteralStep,
Thread as LiteralThread,
)
from literalai.observability.filter import threads_filters as LiteralThreadsFilters
from literalai.observability.step import StepDict as LiteralStepDict

from chainlit.data.base import BaseDataLayer
from chainlit.data.utils import queue_until_user_message
from chainlit.element import Audio, Element, ElementDict, File, Image, Pdf, Text, Video
from chainlit.logger import logger
from chainlit.step import (
FeedbackDict,
Step,
StepDict,
StepType,
TrueStepType,
check_add_step_in_cot,
stub_step,
)
from chainlit.types import (
Feedback,
PageInfo,
Expand All @@ -14,50 +34,19 @@
ThreadFilter,
)
from chainlit.user import PersistedUser, User
from httpx import HTTPStatusError, RequestError
from literalai import Attachment
from literalai import Score as LiteralScore
from literalai import Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
from chainlit.step import FeedbackDict, StepDict


_data_layer: Optional[BaseDataLayer] = None


class LiteralDataLayer(BaseDataLayer):
def __init__(self, api_key: str, server: Optional[str]):
from literalai import AsyncLiteralClient
class LiteralToChainlitConverter:
@classmethod
def steptype_to_steptype(cls, step_type: Optional[StepType]) -> TrueStepType:
if step_type in ["user_message", "assistant_message", "system_message"]:
return "undefined"
return cast(TrueStepType, step_type or "undefined")

self.client = AsyncLiteralClient(api_key=api_key, url=server)
logger.info("Chainlit data layer initialized")

def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
metadata = attachment.metadata or {}
return {
"chainlitKey": None,
"display": metadata.get("display", "side"),
"language": metadata.get("language"),
"autoPlay": metadata.get("autoPlay", None),
"playerConfig": metadata.get("playerConfig", None),
"page": metadata.get("page"),
"size": metadata.get("size"),
"type": metadata.get("type", "file"),
"forId": attachment.step_id,
"id": attachment.id or "",
"mime": attachment.mime,
"name": attachment.name or "",
"objectKey": attachment.object_key,
"url": attachment.url,
"threadId": attachment.thread_id,
}

def score_to_feedback_dict(
self, score: Optional[LiteralScore]
@classmethod
def score_to_feedbackdict(
cls,
score: Optional[LiteralScore],
) -> "Optional[FeedbackDict]":
if not score:
return None
Expand All @@ -68,7 +57,8 @@ def score_to_feedback_dict(
"comment": score.comment,
}

def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
@classmethod
def step_to_stepdict(cls, step: LiteralStep) -> "StepDict":
metadata = step.metadata or {}
input = (step.input or {}).get("content") or (
json.dumps(step.input) if step.input and step.input != {} else ""
Expand All @@ -95,7 +85,7 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
"id": step.id or "",
"threadId": step.thread_id or "",
"parentId": step.parent_id,
"feedback": self.score_to_feedback_dict(user_feedback),
"feedback": cls.score_to_feedbackdict(user_feedback),
"start": step.start_time,
"end": step.end_time,
"type": step.type or "undefined",
Expand All @@ -110,6 +100,116 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
"waitForAnswer": metadata.get("waitForAnswer", False),
}

@classmethod
def attachment_to_elementdict(cls, attachment: LiteralAttachment) -> ElementDict:
metadata = attachment.metadata or {}
return {
"chainlitKey": None,
"display": metadata.get("display", "side"),
"language": metadata.get("language"),
"autoPlay": metadata.get("autoPlay", None),
"playerConfig": metadata.get("playerConfig", None),
"page": metadata.get("page"),
"size": metadata.get("size"),
"type": metadata.get("type", "file"),
"forId": attachment.step_id,
"id": attachment.id or "",
"mime": attachment.mime,
"name": attachment.name or "",
"objectKey": attachment.object_key,
"url": attachment.url,
"threadId": attachment.thread_id,
}

@classmethod
def attachment_to_element(
cls, attachment: LiteralAttachment, thread_id: Optional[str] = None
) -> Element:
metadata = attachment.metadata or {}
element_type = metadata.get("type", "file")

element_class = {
"file": File,
"image": Image,
"audio": Audio,
"video": Video,
"text": Text,
"pdf": Pdf,
}.get(element_type, Element)

assert thread_id or attachment.thread_id

element = element_class(
name=attachment.name or "",
display=metadata.get("display", "side"),
language=metadata.get("language"),
size=metadata.get("size"),
url=attachment.url,
mime=attachment.mime,
thread_id=thread_id or attachment.thread_id,
)
element.id = attachment.id or ""
element.for_id = attachment.step_id
element.object_key = attachment.object_key
return element

@classmethod
def step_to_step(cls, step: LiteralStep) -> Step:
chainlit_step = Step(
name=step.name or "",
type=cls.steptype_to_steptype(step.type),
id=step.id,
parent_id=step.parent_id,
thread_id=step.thread_id or None,
)
chainlit_step.start = step.start_time
chainlit_step.end = step.end_time
chainlit_step.created_at = step.created_at
chainlit_step.input = step.input.get("content", "") if step.input else ""
chainlit_step.output = step.output.get("content", "") if step.output else ""
chainlit_step.is_error = bool(step.error)
chainlit_step.metadata = step.metadata or {}
chainlit_step.tags = step.tags
chainlit_step.generation = step.generation

if step.attachments:
chainlit_step.elements = [
cls.attachment_to_element(attachment, chainlit_step.thread_id)
for attachment in step.attachments
]

return chainlit_step

@classmethod
def thread_to_threaddict(cls, thread: LiteralThread) -> ThreadDict:
return {
"id": thread.id,
"createdAt": getattr(thread, "created_at", ""),
"name": thread.name,
"userId": thread.participant_id,
"userIdentifier": thread.participant_identifier,
"tags": thread.tags,
"metadata": thread.metadata,
"steps": [cls.step_to_stepdict(step) for step in thread.steps]
if thread.steps
else [],
"elements": [
cls.attachment_to_elementdict(attachment)
for step in thread.steps
for attachment in step.attachments
]
if thread.steps
else [],
}


class LiteralDataLayer(BaseDataLayer):
def __init__(self, api_key: str, server: Optional[str]):
from literalai import AsyncLiteralClient

self.client = AsyncLiteralClient(api_key=api_key, url=server)
logger.info("Chainlit data layer initialized")

async def build_debug_url(self) -> str:
try:
project_id = await self.client.api.get_my_project_id()
Expand Down Expand Up @@ -239,7 +339,7 @@ async def get_element(
attachment = await self.client.api.get_attachment(id=element_id)
if not attachment:
return None
return self.attachment_to_element_dict(attachment)
return LiteralToChainlitConverter.attachment_to_elementdict(attachment)

@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
Expand Down Expand Up @@ -339,32 +439,41 @@ async def list_threads(
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)

chainlit_threads = [
*map(LiteralToChainlitConverter.thread_to_threaddict, literal_response.data)
]

return PaginatedResponse(
pageInfo=PageInfo(
hasNextPage=literal_response.pageInfo.hasNextPage,
startCursor=literal_response.pageInfo.startCursor,
endCursor=literal_response.pageInfo.endCursor,
hasNextPage=literal_response.page_info.has_next_page,
startCursor=literal_response.page_info.start_cursor,
endCursor=literal_response.page_info.end_cursor,
),
data=literal_response.data,
data=chainlit_threads,
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
from chainlit.step import check_add_step_in_cot, stub_step

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
thread = await self.client.api.get_thread(id=thread_id)
if not thread:
return None
elements = [] # List[ElementDict]
steps = [] # List[StepDict]

elements: List[ElementDict] = []
steps: List[StepDict] = []
if thread.steps:
for step in thread.steps:
for attachment in step.attachments:
elements.append(self.attachment_to_element_dict(attachment))

if check_add_step_in_cot(step):
steps.append(self.step_to_step_dict(step))
elements.append(
LiteralToChainlitConverter.attachment_to_elementdict(attachment)
)

chainlit_step = LiteralToChainlitConverter.step_to_step(step)
if check_add_step_in_cot(chainlit_step):
steps.append(
LiteralToChainlitConverter.step_to_stepdict(step)
) # TODO: chainlit_step.to_dict()
else:
steps.append(stub_step(step))
steps.append(stub_step(chainlit_step))

return {
"createdAt": thread.created_at or "",
Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class ElementDict(TypedDict):

@dataclass
class Element:
# Thread id
thread_id: str = Field(default_factory=lambda: context.session.thread_id)
# The type of the element. This will be used to determine how to display the element in the UI.
type: ClassVar[ElementType]
# Name of the element, this will be used to reference the element in the UI.
Expand Down Expand Up @@ -88,7 +90,6 @@ def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.persisted = False
self.updatable = False
self.thread_id = context.session.thread_id

if not self.url and not self.path and not self.content:
raise ValueError("Must provide url, path or content to instantiate element")
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from literalai.helper import utc_now
from literalai.step import TrueStepType
from literalai.observability.step import TrueStepType

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]

Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
FileDict,
)
from literalai.helper import utc_now
from literalai.step import MessageStepType
from literalai.observability.step import MessageStepType


class MessageBase(ABC):
Expand Down
7 changes: 4 additions & 3 deletions backend/chainlit/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from chainlit.types import FeedbackDict
from literalai import BaseGeneration
from literalai.helper import utc_now
from literalai.step import StepType, TrueStepType
from literalai.observability.step import StepType, TrueStepType


def check_add_step_in_cot(step: "Step"):
Expand All @@ -30,7 +30,7 @@ def check_add_step_in_cot(step: "Step"):
return True


def stub_step(step: "Step"):
def stub_step(step: "Step") -> "StepDict":
return {
"type": step.type,
"name": step.name,
Expand Down Expand Up @@ -189,12 +189,13 @@ def __init__(
tags: Optional[List[str]] = None,
language: Optional[str] = None,
show_input: Union[bool, str] = "json",
thread_id: Optional[str] = None,
):
trace_event(f"init {self.__class__.__name__} {type}")
time.sleep(0.001)
self._input = ""
self._output = ""
self.thread_id = context.session.thread_id
self.thread_id = thread_id or context.session.thread_id
self.name = name or ""
self.type = type
self.id = id or str(uuid.uuid4())
Expand Down
Loading
Loading