From 1f1d9c392c248acc1869626f41aa3cc772450a91 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 6 Dec 2024 14:19:07 +0000 Subject: [PATCH] PR comments --- backend/src/directors/chat_director.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/backend/src/directors/chat_director.py b/backend/src/directors/chat_director.py index 95965a12..2aa5a513 100644 --- a/backend/src/directors/chat_director.py +++ b/backend/src/directors/chat_director.py @@ -1,6 +1,7 @@ +from dataclasses import dataclass import json import logging -from typing import Any +from typing import Optional from uuid import uuid4 from src.utils.json import try_pretty_print @@ -20,6 +21,11 @@ engine = PromptEngine() director_prompt = engine.load_prompt("chat_director") +@dataclass +class FinalAnswer: + message: Optional[str] = "" + dataset: Optional[str] = None + async def question(question: str) -> ChatResponse: intent = await get_intent_agent().invoke(question) @@ -46,10 +52,10 @@ async def question(question: str) -> ChatResponse: dataset=None, reasoning=try_pretty_print(current_scratchpad)) - final_answer = {} + final_answer = FinalAnswer() try: - final_answer = await __get_final_answer(question, intent_json) - update_session_chat(role="system", content=final_answer.get("message")) + final_answer = await __create_final_answer(question, intent_json) + update_session_chat(role="system", content=final_answer.message) except Exception as error: logger.error(f"Error during answer generation: {error}", error) update_scratchpad(error=str(error)) @@ -58,8 +64,8 @@ async def question(question: str) -> ChatResponse: response = ChatResponse(id=str(uuid4()), question=question, - answer=final_answer.get("message") or '', - dataset=final_answer.get("dataset"), + answer=final_answer.message or '', + dataset=final_answer.dataset, reasoning=try_pretty_print(current_scratchpad)) store_chat_message(response) @@ -69,7 +75,7 @@ async def question(question: str) -> ChatResponse: return response -async def __get_final_answer(question: str, intent_json: dict) -> dict[str, Any]: +async def __create_final_answer(question: str, intent_json: dict) -> FinalAnswer: dataset = None if intent_json['result_type'] == 'dataset': # get the last DatastoreAgent result dataset from the scratchpad @@ -80,7 +86,7 @@ async def __get_final_answer(question: str, intent_json: dict) -> dict[str, Any] message = await get_answer_agent().invoke(question) - return { "message": message, "dataset": dataset } + return FinalAnswer(message, dataset) async def dataset_upload() -> None: