Skip to content

Commit

Permalink
FS-125 Remove redis TTL
Browse files Browse the repository at this point in the history
  • Loading branch information
mic-smith committed Dec 13, 2024
1 parent a8d76c8 commit 34e8e90
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 16 deletions.
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ FILES_DIRECTORY=files

# redis cache configuration
REDIS_HOST="localhost"
REDIS_CACHE_DURATION=3600

# backend LLM properties
MISTRAL_KEY=my-api-key
Expand Down
6 changes: 4 additions & 2 deletions backend/src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from fastapi import FastAPI, HTTPException, Response, WebSocket, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from src.chat_storage_service import get_chat_message
from src.session.chat_response import get_session_chat_response_ids
from src.chat_storage_service import clear_chat_messages, get_chat_message
from src.directors.report_director import report_on_file_upload
from src.session.file_uploads import clear_session_file_uploads, get_report
from src.session.redis_session_middleware import reset_session
Expand Down Expand Up @@ -88,7 +89,8 @@ async def chat(utterance: str):
async def clear_chat():
logger.info("Delete the chat session")
try:
# clear files first as need session data for file keys
# clear chatresponses and files first as need session data for keys
clear_chat_messages(get_session_chat_response_ids())
clear_session_file_uploads()
reset_session()
return Response(status_code=204)
Expand Down
8 changes: 8 additions & 0 deletions backend/src/chat_storage_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import json
import logging
from typing import TypedDict
import redis

Expand All @@ -14,6 +15,7 @@ class ChatResponse(TypedDict):
reasoning: str | None

config = Config()
logger = logging.getLogger(__name__)

redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True)

Expand All @@ -29,3 +31,9 @@ def get_chat_message(id: str) -> ChatResponse | None:
if parsed_session_data := try_parse_to_json(value):
return parsed_session_data
return None

def clear_chat_messages(ids:list[str]):
if ids:
logger.info(f"Clearing chat message keys {ids}")
for id in ids:
redis_client.delete(CHAT_KEY_PREFIX + id)
2 changes: 2 additions & 0 deletions backend/src/directors/chat_director.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional
from uuid import uuid4

from src.session.chat_response import update_session_chat_response_ids
from src.utils.json import try_pretty_print
from src.chat_storage_service import ChatResponse, store_chat_message
from src.utils import clear_scratchpad, update_scratchpad, get_scratchpad
Expand Down Expand Up @@ -69,6 +70,7 @@ async def question(question: str) -> ChatResponse:
reasoning=try_pretty_print(current_scratchpad))

store_chat_message(response)
update_session_chat_response_ids(response.get("id"))

clear_scratchpad()

Expand Down
20 changes: 20 additions & 0 deletions backend/src/session/chat_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging

from .redis_session_middleware import get_session, set_session

logger = logging.getLogger(__name__)

CHAT_RESPONSE_SESSION_KEY = "chatresponse"

def get_session_chat_response_ids() -> list[str]:
return get_session(CHAT_RESPONSE_SESSION_KEY, [])


def update_session_chat_response_ids(id:str):
ids = get_session_chat_response_ids()
ids.append(id)
set_session(CHAT_RESPONSE_SESSION_KEY, ids)


def clear_session_chat_response_ids():
set_session(CHAT_RESPONSE_SESSION_KEY, [])
6 changes: 3 additions & 3 deletions backend/src/session/file_uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def clear_session_file_uploads():
keys.append(REPORT_KEY_PREFIX + meta["uploadId"])

if keys:
keystr = " ".join(keys)
logger.info("Deleting keys " + keystr)
redis_client.delete(keystr)
logger.info(f"Deleting keys {keys}")
for key in keys:
redis_client.delete(key)

set_session(UPLOADS_META_SESSION_KEY, [])

Expand Down
12 changes: 8 additions & 4 deletions backend/src/session/redis_session_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

REQUEST_CONTEXT_KEY = "redis_session_context"
SESSION_COOKIE_NAME = "session_id"
SESSION_TTL = int(config.redis_cache_duration) # config value or default to 1 hour

request_context = contextvars.ContextVar(REQUEST_CONTEXT_KEY)
redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True)
Expand All @@ -24,7 +23,8 @@ async def dispatch(self, request: Request, call_next):
request_context.set(request)

redis_healthy = test_redis_connection()
if (not redis_healthy):

if (not redis_healthy or ignore_request(request)):
response = await call_next(request)
else:
session_data = get_redis_session(request)
Expand All @@ -36,16 +36,19 @@ async def dispatch(self, request: Request, call_next):
response.set_cookie(
SESSION_COOKIE_NAME,
session_id,
max_age=SESSION_TTL,
domain=request.url.hostname,
samesite='strict',
httponly=True,
secure=config.redis_host != "redis"
)

redis_client.set(session_id, json.dumps(request.state.session), ex=SESSION_TTL)
redis_client.set(session_id, json.dumps(request.state.session))

return response

def ignore_request(request:Request) -> bool:
# prevent generating new session for each health check request
return request.url.path == '/health' or request.method == 'OPTIONS'

def get_session(key: str, default=[]):
request: Request = request_context.get()
Expand All @@ -60,6 +63,7 @@ def reset_session():
logger.info("Reset chat session")
request: Request = request_context.get()
request.state.session = {}
logger.info(f"db size {redis_client.dbsize()}")

def get_redis_session(request: Request):
session_id = request.cookies.get(SESSION_COOKIE_NAME)
Expand Down
3 changes: 0 additions & 3 deletions backend/src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
default_neo4j_uri = "bolt://localhost:7687"
default_files_directory = "files"
default_redis_host = "localhost"
default_redis_cache_duration = 3600


class Config(object):
Expand Down Expand Up @@ -38,7 +37,6 @@ def __init__(self):
self.router_model = None
self.files_directory = default_files_directory
self.redis_host = default_redis_host
self.redis_cache_duration = default_redis_cache_duration
self.suggestions_model = None
self.dynamic_knowledge_graph_model = None
self.load_env()
Expand Down Expand Up @@ -77,7 +75,6 @@ def load_env(self):
self.chart_generator_model = os.getenv("CHART_GENERATOR_MODEL")
self.router_model = os.getenv("ROUTER_MODEL")
self.redis_host = os.getenv("REDIS_HOST", default_redis_host)
self.redis_cache_duration = os.getenv("REDIS_CACHE_DURATION", default_redis_cache_duration)
self.suggestions_model = os.getenv("SUGGESTIONS_MODEL")
self.dynamic_knowledge_graph_model = os.getenv("DYNAMIC_KNOWLEDGE_GRAPH_MODEL")
except FileNotFoundError:
Expand Down
4 changes: 4 additions & 0 deletions backend/tests/api/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def test_chat_response_failure(mocker):
def test_chat_delete(mocker):
mock_reset_session = mocker.patch("src.api.app.reset_session")
mock_clear_files = mocker.patch("src.api.app.clear_session_file_uploads")
mock_clear_chat_messages = mocker.patch("src.api.app.clear_chat_messages")
mock_get_session_chat_response_ids = mocker.patch("src.api.app.get_session_chat_response_ids")

response = client.delete("/chat")

mock_clear_chat_messages.assert_called_once()
mock_get_session_chat_response_ids.assert_called_once()
mock_clear_files.assert_called_once()
mock_reset_session.assert_called_once()

Expand Down
29 changes: 29 additions & 0 deletions backend/tests/session/test_chat_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from unittest.mock import patch, MagicMock
from src.session.chat_response import (clear_session_chat_response_ids,
get_session_chat_response_ids,
update_session_chat_response_ids)

@pytest.fixture
def mock_request_context():
with patch('src.session.redis_session_middleware.request_context'):
mock_instance = MagicMock()
mock_instance.get.return_value.state.session = {}
yield mock_instance


def test_session_chat(mocker, mock_request_context):
mocker.patch("src.session.redis_session_middleware.request_context", mock_request_context)

update_session_chat_response_ids("one")
update_session_chat_response_ids("two")
assert get_session_chat_response_ids() == ["one", "two"]


def test_clear_session_chat(mocker, mock_request_context):
mocker.patch("src.session.redis_session_middleware.request_context", mock_request_context)

update_session_chat_response_ids("123")
assert get_session_chat_response_ids() == ["123"]
clear_session_chat_response_ids()
assert get_session_chat_response_ids() == []
9 changes: 6 additions & 3 deletions backend/tests/session/test_file_uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def test_clear_session_file_uploads_meta(mocker, mock_redis, mock_request_contex

clear_session_file_uploads()
assert get_session_file_uploads_meta() == []
mock_redis.delete.assert_called_with("file_upload_1234 report_1234")
mock_redis.delete.assert_any_call("file_upload_1234")
mock_redis.delete.assert_any_call("report_1234")

update_session_file_uploads(file_upload=file)
update_session_file_uploads(file_upload=file2)
Expand All @@ -86,8 +87,10 @@ def test_clear_session_file_uploads_meta(mocker, mock_redis, mock_request_contex

clear_session_file_uploads()
assert get_session_file_uploads_meta() == []
mock_redis.delete.assert_called_with("file_upload_1234 report_1234 file_upload_12345 report_12345")

mock_redis.delete.assert_any_call("file_upload_1234")
mock_redis.delete.assert_any_call("report_1234")
mock_redis.delete.assert_any_call("file_upload_12345")
mock_redis.delete.assert_any_call("report_12345")

def test_store_report(mocker, mock_redis):
mocker.patch("src.session.file_uploads.redis_client", mock_redis)
Expand Down

0 comments on commit 34e8e90

Please sign in to comment.