Skip to content

Commit

Permalink
F-62: Implement Redis sessions to replace conversation-history.txt (#10)
Browse files Browse the repository at this point in the history
* Removed the current conversation history settings on InferGPT

* Included the redis session files from POC Repo

---------

Co-authored-by: Gagan Singh <[email protected]>
  • Loading branch information
gaganahluwalia and Gagan Singh authored Oct 29, 2024
1 parent 6ca8e97 commit 675b6ee
Show file tree
Hide file tree
Showing 23 changed files with 374 additions and 47 deletions.
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ WEB_AGENT_MODEL="gpt-4o mini"
CHART_GENERATOR_MODEL="gpt-4o mini"
ROUTER_MODEL="gpt-4o mini"
FILE_AGENT_MODEL="gpt-4o mini"

REDIS_HOST="redis"
REDIS_CACHE_DURATION=3600
3 changes: 3 additions & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ googlesearch-python==1.2.4
matplotlib==3.9.1
pillow==10.4.0
pypdf==4.3.1
hiredis==3.0.0
redis==5.0.8

8 changes: 7 additions & 1 deletion backend/src/agents/answer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from src.utils import get_scratchpad
from src.prompts import PromptEngine
from src.agents import Agent, agent
from src.session import get_session_chat

engine = PromptEngine()

Expand All @@ -14,6 +15,11 @@
class AnswerAgent(Agent):
async def invoke(self, utterance: str) -> str:
final_scratchpad = get_scratchpad()
create_answer = engine.load_prompt("create-answer", final_scratchpad=final_scratchpad, datetime=datetime.now())
create_answer = engine.load_prompt(
"create-answer",
chat_history=get_session_chat(),
final_scratchpad=final_scratchpad,
datetime=datetime.now()
)

return await self.llm.chat(self.model, create_answer, user_prompt=utterance)
36 changes: 2 additions & 34 deletions backend/src/agents/intent_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from src.prompts import PromptEngine
from src.agents import Agent, agent
from src.session import get_session_chat
import logging
import os
import json
from src.utils.config import Config


Expand All @@ -11,12 +10,7 @@
engine = PromptEngine()
intent_format = engine.load_prompt("intent-format")
logger = logging.getLogger(__name__)
FILES_DIRECTORY = f"/app/{config.files_directory}"

# Constants for response status
IGNORE_VALIDATION = "true"
STATUS_SUCCESS = "success"
STATUS_ERROR = "error"

@agent(
name="IntentAgent",
Expand All @@ -25,32 +19,6 @@
)
class IntentAgent(Agent):

async def read_file_core(self, file_path: str) -> str:
full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
try:
with open(full_path, 'r') as file:
content = file.read()
return content
except FileNotFoundError:
error_message = f"File {file_path} not found."
logger.error(error_message)
return ""
except Exception as e:
logger.error(f"Error reading file {full_path}: {e}")
return ""

async def invoke(self, utterance: str) -> str:
chat_history = await self.read_file_core("conversation-history.txt")

user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history)

user_prompt = engine.load_prompt("intent", question=utterance, chat_history=get_session_chat())
return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True)


# Utility function for error responses
def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
return json.dumps({
"content": content,
"ignore_validation": IGNORE_VALIDATION,
"status": status
}, indent=4)
4 changes: 3 additions & 1 deletion backend/src/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from src.utils import Config, test_connection
from src.director import question
from src.websockets.connection_manager import connection_manager, parse_message
from src.session import RedisSessionMiddleware
from src.utils.cyper_import_data_from_csv import import_data_from_csv_script

config_file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "config.ini"))
Expand Down Expand Up @@ -55,6 +56,8 @@ async def lifespan(app: FastAPI):
allow_headers=["*"],
)

app.add_middleware(RedisSessionMiddleware)

health_prefix = "InferESG healthcheck: "
further_guidance = "Please check the README files for further guidance."

Expand All @@ -77,7 +80,6 @@ async def health_check():
finally:
return response


@app.get("/chat")
async def chat(utterance: str):
logger.info(f"Chat method called with utterance: {utterance}")
Expand Down
3 changes: 3 additions & 0 deletions backend/src/director.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from src.utils import clear_scratchpad, update_scratchpad, get_scratchpad
from src.session import update_session_chat
from src.agents import get_intent_agent, get_answer_agent
from src.prompts import PromptEngine
from src.supervisors import solve_all
Expand All @@ -16,6 +17,7 @@
async def question(question: str) -> str:
intent = await get_intent_agent().invoke(question)
intent_json = json.loads(intent)
update_session_chat(role="user", content=question)
logger.info(f"Intent determined: {intent}")

try:
Expand All @@ -34,6 +36,7 @@ async def question(question: str) -> str:
return ""

final_answer = await get_answer_agent().invoke(question)
update_session_chat(role="system", content=final_answer)
logger.info(f"final answer: {final_answer}")

clear_scratchpad()
Expand Down
3 changes: 3 additions & 0 deletions backend/src/prompts/templates/create-answer.j2
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
You have been provided the final scratchpad which contains the results for the question in the user prompt.
Your goal is to turn the results into a natural language format to present to the user.

The conversation history is:
{{ chat_history }}

By using the final scratchpad below:
{{ final_scratchpad }}

Expand Down
13 changes: 4 additions & 9 deletions backend/src/prompts/templates/intent.j2
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
You are an expert in determining the intent behind a user's question.

The question is:

{{ question }}

The previous chat history is:

The conversation history is:
{{ chat_history }}

The question is:
{{ question }}

Your task is to accurately comprehend the intentions behind the current question.
The question can be composed of different intents and may depend on the context provided by the previous question and its response.

Expand Down Expand Up @@ -82,6 +80,3 @@ Finally, if no tool fits the task, return the following:
"tool_parameters": "{}",
"reasoning": "No tool was appropriate for the task"
}

Important:
Please always create the last intent to append the retrieved info in a 'conversation-history.txt' file and make sure this history file is always named 'conversation-history.txt'
15 changes: 15 additions & 0 deletions backend/src/session/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .redis_session_middleware import RedisSessionMiddleware
from .chat import Message, clear_session_chat, get_session_chat, update_session_chat
from .cypher_query import CypherQuery, clear_session_cypher_query, get_session_cypher_query, update_session_cypher_query

__all__ = [
"RedisSessionMiddleware",
"Message",
"clear_session_chat",
"get_session_chat",
"update_session_chat",
"CypherQuery",
"clear_session_cypher_query",
"get_session_cypher_query",
"update_session_cypher_query"
]
28 changes: 28 additions & 0 deletions backend/src/session/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import TypedDict
import logging

from .redis_session_middleware import get_session, set_session

logger = logging.getLogger(__name__)

CHAT_SESSION_KEY = "chat"

class Message(TypedDict):
role: str | None # user or system
content: str | None


def get_session_chat() -> list[Message] | None:
return get_session(CHAT_SESSION_KEY, [])


def update_session_chat(role=None, content=None):
chat_session = get_session(CHAT_SESSION_KEY, [])
if not chat_session:
chat_session = []
chat_session.append({"role": role, "content": content})
set_session(CHAT_SESSION_KEY, chat_session)


def clear_session_chat():
set_session(CHAT_SESSION_KEY, [])
31 changes: 31 additions & 0 deletions backend/src/session/cypher_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import TypedDict
import uuid
import logging

from .redis_session_middleware import get_session, set_session

logger = logging.getLogger(__name__)

CYPHER_QUERY_SESSION_KEY = "cypher_query"

class CypherQuery(TypedDict):
queryId: uuid.UUID
cypher_query: str


def get_session_cypher_query() -> list[CypherQuery] | None:
return get_session(CYPHER_QUERY_SESSION_KEY, [])


def update_session_cypher_query(queryid=None, cypher_query=None):
cypher_query_session = get_session(CYPHER_QUERY_SESSION_KEY, [])
if not cypher_query_session:
# initialise the session object
set_session(CYPHER_QUERY_SESSION_KEY, cypher_query_session)

cypher_query_session.append({"queryid": str(queryid), "cypher_query": cypher_query})


def clear_session_cypher_query():
logger.info("Cypher query session cleared")
set_session(CYPHER_QUERY_SESSION_KEY, [])
72 changes: 72 additions & 0 deletions backend/src/session/redis_session_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
from uuid import uuid4
import redis
from src.utils import test_redis_connection
from src.utils import Config
from src.utils import try_parse_to_json
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import contextvars
import logging

config = Config()
logger = logging.getLogger(__name__)

REQUEST_CONTEXT_KEY = "redis_session_context"
SESSION_COOKIE_NAME = "session_id"
SESSION_TTL = int(config.redis_cache_duration) # config value or default to 1 hour

request_context = contextvars.ContextVar(REQUEST_CONTEXT_KEY)
redis_client = redis.Redis(host=config.redis_host, port=6379, decode_responses=True)

class RedisSessionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_context.set(request)

redis_healthy = test_redis_connection()
if (not redis_healthy):
response = await call_next(request)
else:
session_data = get_redis_session(request)
request.state.session = session_data

response = await call_next(request)

session_id = request.cookies.get(SESSION_COOKIE_NAME) or str(uuid4())
response.set_cookie(
SESSION_COOKIE_NAME,
session_id,
max_age=SESSION_TTL,
domain=request.url.hostname,
samesite='strict',
httponly=True,
secure=config.redis_host != "redis"
)

redis_client.set(session_id, json.dumps(request.state.session), ex=SESSION_TTL)
return response


def get_session(key: str, default=[]):
request: Request = request_context.get()
return request.state.session.get(key, default)


def set_session(key: str, value):
request: Request = request_context.get()
request.state.session[key] = value


def get_redis_session(request: Request):
session_id = request.cookies.get(SESSION_COOKIE_NAME)
logger.info(f"Attempting to get session for session_id: {session_id}")
if session_id:
session_data = redis_client.get(session_id)
logger.info(f"***************** Session data retrieved from Redis for {session_id}: {session_data}")
if session_data and isinstance(session_data, str):
parsed_session_data = try_parse_to_json(session_data)
if parsed_session_data:
logger.info(f"Parsed session data: {parsed_session_data}")
return parsed_session_data
return {}

5 changes: 4 additions & 1 deletion backend/src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .config import Config
from .graph_db_utils import test_connection
from .json import to_json
from .redis_utils import test_redis_connection
from .json import to_json, try_parse_to_json
from .scratchpad import clear_scratchpad, get_scratchpad, update_scratchpad, Scratchpad

__all__ = [
Expand All @@ -11,4 +12,6 @@
"test_connection",
"to_json",
"update_scratchpad",
"test_redis_connection",
"try_parse_to_json",
]
6 changes: 6 additions & 0 deletions backend/src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
default_frontend_url = "http://localhost:8650"
default_neo4j_uri = "bolt://localhost:7687"
default_files_directory = "files"
default_redis_host = "redis"
default_redis_cache_duration = 3600


class Config(object):
Expand Down Expand Up @@ -37,6 +39,8 @@ def __init__(self):
self.router_model = None
self.files_directory = default_files_directory
self.file_agent_model = None
self.redis_host = default_redis_host
self.redis_cache_duration = default_redis_cache_duration
self.load_env()

def load_env(self):
Expand Down Expand Up @@ -75,6 +79,8 @@ def load_env(self):
self.maths_agent_model = os.getenv("MATHS_AGENT_MODEL")
self.router_model = os.getenv("ROUTER_MODEL")
self.file_agent_model = os.getenv("FILE_AGENT_MODEL")
self.redis_host = os.getenv("REDIS_HOST", default_redis_host)
self.redis_cache_duration = os.getenv("REDIS_CACHE_DURATION", default_redis_cache_duration)
except FileNotFoundError:
raise FileNotFoundError("Please provide a .env file. See the Getting Started guide on the README.md")
except Exception:
Expand Down
9 changes: 9 additions & 0 deletions backend/src/utils/json.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import json
import logging

logger = logging.getLogger(__name__)

def to_json(input, error_message="Failed to interpret JSON"):
try:
return json.loads(input)
except Exception:
raise Exception(f'{error_message}: "{input}"')

def try_parse_to_json(json_string: str):
try:
return json.loads(json_string)
except json.JSONDecodeError as error:
logger.error(f"Error parsing json: {error}")
return None
Loading

0 comments on commit 675b6ee

Please sign in to comment.