From c49d55f00456842f7ceab252d25f2ea42f502bca Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 30 Aug 2024 20:17:19 +0530 Subject: [PATCH] Refactor load_state_from_sr method to current_sr_to_session_state across the codebase --- daras_ai_v2/base.py | 7 +++---- recipes/asr_page.py | 4 ++-- routers/api.py | 3 +-- routers/root.py | 5 ++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index e4d58e2a2..39e9e3464 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1680,12 +1680,11 @@ def _render_after_output(self): gui.session_state[StateKeys.pressed_randomize] = True gui.rerun() - @classmethod - def load_state_from_sr(cls, sr: SavedRun) -> dict: - state = sr.to_dict() + def current_sr_to_session_state(self) -> dict: + state = self.current_sr.to_dict() if state is None: raise HTTPException(status_code=404) - return cls.load_state_defaults(state) + return self.load_state_defaults(state) @classmethod def load_state_defaults(cls, state: dict): diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 58d49ffa4..903150a6b 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -64,8 +64,8 @@ class ResponseModel(BaseModel): raw_output_text: list[str] | None output_text: list[str | AsrOutputJson] - def load_state_from_sr(self, sr: SavedRun) -> dict: - state = super().load_state_from_sr(sr) + def current_sr_to_session_state(self, sr: SavedRun) -> dict: + state = super().current_sr_to_session_state(sr) google_translate_target = state.pop("google_translate_target", None) translation_model = state.get("translation_model") if google_translate_target and not translation_model: diff --git a/routers/api.py b/routers/api.py index dd74b5a00..d1878bd53 100644 --- a/routers/api.py +++ b/routers/api.py @@ -344,8 +344,7 @@ def submit_api_call( self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - sr = self.current_sr - state = self.load_state_from_sr(sr) + state = self.current_sr_to_session_state() # load request data state.update(request_body) diff --git a/routers/root.py b/routers/root.py index 5d31b1ceb..80b7e8194 100644 --- a/routers/root.py +++ b/routers/root.py @@ -671,11 +671,10 @@ def render_recipe_page( request._query_params = dict(request.query_params) | dict(example_id=example_id) page = page_cls(tab=tab, request=request) - sr = page.current_sr - page.run_user = get_run_user(request, sr.uid) + page.run_user = get_run_user(request, page.current_sr.uid) if not gui.session_state: - gui.session_state.update(page.load_state_from_sr(sr)) + gui.session_state.update(page.current_sr_to_session_state()) with page_wrapper(request): page.render()