Skip to content

Commit

Permalink
FS-73 Add grid view (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
mic-smith authored Dec 9, 2024
1 parent 777b3b8 commit d17eb78
Show file tree
Hide file tree
Showing 20 changed files with 952 additions and 511 deletions.
1 change: 1 addition & 0 deletions backend/src/chat_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ChatResponse(TypedDict):
id: str
question:str
answer: str
dataset: str | None
reasoning: str | None

logger = logging.getLogger(__name__)
Expand Down
35 changes: 32 additions & 3 deletions backend/src/directors/chat_director.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
import json
import logging
from typing import Optional
from uuid import uuid4

from src.utils.json import try_pretty_print
Expand All @@ -19,6 +21,11 @@
engine = PromptEngine()
director_prompt = engine.load_prompt("chat_director")

@dataclass
class FinalAnswer:
message: str = ""
dataset: Optional[str] = None


async def question(question: str) -> ChatResponse:
intent = await get_intent_agent().invoke(question)
Expand All @@ -42,15 +49,23 @@ async def question(question: str) -> ChatResponse:
return ChatResponse(id=str(uuid4()),
question=question,
answer="",
dataset=None,
reasoning=try_pretty_print(current_scratchpad))

final_answer = await get_answer_agent().invoke(question)
update_session_chat(role="system", content=final_answer)
final_answer = FinalAnswer()
try:
final_answer = await __create_final_answer(question, intent_json)
update_session_chat(role="system", content=final_answer.message)
except Exception as error:
logger.error(f"Error during answer generation: {error}", error)
update_scratchpad(error=str(error))

logger.info(f"final answer: {final_answer}")

response = ChatResponse(id=str(uuid4()),
question=question,
answer=final_answer,
answer=final_answer.message or '',
dataset=final_answer.dataset,
reasoning=try_pretty_print(current_scratchpad))

store_chat_message(response)
Expand All @@ -60,6 +75,20 @@ async def question(question: str) -> ChatResponse:
return response


async def __create_final_answer(question: str, intent_json: dict) -> FinalAnswer:
dataset = None
if intent_json['result_type'] == 'dataset':
# get the last DatastoreAgent result dataset from the scratchpad
datastore_agents = [scratch for scratch in get_scratchpad() if scratch['agent_name'] == 'DatastoreAgent']
query_result = datastore_agents[-1]['result'] if datastore_agents else None
if query_result is not None:
dataset = query_result

message = await get_answer_agent().invoke(question)

return FinalAnswer(message, dataset)


async def dataset_upload() -> None:
dataset_file = "./datasets/bloomberg.csv"

Expand Down
1 change: 1 addition & 0 deletions backend/src/prompts/templates/intent-system.j2
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ Output your result in the following json format:
{
"question": "string of the original question",
"user_intent": "string of the intent of the user's question",
"result_type": "string of the type of result expected, this will be either 'text' or 'dataset'",
"questions": array of singular objective questions or if the question mentions csv, dataset or database an empty array
}
2 changes: 1 addition & 1 deletion backend/tests/api/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_chat_delete(mocker):
assert response.status_code == 204

def test_chat_message_success(mocker):
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning", dataset="dataset")
mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=message)

response = client.get("/chat/123")
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/chat_storage_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def mock_redis():
def test_store_chat_message(mocker, mock_redis):
mocker.patch('src.chat_storage_service.redis_client', mock_redis)

message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning", dataset="dataset")
store_chat_message(message)

mock_redis.set.assert_called_once_with("chat_1", json.dumps(message))
Expand All @@ -24,7 +24,7 @@ def test_store_chat_message(mocker, mock_redis):
def test_get_chat_message(mocker, mock_redis):
mocker.patch('src.chat_storage_service.redis_client', mock_redis)

message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning", dataset="dataset")
mock_redis.get.return_value = json.dumps(message)

value = get_chat_message("1")
Expand Down
Loading

0 comments on commit d17eb78

Please sign in to comment.