From 179a43f9ab06e81e2b67f926778141845692da8e Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 15 Nov 2024 15:42:14 +0000 Subject: [PATCH] FS-99 Clarify answer --- backend/src/api/app.py | 12 ++++++ backend/src/chat_storage_service.py | 34 +++++++++++++++ backend/src/director.py | 21 ++++++++-- backend/src/utils/json.py | 7 ++++ backend/tests/BDD/step_defs/test_prompts.py | 4 +- backend/tests/api/app_test.py | 18 ++++++++ backend/tests/chat_storage_service_test.py | 32 ++++++++++++++ frontend/src/components/message.module.css | 46 +++++++++++++++++---- frontend/src/components/message.tsx | 38 ++++++++++++++--- frontend/src/server.ts | 14 +++---- frontend/src/useMessages.ts | 32 ++++++++++---- 11 files changed, 223 insertions(+), 35 deletions(-) create mode 100644 backend/src/chat_storage_service.py create mode 100644 backend/tests/chat_storage_service_test.py diff --git a/backend/src/api/app.py b/backend/src/api/app.py index 7f38c75d..433addb7 100644 --- a/backend/src/api/app.py +++ b/backend/src/api/app.py @@ -8,6 +8,7 @@ from fastapi import FastAPI, HTTPException, Response, WebSocket, UploadFile from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from src.chat_storage_service import get_chat_message from src.session.file_uploads import clear_session_file_uploads from src.session.redis_session_middleware import reset_session from src.utils.graph_db_utils import populate_db @@ -109,6 +110,17 @@ async def clear_chat(): logger.exception(e) return Response(status_code=500) +@app.get("/chat/{id}") +def chat_message(id: str): + logger.info(f"Get chat message called with id: {id}") + try: + final_result = get_chat_message(id) + if final_result is None: + return JSONResponse(status_code=404, content=f"Message with id {id} not found") + return JSONResponse(status_code=200, content=final_result) + except Exception as e: + logger.exception(e) + return JSONResponse(status_code=500, content=chat_fail_response) @app.get("/suggestions") async def suggestions(): diff --git a/backend/src/chat_storage_service.py b/backend/src/chat_storage_service.py new file mode 100644 index 00000000..f6b5dc59 --- /dev/null +++ b/backend/src/chat_storage_service.py @@ -0,0 +1,34 @@ + +import json +import logging +from typing import TypedDict +import redis + +from src.utils.json import try_parse_to_json +from src.utils import Config + +class ChatResponse(TypedDict): + id: str + question:str + answer: str + reasoning: str | None + +logger = logging.getLogger(__name__) + +config = Config() + +redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True) + +CHAT_KEY_PREFIX = "chat_" + +def store_chat_message(chat:ChatResponse): + redis_client.set(CHAT_KEY_PREFIX + chat["id"], json.dumps(chat)) + + +def get_chat_message(id: str) -> ChatResponse | None: + value = redis_client.get(CHAT_KEY_PREFIX + id) + if value and isinstance(value, str): + parsed_session_data = try_parse_to_json(value) + if parsed_session_data: + return parsed_session_data + return None diff --git a/backend/src/director.py b/backend/src/director.py index c65be9ee..fb4768c9 100644 --- a/backend/src/director.py +++ b/backend/src/director.py @@ -1,5 +1,9 @@ import json import logging +from uuid import uuid4 + +from src.utils.json import try_pretty_print +from src.chat_storage_service import ChatResponse, store_chat_message from src.utils import clear_scratchpad, update_scratchpad, get_scratchpad from src.session import update_session_chat from src.agents import get_intent_agent, get_answer_agent @@ -13,8 +17,7 @@ engine = PromptEngine() director_prompt = engine.load_prompt("director") - -async def question(question: str) -> str: +async def question(question: str) -> ChatResponse: intent = await get_intent_agent().invoke(question) intent_json = json.loads(intent) update_session_chat(role="user", content=question) @@ -33,12 +36,22 @@ async def question(question: str) -> str: generated_figure = entry["result"] await connection_manager.send_chart({"type": "image", "data": generated_figure}) clear_scratchpad() - return "" + return ChatResponse(id=str(uuid4()), + question=question, + answer="", + reasoning=try_pretty_print(current_scratchpad)) final_answer = await get_answer_agent().invoke(question) update_session_chat(role="system", content=final_answer) logger.info(f"final answer: {final_answer}") + response = ChatResponse(id=str(uuid4()), + question=question, + answer=final_answer, + reasoning=try_pretty_print(current_scratchpad)) + + store_chat_message(response) + clear_scratchpad() - return final_answer + return response diff --git a/backend/src/utils/json.py b/backend/src/utils/json.py index ac26f8dd..9ec5f6ed 100644 --- a/backend/src/utils/json.py +++ b/backend/src/utils/json.py @@ -15,3 +15,10 @@ def try_parse_to_json(json_string: str): except json.JSONDecodeError as error: logger.error(f"Error parsing json: {error}") return None + +def try_pretty_print(obj): + try: + return json.dumps(obj, indent=4) + except Exception as error: + logger.error(f"Error pretty printing json: {error}") + return None diff --git a/backend/tests/BDD/step_defs/test_prompts.py b/backend/tests/BDD/step_defs/test_prompts.py index 732356d2..440a1ae0 100644 --- a/backend/tests/BDD/step_defs/test_prompts.py +++ b/backend/tests/BDD/step_defs/test_prompts.py @@ -37,7 +37,7 @@ def get_response(context): @then(parsers.parse("the response to this '{prompt}' should match the '{expected_response}'")) def check_response_includes_expected_response(context, prompt, expected_response): response = send_prompt(prompt) - actual_response = response.json() + actual_response = response.json()["answer"] try: expected_value = Decimal(str(expected_response).strip()) @@ -81,5 +81,5 @@ def check_response_includes_expected_response(context, prompt, expected_response @then(parsers.parse("the response to this '{prompt}' should give a confident answer")) def check_bot_response_confidence(prompt): response = send_prompt(prompt) - result = check_response_confidence(prompt, response.json()) + result = check_response_confidence(prompt, response.json()["answer"]) assert result["score"] == 1, "The bot response is not confident enough. \nReasoning: " + result["reasoning"] diff --git a/backend/tests/api/app_test.py b/backend/tests/api/app_test.py index f87dfb4e..54aba2ca 100644 --- a/backend/tests/api/app_test.py +++ b/backend/tests/api/app_test.py @@ -1,5 +1,6 @@ from fastapi.testclient import TestClient import pytest +from src.chat_storage_service import ChatResponse from src.api import app, healthy_response, unhealthy_neo4j_response, chat_fail_response client = TestClient(app) @@ -73,6 +74,23 @@ def test_chat_delete(mocker): assert response.status_code == 204 +def test_chat_message_success(mocker): + message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning") + mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=message) + + response = client.get("/chat/123") + + mock_get_chat_message.assert_called_with("123") + assert response.status_code == 200 + assert response.json() == message + +def test_chat_message_not_found(mocker): + mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=None) + + response = client.get("/chat/123") + + mock_get_chat_message.assert_called_with("123") + assert response.status_code == 404 @pytest.mark.asyncio async def test_lifespan_populates_db(mocker, mock_initial_data) -> None: diff --git a/backend/tests/chat_storage_service_test.py b/backend/tests/chat_storage_service_test.py new file mode 100644 index 00000000..a80bbe78 --- /dev/null +++ b/backend/tests/chat_storage_service_test.py @@ -0,0 +1,32 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from src.chat_storage_service import ChatResponse, get_chat_message, store_chat_message + +@pytest.fixture +def mock_redis(): + with patch('src.chat_storage_service.redis_client') as mock_redis: + mock_instance = MagicMock() + mock_redis.return_value = mock_instance + yield mock_instance + +def test_store_chat_message(mocker, mock_redis): + mocker.patch('src.chat_storage_service.redis_client', mock_redis) + + message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning") + store_chat_message(message) + + mock_redis.set.assert_called_once_with("chat_1", json.dumps(message)) + + +def test_get_chat_message(mocker, mock_redis): + mocker.patch('src.chat_storage_service.redis_client', mock_redis) + + message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning") + mock_redis.get.return_value = json.dumps(message) + + value = get_chat_message("1") + + assert value == message diff --git a/frontend/src/components/message.module.css b/frontend/src/components/message.module.css index bb2b92d2..91be66f9 100644 --- a/frontend/src/components/message.module.css +++ b/frontend/src/components/message.module.css @@ -1,22 +1,20 @@ .container { - align-items: center; - display: flex; - flex-direction: row; - justify-content: left; border-radius: 16px; width: 100%; margin-bottom: 8px; } +.message_container { + display: flex; +} + .bot { - align-self: flex-start; background-color: var(--grey-50); border: 1px solid var(--grey-300); box-sizing: border-box; } .user { - align-self: flex-start; background-color: var(--primary); box-shadow: inset 0 0 0 100vmax rgba(255, 255, 255, 0.65); } @@ -25,7 +23,6 @@ width: 40px; height: 40px; margin: 16px; - align-self: flex-start; } .messageStyle { @@ -33,3 +30,38 @@ line-height: 22px; font-size: 16px; } + +.reasoning_header { + font-weight: bold; + font-size: 16px; + display: flex; + justify-content: space-between; + align-items: center; + border-top: 1px solid var(--grey-300); + padding: 0 24px 0 16px; + height: 40px; +} + +.reasoning_header:hover { + background-color: var(--grey-200); + cursor: pointer; +} + +.expandIcon { + width: 24px; + height: 24px; +} + +.reasoning_header_expanded { + border-top: 1px solid var(--grey-500); + + .expandIcon { + transform: rotate(90deg); + } +} + +.reason { + padding: 12px 32px; + font-size: 16px; + white-space: pre-wrap; +} diff --git a/frontend/src/components/message.tsx b/frontend/src/components/message.tsx index 19020c77..40c970c5 100644 --- a/frontend/src/components/message.tsx +++ b/frontend/src/components/message.tsx @@ -1,8 +1,9 @@ import classNames from 'classnames'; -import React, { useMemo } from 'react'; +import React, { useState } from 'react'; import styles from './message.module.css'; import UserIcon from '../icons/account-circle.svg'; import BotIcon from '../icons/logomark.svg'; +import ChevronIcon from '../icons/chevron.svg'; export enum Role { User = 'User', @@ -10,8 +11,10 @@ export enum Role { } export interface Message { + id?: string; role: Role; content: string; + reasoning?: string; time: string; } @@ -36,14 +39,39 @@ const roleStyleMap: Record = { }; export const MessageComponent = ({ message }: MessageProps) => { - const { content, role } = message; + const { content, role, reasoning } = message; - const { class: roleClass, icon } = useMemo(() => roleStyleMap[role], [role]); + const { class: roleClass, icon } = roleStyleMap[role]; + + const [expanded, setExpanded] = useState(false); return (
- -

{content}

+
+ +

{content}

+
+ {role == Role.Bot && reasoning && ( + <> +
setExpanded(!expanded)} + onKeyDown={(event) => { + if (event.key === 'Enter' || event.key === ' ') { + setExpanded(!expanded); + } + }} + > + How I came to this conclusion + +
+ {expanded &&
{reasoning}
} + + )}
); }; diff --git a/frontend/src/server.ts b/frontend/src/server.ts index 33b7b56c..1189b180 100644 --- a/frontend/src/server.ts +++ b/frontend/src/server.ts @@ -1,9 +1,12 @@ export interface ChatMessageResponse { - message: string; + id?: string; + question?: string; + answer: string; + reasoning?: string; } function createChatMessageResponse(message: string): ChatMessageResponse { - return { message }; + return { answer: message }; } export const getResponse = async ( @@ -50,9 +53,6 @@ const callChatEndpoint = async ( return response; }) .then((response) => response.json()) - .then((responseJson) => { - return createChatMessageResponse(responseJson); - }) .catch((error) => { console.error('Error making REST call to /chat: ', error); return unhappyChatResponse; @@ -63,9 +63,7 @@ export const getSuggestions = async (): Promise => { return await fetch(`${process.env.BACKEND_URL}/suggestions`, { credentials: 'include', }) - .then((response) => { - return response.json(); - }) + .then((response) => response.json()) .catch((error) => { console.error('Error making REST call to /suggestions: ', error); return []; diff --git a/frontend/src/useMessages.ts b/frontend/src/useMessages.ts index 1441fc82..25a38dce 100644 --- a/frontend/src/useMessages.ts +++ b/frontend/src/useMessages.ts @@ -1,6 +1,11 @@ import { useCallback, useState } from 'react'; import { Message, Role } from './components/message'; -import { getResponse, getSuggestions, resetChat } from './server'; +import { + ChatMessageResponse, + getResponse, + getSuggestions, + resetChat, +} from './server'; const starterMessage: Message = { role: Role.Bot, @@ -29,20 +34,29 @@ export const useMessages = (): UseMessagesHook => { } }, []); - const appendMessage = useCallback((message: string, role: Role) => { - setMessages((prevMessages) => [ - ...prevMessages, - { role, content: message, time: new Date().toLocaleTimeString() }, - ]); - }, []); + const appendMessage = useCallback( + (response: ChatMessageResponse, role: Role) => { + setMessages((prevMessages) => [ + ...prevMessages, + { + role, + id: response.id, + content: response.answer, + reasoning: response.reasoning, + time: new Date().toLocaleTimeString(), + }, + ]); + }, + [], + ); const sendMessage = useCallback( async (message: string) => { - appendMessage(message, Role.User); + appendMessage({ answer: message }, Role.User); setWaiting(true); const response = await getResponse(message); setWaiting(false); - appendMessage(response.message, Role.Bot); + appendMessage(response, Role.Bot); if (message !== 'healthcheck') { fetchSuggestions(); }