Skip to content

Commit

Permalink
FS-99 Clarify answer
Browse files Browse the repository at this point in the history
  • Loading branch information
mic-smith committed Nov 19, 2024
1 parent ac8e5c4 commit 179a43f
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 35 deletions.
12 changes: 12 additions & 0 deletions backend/src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import FastAPI, HTTPException, Response, WebSocket, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from src.chat_storage_service import get_chat_message
from src.session.file_uploads import clear_session_file_uploads
from src.session.redis_session_middleware import reset_session
from src.utils.graph_db_utils import populate_db
Expand Down Expand Up @@ -109,6 +110,17 @@ async def clear_chat():
logger.exception(e)
return Response(status_code=500)

@app.get("/chat/{id}")
def chat_message(id: str):
logger.info(f"Get chat message called with id: {id}")
try:
final_result = get_chat_message(id)
if final_result is None:
return JSONResponse(status_code=404, content=f"Message with id {id} not found")
return JSONResponse(status_code=200, content=final_result)
except Exception as e:
logger.exception(e)
return JSONResponse(status_code=500, content=chat_fail_response)

@app.get("/suggestions")
async def suggestions():
Expand Down
34 changes: 34 additions & 0 deletions backend/src/chat_storage_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

import json
import logging
from typing import TypedDict
import redis

from src.utils.json import try_parse_to_json
from src.utils import Config

class ChatResponse(TypedDict):
id: str
question:str
answer: str
reasoning: str | None

logger = logging.getLogger(__name__)

config = Config()

redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True)

CHAT_KEY_PREFIX = "chat_"

def store_chat_message(chat:ChatResponse):
redis_client.set(CHAT_KEY_PREFIX + chat["id"], json.dumps(chat))


def get_chat_message(id: str) -> ChatResponse | None:
value = redis_client.get(CHAT_KEY_PREFIX + id)
if value and isinstance(value, str):
parsed_session_data = try_parse_to_json(value)
if parsed_session_data:
return parsed_session_data
return None
21 changes: 17 additions & 4 deletions backend/src/director.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import json
import logging
from uuid import uuid4

from src.utils.json import try_pretty_print
from src.chat_storage_service import ChatResponse, store_chat_message
from src.utils import clear_scratchpad, update_scratchpad, get_scratchpad
from src.session import update_session_chat
from src.agents import get_intent_agent, get_answer_agent
Expand All @@ -13,8 +17,7 @@
engine = PromptEngine()
director_prompt = engine.load_prompt("director")


async def question(question: str) -> str:
async def question(question: str) -> ChatResponse:
intent = await get_intent_agent().invoke(question)
intent_json = json.loads(intent)
update_session_chat(role="user", content=question)
Expand All @@ -33,12 +36,22 @@ async def question(question: str) -> str:
generated_figure = entry["result"]
await connection_manager.send_chart({"type": "image", "data": generated_figure})
clear_scratchpad()
return ""
return ChatResponse(id=str(uuid4()),
question=question,
answer="",
reasoning=try_pretty_print(current_scratchpad))

final_answer = await get_answer_agent().invoke(question)
update_session_chat(role="system", content=final_answer)
logger.info(f"final answer: {final_answer}")

response = ChatResponse(id=str(uuid4()),
question=question,
answer=final_answer,
reasoning=try_pretty_print(current_scratchpad))

store_chat_message(response)

clear_scratchpad()

return final_answer
return response
7 changes: 7 additions & 0 deletions backend/src/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ def try_parse_to_json(json_string: str):
except json.JSONDecodeError as error:
logger.error(f"Error parsing json: {error}")
return None

def try_pretty_print(obj):
try:
return json.dumps(obj, indent=4)
except Exception as error:
logger.error(f"Error pretty printing json: {error}")
return None
4 changes: 2 additions & 2 deletions backend/tests/BDD/step_defs/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_response(context):
@then(parsers.parse("the response to this '{prompt}' should match the '{expected_response}'"))
def check_response_includes_expected_response(context, prompt, expected_response):
response = send_prompt(prompt)
actual_response = response.json()
actual_response = response.json()["answer"]

try:
expected_value = Decimal(str(expected_response).strip())
Expand Down Expand Up @@ -81,5 +81,5 @@ def check_response_includes_expected_response(context, prompt, expected_response
@then(parsers.parse("the response to this '{prompt}' should give a confident answer"))
def check_bot_response_confidence(prompt):
response = send_prompt(prompt)
result = check_response_confidence(prompt, response.json())
result = check_response_confidence(prompt, response.json()["answer"])
assert result["score"] == 1, "The bot response is not confident enough. \nReasoning: " + result["reasoning"]
18 changes: 18 additions & 0 deletions backend/tests/api/app_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi.testclient import TestClient
import pytest
from src.chat_storage_service import ChatResponse
from src.api import app, healthy_response, unhealthy_neo4j_response, chat_fail_response

client = TestClient(app)
Expand Down Expand Up @@ -73,6 +74,23 @@ def test_chat_delete(mocker):

assert response.status_code == 204

def test_chat_message_success(mocker):
message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=message)

response = client.get("/chat/123")

mock_get_chat_message.assert_called_with("123")
assert response.status_code == 200
assert response.json() == message

def test_chat_message_not_found(mocker):
mock_get_chat_message = mocker.patch("src.api.app.get_chat_message", return_value=None)

response = client.get("/chat/123")

mock_get_chat_message.assert_called_with("123")
assert response.status_code == 404

@pytest.mark.asyncio
async def test_lifespan_populates_db(mocker, mock_initial_data) -> None:
Expand Down
32 changes: 32 additions & 0 deletions backend/tests/chat_storage_service_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from unittest.mock import MagicMock, patch

import pytest

from src.chat_storage_service import ChatResponse, get_chat_message, store_chat_message

@pytest.fixture
def mock_redis():
with patch('src.chat_storage_service.redis_client') as mock_redis:
mock_instance = MagicMock()
mock_redis.return_value = mock_instance
yield mock_instance

def test_store_chat_message(mocker, mock_redis):
mocker.patch('src.chat_storage_service.redis_client', mock_redis)

message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
store_chat_message(message)

mock_redis.set.assert_called_once_with("chat_1", json.dumps(message))


def test_get_chat_message(mocker, mock_redis):
mocker.patch('src.chat_storage_service.redis_client', mock_redis)

message = ChatResponse(id="1", question="Question", answer="Answer", reasoning="Reasoning")
mock_redis.get.return_value = json.dumps(message)

value = get_chat_message("1")

assert value == message
46 changes: 39 additions & 7 deletions frontend/src/components/message.module.css
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
.container {
align-items: center;
display: flex;
flex-direction: row;
justify-content: left;
border-radius: 16px;
width: 100%;
margin-bottom: 8px;
}

.message_container {
display: flex;
}

.bot {
align-self: flex-start;
background-color: var(--grey-50);
border: 1px solid var(--grey-300);
box-sizing: border-box;
}

.user {
align-self: flex-start;
background-color: var(--primary);
box-shadow: inset 0 0 0 100vmax rgba(255, 255, 255, 0.65);
}
Expand All @@ -25,11 +23,45 @@
width: 40px;
height: 40px;
margin: 16px;
align-self: flex-start;
}

.messageStyle {
margin: 24px 32px 16px 0;
line-height: 22px;
font-size: 16px;
}

.reasoning_header {
font-weight: bold;
font-size: 16px;
display: flex;
justify-content: space-between;
align-items: center;
border-top: 1px solid var(--grey-300);
padding: 0 24px 0 16px;
height: 40px;
}

.reasoning_header:hover {
background-color: var(--grey-200);
cursor: pointer;
}

.expandIcon {
width: 24px;
height: 24px;
}

.reasoning_header_expanded {
border-top: 1px solid var(--grey-500);

.expandIcon {
transform: rotate(90deg);
}
}

.reason {
padding: 12px 32px;
font-size: 16px;
white-space: pre-wrap;
}
38 changes: 33 additions & 5 deletions frontend/src/components/message.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import classNames from 'classnames';
import React, { useMemo } from 'react';
import React, { useState } from 'react';
import styles from './message.module.css';
import UserIcon from '../icons/account-circle.svg';
import BotIcon from '../icons/logomark.svg';
import ChevronIcon from '../icons/chevron.svg';

export enum Role {
User = 'User',
Bot = 'Bot',
}

export interface Message {
id?: string;
role: Role;
content: string;
reasoning?: string;
time: string;
}

Expand All @@ -36,14 +39,39 @@ const roleStyleMap: Record<Role, MessageStyle> = {
};

export const MessageComponent = ({ message }: MessageProps) => {
const { content, role } = message;
const { content, role, reasoning } = message;

const { class: roleClass, icon } = useMemo(() => roleStyleMap[role], [role]);
const { class: roleClass, icon } = roleStyleMap[role];

const [expanded, setExpanded] = useState(false);

return (
<div className={classNames(styles.container, roleClass)}>
<img src={icon} className={styles.iconStyle} />
<p className={styles.messageStyle}>{content}</p>
<div className={styles.message_container}>
<img src={icon} className={styles.iconStyle} />
<p className={styles.messageStyle}>{content}</p>
</div>
{role == Role.Bot && reasoning && (
<>
<div
className={classNames(styles.reasoning_header, {
[styles.reasoning_header_expanded]: expanded,
})}
role="button"
tabIndex={0}
onClick={() => setExpanded(!expanded)}
onKeyDown={(event) => {
if (event.key === 'Enter' || event.key === ' ') {
setExpanded(!expanded);
}
}}
>
How I came to this conclusion
<img className={styles.expandIcon} src={ChevronIcon} />
</div>
{expanded && <div className={styles.reason}>{reasoning}</div>}
</>
)}
</div>
);
};
14 changes: 6 additions & 8 deletions frontend/src/server.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
export interface ChatMessageResponse {
message: string;
id?: string;
question?: string;
answer: string;
reasoning?: string;
}

function createChatMessageResponse(message: string): ChatMessageResponse {
return { message };
return { answer: message };
}

export const getResponse = async (
Expand Down Expand Up @@ -50,9 +53,6 @@ const callChatEndpoint = async (
return response;
})
.then((response) => response.json())
.then((responseJson) => {
return createChatMessageResponse(responseJson);
})
.catch((error) => {
console.error('Error making REST call to /chat: ', error);
return unhappyChatResponse;
Expand All @@ -63,9 +63,7 @@ export const getSuggestions = async (): Promise<string[]> => {
return await fetch(`${process.env.BACKEND_URL}/suggestions`, {
credentials: 'include',
})
.then((response) => {
return response.json();
})
.then((response) => response.json())
.catch((error) => {
console.error('Error making REST call to /suggestions: ', error);
return [];
Expand Down
Loading

0 comments on commit 179a43f

Please sign in to comment.