diff --git a/backend/chainlit/data/literalai.py b/backend/chainlit/data/literalai.py index 5572dbf4a9..7b2e39fbed 100644 --- a/backend/chainlit/data/literalai.py +++ b/backend/chainlit/data/literalai.py @@ -1,10 +1,30 @@ import json -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast +from typing import Dict, List, Literal, Optional, Union, cast import aiofiles +from httpx import HTTPStatusError, RequestError +from literalai import ( + Attachment as LiteralAttachment, + Score as LiteralScore, + Step as LiteralStep, + Thread as LiteralThread, +) +from literalai.observability.filter import threads_filters as LiteralThreadsFilters +from literalai.observability.step import StepDict as LiteralStepDict + from chainlit.data.base import BaseDataLayer from chainlit.data.utils import queue_until_user_message +from chainlit.element import Audio, Element, ElementDict, File, Image, Pdf, Text, Video from chainlit.logger import logger +from chainlit.step import ( + FeedbackDict, + Step, + StepDict, + StepType, + TrueStepType, + check_add_step_in_cot, + stub_step, +) from chainlit.types import ( Feedback, PageInfo, @@ -14,50 +34,19 @@ ThreadFilter, ) from chainlit.user import PersistedUser, User -from httpx import HTTPStatusError, RequestError -from literalai import Attachment -from literalai import Score as LiteralScore -from literalai import Step as LiteralStep -from literalai.filter import threads_filters as LiteralThreadsFilters -from literalai.step import StepDict as LiteralStepDict - -if TYPE_CHECKING: - from chainlit.element import Element, ElementDict - from chainlit.step import FeedbackDict, StepDict -_data_layer: Optional[BaseDataLayer] = None - - -class LiteralDataLayer(BaseDataLayer): - def __init__(self, api_key: str, server: Optional[str]): - from literalai import AsyncLiteralClient +class LiteralToChainlitConverter: + @classmethod + def steptype_to_steptype(cls, step_type: Optional[StepType]) -> TrueStepType: + if step_type in ["user_message", "assistant_message", "system_message"]: + return "undefined" + return cast(TrueStepType, step_type or "undefined") - self.client = AsyncLiteralClient(api_key=api_key, url=server) - logger.info("Chainlit data layer initialized") - - def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": - metadata = attachment.metadata or {} - return { - "chainlitKey": None, - "display": metadata.get("display", "side"), - "language": metadata.get("language"), - "autoPlay": metadata.get("autoPlay", None), - "playerConfig": metadata.get("playerConfig", None), - "page": metadata.get("page"), - "size": metadata.get("size"), - "type": metadata.get("type", "file"), - "forId": attachment.step_id, - "id": attachment.id or "", - "mime": attachment.mime, - "name": attachment.name or "", - "objectKey": attachment.object_key, - "url": attachment.url, - "threadId": attachment.thread_id, - } - - def score_to_feedback_dict( - self, score: Optional[LiteralScore] + @classmethod + def score_to_feedbackdict( + cls, + score: Optional[LiteralScore], ) -> "Optional[FeedbackDict]": if not score: return None @@ -68,7 +57,8 @@ def score_to_feedback_dict( "comment": score.comment, } - def step_to_step_dict(self, step: LiteralStep) -> "StepDict": + @classmethod + def step_to_stepdict(cls, step: LiteralStep) -> "StepDict": metadata = step.metadata or {} input = (step.input or {}).get("content") or ( json.dumps(step.input) if step.input and step.input != {} else "" @@ -95,7 +85,7 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict": "id": step.id or "", "threadId": step.thread_id or "", "parentId": step.parent_id, - "feedback": self.score_to_feedback_dict(user_feedback), + "feedback": cls.score_to_feedbackdict(user_feedback), "start": step.start_time, "end": step.end_time, "type": step.type or "undefined", @@ -110,6 +100,116 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict": "waitForAnswer": metadata.get("waitForAnswer", False), } + @classmethod + def attachment_to_elementdict(cls, attachment: LiteralAttachment) -> ElementDict: + metadata = attachment.metadata or {} + return { + "chainlitKey": None, + "display": metadata.get("display", "side"), + "language": metadata.get("language"), + "autoPlay": metadata.get("autoPlay", None), + "playerConfig": metadata.get("playerConfig", None), + "page": metadata.get("page"), + "size": metadata.get("size"), + "type": metadata.get("type", "file"), + "forId": attachment.step_id, + "id": attachment.id or "", + "mime": attachment.mime, + "name": attachment.name or "", + "objectKey": attachment.object_key, + "url": attachment.url, + "threadId": attachment.thread_id, + } + + @classmethod + def attachment_to_element( + cls, attachment: LiteralAttachment, thread_id: Optional[str] = None + ) -> Element: + metadata = attachment.metadata or {} + element_type = metadata.get("type", "file") + + element_class = { + "file": File, + "image": Image, + "audio": Audio, + "video": Video, + "text": Text, + "pdf": Pdf, + }.get(element_type, Element) + + assert thread_id or attachment.thread_id + + element = element_class( + name=attachment.name or "", + display=metadata.get("display", "side"), + language=metadata.get("language"), + size=metadata.get("size"), + url=attachment.url, + mime=attachment.mime, + thread_id=thread_id or attachment.thread_id, + ) + element.id = attachment.id or "" + element.for_id = attachment.step_id + element.object_key = attachment.object_key + return element + + @classmethod + def step_to_step(cls, step: LiteralStep) -> Step: + chainlit_step = Step( + name=step.name or "", + type=cls.steptype_to_steptype(step.type), + id=step.id, + parent_id=step.parent_id, + thread_id=step.thread_id or None, + ) + chainlit_step.start = step.start_time + chainlit_step.end = step.end_time + chainlit_step.created_at = step.created_at + chainlit_step.input = step.input.get("content", "") if step.input else "" + chainlit_step.output = step.output.get("content", "") if step.output else "" + chainlit_step.is_error = bool(step.error) + chainlit_step.metadata = step.metadata or {} + chainlit_step.tags = step.tags + chainlit_step.generation = step.generation + + if step.attachments: + chainlit_step.elements = [ + cls.attachment_to_element(attachment, chainlit_step.thread_id) + for attachment in step.attachments + ] + + return chainlit_step + + @classmethod + def thread_to_threaddict(cls, thread: LiteralThread) -> ThreadDict: + return { + "id": thread.id, + "createdAt": getattr(thread, "created_at", ""), + "name": thread.name, + "userId": thread.participant_id, + "userIdentifier": thread.participant_identifier, + "tags": thread.tags, + "metadata": thread.metadata, + "steps": [cls.step_to_stepdict(step) for step in thread.steps] + if thread.steps + else [], + "elements": [ + cls.attachment_to_elementdict(attachment) + for step in thread.steps + for attachment in step.attachments + ] + if thread.steps + else [], + } + + +class LiteralDataLayer(BaseDataLayer): + def __init__(self, api_key: str, server: Optional[str]): + from literalai import AsyncLiteralClient + + self.client = AsyncLiteralClient(api_key=api_key, url=server) + logger.info("Chainlit data layer initialized") + async def build_debug_url(self) -> str: try: project_id = await self.client.api.get_my_project_id() @@ -239,7 +339,7 @@ async def get_element( attachment = await self.client.api.get_attachment(id=element_id) if not attachment: return None - return self.attachment_to_element_dict(attachment) + return LiteralToChainlitConverter.attachment_to_elementdict(attachment) @queue_until_user_message() async def delete_element(self, element_id: str, thread_id: Optional[str] = None): @@ -339,32 +439,41 @@ async def list_threads( filters=literal_filters, order_by={"column": "createdAt", "direction": "DESC"}, ) + + chainlit_threads = [ + *map(LiteralToChainlitConverter.thread_to_threaddict, literal_response.data) + ] + return PaginatedResponse( pageInfo=PageInfo( - hasNextPage=literal_response.pageInfo.hasNextPage, - startCursor=literal_response.pageInfo.startCursor, - endCursor=literal_response.pageInfo.endCursor, + hasNextPage=literal_response.page_info.has_next_page, + startCursor=literal_response.page_info.start_cursor, + endCursor=literal_response.page_info.end_cursor, ), - data=literal_response.data, + data=chainlit_threads, ) - async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": - from chainlit.step import check_add_step_in_cot, stub_step - + async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: thread = await self.client.api.get_thread(id=thread_id) if not thread: return None - elements = [] # List[ElementDict] - steps = [] # List[StepDict] + + elements: List[ElementDict] = [] + steps: List[StepDict] = [] if thread.steps: for step in thread.steps: for attachment in step.attachments: - elements.append(self.attachment_to_element_dict(attachment)) - - if check_add_step_in_cot(step): - steps.append(self.step_to_step_dict(step)) + elements.append( + LiteralToChainlitConverter.attachment_to_elementdict(attachment) + ) + + chainlit_step = LiteralToChainlitConverter.step_to_step(step) + if check_add_step_in_cot(chainlit_step): + steps.append( + LiteralToChainlitConverter.step_to_stepdict(step) + ) # TODO: chainlit_step.to_dict() else: - steps.append(stub_step(step)) + steps.append(stub_step(chainlit_step)) return { "createdAt": thread.created_at or "", diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 3b58b38714..56384581ec 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -57,6 +57,8 @@ class ElementDict(TypedDict): @dataclass class Element: + # Thread id + thread_id: str = Field(default_factory=lambda: context.session.thread_id) # The type of the element. This will be used to determine how to display the element in the UI. type: ClassVar[ElementType] # Name of the element, this will be used to reference the element in the UI. @@ -88,7 +90,6 @@ def __post_init__(self) -> None: trace_event(f"init {self.__class__.__name__}") self.persisted = False self.updatable = False - self.thread_id = context.session.thread_id if not self.url and not self.path and not self.content: raise ValueError("Must provide url, path or content to instantiate element") diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py index fde942c358..147f9de413 100644 --- a/backend/chainlit/langchain/callbacks.py +++ b/backend/chainlit/langchain/callbacks.py @@ -12,7 +12,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from literalai import ChatGeneration, CompletionGeneration, GenerationMessage from literalai.helper import utc_now -from literalai.step import TrueStepType +from literalai.observability.step import TrueStepType DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 637fa44957..7bb590f8be 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -23,7 +23,7 @@ FileDict, ) from literalai.helper import utc_now -from literalai.step import MessageStepType +from literalai.observability.step import MessageStepType class MessageBase(ABC): diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index bb8b7634f9..3d67ae7b11 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -16,7 +16,7 @@ from chainlit.types import FeedbackDict from literalai import BaseGeneration from literalai.helper import utc_now -from literalai.step import StepType, TrueStepType +from literalai.observability.step import StepType, TrueStepType def check_add_step_in_cot(step: "Step"): @@ -30,7 +30,7 @@ def check_add_step_in_cot(step: "Step"): return True -def stub_step(step: "Step"): +def stub_step(step: "Step") -> "StepDict": return { "type": step.type, "name": step.name, @@ -189,12 +189,13 @@ def __init__( tags: Optional[List[str]] = None, language: Optional[str] = None, show_input: Union[bool, str] = "json", + thread_id: Optional[str] = None, ): trace_event(f"init {self.__class__.__name__} {type}") time.sleep(0.001) self._input = "" self._output = "" - self.thread_id = context.session.thread_id + self.thread_id = thread_id or context.session.thread_id self.name = name or "" self.type = type self.id = id or str(uuid.uuid4()) diff --git a/backend/poetry.lock b/backend/poetry.lock index 345bd1e430..4c717509ad 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -2313,12 +2313,12 @@ testing = ["packaging", "pytest"] [[package]] name = "literalai" -version = "0.0.607" +version = "0.0.623" description = "An SDK for observability in Python applications" optional = false python-versions = "*" files = [ - {file = "literalai-0.0.607.tar.gz", hash = "sha256:783c495d9fb2ae9f84e5877ecb34587f6a6ac0224d9f0936336dbc8b6765b30d"}, + {file = "literalai-0.0.623.tar.gz", hash = "sha256:d65c04dde6b1e99d585e4112a607e5fd574d282b70f600c55a671018340dfb0f"}, ] [package.dependencies] @@ -3332,6 +3332,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -3352,6 +3353,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -3363,8 +3365,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5538,4 +5540,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0.0" -content-hash = "abd6fcef6a72a72b26f8c5842a877c57fd32414a729f8c6b1ec37a488b217b09" +content-hash = "ef9341345f921f6b78cccbcf94fd539d9d0814428d40de0ade9a244362616e04" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 7b02912a16..5f6deda7ae 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -27,7 +27,7 @@ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] python = ">=3.9,<4.0.0" httpx = ">=0.23.0" -literalai = "0.0.607" +literalai = "0.0.623" dataclasses_json = "^0.6.7" fastapi = ">=0.110.1,<0.113" starlette = "^0.37.2" @@ -127,3 +127,10 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" + + +[tool.ruff.lint] +select = ["I"] + +[tool.ruff.lint.isort] +combine-as-imports = true diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index cb76427d0b..94ae4a596a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,3 +1,4 @@ +import datetime from contextlib import asynccontextmanager from unittest.mock import AsyncMock, Mock @@ -10,10 +11,12 @@ @pytest.fixture -def mock_persisted_user(): - mock = Mock(spec=PersistedUser) - mock.id = "test_user_id" - return mock +def persisted_test_user(): + return PersistedUser( + id="test_user_id", + createdAt=datetime.datetime.now().isoformat(), + identifier="test_user_identifier", + ) @pytest.fixture @@ -44,8 +47,8 @@ async def create_chainlit_context(mock_session): @pytest_asyncio.fixture -async def mock_chainlit_context(mock_persisted_user, mock_session): - mock_session.user = mock_persisted_user +async def mock_chainlit_context(persisted_test_user, mock_session): + mock_session.user = persisted_test_user return create_chainlit_context(mock_session) diff --git a/backend/tests/data/conftest.py b/backend/tests/data/conftest.py index 0e03190426..19177756fc 100644 --- a/backend/tests/data/conftest.py +++ b/backend/tests/data/conftest.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock from chainlit.data.base import BaseStorageClient +from chainlit.user import User @pytest.fixture @@ -13,3 +14,8 @@ def mock_storage_client(): "object_key": "test_user/test_element/test.txt", } return mock_client + + +@pytest.fixture +def test_user() -> User: + return User(identifier="test_user_identifier", metadata={}) diff --git a/backend/tests/data/test_literalai.py b/backend/tests/data/test_literalai.py new file mode 100644 index 0000000000..ae9a2ee93b --- /dev/null +++ b/backend/tests/data/test_literalai.py @@ -0,0 +1,1060 @@ +import datetime +import uuid +from unittest.mock import ANY, AsyncMock, Mock, patch + +import pytest +from httpx import HTTPStatusError, RequestError +from literalai import ( + AsyncLiteralClient, + Attachment, + Attachment as LiteralAttachment, + PageInfo, + PaginatedResponse, + Score as LiteralScore, + Step as LiteralStep, + Thread, + Thread as LiteralThread, + User as LiteralUser, + UserDict, +) +from literalai.api import AsyncLiteralAPI +from literalai.observability.step import ( + AttachmentDict as LiteralAttachmentDict, + StepDict as LiteralStepDict, +) +from literalai.observability.thread import ThreadDict as LiteralThreadDict + +from chainlit.data.literalai import LiteralDataLayer, LiteralToChainlitConverter +from chainlit.element import Audio, File, Image, Pdf, Text, Video +from chainlit.step import Step, StepDict +from chainlit.types import ( + Feedback, + Pagination, + ThreadFilter, +) +from chainlit.user import PersistedUser, User + + +@pytest.fixture +async def mock_literal_client(monkeypatch: pytest.MonkeyPatch): + client = Mock(spec=AsyncLiteralClient) + client.api = AsyncMock(spec=AsyncLiteralAPI) + monkeypatch.setattr("literalai.AsyncLiteralClient", client) + yield client + + +@pytest.fixture +async def literal_data_layer(mock_literal_client): + data_layer = LiteralDataLayer(api_key="fake_api_key", server="https://fake.server") + data_layer.client = mock_literal_client + return data_layer + + +@pytest.fixture +def test_thread(): + return LiteralThread.from_dict( + { + "id": "test_thread_id", + "name": "Test Thread", + "createdAt": "2023-01-01T00:00:00Z", + "metadata": {}, + "participant": {}, + "steps": [], + "tags": [], + } + ) + + +@pytest.fixture +def test_step_dict(test_thread) -> StepDict: + return { + "createdAt": "2023-01-01T00:00:00Z", + "start": "2023-01-01T00:00:00Z", + "end": "2023-01-01T00:00:00Z", + "generation": {}, + "id": "test_step_id", + "name": "Test Step", + "threadId": test_thread.id, + "type": "user_message", + "tags": [], + "metadata": {"key": "value"}, + "input": "test input", + "output": "test output", + "waitForAnswer": True, + "showInput": True, + "language": "en", + } + + +@pytest.fixture +def test_step(test_thread: LiteralThread): + return LiteralStep.from_dict( + { + "id": str(uuid.uuid4()), + "name": "Test Step", + "type": "user_message", + "environment": None, + "threadId": test_thread.id, + "error": None, + "input": {}, + "output": {}, + "metadata": {}, + "tags": [], + "parentId": None, + "createdAt": "2023-01-01T00:00:00Z", + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:00Z", + "generation": {}, + "scores": [], + "attachments": [], + "rootRunId": None, + } + ) + + +@pytest.fixture +def literal_test_user(test_user: User): + return LiteralUser( + id=str(uuid.uuid4()), + created_at=datetime.datetime.now().isoformat(), + identifier=test_user.identifier, + metadata=test_user.metadata, + ) + + +@pytest.fixture +def test_filters() -> ThreadFilter: + return ThreadFilter(feedback=1, userId="user1", search="test") + + +@pytest.fixture +def test_pagination() -> Pagination: + return Pagination(first=10, cursor=None) + + +@pytest.fixture +def test_attachment( + test_thread: LiteralThread, test_step: LiteralStep +) -> LiteralAttachment: + return Attachment( + id="test_attachment_id", + step_id=test_step.id, + thread_id=test_thread.id, + metadata={ + "display": "side", + "language": "python", + "type": "file", + }, + mime="text/plain", + name="test_file.txt", + object_key="test_object_key", + url="https://example.com/test_file.txt", + ) + + +async def test_create_step( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_step_dict: StepDict, + mock_chainlit_context, +): + async with mock_chainlit_context: + await literal_data_layer.create_step(test_step_dict) + + mock_literal_client.api.send_steps.assert_awaited_once_with( + [ + { + "createdAt": "2023-01-01T00:00:00Z", + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:00Z", + "generation": {}, + "id": "test_step_id", + "parentId": None, + "name": "Test Step", + "threadId": "test_thread_id", + "type": "user_message", + "tags": [], + "metadata": { + "key": "value", + "waitForAnswer": True, + "language": "en", + "showInput": True, + }, + "input": {"content": "test input"}, + "output": {"content": "test output"}, + } + ] + ) + + +async def test_safely_send_steps_success( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, +): + test_steps = [{"id": "test_step_id", "name": "Test Step"}] + + async with mock_chainlit_context: + await literal_data_layer.safely_send_steps(test_steps) + + mock_literal_client.api.send_steps.assert_awaited_once_with(test_steps) + + +async def test_safely_send_steps_http_status_error( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, + caplog, +): + test_steps = [{"id": "test_step_id", "name": "Test Step"}] + mock_literal_client.api.send_steps.side_effect = HTTPStatusError( + "HTTP Error", request=Mock(), response=Mock(status_code=500) + ) + + async with mock_chainlit_context: + await literal_data_layer.safely_send_steps(test_steps) + + mock_literal_client.api.send_steps.assert_awaited_once_with(test_steps) + assert "HTTP Request: error sending steps: 500" in caplog.text + + +async def test_safely_send_steps_request_error( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, + caplog, +): + test_steps = [{"id": "test_step_id", "name": "Test Step"}] + mock_request = Mock() + mock_request.url = "https://example.com/api" + mock_literal_client.api.send_steps.side_effect = RequestError( + "Request Error", request=mock_request + ) + + async with mock_chainlit_context: + await literal_data_layer.safely_send_steps(test_steps) + + mock_literal_client.api.send_steps.assert_awaited_once_with(test_steps) + assert "HTTP Request: error for 'https://example.com/api'." in caplog.text + + +async def test_get_user( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + literal_test_user: LiteralUser, + persisted_test_user: PersistedUser, +): + mock_literal_client.api.get_user.return_value = literal_test_user + + user = await literal_data_layer.get_user("test_user_id") + + assert user is not None + assert user.id == literal_test_user.id + assert user.identifier == literal_test_user.identifier + + mock_literal_client.api.get_user.assert_awaited_once_with(identifier="test_user_id") + + +async def test_get_user_not_found( + literal_data_layer: LiteralDataLayer, mock_literal_client: Mock +): + mock_literal_client.api.get_user.return_value = None + + user = await literal_data_layer.get_user("non_existent_user_id") + + assert user is None + mock_literal_client.api.get_user.assert_awaited_once_with( + identifier="non_existent_user_id" + ) + + +async def test_create_user_not_existing( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_user: User, + literal_test_user: LiteralUser, +): + mock_literal_client.api.get_user.return_value = None + mock_literal_client.api.create_user.return_value = literal_test_user + + persisted_user = await literal_data_layer.create_user(test_user) + + mock_literal_client.api.create_user.assert_awaited_once_with( + identifier=test_user.identifier, metadata=test_user.metadata + ) + + assert persisted_user is not None + assert isinstance(persisted_user, PersistedUser) + assert persisted_user.id == literal_test_user.id + assert persisted_user.identifier == literal_test_user.identifier + + +async def test_create_user_update_existing( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_user: User, + literal_test_user: LiteralUser, + persisted_test_user: PersistedUser, +): + mock_literal_client.api.get_user.return_value = literal_test_user + + persisted_user = await literal_data_layer.create_user(test_user) + + mock_literal_client.api.create_user.assert_not_called() + mock_literal_client.api.update_user.assert_awaited_once_with( + id=literal_test_user.id, metadata=test_user.metadata + ) + + assert persisted_user is not None + assert isinstance(persisted_user, PersistedUser) + assert persisted_user.id == literal_test_user.id + assert persisted_user.identifier == literal_test_user.identifier + + +async def test_create_user_id_none( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_user: User, + literal_test_user: LiteralUser, +): + """Weird edge case; persisted user without an id. Do we need this!??""" + + literal_test_user.id = None + mock_literal_client.api.get_user.return_value = literal_test_user + + persisted_user = await literal_data_layer.create_user(test_user) + + mock_literal_client.api.create_user.assert_not_called() + mock_literal_client.api.update_user.assert_not_called() + + assert persisted_user is not None + assert isinstance(persisted_user, PersistedUser) + assert persisted_user.id == "" + assert persisted_user.identifier == literal_test_user.identifier + + +async def test_update_thread( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_thread: LiteralThread, +): + await literal_data_layer.update_thread(test_thread.id, name=test_thread.name) + + mock_literal_client.api.upsert_thread.assert_awaited_once_with( + id=test_thread.id, + name=test_thread.name, + participant_id=None, + metadata=None, + tags=None, + ) + + +async def test_get_thread_author( + literal_data_layer, mock_literal_client: Mock, test_thread: LiteralThread +): + test_thread.participant_identifier = "test_user_identifier" + mock_literal_client.api.get_thread.return_value = test_thread + + author = await literal_data_layer.get_thread_author(test_thread.id) + + assert author == "test_user_identifier" + mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) + + +async def test_get_thread( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_thread: LiteralThread, + test_step: LiteralStep, +): + assert isinstance(test_thread.steps, list) + test_thread.steps.append(test_step) + + mock_literal_client.api.get_thread.return_value = test_thread + + thread = await literal_data_layer.get_thread(test_thread.id) + mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) + + assert thread is not None + assert thread["id"] == test_thread.id + assert thread["name"] == test_thread.name + assert len(thread["steps"]) == 1 + assert thread["steps"][0].get("id") == test_step.id + + +async def test_get_thread_with_stub_step( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_thread: LiteralThread, +): + # Create a step that should be stubbed + stub_step = LiteralStep.from_dict( + { + "id": "stub_step_id", + "name": "Stub Step", + "type": "undefined", + "threadId": test_thread.id, + "createdAt": "2023-01-01T00:00:00Z", + } + ) + test_thread.steps = [stub_step] + + mock_literal_client.api.get_thread.return_value = test_thread + + # Mock the config.ui.cot value to ensure check_add_step_in_cot returns False + with patch("chainlit.config.config.ui.cot", "hidden"): + thread = await literal_data_layer.get_thread(test_thread.id) + + mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) + + assert thread is not None + assert thread["id"] == test_thread.id + assert thread["name"] == test_thread.name + assert len(thread["steps"]) == 1 + assert thread["steps"][0].get("id") == "stub_step_id" + assert thread["steps"][0].get("type") == "undefined" + assert thread["steps"][0].get("input") == "" + assert thread["steps"][0].get("output") == "" + + # Additional assertions to ensure the step is stubbed + assert "metadata" not in thread["steps"][0] + assert "createdAt" not in thread["steps"][0] + + +async def test_get_thread_with_attachment( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_thread: LiteralThread, + test_step: LiteralStep, + test_attachment: LiteralAttachment, +): + # Add the attachment to the test step + test_step.attachments = [test_attachment] + test_thread.steps = [test_step] + + mock_literal_client.api.get_thread.return_value = test_thread + + thread = await literal_data_layer.get_thread(test_thread.id) + mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) + + assert thread is not None + assert thread["id"] == test_thread.id + assert thread["name"] == test_thread.name + assert thread["steps"] is not None + assert len(thread["steps"]) == 1 + assert thread["elements"] is not None + assert len(thread["elements"]) == 1 + + element = thread["elements"][0] if thread["elements"] else None + assert element is not None + assert element["id"] == "test_attachment_id" + assert element["forId"] == test_step.id + assert element["threadId"] == test_thread.id + assert element["type"] == "file" + assert element["display"] == "side" + assert element["language"] == "python" + assert element["mime"] == "text/plain" + assert element["name"] == "test_file.txt" + assert element["objectKey"] == "test_object_key" + assert element["url"] == "https://example.com/test_file.txt" + + +async def test_get_thread_non_existing( + literal_data_layer: LiteralDataLayer, mock_literal_client: Mock +): + mock_literal_client.api.get_thread.return_value = None + + thread = await literal_data_layer.get_thread("non_existent_thread_id") + mock_literal_client.api.get_thread.assert_awaited_once_with( + id="non_existent_thread_id" + ) + + assert thread is None + + +async def test_delete_thread( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_thread: LiteralThread, +): + await literal_data_layer.delete_thread(test_thread.id) + + mock_literal_client.api.delete_thread.assert_awaited_once_with(id=test_thread.id) + + +async def test_list_threads( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_filters: ThreadFilter, + test_pagination: Pagination, +): + response: PaginatedResponse[Thread] = PaginatedResponse( + page_info=PageInfo( + has_next_page=True, start_cursor="start_cursor", end_cursor="end_cursor" + ), + data=[ + Thread( + id="thread1", + name="Thread 1", + ), + Thread( + id="thread2", + name="Thread 2", + ), + ], + ) + + mock_literal_client.api.list_threads.return_value = response + + result = await literal_data_layer.list_threads(test_pagination, test_filters) + + mock_literal_client.api.list_threads.assert_awaited_once_with( + first=10, + after=None, + filters=[ + {"field": "participantId", "operator": "eq", "value": "user1"}, + { + "field": "stepOutput", + "operator": "ilike", + "value": "test", + "path": "content", + }, + { + "field": "scoreValue", + "operator": "eq", + "value": 1, + "path": "user-feedback", + }, + ], + order_by={"column": "createdAt", "direction": "DESC"}, + ) + + assert result.pageInfo.hasNextPage + assert result.pageInfo.startCursor == "start_cursor" + assert result.pageInfo.endCursor == "end_cursor" + assert len(result.data) == 2 + assert result.data[0]["id"] == "thread1" + assert result.data[1]["id"] == "thread2" + + +async def test_create_element( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, +): + mock_literal_client.api.upload_file.return_value = {"object_key": "test_object_key"} + + async with mock_chainlit_context: + text_element = Text( + id=str(uuid.uuid4()), + name="test.txt", + mime="text/plain", + content="test content", + for_id="test_step_id", + ) + + await literal_data_layer.create_element(text_element) + + mock_literal_client.api.upload_file.assert_awaited_once_with( + content=text_element.content, + mime=text_element.mime, + thread_id=text_element.thread_id, + ) + + mock_literal_client.api.send_steps.assert_awaited_once_with( + [ + { + "id": text_element.for_id, + "threadId": text_element.thread_id, + "attachments": [ + { + "id": ANY, + "name": text_element.name, + "metadata": { + "size": None, + "language": None, + "display": text_element.display, + "type": text_element.type, + "page": None, + }, + "mime": text_element.mime, + "url": None, + "objectKey": "test_object_key", + } + ], + } + ] + ) + + +async def test_get_element( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + test_attachment: LiteralAttachment, +): + mock_literal_client.api.get_attachment.return_value = test_attachment + + element_dict = await literal_data_layer.get_element( + "test_thread_id", "test_element_id" + ) + + mock_literal_client.api.get_attachment.assert_awaited_once_with( + id="test_element_id" + ) + + assert element_dict is not None + + # Compare element_dict attributes to attachment attributes + assert element_dict["id"] == test_attachment.id + assert element_dict["forId"] == test_attachment.step_id + assert element_dict["threadId"] == test_attachment.thread_id + assert element_dict["name"] == test_attachment.name + assert element_dict["mime"] == test_attachment.mime + assert element_dict["url"] == test_attachment.url + assert element_dict["objectKey"] == test_attachment.object_key + assert test_attachment.metadata + assert element_dict["display"] == test_attachment.metadata["display"] + assert element_dict["language"] == test_attachment.metadata["language"] + assert element_dict["type"] == test_attachment.metadata["type"] + + +async def test_upsert_feedback_create( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, +): + feedback = Feedback(forId="test_step_id", value=1, comment="Great!") + mock_literal_client.api.create_score.return_value = Mock(id="new_feedback_id") + + result = await literal_data_layer.upsert_feedback(feedback) + + mock_literal_client.api.create_score.assert_awaited_once_with( + step_id="test_step_id", + value=1, + comment="Great!", + name="user-feedback", + type="HUMAN", + ) + assert result == "new_feedback_id" + + +async def test_upsert_feedback_update( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, +): + feedback = Feedback( + id="existing_feedback_id", + forId="test_step_id", + value=0, + comment="Needs improvement", + ) + + result = await literal_data_layer.upsert_feedback(feedback) + + mock_literal_client.api.update_score.assert_awaited_once_with( + id="existing_feedback_id", + update_params={ + "comment": "Needs improvement", + "value": 0, + }, + ) + assert result == "existing_feedback_id" + + +async def test_delete_feedback( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, +): + feedback_id = "test_feedback_id" + + result = await literal_data_layer.delete_feedback(feedback_id) + + mock_literal_client.api.delete_score.assert_awaited_once_with(id=feedback_id) + assert result is True + + +async def test_delete_feedback_empty_id( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, +): + feedback_id = "" + + result = await literal_data_layer.delete_feedback(feedback_id) + + mock_literal_client.api.delete_score.assert_not_awaited() + assert result is False + + +async def test_build_debug_url( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, +): + mock_literal_client.api.get_my_project_id.return_value = "test_project_id" + mock_literal_client.api.url = "https://api.example.com" + + result = await literal_data_layer.build_debug_url() + + mock_literal_client.api.get_my_project_id.assert_awaited_once() + assert ( + result + == "https://api.example.com/projects/test_project_id/logs/threads/[thread_id]?currentStepId=[step_id]" + ) + + +async def test_build_debug_url_error( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + caplog, +): + mock_literal_client.api.get_my_project_id.side_effect = Exception("API Error") + + result = await literal_data_layer.build_debug_url() + + mock_literal_client.api.get_my_project_id.assert_awaited_once() + assert result == "" + assert "Error building debug url: API Error" in caplog.text + + +async def test_delete_element( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, +): + element_id = "test_element_id" + + async with mock_chainlit_context: + await literal_data_layer.delete_element(element_id) + + mock_literal_client.api.delete_attachment.assert_awaited_once_with(id=element_id) + + +async def test_delete_step( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, +): + step_id = "test_step_id" + + async with mock_chainlit_context: + await literal_data_layer.delete_step(step_id) + + mock_literal_client.api.delete_step.assert_awaited_once_with(id=step_id) + + +async def test_update_step( + literal_data_layer: LiteralDataLayer, + mock_literal_client: Mock, + mock_chainlit_context, + test_step_dict: StepDict, +): + async with mock_chainlit_context: + await literal_data_layer.update_step(test_step_dict) + + mock_literal_client.api.send_steps.assert_awaited_once_with( + [ + { + "createdAt": "2023-01-01T00:00:00Z", + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:00Z", + "generation": {}, + "id": "test_step_id", + "parentId": None, + "name": "Test Step", + "threadId": "test_thread_id", + "type": "user_message", + "tags": [], + "metadata": { + "key": "value", + "waitForAnswer": True, + "language": "en", + "showInput": True, + }, + "input": {"content": "test input"}, + "output": {"content": "test output"}, + } + ] + ) + + +def test_steptype_to_steptype(): + assert ( + LiteralToChainlitConverter.steptype_to_steptype("user_message") == "undefined" + ) + assert ( + LiteralToChainlitConverter.steptype_to_steptype("assistant_message") + == "undefined" + ) + assert ( + LiteralToChainlitConverter.steptype_to_steptype("system_message") == "undefined" + ) + assert LiteralToChainlitConverter.steptype_to_steptype("tool") == "tool" + assert LiteralToChainlitConverter.steptype_to_steptype(None) == "undefined" + + +def test_score_to_feedbackdict(): + score = LiteralScore( + id="test_score_id", + step_id="test_step_id", + value=1, + comment="Great job!", + name="user-feedback", + type="HUMAN", + dataset_experiment_item_id=None, + tags=None, + ) + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) + assert feedback_dict == { + "id": "test_score_id", + "forId": "test_step_id", + "value": 1, + "comment": "Great job!", + } + + assert LiteralToChainlitConverter.score_to_feedbackdict(None) is None + + score.value = 0 + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) + assert feedback_dict is not None + assert feedback_dict["value"] == 0 + + score.id = None + score.step_id = None + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) + assert feedback_dict is not None + assert feedback_dict["id"] == "" + assert feedback_dict["forId"] == "" + + +def test_step_to_stepdict(): + literal_step = LiteralStep.from_dict( + { + "id": "test_step_id", + "threadId": "test_thread_id", + "type": "user_message", + "name": "Test Step", + "input": {"content": "test input"}, + "output": {"content": "test output"}, + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:01Z", + "createdAt": "2023-01-01T00:00:00Z", + "metadata": {"showInput": True, "language": "en"}, + "error": None, + "scores": [ + { + "id": "test_score_id", + "stepId": "test_step_id", + "value": 1, + "comment": "Great job!", + "name": "user-feedback", + "type": "HUMAN", + } + ], + } + ) + + step_dict = LiteralToChainlitConverter.step_to_stepdict(literal_step) + + assert step_dict.get("id") == "test_step_id" + assert step_dict.get("threadId") == "test_thread_id" + assert step_dict.get("type") == "user_message" + assert step_dict.get("name") == "Test Step" + assert step_dict.get("input") == "test input" + assert step_dict.get("output") == "test output" + assert step_dict.get("start") == "2023-01-01T00:00:00Z" + assert step_dict.get("end") == "2023-01-01T00:00:01Z" + assert step_dict.get("createdAt") == "2023-01-01T00:00:00Z" + assert step_dict.get("showInput") == True + assert step_dict.get("language") == "en" + assert step_dict.get("isError") == False + assert step_dict.get("feedback") == { + "id": "test_score_id", + "forId": "test_step_id", + "value": 1, + "comment": "Great job!", + } + + +def test_attachment_to_elementdict(): + attachment = Attachment( + id="test_attachment_id", + step_id="test_step_id", + thread_id="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + object_key="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "file", + "size": "large", + }, + ) + + element_dict = LiteralToChainlitConverter.attachment_to_elementdict(attachment) + + assert element_dict["id"] == "test_attachment_id" + assert element_dict["forId"] == "test_step_id" + assert element_dict["threadId"] == "test_thread_id" + assert element_dict["name"] == "test.txt" + assert element_dict["mime"] == "text/plain" + assert element_dict["url"] == "https://example.com/test.txt" + assert element_dict["objectKey"] == "test_object_key" + assert element_dict["display"] == "side" + assert element_dict["language"] == "python" + assert element_dict["type"] == "file" + assert element_dict["size"] == "large" + + +def test_attachment_to_element(): + attachment = Attachment( + id="test_attachment_id", + step_id="test_step_id", + thread_id="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + object_key="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "text", + "size": "small", + }, + ) + + element = LiteralToChainlitConverter.attachment_to_element(attachment) + + assert isinstance(element, Text) + assert element.id == "test_attachment_id" + assert element.for_id == "test_step_id" + assert element.thread_id == "test_thread_id" + assert element.name == "test.txt" + assert element.mime == "text/plain" + assert element.url == "https://example.com/test.txt" + assert element.object_key == "test_object_key" + assert element.display == "side" + assert element.language == "python" + assert element.size == "small" + + # Test other element types + for element_type in ["file", "image", "audio", "video", "pdf"]: + attachment.metadata = {"type": element_type, "size": "small"} + + element = LiteralToChainlitConverter.attachment_to_element(attachment) + assert isinstance( + element, + { + "file": File, + "image": Image, + "audio": Audio, + "video": Video, + "text": Text, + "pdf": Pdf, + }[element_type], + ) + + +def test_step_to_step(): + literal_step = LiteralStep.from_dict( + { + "id": "test_step_id", + "threadId": "test_thread_id", + "type": "user_message", + "name": "Test Step", + "input": {"content": "test input"}, + "output": {"content": "test output"}, + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:01Z", + "createdAt": "2023-01-01T00:00:00Z", + "metadata": {"showInput": True, "language": "en"}, + "error": None, + "attachments": [ + { + "id": "test_attachment_id", + "name": "test.txt", + "mime": "text/plain", + "url": "https://example.com/test.txt", + "objectKey": "test_object_key", + "metadata": { + "display": "side", + "language": "python", + "type": "text", + }, + } + ], + } + ) + + chainlit_step = LiteralToChainlitConverter.step_to_step(literal_step) + + assert isinstance(chainlit_step, Step) + assert chainlit_step.id == "test_step_id" + assert chainlit_step.thread_id == "test_thread_id" + assert chainlit_step.type == "undefined" + assert chainlit_step.name == "Test Step" + assert chainlit_step.input == "test input" + assert chainlit_step.output == "test output" + assert chainlit_step.start == "2023-01-01T00:00:00Z" + assert chainlit_step.end == "2023-01-01T00:00:01Z" + assert chainlit_step.created_at == "2023-01-01T00:00:00Z" + assert chainlit_step.metadata == {"showInput": True, "language": "en"} + assert not chainlit_step.is_error + assert chainlit_step.elements is not None + assert len(chainlit_step.elements) == 1 + assert isinstance(chainlit_step.elements[0], Text) + + +def test_thread_to_threaddict(): + attachment_dict = LiteralAttachmentDict( + id="test_attachment_id", + stepId="test_step_id", + threadId="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + objectKey="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "text", + }, + ) + step_dict = LiteralStepDict( + id="test_step_id", + threadId="test_thread_id", + type="user_message", + name="Test Step", + input={"content": "test input"}, + output={"content": "test output"}, + startTime="2023-01-01T00:00:00Z", + endTime="2023-01-01T00:00:01Z", + createdAt="2023-01-01T00:00:00Z", + metadata={"showInput": True, "language": "en"}, + error=None, + attachments=[attachment_dict], + ) + literal_thread = LiteralThread.from_dict( + LiteralThreadDict( + id="test_thread_id", + name="Test Thread", + createdAt="2023-01-01T00:00:00Z", + participant=UserDict(id="test_user_id", identifier="test_user_identifier_"), + tags=["tag1", "tag2"], + metadata={"key": "value"}, + steps=[step_dict], + ) + ) + + thread_dict = LiteralToChainlitConverter.thread_to_threaddict(literal_thread) + + assert thread_dict["id"] == "test_thread_id" + assert thread_dict["name"] == "Test Thread" + assert thread_dict["createdAt"] == "2023-01-01T00:00:00Z" + assert thread_dict["userId"] == "test_user_id" + assert thread_dict["userIdentifier"] == "test_user_identifier_" + assert thread_dict["tags"] == ["tag1", "tag2"] + assert thread_dict["metadata"] == {"key": "value"} + assert thread_dict["steps"] is not None + assert len(thread_dict["steps"]) == 1 + assert thread_dict["elements"] is not None + assert len(thread_dict["elements"]) == 1 diff --git a/backend/tests/data/test_sql_alchemy.py b/backend/tests/data/test_sql_alchemy.py index d8dd5169e8..49b2619181 100644 --- a/backend/tests/data/test_sql_alchemy.py +++ b/backend/tests/data/test_sql_alchemy.py @@ -1,17 +1,14 @@ -from unittest.mock import Mock import uuid from pathlib import Path import pytest -import pytest_asyncio -from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine -from chainlit.data.base import BaseDataLayer, BaseStorageClient +from chainlit import User +from chainlit.data.base import BaseStorageClient from chainlit.data.sql_alchemy import SQLAlchemyDataLayer from chainlit.element import Text -from chainlit import User -from chainlit.user import PersistedUser @pytest.fixture @@ -116,11 +113,6 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path): yield data_layer -@pytest.fixture -def test_user() -> User: - return User(identifier="sqlalchemy_test_user_id") - - async def test_create_and_get_element( mock_chainlit_context, data_layer: SQLAlchemyDataLayer ):