From 1b3ca5b5ea5efaa84c5ed8eb665bb2e2bf521876 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:14:11 +0530 Subject: [PATCH] Fix pricing for multiple outputs --- recipes/CompareText2Img.py | 2 +- recipes/GoogleGPT.py | 3 +++ recipes/VideoBots.py | 6 ++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index dc5ea1ae2..4280eeb4b 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -264,4 +264,4 @@ def get_raw_price(self, state: dict) -> int: total += 15 case _: total += 2 - return total + return total * state.get("num_outputs", 1) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 35d650acb..48c512560 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -279,3 +279,6 @@ def run_v2( max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, ) + + def get_raw_price(self, state: dict) -> float: + return self.price * state.get("num_outputs", 1) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 410416342..1daf35269 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -601,11 +601,13 @@ def get_raw_price(self, state: dict): "raw_tts_text", state.get("raw_output_text", []) ) tts_state = {"text_prompt": "".join(output_text_list)} - return super().get_raw_price(state) + TextToSpeechPage().get_raw_price( + total = super().get_raw_price(state) + TextToSpeechPage().get_raw_price( tts_state ) case _: - return super().get_raw_price(state) + total = super().get_raw_price(state) + + return total * state.get("num_outputs", 1) def additional_notes(self): tts_provider = st.session_state.get("tts_provider")