forked from ScottLogic/InferLLM
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
F-62: Implement Redis sessions to replace conversation-history.txt (#10)
* Removed the current conversation history settings on InferGPT * Included the redis session files from POC Repo --------- Co-authored-by: Gagan Singh <[email protected]>
- Loading branch information
1 parent
6ca8e97
commit 675b6ee
Showing
23 changed files
with
374 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,6 @@ googlesearch-python==1.2.4 | |
matplotlib==3.9.1 | ||
pillow==10.4.0 | ||
pypdf==4.3.1 | ||
hiredis==3.0.0 | ||
redis==5.0.8 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from .redis_session_middleware import RedisSessionMiddleware | ||
from .chat import Message, clear_session_chat, get_session_chat, update_session_chat | ||
from .cypher_query import CypherQuery, clear_session_cypher_query, get_session_cypher_query, update_session_cypher_query | ||
|
||
__all__ = [ | ||
"RedisSessionMiddleware", | ||
"Message", | ||
"clear_session_chat", | ||
"get_session_chat", | ||
"update_session_chat", | ||
"CypherQuery", | ||
"clear_session_cypher_query", | ||
"get_session_cypher_query", | ||
"update_session_cypher_query" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import TypedDict | ||
import logging | ||
|
||
from .redis_session_middleware import get_session, set_session | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
CHAT_SESSION_KEY = "chat" | ||
|
||
class Message(TypedDict): | ||
role: str | None # user or system | ||
content: str | None | ||
|
||
|
||
def get_session_chat() -> list[Message] | None: | ||
return get_session(CHAT_SESSION_KEY, []) | ||
|
||
|
||
def update_session_chat(role=None, content=None): | ||
chat_session = get_session(CHAT_SESSION_KEY, []) | ||
if not chat_session: | ||
chat_session = [] | ||
chat_session.append({"role": role, "content": content}) | ||
set_session(CHAT_SESSION_KEY, chat_session) | ||
|
||
|
||
def clear_session_chat(): | ||
set_session(CHAT_SESSION_KEY, []) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import TypedDict | ||
import uuid | ||
import logging | ||
|
||
from .redis_session_middleware import get_session, set_session | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
CYPHER_QUERY_SESSION_KEY = "cypher_query" | ||
|
||
class CypherQuery(TypedDict): | ||
queryId: uuid.UUID | ||
cypher_query: str | ||
|
||
|
||
def get_session_cypher_query() -> list[CypherQuery] | None: | ||
return get_session(CYPHER_QUERY_SESSION_KEY, []) | ||
|
||
|
||
def update_session_cypher_query(queryid=None, cypher_query=None): | ||
cypher_query_session = get_session(CYPHER_QUERY_SESSION_KEY, []) | ||
if not cypher_query_session: | ||
# initialise the session object | ||
set_session(CYPHER_QUERY_SESSION_KEY, cypher_query_session) | ||
|
||
cypher_query_session.append({"queryid": str(queryid), "cypher_query": cypher_query}) | ||
|
||
|
||
def clear_session_cypher_query(): | ||
logger.info("Cypher query session cleared") | ||
set_session(CYPHER_QUERY_SESSION_KEY, []) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import json | ||
from uuid import uuid4 | ||
import redis | ||
from src.utils import test_redis_connection | ||
from src.utils import Config | ||
from src.utils import try_parse_to_json | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.requests import Request | ||
import contextvars | ||
import logging | ||
|
||
config = Config() | ||
logger = logging.getLogger(__name__) | ||
|
||
REQUEST_CONTEXT_KEY = "redis_session_context" | ||
SESSION_COOKIE_NAME = "session_id" | ||
SESSION_TTL = int(config.redis_cache_duration) # config value or default to 1 hour | ||
|
||
request_context = contextvars.ContextVar(REQUEST_CONTEXT_KEY) | ||
redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True) | ||
|
||
class RedisSessionMiddleware(BaseHTTPMiddleware): | ||
async def dispatch(self, request: Request, call_next): | ||
request_context.set(request) | ||
|
||
redis_healthy = test_redis_connection() | ||
if (not redis_healthy): | ||
response = await call_next(request) | ||
else: | ||
session_data = get_redis_session(request) | ||
request.state.session = session_data | ||
|
||
response = await call_next(request) | ||
|
||
session_id = request.cookies.get(SESSION_COOKIE_NAME) or str(uuid4()) | ||
response.set_cookie( | ||
SESSION_COOKIE_NAME, | ||
session_id, | ||
max_age=SESSION_TTL, | ||
domain=request.url.hostname, | ||
samesite='strict', | ||
httponly=True, | ||
secure=config.redis_host != "redis" | ||
) | ||
|
||
redis_client.set(session_id, json.dumps(request.state.session), ex=SESSION_TTL) | ||
return response | ||
|
||
|
||
def get_session(key: str, default=[]): | ||
request: Request = request_context.get() | ||
return request.state.session.get(key, default) | ||
|
||
|
||
def set_session(key: str, value): | ||
request: Request = request_context.get() | ||
request.state.session[key] = value | ||
|
||
|
||
def get_redis_session(request: Request): | ||
session_id = request.cookies.get(SESSION_COOKIE_NAME) | ||
logger.info(f"Attempting to get session for session_id: {session_id}") | ||
if session_id: | ||
session_data = redis_client.get(session_id) | ||
logger.info(f"***************** Session data retrieved from Redis for {session_id}: {session_data}") | ||
if session_data and isinstance(session_data, str): | ||
parsed_session_data = try_parse_to_json(session_data) | ||
if parsed_session_data: | ||
logger.info(f"Parsed session data: {parsed_session_data}") | ||
return parsed_session_data | ||
return {} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,17 @@ | ||
import json | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def to_json(input, error_message="Failed to interpret JSON"): | ||
try: | ||
return json.loads(input) | ||
except Exception: | ||
raise Exception(f'{error_message}: "{input}"') | ||
|
||
def try_parse_to_json(json_string: str): | ||
try: | ||
return json.loads(json_string) | ||
except json.JSONDecodeError as error: | ||
logger.error(f"Error parsing json: {error}") | ||
return None |
Oops, something went wrong.