Skip to content

Commit

Permalink
Refactor load_state_from_sr method to current_sr_to_session_state acr…
Browse files Browse the repository at this point in the history
…oss the codebase
  • Loading branch information
devxpy committed Sep 6, 2024
1 parent e3d2fc5 commit 73573fc
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
7 changes: 3 additions & 4 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,12 +1683,11 @@ def _render_after_output(self):
gui.session_state[StateKeys.pressed_randomize] = True
gui.rerun()

@classmethod
def load_state_from_sr(cls, sr: SavedRun) -> dict:
state = sr.to_dict()
def current_sr_to_session_state(self) -> dict:
state = self.current_sr.to_dict()
if state is None:
raise HTTPException(status_code=404)
return cls.load_state_defaults(state)
return self.load_state_defaults(state)

@classmethod
def load_state_defaults(cls, state: dict):
Expand Down
4 changes: 2 additions & 2 deletions recipes/asr_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class ResponseModel(BaseModel):
raw_output_text: list[str] | None
output_text: list[str | AsrOutputJson]

def load_state_from_sr(self, sr: SavedRun) -> dict:
state = super().load_state_from_sr(sr)
def current_sr_to_session_state(self, sr: SavedRun) -> dict:
state = super().current_sr_to_session_state(sr)
google_translate_target = state.pop("google_translate_target", None)
translation_model = state.get("translation_model")
if google_translate_target and not translation_model:
Expand Down
3 changes: 1 addition & 2 deletions routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,7 @@ def submit_api_call(
self = page_cls(request=SimpleNamespace(user=user, query_params=query_params))

# get saved state from db
sr = self.current_sr
state = self.load_state_from_sr(sr)
state = self.current_sr_to_session_state()
# load request data
state.update(request_body)

Expand Down
5 changes: 2 additions & 3 deletions routers/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,10 @@ def render_recipe_page(
request._query_params = dict(request.query_params) | dict(example_id=example_id)

page = page_cls(tab=tab, request=request)
sr = page.current_sr
page.run_user = get_run_user(request, sr.uid)
page.run_user = get_run_user(request, page.current_sr.uid)

if not gui.session_state:
gui.session_state.update(page.load_state_from_sr(sr))
gui.session_state.update(page.current_sr_to_session_state())

with page_wrapper(request):
page.render()
Expand Down

0 comments on commit 73573fc

Please sign in to comment.