From 72629cf1f026e563f83590cdc0b116135cac7c3a Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Thu, 12 Dec 2024 10:35:13 +0000 Subject: [PATCH] FS-133 Make scratchpad per request --- backend/src/api/app.py | 2 ++ backend/src/utils/scratchpad.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/backend/src/api/app.py b/backend/src/api/app.py index 71bcac78..64c2ab80 100644 --- a/backend/src/api/app.py +++ b/backend/src/api/app.py @@ -5,6 +5,7 @@ from fastapi import FastAPI, HTTPException, Response, WebSocket, UploadFile from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from src.utils.scratchpad import ScratchPadMiddleware from src.chat_storage_service import get_chat_message from src.directors.report_director import report_on_file_upload from src.session.file_uploads import clear_session_file_uploads @@ -45,6 +46,7 @@ async def lifespan(app: FastAPI): ) app.add_middleware(RedisSessionMiddleware) +app.add_middleware(ScratchPadMiddleware) health_prefix = "InferESG healthcheck: " further_guidance = "Please check the README files for further guidance." diff --git a/backend/src/utils/scratchpad.py b/backend/src/utils/scratchpad.py index 4a5376f8..cb55d314 100644 --- a/backend/src/utils/scratchpad.py +++ b/backend/src/utils/scratchpad.py @@ -1,4 +1,7 @@ from typing import TypedDict +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +import contextvars import logging logger = logging.getLogger(__name__) @@ -10,20 +13,23 @@ class Answer(TypedDict): result: str | None error: str | None +scratchpad_context = contextvars.ContextVar("scratchpad", default=[]) Scratchpad = list[Answer] -scratchpad: Scratchpad = [] - +class ScratchPadMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + scratchpad_context.set([]) + return await call_next(request) def get_scratchpad() -> Scratchpad: - return scratchpad + return scratchpad_context.get() def update_scratchpad(agent_name=None, question=None, result=None, error=None): - scratchpad.append({"agent_name": agent_name, "question": question, "result": result, "error": error}) + get_scratchpad().append({"agent_name": agent_name, "question": question, "result": result, "error": error}) def clear_scratchpad(): logger.debug("Scratchpad cleared") - scratchpad.clear() + get_scratchpad().clear()