Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mic-smith committed Dec 6, 2024
1 parent 9188071 commit 1f1d9c3
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions backend/src/directors/chat_director.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import json
import logging
from typing import Any
from typing import Optional
from uuid import uuid4

from src.utils.json import try_pretty_print
Expand All @@ -20,6 +21,11 @@
engine = PromptEngine()
director_prompt = engine.load_prompt("chat_director")

@dataclass
class FinalAnswer:
message: Optional[str] = ""
dataset: Optional[str] = None


async def question(question: str) -> ChatResponse:
intent = await get_intent_agent().invoke(question)
Expand All @@ -46,10 +52,10 @@ async def question(question: str) -> ChatResponse:
dataset=None,
reasoning=try_pretty_print(current_scratchpad))

final_answer = {}
final_answer = FinalAnswer()
try:
final_answer = await __get_final_answer(question, intent_json)
update_session_chat(role="system", content=final_answer.get("message"))
final_answer = await __create_final_answer(question, intent_json)
update_session_chat(role="system", content=final_answer.message)
except Exception as error:
logger.error(f"Error during answer generation: {error}", error)
update_scratchpad(error=str(error))
Expand All @@ -58,8 +64,8 @@ async def question(question: str) -> ChatResponse:

response = ChatResponse(id=str(uuid4()),
question=question,
answer=final_answer.get("message") or '',
dataset=final_answer.get("dataset"),
answer=final_answer.message or '',
dataset=final_answer.dataset,
reasoning=try_pretty_print(current_scratchpad))

store_chat_message(response)
Expand All @@ -69,7 +75,7 @@ async def question(question: str) -> ChatResponse:
return response


async def __get_final_answer(question: str, intent_json: dict) -> dict[str, Any]:
async def __create_final_answer(question: str, intent_json: dict) -> FinalAnswer:
dataset = None
if intent_json['result_type'] == 'dataset':
# get the last DatastoreAgent result dataset from the scratchpad
Expand All @@ -80,7 +86,7 @@ async def __get_final_answer(question: str, intent_json: dict) -> dict[str, Any]

message = await get_answer_agent().invoke(question)

return { "message": message, "dataset": dataset }
return FinalAnswer(message, dataset)


async def dataset_upload() -> None:
Expand Down

0 comments on commit 1f1d9c3

Please sign in to comment.