diff --git a/.env.example b/.env.example index a3bac444..0210416f 100644 --- a/.env.example +++ b/.env.example @@ -18,7 +18,6 @@ FILES_DIRECTORY=files # redis cache configuration REDIS_HOST="localhost" -REDIS_CACHE_DURATION=3600 # backend LLM properties MISTRAL_KEY=my-api-key diff --git a/backend/src/api/app.py b/backend/src/api/app.py index 82bc01f1..dc4eb205 100644 --- a/backend/src/api/app.py +++ b/backend/src/api/app.py @@ -6,7 +6,8 @@ from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from src.utils.scratchpad import ScratchPadMiddleware -from src.chat_storage_service import get_chat_message +from src.session.chat_response import get_session_chat_response_ids +from src.chat_storage_service import clear_chat_messages, get_chat_message from src.directors.report_director import report_on_file_upload from src.session.file_uploads import clear_session_file_uploads, get_report from src.session.redis_session_middleware import reset_session @@ -90,7 +91,8 @@ async def chat(utterance: str): async def clear_chat(): logger.info("Delete the chat session") try: - # clear files first as need session data for file keys + # clear chatresponses and files first as need session data for keys + clear_chat_messages(get_session_chat_response_ids()) clear_session_file_uploads() reset_session() return Response(status_code=204) diff --git a/backend/src/chat_storage_service.py b/backend/src/chat_storage_service.py index c048c06c..ce13e5fd 100644 --- a/backend/src/chat_storage_service.py +++ b/backend/src/chat_storage_service.py @@ -1,5 +1,6 @@ import json +import logging from typing import TypedDict import redis @@ -14,6 +15,7 @@ class ChatResponse(TypedDict): reasoning: str | None config = Config() +logger = logging.getLogger(__name__) redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True) @@ -29,3 +31,9 @@ def get_chat_message(id: str) -> ChatResponse | None: if parsed_session_data := try_parse_to_json(value): return parsed_session_data return None + +def clear_chat_messages(ids:list[str]): + if ids: + logger.info(f"Clearing chat message keys {ids}") + for id in ids: + redis_client.delete(CHAT_KEY_PREFIX + id) diff --git a/backend/src/directors/chat_director.py b/backend/src/directors/chat_director.py index 002c3bd8..8c89bec4 100644 --- a/backend/src/directors/chat_director.py +++ b/backend/src/directors/chat_director.py @@ -4,6 +4,7 @@ from typing import Optional from uuid import uuid4 +from src.session.chat_response import update_session_chat_response_ids from src.utils.json import try_pretty_print from src.chat_storage_service import ChatResponse, store_chat_message from src.utils import clear_scratchpad, update_scratchpad, get_scratchpad @@ -69,6 +70,7 @@ async def question(question: str) -> ChatResponse: reasoning=try_pretty_print(current_scratchpad)) store_chat_message(response) + update_session_chat_response_ids(response.get("id")) clear_scratchpad() diff --git a/backend/src/session/chat_response.py b/backend/src/session/chat_response.py new file mode 100644 index 00000000..bfe1e093 --- /dev/null +++ b/backend/src/session/chat_response.py @@ -0,0 +1,20 @@ +import logging + +from .redis_session_middleware import get_session, set_session + +logger = logging.getLogger(__name__) + +CHAT_RESPONSE_SESSION_KEY = "chatresponse" + +def get_session_chat_response_ids() -> list[str]: + return get_session(CHAT_RESPONSE_SESSION_KEY, []) + + +def update_session_chat_response_ids(id:str): + ids = get_session_chat_response_ids() + ids.append(id) + set_session(CHAT_RESPONSE_SESSION_KEY, ids) + + +def clear_session_chat_response_ids(): + set_session(CHAT_RESPONSE_SESSION_KEY, []) diff --git a/backend/src/session/file_uploads.py b/backend/src/session/file_uploads.py index 6b0d7108..4c462171 100644 --- a/backend/src/session/file_uploads.py +++ b/backend/src/session/file_uploads.py @@ -76,9 +76,9 @@ def clear_session_file_uploads(): keys.append(REPORT_KEY_PREFIX + meta["uploadId"]) if keys: - keystr = " ".join(keys) - logger.info("Deleting keys " + keystr) - redis_client.delete(keystr) + logger.info(f"Deleting keys {keys}") + for key in keys: + redis_client.delete(key) set_session(UPLOADS_META_SESSION_KEY, []) diff --git a/backend/src/session/redis_session_middleware.py b/backend/src/session/redis_session_middleware.py index 1b6a1606..d04072fd 100644 --- a/backend/src/session/redis_session_middleware.py +++ b/backend/src/session/redis_session_middleware.py @@ -14,7 +14,6 @@ REQUEST_CONTEXT_KEY = "redis_session_context" SESSION_COOKIE_NAME = "session_id" -SESSION_TTL = int(config.redis_cache_duration) # config value or default to 1 hour request_context = contextvars.ContextVar(REQUEST_CONTEXT_KEY) redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True) @@ -25,7 +24,8 @@ async def dispatch(self, request: Request, call_next): request_context.set(request) redis_healthy = test_redis_connection() - if (not redis_healthy): + + if (not redis_healthy or ignore_request(request)): response = await call_next(request) else: session_data = get_redis_session(request) @@ -37,16 +37,19 @@ async def dispatch(self, request: Request, call_next): response.set_cookie( SESSION_COOKIE_NAME, session_id, - max_age=SESSION_TTL, domain=request.url.hostname, samesite='strict', httponly=True, secure=config.redis_host != "redis" ) - redis_client.set(session_id, json.dumps(request.state.session), ex=SESSION_TTL) + redis_client.set(session_id, json.dumps(request.state.session)) + return response +def ignore_request(request:Request) -> bool: + # prevent generating new session for each health check request + return request.url.path == '/health' or request.method == 'OPTIONS' def get_session(key: str, default=[]): request: Request = request_context.get() @@ -62,6 +65,7 @@ def reset_session(): logger.info("Reset chat session") request: Request = request_context.get() request.state.session = {} + logger.info(f"db size {redis_client.dbsize()}") def get_redis_session(request: Request): diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py index 2b8e2001..6f3cb724 100644 --- a/backend/src/utils/config.py +++ b/backend/src/utils/config.py @@ -5,7 +5,6 @@ default_neo4j_uri = "bolt://localhost:7687" default_files_directory = "files" default_redis_host = "localhost" -default_redis_cache_duration = 3600 class Config(object): @@ -40,7 +39,6 @@ def __init__(self): self.router_model = None self.files_directory = default_files_directory self.redis_host = default_redis_host - self.redis_cache_duration = default_redis_cache_duration self.suggestions_model = None self.dynamic_knowledge_graph_model = None self.load_env() @@ -81,7 +79,6 @@ def load_env(self): self.chart_generator_model = os.getenv("CHART_GENERATOR_MODEL") self.router_model = os.getenv("ROUTER_MODEL") self.redis_host = os.getenv("REDIS_HOST", default_redis_host) - self.redis_cache_duration = os.getenv("REDIS_CACHE_DURATION", default_redis_cache_duration) self.suggestions_model = os.getenv("SUGGESTIONS_MODEL") self.dynamic_knowledge_graph_model = os.getenv("DYNAMIC_KNOWLEDGE_GRAPH_MODEL") except FileNotFoundError: diff --git a/backend/tests/api/app_test.py b/backend/tests/api/app_test.py index 2fb72e74..f44f79a7 100644 --- a/backend/tests/api/app_test.py +++ b/backend/tests/api/app_test.py @@ -53,9 +53,13 @@ def test_chat_response_failure(mocker): def test_chat_delete(mocker): mock_reset_session = mocker.patch("src.api.app.reset_session") mock_clear_files = mocker.patch("src.api.app.clear_session_file_uploads") + mock_clear_chat_messages = mocker.patch("src.api.app.clear_chat_messages") + mock_get_session_chat_response_ids = mocker.patch("src.api.app.get_session_chat_response_ids") response = client.delete("/chat") + mock_clear_chat_messages.assert_called_once() + mock_get_session_chat_response_ids.assert_called_once() mock_clear_files.assert_called_once() mock_reset_session.assert_called_once() diff --git a/backend/tests/session/test_chat_response.py b/backend/tests/session/test_chat_response.py new file mode 100644 index 00000000..02fbb9e2 --- /dev/null +++ b/backend/tests/session/test_chat_response.py @@ -0,0 +1,29 @@ +import pytest +from unittest.mock import patch, MagicMock +from src.session.chat_response import (clear_session_chat_response_ids, + get_session_chat_response_ids, + update_session_chat_response_ids) + +@pytest.fixture +def mock_request_context(): + with patch('src.session.redis_session_middleware.request_context'): + mock_instance = MagicMock() + mock_instance.get.return_value.state.session = {} + yield mock_instance + + +def test_session_chat(mocker, mock_request_context): + mocker.patch("src.session.redis_session_middleware.request_context", mock_request_context) + + update_session_chat_response_ids("one") + update_session_chat_response_ids("two") + assert get_session_chat_response_ids() == ["one", "two"] + + +def test_clear_session_chat(mocker, mock_request_context): + mocker.patch("src.session.redis_session_middleware.request_context", mock_request_context) + + update_session_chat_response_ids("123") + assert get_session_chat_response_ids() == ["123"] + clear_session_chat_response_ids() + assert get_session_chat_response_ids() == [] diff --git a/backend/tests/session/test_file_uploads.py b/backend/tests/session/test_file_uploads.py index dac086e2..07832a12 100644 --- a/backend/tests/session/test_file_uploads.py +++ b/backend/tests/session/test_file_uploads.py @@ -77,7 +77,8 @@ def test_clear_session_file_uploads_meta(mocker, mock_redis, mock_request_contex clear_session_file_uploads() assert get_session_file_uploads_meta() == [] - mock_redis.delete.assert_called_with("file_upload_1234 report_1234") + mock_redis.delete.assert_any_call("file_upload_1234") + mock_redis.delete.assert_any_call("report_1234") update_session_file_uploads(file_upload=file) update_session_file_uploads(file_upload=file2) @@ -86,8 +87,10 @@ def test_clear_session_file_uploads_meta(mocker, mock_redis, mock_request_contex clear_session_file_uploads() assert get_session_file_uploads_meta() == [] - mock_redis.delete.assert_called_with("file_upload_1234 report_1234 file_upload_12345 report_12345") - + mock_redis.delete.assert_any_call("file_upload_1234") + mock_redis.delete.assert_any_call("report_1234") + mock_redis.delete.assert_any_call("file_upload_12345") + mock_redis.delete.assert_any_call("report_12345") def test_store_report(mocker, mock_redis): mocker.patch("src.session.file_uploads.redis_client", mock_redis)