From b0c80dac8e22faafa319d5466947df8723dfaa4a Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sat, 27 Jul 2024 19:34:02 +0530 Subject: [PATCH] Move ui lib to gooey_gui package, rename gooey_ui to gooey_gui for consistency --- celeryapp/tasks.py | 18 +- components_doc.py | 2 +- conftest.py | 2 +- daras_ai_v2/analysis_results.py | 72 +- daras_ai_v2/api_examples_widget.py | 18 +- daras_ai_v2/asr.py | 18 +- daras_ai_v2/base.py | 538 +++++---- daras_ai_v2/billing.py | 265 ++--- daras_ai_v2/bot_integration_widgets.py | 166 +-- daras_ai_v2/bots.py | 4 +- daras_ai_v2/breadcrumbs.py | 12 +- daras_ai_v2/chat_explore.py | 26 +- .../copy_to_clipboard_button_widget.py | 2 +- daras_ai_v2/descriptions.py | 10 +- daras_ai_v2/doc_search_settings_widgets.py | 50 +- daras_ai_v2/enum_selector_widget.py | 26 +- daras_ai_v2/grid_layout_widget.py | 4 +- daras_ai_v2/html_error_widget.py | 4 +- daras_ai_v2/html_spinner_widget.py | 6 +- daras_ai_v2/img_model_settings_widgets.py | 94 +- .../language_model_settings_widgets.py | 20 +- daras_ai_v2/lipsync_settings_widgets.py | 34 +- daras_ai_v2/loom_video_widget.py | 6 +- daras_ai_v2/manage_api_keys_widget.py | 16 +- daras_ai_v2/patch_widgets.py | 214 ---- daras_ai_v2/profiles.py | 186 ++- daras_ai_v2/prompt_vars.py | 42 +- daras_ai_v2/query_params.py | 7 - daras_ai_v2/repositioning.py | 4 +- daras_ai_v2/scrollable_html_widget.py | 2 +- daras_ai_v2/search_ref.py | 1 - daras_ai_v2/send_email.py | 63 - daras_ai_v2/serp_search_locations.py | 14 +- daras_ai_v2/text_output_widget.py | 10 +- .../text_to_speech_settings_widgets.py | 84 +- daras_ai_v2/text_training_data_widget.py | 18 +- daras_ai_v2/vector_search.py | 2 +- daras_ai_v2/workflow_url_input.py | 33 +- explore.py | 2 +- functions/recipe_functions.py | 22 +- gooey_ui/__init__.py | 10 - gooey_ui/components/__init__.py | 1019 ----------------- gooey_ui/components/modal.py | 96 -- gooey_ui/components/pills.py | 23 - gooey_ui/components/url_button.py | 13 - gooey_ui/pubsub.py | 194 ---- gooey_ui/state.py | 229 ---- pages/Stats.py | 42 +- poetry.lock | 37 +- pyproject.toml | 5 +- recipes/BulkEval.py | 56 +- recipes/BulkRunner.py | 108 +- recipes/ChyronPlant.py | 23 +- recipes/CompareLLM.py | 20 +- recipes/CompareText2Img.py | 42 +- recipes/CompareUpscaler.py | 40 +- recipes/DeforumSD.py | 104 +- recipes/DocExtract.py | 20 +- recipes/DocSearch.py | 44 +- recipes/DocSummary.py | 34 +- recipes/EmailFaceInpainting.py | 68 +- recipes/FaceInpainting.py | 92 +- recipes/Functions.py | 28 +- recipes/GoogleGPT.py | 56 +- recipes/GoogleImageGen.py | 36 +- recipes/GoogleTTS.py | 30 +- recipes/ImageSegmentation.py | 84 +- recipes/Img2Img.py | 26 +- recipes/LetterWriter.py | 76 +- recipes/Lipsync.py | 22 +- recipes/LipsyncTTS.py | 22 +- recipes/ObjectInpainting.py | 82 +- recipes/QRCodeGenerator.py | 152 +-- recipes/RelatedQnA.py | 28 +- recipes/RelatedQnADoc.py | 28 +- recipes/SEOSummary.py | 78 +- recipes/SmartGPT.py | 26 +- recipes/SocialLookupEmail.py | 52 +- recipes/Text2Audio.py | 20 +- recipes/TextToSpeech.py | 30 +- recipes/Translation.py | 30 +- recipes/VideoBots.py | 470 ++++---- recipes/VideoBotsStats.py | 105 +- recipes/asr_page.py | 28 +- recipes/embeddings_page.py | 28 +- recipes/uberduck.py | 21 +- routers/account.py | 60 +- routers/api.py | 8 +- routers/root.py | 93 +- tests/test_checkout.py | 2 +- tests/test_pricing.py | 16 +- url_shortener/models.py | 4 +- usage_costs/cost_utils.py | 4 +- 93 files changed, 2163 insertions(+), 4018 deletions(-) delete mode 100644 daras_ai_v2/patch_widgets.py delete mode 100644 
daras_ai_v2/query_params.py delete mode 100644 gooey_ui/__init__.py delete mode 100644 gooey_ui/components/__init__.py delete mode 100644 gooey_ui/components/modal.py delete mode 100644 gooey_ui/components/pills.py delete mode 100644 gooey_ui/components/url_button.py delete mode 100644 gooey_ui/pubsub.py delete mode 100644 gooey_ui/state.py diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 29e480602..439f9114e 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -5,13 +5,13 @@ from time import time from types import SimpleNamespace +import gooey_gui as gui import requests import sentry_sdk from django.db.models import Sum from django.utils import timezone from fastapi import HTTPException -import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.admin_links import change_obj_url from bots.models import SavedRun, Platform, Workflow @@ -22,8 +22,6 @@ from daras_ai_v2.exceptions import UserError from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email from daras_ai_v2.settings import templates -from gooey_ui.pubsub import realtime_push -from gooey_ui.state import set_query_params from gooeysite.bg_db_conn import db_middleware from payments.auto_recharge import ( should_attempt_auto_recharge, @@ -69,7 +67,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # extract outputs from local state | { k: v - for k, v in st.session_state.items() + for k, v in gui.session_state.items() if k in page.ResponseModel.__fields__ } # add extra outputs from the run @@ -77,20 +75,20 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False ) # send outputs to ui - realtime_push(channel, output) + gui.realtime_push(channel, output) # save to db - page.dump_state_to_sr(st.session_state | output, sr) + page.dump_state_to_sr(gui.session_state | output, sr) user = AppUser.objects.get(id=user_id) page = page_cls(request=SimpleNamespace(user=user)) page.setup_sentry() sr = page.run_doc_sr(run_id, uid) - st.set_session_state(sr.to_dict() | (unsaved_state or {})) - set_query_params(dict(run_id=run_id, uid=uid)) + gui.set_session_state(sr.to_dict() | (unsaved_state or {})) + gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() - for val in page.main(sr, st.session_state): + for val in page.main(sr, gui.session_state): save_on_step(val) # render errors nicely @@ -107,7 +105,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # run completed successfully, deduct credits else: - sr.transaction, sr.price = page.deduct_credits(st.session_state) + sr.transaction, sr.price = page.deduct_credits(gui.session_state) # save everything, mark run as completed finally: diff --git a/components_doc.py b/components_doc.py index 74c9cdde2..f6ec26d8a 100644 --- a/components_doc.py +++ b/components_doc.py @@ -1,7 +1,7 @@ import inspect from functools import wraps -import gooey_ui as gui +import gooey_gui as gui META_TITLE = "Gooey Components" META_DESCRIPTION = "Explore the Gooey Component Library" diff --git a/conftest.py b/conftest.py index a38c6a11a..57d5d5cca 100644 --- a/conftest.py +++ b/conftest.py @@ -55,7 +55,7 @@ def mock_celery_tasks(): with ( patch("celeryapp.tasks.runner_task", _mock_runner_task), patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks), - patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe), + patch("gooey_gui.realtime_subscribe", _mock_realtime_subscribe), ): yield diff --git 
a/daras_ai_v2/analysis_results.py b/daras_ai_v2/analysis_results.py index 7a361c4aa..5bd519982 100644 --- a/daras_ai_v2/analysis_results.py +++ b/daras_ai_v2/analysis_results.py @@ -4,14 +4,14 @@ from django.db.models import IntegerChoices -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import BotIntegration, Message from daras_ai_v2.base import RecipeTabs from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.workflow_url_input import del_button -from gooey_ui import QueryParamsRedirectException +from gooey_gui import QueryParamsRedirectException from gooeysite.custom_filters import related_json_field_summary from recipes.BulkRunner import list_view_editor from recipes.VideoBots import VideoBotsPage @@ -41,7 +41,7 @@ def render_analysis_results_page( render_title_breadcrumb_share(bi, current_url, current_user) if title: - st.write(title) + gui.write(title) if graphs_json: graphs = json.loads(graphs_json) @@ -55,40 +55,40 @@ def render_analysis_results_page( except Message.DoesNotExist: results = None if not results: - with st.center(): - st.error("No analysis results found") + with gui.center(): + gui.error("No analysis results found") return - with st.div(className="pb-5 pt-3"): + with gui.div(className="pb-5 pt-3"): grid_layout(2, graphs, partial(render_graph_data, bi, results), separator=False) - st.checkbox("๐Ÿ”„ Refresh every 10s", key="autorefresh") + gui.checkbox("๐Ÿ”„ Refresh every 10s", key="autorefresh") - with st.expander("โœ๏ธ Edit"): - title = st.text_area("##### Title", value=title) + with gui.expander("โœ๏ธ Edit"): + title = gui.text_area("##### Title", value=title) - st.session_state.setdefault("selected_graphs", graphs) + gui.session_state.setdefault("selected_graphs", graphs) selected_graphs = list_view_editor( add_btn_label="โž• Add a Graph", key="selected_graphs", render_inputs=partial(render_inputs, results), ) - with st.center(): - if st.button("โœ… Update"): + with gui.center(): + if gui.button("โœ… Update"): _on_press_update(title, selected_graphs) def render_inputs(results: dict, key: str, del_key: str, d: dict): - ocol1, ocol2 = st.columns([11, 1], responsive=False) + ocol1, ocol2 = gui.columns([11, 1], responsive=False) with ocol1: - col1, col2, col3 = st.columns(3) + col1, col2, col3 = gui.columns(3) with ocol2: ocol2.node.props["style"] = dict(paddingTop="2rem") del_button(del_key) with col1: - d["key"] = st.selectbox( + d["key"] = gui.selectbox( label="##### Key", options=results.keys(), key=f"{key}_key", @@ -96,7 +96,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): ) with col2: col2.node.props["style"] = dict(paddingTop="0.45rem") - d["graph_type"] = st.selectbox( + d["graph_type"] = gui.selectbox( label="###### Graph Type", options=[g.value for g in GraphType], format_func=lambda x: GraphType(x).label, @@ -105,7 +105,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): ) with col3: col3.node.props["style"] = dict(paddingTop="0.45rem") - d["data_selection"] = st.selectbox( + d["data_selection"] = gui.selectbox( label="###### Data Selection", options=[d.value for d in DataSelection], format_func=lambda x: DataSelection(x).label, @@ -115,10 +115,10 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): def _autorefresh_script(): - if not st.session_state.get("autorefresh"): + if not gui.session_state.get("autorefresh"): return - 
st.session_state.pop("__cache__", None) - st.js( + gui.session_state.pop("__cache__", None) + gui.js( # language=JavaScript """ setTimeout(() => { @@ -139,23 +139,23 @@ def render_title_breadcrumb_share( else: run_title = bi.saved_run.page_title # this is mostly for backwards compat query_params = dict(run_id=bi.saved_run.run_id, uid=bi.saved_run.uid) - with st.div(className="d-flex justify-content-between mt-4"): - with st.div(className="d-lg-flex d-block align-items-center"): - with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - with st.breadcrumbs(): + with gui.div(className="d-flex justify-content-between mt-4"): + with gui.div(className="d-lg-flex d-block align-items-center"): + with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with gui.breadcrumbs(): metadata = VideoBotsPage.workflow.get_or_create_metadata() - st.breadcrumb_item( + gui.breadcrumb_item( metadata.short_title, link_to=VideoBotsPage.app_url(), className="text-muted", ) if not (bi.published_run_id and bi.published_run.is_root()): - st.breadcrumb_item( + gui.breadcrumb_item( run_title, link_to=VideoBotsPage.app_url(**query_params), className="text-muted", ) - st.breadcrumb_item( + gui.breadcrumb_item( "Integrations", link_to=VideoBotsPage.app_url( **query_params, @@ -170,8 +170,8 @@ def render_title_breadcrumb_share( show_as_link=current_user and VideoBotsPage.is_user_admin(current_user), ) - with st.div(className="d-flex align-items-center"): - with st.div(className="d-flex align-items-start right-action-icons"): + with gui.div(className="d-flex align-items-center"): + with gui.div(className="d-flex align-items-start right-action-icons"): copy_to_clipboard_button( f' Copy Link', value=current_url, @@ -197,7 +197,7 @@ def _on_press_update(title: str, selected_graphs: list[dict]): raise QueryParamsRedirectException(dict(title=title, graphs=graphs_json)) -@st.cache_in_session_state +@gui.cache_in_session_state def fetch_analysis_results(bi: BotIntegration) -> dict: msgs = Message.objects.filter( conversation__bot_integration=bi, @@ -217,7 +217,7 @@ def fetch_analysis_results(bi: BotIntegration) -> dict: def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): key = graph_data["key"] - st.write(f"##### {key}") + gui.write(f"##### {key}") obj_key = f"analysis_result__{key}" if graph_data["data_selection"] == DataSelection.last.value: @@ -227,7 +227,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): .latest() ) if not latest_msg: - st.write("No analysis results found") + gui.write("No analysis results found") return values = [[latest_msg.analysis_result.get(key), 1]] elif graph_data["data_selection"] == DataSelection.convo_last.value: @@ -249,7 +249,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): for val in values: if not val: continue - st.write(val[0]) + gui.write(val[0]) case GraphType.table_count.value: render_table_count(values) case GraphType.bar_count.value: @@ -261,8 +261,8 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): def render_table_count(values): - st.div(className="p-1") - st.data_table( + gui.div(className="p-1") + gui.data_table( [["Value", "Count"]] + [[result[0], result[1]] for result in values], ) @@ -318,4 +318,4 @@ def render_data_in_plotly(*data): dragmode="pan", ), ) - st.plotly_chart(fig) + gui.plotly_chart(fig) diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index 44086ba02..9cd9beeb1 100644 --- 
a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -5,7 +5,7 @@ from furl import furl -import gooey_ui as st +import gooey_gui as gui from auth.token_authentication import auth_keyword from daras_ai_v2 import settings from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url @@ -30,7 +30,7 @@ def get_filenames(request_body): def api_example_generator( *, api_url: furl, request_body: dict, as_form_data: bool, as_async: bool ): - js, python, curl = st.tabs(["`node.js`", "`python`", "`curl`"]) + js, python, curl = gui.tabs(["`node.js`", "`python`", "`curl`"]) filenames = [] if as_async: @@ -95,7 +95,7 @@ def api_example_generator( json=shlex.quote(json.dumps(request_body, indent=2)), ) - st.write( + gui.write( """ 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -193,7 +193,7 @@ def api_example_generator( from black.mode import Mode py_code = format_str(py_code, mode=Mode()) - st.write( + gui.write( rf""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -308,7 +308,7 @@ def api_example_generator( js_code += "\n}\n\ngooeyAPI();" - st.write( + gui.write( r""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -385,7 +385,7 @@ def bot_api_example_generator(integration_id: str): integration_id=integration_id, ) - st.write( + gui.write( f""" Your Integration ID: `{integration_id}` @@ -407,14 +407,14 @@ def bot_api_example_generator(integration_id: str): ) / "docs" ) - st.markdown( + gui.markdown( f""" Read our complete API for features like conversation history, input media files, and more. """, unsafe_allow_html=True, ) - st.js( + gui.js( """ document.startStreaming = async function() { document.getElementById('stream-output').style.display = 'flex'; @@ -426,7 +426,7 @@ def bot_api_example_generator(integration_id: str): ).strip() ) - st.html( + gui.html( f"""
diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 78df47297..eb7f03bd3 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -10,7 +10,7 @@ from django.db.models import F from furl import furl -import gooey_ui as st +import gooey_gui as gui from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.azure_asr import azure_asr @@ -297,7 +297,7 @@ def translation_language_selector( **kwargs, ) -> str | None: if not model: - st.session_state[key] = None + gui.session_state[key] = None return if model == TranslationModels.google: @@ -308,7 +308,7 @@ def translation_language_selector( raise ValueError("Unsupported translation model: " + str(model)) options = list(languages.keys()) - return st.selectbox( + return gui.selectbox( label=label, key=key, format_func=lang_format_func, @@ -351,7 +351,7 @@ def google_translate_language_selector( """ languages = google_translate_target_languages() options = list(languages.keys()) - return st.selectbox( + return gui.selectbox( label=label, key=key, format_func=lambda k: languages[k] if k else "โ€”โ€”โ€”", @@ -411,7 +411,7 @@ def asr_language_selector( # don't show language selector for models with forced language forced_lang = forced_asr_languages.get(selected_model) if forced_lang: - st.session_state[key] = forced_lang + gui.session_state[key] = forced_lang return forced_lang options = list(asr_supported_languages.get(selected_model, [])) @@ -419,14 +419,14 @@ def asr_language_selector( options.insert(0, None) # handle non-canonical language codes - old_lang = st.session_state.get(key) + old_lang = gui.session_state.get(key) if old_lang: try: - st.session_state[key] = normalised_lang_in_collection(old_lang, options) + gui.session_state[key] = normalised_lang_in_collection(old_lang, options) except UserError: - st.session_state[key] = None + gui.session_state[key] = None - return st.selectbox( + return gui.selectbox( label=label, key=key, format_func=lang_format_func, diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 88f78add4..a90f2504b 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -12,6 +12,7 @@ from time import sleep from types import SimpleNamespace +import gooey_gui as gui import sentry_sdk from django.db.models import Sum from django.utils import timezone @@ -25,7 +26,6 @@ ) from starlette.requests import Request -import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.models import ( SavedRun, @@ -35,6 +35,7 @@ Workflow, RetentionPolicy, ) +from daras_ai.image_input import truncate_text_words from daras_ai.text_format import format_number_with_suffix from daras_ai_v2 import settings, urls from daras_ai_v2.api_examples_widget import api_example_generator @@ -55,9 +56,6 @@ from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.prompt_vars import variables_input -from daras_ai_v2.query_params import ( - gooey_get_query_params, -) from daras_ai_v2.query_params_util import ( extract_query_params, ) @@ -76,20 +74,12 @@ is_functions_enabled, render_called_functions, ) -from gooey_ui import ( - realtime_clear_subs, - RedirectException, -) -from gooey_ui.components.modal import Modal -from gooey_ui.components.pills import pill -from gooey_ui.pubsub import realtime_pull from payments.auto_recharge import ( should_attempt_auto_recharge, run_auto_recharge_gracefully, ) from routers.account import AccountTabs from routers.root import RecipeTabs 
-from daras_ai.image_input import truncate_text_words DEFAULT_META_IMG = ( # Small @@ -183,7 +173,7 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + example_id, run_id, uid = extract_query_params(gui.get_query_params()) return cls.app_url( tab=tab, example_id=example_id, @@ -277,7 +267,7 @@ def setup_sentry(self): def sentry_event_set_request(self, event, hint): request = event.setdefault("request", {}) request.setdefault("method", "POST") - request["data"] = st.session_state + request["data"] = gui.session_state if url := request.get("url"): f = furl(url) request["url"] = str( @@ -285,7 +275,7 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=st.get_query_params() + tab=self.tab, query_params=gui.get_query_params() ) return event @@ -314,36 +304,36 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - _, run_id, uid = extract_query_params(gooey_get_query_params()) + _, run_id, uid = extract_query_params(gui.get_query_params()) channel = self.realtime_channel_name(run_id, uid) - output = realtime_pull([channel])[0] + output = gui.realtime_pull([channel])[0] if output: - st.session_state.update(output) + gui.session_state.update(output) def render(self): self.setup_sentry() - if self.get_run_state(st.session_state) == RecipeRunState.running: + if self.get_run_state(gui.session_state) == RecipeRunState.running: self.refresh_state() else: - realtime_clear_subs() + gui.realtime_clear_subs() self._user_disabled_check() self._check_if_flagged() - if st.session_state.get("show_report_workflow"): + if gui.session_state.get("show_report_workflow"): self.render_report_form() return self._render_header() - st.newline() + gui.newline() - with st.nav_tabs(): + with gui.nav_tabs(): for tab in self.get_tabs(): url = self.current_app_url(tab) - with st.nav_item(url, active=tab == self.tab): - st.html(tab.title) - with st.nav_tab_content(): + with gui.nav_item(url, active=tab == self.tab): + gui.html(tab.title) + with gui.nav_tab_content(): self.render_selected_tab() def _render_header(self): @@ -355,13 +345,13 @@ def _render_header(self): self, current_run, published_run, tab=self.tab ) - with st.div(className="d-flex justify-content-between mt-4"): - with st.div(className="d-lg-flex d-block align-items-center"): + with gui.div(className="d-flex justify-content-between mt-4"): + with gui.div(className="d-lg-flex d-block align-items-center"): if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: self._render_title(tbreadcrumbs.h1_title) if tbreadcrumbs: - with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, is_api_call=( @@ -377,7 +367,7 @@ def _render_header(self): if not is_root_example: self.render_author(author) - with st.div(className="d-flex align-items-center"): + with gui.div(className="d-flex align-items-center"): can_user_edit_run = self.can_user_edit_run(current_run, published_run) has_unpublished_changes = ( published_run @@ -390,8 +380,8 @@ def _render_header(self): if can_user_edit_run and has_unpublished_changes: self._render_unpublished_changes_indicator() - with st.div(className="d-flex align-items-start right-action-icons"): - st.html( + with gui.div(className="d-flex align-items-start right-action-icons"): + gui.html( """ """ ) - st.image(user.photo_url, 
className=class_name) + gui.image(user.photo_url, className=class_name) if user.display_name: name_style = {"fontSize": text_size} if text_size else {} - with st.tag("span", style=name_style): - st.html(html.escape(user.display_name)) + with gui.tag("span", style=name_style): + gui.html(html.escape(user.display_name)) def get_credits_click_url(self): if self.request.user and self.request.user.is_anonymous: @@ -1321,9 +1313,9 @@ def get_submit_container_props(self): ) def render_submit_button(self, key="--submit-1"): - with st.div(**self.get_submit_container_props()): - st.write("---") - col1, col2 = st.columns([2, 1], responsive=False) + with gui.div(**self.get_submit_container_props()): + gui.write("---") + col1, col2 = gui.columns([2, 1], responsive=False) col2.node.props[ "className" ] += " d-flex justify-content-end align-items-center" @@ -1331,25 +1323,25 @@ def render_submit_button(self, key="--submit-1"): with col1: self.render_run_cost() with col2: - submitted = st.button( + submitted = gui.button( "๐Ÿƒ Submit", key=key, type="primary", - # disabled=bool(st.session_state.get(StateKeys.run_status)), + # disabled=bool(gui.session_state.get(StateKeys.run_status)), ) if not submitted: return False try: self.validate_form_v2() except AssertionError as e: - st.error(str(e)) + gui.error(str(e)) return False else: return True def render_run_cost(self): url = self.get_credits_click_url() - run_cost = self.get_price_roundoff(st.session_state) + run_cost = self.get_price_roundoff(gui.session_state) ret = f'Run cost = {run_cost} credits' cost_note = self.get_cost_note() @@ -1360,18 +1352,18 @@ def render_run_cost(self): if additional_notes: ret += f" \n{additional_notes}" - st.caption(ret, line_clamp=1, unsafe_allow_html=True) + gui.caption(ret, line_clamp=1, unsafe_allow_html=True) def _render_step_row(self): key = "details-expander" - with st.expander("**โ„น๏ธ Details**", key=key): - if not st.session_state.get(key): + with gui.expander("**โ„น๏ธ Details**", key=key): + if not gui.session_state.get(key): return - col1, col2 = st.columns([1, 2]) + col1, col2 = gui.columns([1, 2]) with col1: self.render_description() with col2: - placeholder = st.div() + placeholder = gui.div() render_called_functions( saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre ) @@ -1381,33 +1373,33 @@ def _render_step_row(self): pass else: with placeholder: - st.write("##### ๐Ÿ‘ฃ Steps") + gui.write("##### ๐Ÿ‘ฃ Steps") render_called_functions( saved_run=self.get_current_sr(), trigger=FunctionTrigger.post ) def _render_help(self): - placeholder = st.div() + placeholder = gui.div() try: self.render_usage_guide() except NotImplementedError: pass else: with placeholder: - st.write( + gui.write( """ ## How to Use This Recipe """ ) key = "discord-expander" - with st.expander( + with gui.expander( f"**๐Ÿ™‹๐Ÿฝโ€โ™€๏ธ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**", key=key, ): - if not st.session_state.get(key): + if not gui.session_state.get(key): return - st.markdown( + gui.markdown( """
@@ -1464,34 +1456,34 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + example_id, run_id, uid = extract_query_params(gui.get_query_params()) # only logged in users can report a run (but not examples/default runs) if not (self.request.user and run_id and uid): return - reported = st.button( + reported = gui.button( ' Report', type="tertiary" ) if not reported: return - st.session_state["show_report_workflow"] = reported - st.experimental_rerun() + gui.session_state["show_report_workflow"] = reported + gui.rerun() def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): ref = self.run_doc_sr(uid=uid, run_id=run_id) ref.is_flagged = is_flagged ref.save(update_fields=["is_flagged"]) - st.session_state["is_flagged"] = is_flagged + gui.session_state["is_flagged"] = is_flagged # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True def _render_input_col(self): self.render_form_v2() - placeholder = st.div() + placeholder = gui.div() - with st.expander("โš™๏ธ Settings"): + with gui.expander("โš™๏ธ Settings"): if self.functions_in_settings: functions_input(self.request.user) self.render_settings() @@ -1500,8 +1492,8 @@ def _render_input_col(self): self.render_variables() submitted = self.render_submit_button() - with st.div(style={"textAlign": "right"}): - st.caption( + with gui.div(style={"textAlign": "right"}): + gui.caption( "_By submitting, you agree to Gooey.AI's [terms](https://gooey.ai/terms) & " "[privacy policy](https://gooey.ai/privacy)._" ) @@ -1509,7 +1501,7 @@ def _render_input_col(self): def render_variables(self): if not self.functions_in_settings: - st.write("---") + gui.write("---") functions_input(self.request.user) variables_input( template_keys=self.template_keys, allow_add=is_functions_enabled() @@ -1528,30 +1520,30 @@ def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: return RecipeRunState.starting def render_deleted_output(self): - col1, *_ = st.columns(2) + col1, *_ = gui.columns(2) with col1: - st.error( + gui.error( "This data has been deleted as per the retention policy.", icon="๐Ÿ—‘๏ธ", color="rgba(255, 200, 100, 0.5)", ) - st.newline() + gui.newline() self._render_output_col(is_deleted=True) - st.newline() + gui.newline() self.render_run_cost() def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = False): assert inspect.isgeneratorfunction(self.run) - if st.session_state.get(StateKeys.pressed_randomize): - st.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED)) - st.session_state.pop(StateKeys.pressed_randomize, None) + if gui.session_state.get(StateKeys.pressed_randomize): + gui.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED)) + gui.session_state.pop(StateKeys.pressed_randomize, None) submitted = True if submitted or self.should_submit_after_login(): self.on_submit() - run_state = self.get_run_state(st.session_state) + run_state = self.get_run_state(gui.session_state) match run_state: case RecipeRunState.completed: self._render_completed_output() @@ -1575,16 +1567,16 @@ def _render_completed_output(self): pass def _render_failed_output(self): - err_msg = st.session_state.get(StateKeys.error_msg) - st.error(err_msg, unsafe_allow_html=True) + err_msg = gui.session_state.get(StateKeys.error_msg) + gui.error(err_msg, unsafe_allow_html=True) def _render_running_output(self): - run_status = st.session_state.get(StateKeys.run_status) + run_status = 
gui.session_state.get(StateKeys.run_status) html_spinner(run_status) self.render_extra_waiting_output() def render_extra_waiting_output(self): - created_at = st.session_state.get("created_at") + created_at = gui.session_state.get("created_at") if not created_at: return @@ -1595,15 +1587,15 @@ def render_extra_waiting_output(self): estimated_run_time = self.estimate_run_duration() if not estimated_run_time: return - with st.countdown_timer( + with gui.countdown_timer( end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: - st.write( + gui.write( f"""We'll email **{self.request.user.email}** when your workflow is done.""" ) - st.write( + gui.write( f"""In the meantime, check out [๐Ÿš€ Examples]({self.current_app_url(RecipeTabs.examples)}) for inspiration.""" ) @@ -1615,21 +1607,21 @@ def on_submit(self): try: sr = self.create_new_run(enable_rate_limits=True) except ValidationError as e: - st.session_state[StateKeys.run_status] = None - st.session_state[StateKeys.error_msg] = str(e) + gui.session_state[StateKeys.run_status] = None + gui.session_state[StateKeys.error_msg] = str(e) return except RateLimitExceeded as e: - st.session_state[StateKeys.run_status] = None - st.session_state[StateKeys.error_msg] = e.detail.get("error", "") + gui.session_state[StateKeys.run_status] = None + gui.session_state[StateKeys.error_msg] = e.detail.get("error", "") return self.call_runner_task(sr) - raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) def should_submit_after_login(self) -> bool: return ( - st.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) + gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) and self.request and self.request.user and not self.request.user.is_anonymous @@ -1638,9 +1630,9 @@ def should_submit_after_login(self) -> bool: def create_new_run( self, *, enable_rate_limits: bool = False, **defaults ) -> SavedRun: - st.session_state[StateKeys.run_status] = "Starting..." - st.session_state.pop(StateKeys.error_msg, None) - st.session_state.pop(StateKeys.run_time, None) + gui.session_state[StateKeys.run_status] = "Starting..." 
+ gui.session_state.pop(StateKeys.error_msg, None) + gui.session_state.pop(StateKeys.run_time, None) self._setup_rng_seed() self.clear_outputs() @@ -1660,7 +1652,7 @@ def create_new_run( run_id = get_random_doc_id() parent_example_id, parent_run_id, parent_uid = extract_query_params( - gooey_get_query_params() + gui.get_query_params() ) parent = self.get_sr_from_query_params( parent_example_id, parent_run_id, parent_uid @@ -1679,8 +1671,8 @@ def create_new_run( ) # ensure the request is validated - state = st.session_state | json.loads( - self.RequestModel.parse_obj(st.session_state).json(exclude_unset=True) + state = gui.session_state | json.loads( + self.RequestModel.parse_obj(gui.session_state).json(exclude_unset=True) ) self.dump_state_to_sr(state, sr) @@ -1748,28 +1740,28 @@ def generate_credit_error_message(self, run_id, uid) -> str: return error_msg def _setup_rng_seed(self): - seed = st.session_state.get("seed") + seed = gui.session_state.get("seed") if not seed: return gooey_rng.seed(seed) def clear_outputs(self): # clear error msg - st.session_state.pop(StateKeys.error_msg, None) + gui.session_state.pop(StateKeys.error_msg, None) # clear outputs for field_name in self.ResponseModel.__fields__: - st.session_state.pop(field_name, None) + gui.session_state.pop(field_name, None) def _render_after_output(self): self._render_report_button() if "seed" in self.RequestModel.schema_json(): - randomize = st.button( + randomize = gui.button( ' Regenerate', type="tertiary" ) if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() + gui.session_state[StateKeys.pressed_randomize] = True + gui.rerun() @classmethod def load_state_from_sr(cls, sr: SavedRun) -> dict: @@ -1805,7 +1797,7 @@ def _unsaved_state(self) -> dict[str, typing.Any]: result = {} for field in self.fields_not_to_save(): try: - result[field] = st.session_state[field] + result[field] = gui.session_state[field] except KeyError: pass return result @@ -1842,12 +1834,12 @@ def _saved_tab(self): created_by=self.request.user, )[:50] if not published_runs: - st.write("No published runs yet") + gui.write("No published runs yet") return def _render(pr: PublishedRun): - with st.div(className="mb-2", style={"font-size": "0.9rem"}): - pill( + with gui.div(className="mb-2", style={"font-size": "0.9rem"}): + gui.pill( PublishedRunVisibility(pr.visibility).get_badge_html(), unsafe_allow_html=True, className="border border-dark", @@ -1864,7 +1856,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gooey_get_query_params().get("updated_at__lt", None) + before = gui.get_query_params().get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -1877,7 +1869,7 @@ def _history_tab(self): )[:25] ) if not run_history: - st.write("No history yet") + gui.write("No history yet") return grid_layout(3, run_history, self._render_run_preview) @@ -1886,15 +1878,15 @@ def _history_tab(self): RecipeTabs.history, query_params={"updated_at__lt": run_history[-1].to_dict()["updated_at"]}, ) - with st.link(to=str(next_url)): - st.html( + with gui.link(to=str(next_url)): + gui.html( # language=HTML f"""""" ) def ensure_authentication(self, next_url: str | None = None, anon_ok: bool = False): if not self.request.user or (self.request.user.is_anonymous and not anon_ok): - raise RedirectException(self.get_auth_url(next_url)) + raise gui.RedirectException(self.get_auth_url(next_url)) def get_auth_url(self, next_url: str | None = 
None) -> str: from routers.root import login @@ -1909,10 +1901,10 @@ def _render_run_preview(self, saved_run: SavedRun): is_latest_version = published_run and published_run.saved_run == saved_run tb = get_title_breadcrumbs(self, sr=saved_run, pr=published_run) - with st.link(to=saved_run.get_app_url()): - with st.div(className="mb-1", style={"fontSize": "0.9rem"}): + with gui.link(to=saved_run.get_app_url()): + with gui.div(className="mb-1", style={"fontSize": "0.9rem"}): if is_latest_version: - pill( + gui.pill( PublishedRunVisibility( published_run.visibility ).get_badge_html(), @@ -1920,7 +1912,7 @@ def _render_run_preview(self, saved_run: SavedRun): className="border border-dark", ) - st.write(f"#### {tb.h1_title}") + gui.write(f"#### {tb.h1_title}") updated_at = saved_run.updated_at if ( @@ -1928,34 +1920,34 @@ def _render_run_preview(self, saved_run: SavedRun): and isinstance(updated_at, datetime.datetime) and not saved_run.run_status ): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + gui.caption("Loading...", **render_local_dt_attrs(updated_at)) if saved_run.run_status: started_at_text(saved_run.created_at) html_spinner(saved_run.run_status, scroll_into_view=False) elif saved_run.error_msg: - st.error(saved_run.error_msg, unsafe_allow_html=True) + gui.error(saved_run.error_msg, unsafe_allow_html=True) return self.render_example(saved_run.to_dict()) def render_published_run_preview(self, published_run: PublishedRun): tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) - with st.link(to=published_run.get_app_url()): - st.write(f"#### {tb.h1_title}") + with gui.link(to=published_run.get_app_url()): + gui.write(f"#### {tb.h1_title}") - with st.div(className="d-flex align-items-center justify-content-between"): - with st.div(): + with gui.div(className="d-flex align-items-center justify-content-between"): + with gui.div(): updated_at = published_run.saved_run.updated_at if updated_at and isinstance(updated_at, datetime.datetime): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + gui.caption("Loading...", **render_local_dt_attrs(updated_at)) if published_run.visibility == PublishedRunVisibility.PUBLIC: run_icon = '' run_count = format_number_with_suffix(published_run.get_run_count()) - st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) if published_run.notes: - st.caption(published_run.notes, line_clamp=2) + gui.caption(published_run.notes, line_clamp=2) doc = published_run.saved_run.to_dict() self.render_example(doc) @@ -1969,28 +1961,28 @@ def _render_example_preview( tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) if published_run.created_by: - with st.div(className="mb-1 text-truncate", style={"height": "1.5rem"}): + with gui.div(className="mb-1 text-truncate", style={"height": "1.5rem"}): self.render_author( published_run.created_by, image_size="20px", text_size="0.9rem", ) - with st.link(to=published_run.get_app_url()): - st.write(f"#### {tb.h1_title}") + with gui.link(to=published_run.get_app_url()): + gui.write(f"#### {tb.h1_title}") - with st.div(className="d-flex align-items-center justify-content-between"): - with st.div(): + with gui.div(className="d-flex align-items-center justify-content-between"): + with gui.div(): updated_at = published_run.saved_run.updated_at if updated_at and isinstance(updated_at, datetime.datetime): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + 
gui.caption("Loading...", **render_local_dt_attrs(updated_at)) run_icon = '' run_count = format_number_with_suffix(published_run.get_run_count()) - st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) if published_run.notes: - st.caption(published_run.notes, line_clamp=2) + gui.caption(published_run.notes, line_clamp=2) if allow_hide: self._example_hide_button(published_run=published_run) @@ -1999,7 +1991,7 @@ def _render_example_preview( self.render_example(doc) def _example_hide_button(self, published_run: PublishedRun): - pressed_delete = st.button( + pressed_delete = gui.button( "๐Ÿ™ˆ๏ธ Hide", key=f"delete_example_{published_run.published_run_id}", style={"color": "red"}, @@ -2009,11 +2001,11 @@ def _example_hide_button(self, published_run: PublishedRun): self.set_hidden(published_run=published_run, hidden=True) def set_hidden(self, *, published_run: PublishedRun, hidden: bool): - with st.spinner("Hiding..."): + with gui.spinner("Hiding..."): published_run.is_approved_example = not hidden published_run.save() - st.experimental_rerun() + gui.rerun() def render_example(self, state: dict): pass @@ -2055,25 +2047,25 @@ def run_as_api_tab(self): / "docs" ) - st.markdown( + gui.markdown( f'๐Ÿ“– To learn more, take a look at our complete API', unsafe_allow_html=True, ) - st.write("#### ๐Ÿ“ค Example Request") + gui.write("#### ๐Ÿ“ค Example Request") - include_all = st.checkbox("##### Show all fields") - as_async = st.checkbox("##### Run Async") - as_form_data = st.checkbox("##### Upload Files via Form Data") + include_all = gui.checkbox("##### Show all fields") + as_async = gui.checkbox("##### Run Async") + as_form_data = gui.checkbox("##### Upload Files via Form Data") pr = self.get_current_published_run() api_url, request_body = self.get_example_request( - st.session_state, + gui.session_state, include_all=include_all, pr=pr, ) response_body = self.get_example_response_body( - st.session_state, as_async=as_async, include_all=include_all + gui.session_state, as_async=as_async, include_all=include_all ) api_example_generator( @@ -2082,18 +2074,18 @@ def run_as_api_tab(self): as_form_data=as_form_data, as_async=as_async, ) - st.write("") + gui.write("") - st.write("#### ๐ŸŽ Example Response") - st.json(response_body, expanded=True) + gui.write("#### ๐ŸŽ Example Response") + gui.json(response_body, expanded=True) if not self.request.user or self.request.user.is_anonymous: - st.write("**Please Login to generate the `$GOOEY_API_KEY`**") + gui.write("**Please Login to generate the `$GOOEY_API_KEY`**") return - st.write("---") - with st.tag("a", id="api-keys"): - st.write("### ๐Ÿ” API keys") + gui.write("---") + with gui.tag("a", id="api-keys"): + gui.write("### ๐Ÿ” API keys") manage_api_keys(self.request.user) @@ -2187,7 +2179,7 @@ def get_example_response_body( include_all: bool = False, ) -> dict: run_id = get_random_doc_id() - created_at = st.session_state.get( + created_at = gui.session_state.get( StateKeys.created_at, datetime.datetime.utcnow().isoformat() ) web_url = self.app_url( @@ -2201,7 +2193,7 @@ def get_example_response_body( run_id=run_id, web_url=web_url, created_at=created_at, - run_time_sec=st.session_state.get(StateKeys.run_time, 0), + run_time_sec=gui.session_state.get(StateKeys.run_time, 0), status="completed", output=output, ) @@ -2239,12 +2231,12 @@ def is_current_user_owner(self) -> bool: def started_at_text(dt: datetime.datetime): - with st.div(className="d-flex"): + with 
gui.div(className="d-flex"): text = "Started" - if seed := st.session_state.get("seed"): + if seed := gui.session_state.get("seed"): text += f' with seed {seed}' - st.caption(text + " on ", unsafe_allow_html=True) - st.caption( + gui.caption(text + " on ", unsafe_allow_html=True) + gui.caption( "...", className="text-black", **render_local_dt_attrs(dt), @@ -2254,23 +2246,23 @@ def started_at_text(dt: datetime.datetime): def render_output_caption(): caption = "" - run_time = st.session_state.get(StateKeys.run_time, 0) + run_time = gui.session_state.get(StateKeys.run_time, 0) if run_time: caption += f'Generated in {run_time :.1f}s' - if seed := st.session_state.get("seed"): + if seed := gui.session_state.get("seed"): caption += f' with seed {seed} ' - updated_at = st.session_state.get(StateKeys.updated_at, datetime.datetime.today()) + updated_at = gui.session_state.get(StateKeys.updated_at, datetime.datetime.today()) if updated_at: if isinstance(updated_at, str): updated_at = datetime.datetime.fromisoformat(updated_at) caption += " on " - with st.div(className="d-flex"): - st.caption(caption, unsafe_allow_html=True) + with gui.div(className="d-flex"): + gui.caption(caption, unsafe_allow_html=True) if updated_at: - st.caption( + gui.caption( "...", className="text-black", **render_local_dt_attrs( diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 0e16d25db..09b48d375 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,18 +1,15 @@ from typing import Literal +import gooey_gui as gui import stripe from django.core.exceptions import ValidationError -import gooey_ui as st from app_users.models import AppUser, PaymentProvider from daras_ai_v2 import icons, settings, paypal -from daras_ai_v2.base import RedirectException from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.settings import templates from daras_ai_v2.user_date_widgets import render_local_date_attrs -from gooey_ui.components.modal import Modal -from gooey_ui.components.pills import pill from payments.models import PaymentMethodSummary from payments.plans import PricingPlan from scripts.migrate_existing_subscriptions import available_subscriptions @@ -29,30 +26,30 @@ def billing_page(user: AppUser): if user.subscription: render_current_plan(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_credit_balance(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): selected_payment_provider = render_all_plans(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_addon_section(user, selected_payment_provider) if user.subscription and user.subscription.payment_provider: if user.subscription.payment_provider == PaymentProvider.STRIPE: - with st.div(className="my-5"): + with gui.div(className="my-5"): render_auto_recharge_section(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_payment_information(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_billing_history(user) def render_payments_setup(): from routers.account import payment_processing_route - st.html( + gui.html( templates.get_template("payment_setup.html").render( settings=settings, payment_processing_url=get_app_route_url(payment_processing_route), @@ -68,25 +65,25 @@ def render_current_plan(user: AppUser): else None ) - with st.div(className=f"{rounded_border} border-dark"): + with gui.div(className=f"{rounded_border} border-dark"): # ROW 1: 
Plan title and next invoice date left, right = left_and_right() with left: - st.write(f"#### Gooey.AI {plan.title}") + gui.write(f"#### Gooey.AI {plan.title}") if provider: - st.write( + gui.write( f"[{icons.edit} Manage Subscription](#payment-information)", unsafe_allow_html=True, ) - with right, st.div(className="d-flex align-items-center gap-1"): + with right, gui.div(className="d-flex align-items-center gap-1"): if provider and ( - next_invoice_ts := st.run_in_thread( + next_invoice_ts := gui.run_in_thread( user.subscription.get_next_invoice_timestamp, cache=True ) ): - st.html("Next invoice on ") - pill( + gui.html("Next invoice on ") + gui.pill( "...", text_bg="dark", **render_local_date_attrs( @@ -102,22 +99,22 @@ def render_current_plan(user: AppUser): # ROW 2: Plan pricing details left, right = left_and_right(className="mt-5") with left: - st.write(f"# {plan.pricing_title()}", className="no-margin") + gui.write(f"# {plan.pricing_title()}", className="no-margin") if plan.monthly_charge: provider_text = f" **via {provider.label}**" if provider else "" - st.caption("per month" + provider_text) + gui.caption("per month" + provider_text) - with right, st.div(className="text-end"): - st.write(f"# {plan.credits:,} credits", className="no-margin") + with right, gui.div(className="text-end"): + gui.write(f"# {plan.credits:,} credits", className="no-margin") if plan.monthly_charge: - st.write( + gui.write( f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits" ) def render_credit_balance(user: AppUser): - st.write(f"## Credit Balance: {user.balance:,}") - st.caption( + gui.write(f"## Credit Balance: {user.balance:,}") + gui.caption( "Every time you submit a workflow or make an API call, we deduct credits from your account." ) @@ -130,13 +127,13 @@ def render_all_plans(user: AppUser) -> PaymentProvider: ) all_plans = [plan for plan in PricingPlan if not plan.deprecated] - st.write("## All Plans") - plans_div = st.div(className="mb-1") + gui.write("## All Plans") + plans_div = gui.div(className="mb-1") if user.subscription and user.subscription.payment_provider: selected_payment_provider = None else: - with st.div(): + with gui.div(): selected_payment_provider = PaymentProvider[ payment_provider_radio() or PaymentProvider.STRIPE.name ] @@ -146,8 +143,8 @@ def _render_plan(plan: PricingPlan): extra_class = "border-dark" else: extra_class = "bg-light" - with st.div(className="d-flex flex-column h-100"): - with st.div( + with gui.div(className="d-flex flex-column h-100"): + with gui.div( className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" ): _render_plan_details(plan) @@ -158,30 +155,32 @@ def _render_plan(plan: PricingPlan): with plans_div: grid_layout(4, all_plans, _render_plan, separator=False) - with st.div(className="my-2 d-flex justify-content-center"): - st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**") + with gui.div(className="my-2 d-flex justify-content-center"): + gui.caption( + f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**" + ) return selected_payment_provider def _render_plan_details(plan: PricingPlan): - with st.div(className="flex-grow-1"): - with st.div(className="mb-4"): - with st.tag("h4", className="mb-0"): - st.html(plan.title) - st.caption( + with gui.div(className="flex-grow-1"): + with gui.div(className="mb-4"): + with gui.tag("h4", className="mb-0"): + gui.html(plan.title) + gui.caption( plan.description, style={ "minHeight": "calc(var(--bs-body-line-height) * 2em)", 
"display": "block", }, ) - with st.div(className="my-3 w-100"): - with st.tag("h4", className="my-0 d-inline me-2"): - st.html(plan.pricing_title()) - with st.tag("span", className="text-muted my-0"): - st.html(plan.pricing_caption()) - st.write(plan.long_description, unsafe_allow_html=True) + with gui.div(className="my-3 w-100"): + with gui.tag("h4", className="my-0 d-inline me-2"): + gui.html(plan.pricing_title()) + with gui.tag("span", className="text-muted my-0"): + gui.html(plan.pricing_caption()) + gui.write(plan.long_description, unsafe_allow_html=True) def _render_plan_action_button( @@ -192,13 +191,13 @@ def _render_plan_action_button( ): btn_classes = "w-100 mt-3" if plan == current_plan: - st.button("Your Plan", className=btn_classes, disabled=True, type="tertiary") + gui.button("Your Plan", className=btn_classes, disabled=True, type="tertiary") elif plan.contact_us_link: - with st.link( + with gui.link( to=plan.contact_us_link, className=btn_classes + " btn btn-theme btn-primary", ): - st.html("Contact Us") + gui.html("Contact Us") elif user.subscription and not user.subscription.payment_provider: # don't show upgrade/downgrade buttons for enterprise customers # assumption: anyone without a payment provider attached is admin/enterprise @@ -262,11 +261,11 @@ def _render_update_subscription_button( key = f"change-sub-{plan.key}" match label: case "Downgrade": - downgrade_modal = Modal( + downgrade_modal = gui.Modal( "Confirm downgrade", key=f"downgrade-plan-modal-{plan.key}", ) - if st.button( + if gui.button( label, className=className, key=key, @@ -275,28 +274,28 @@ def _render_update_subscription_button( if downgrade_modal.is_open(): with downgrade_modal.container(): - st.write( + gui.write( f""" Are you sure you want to change from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**? 
""", className="d-block py-4", ) - with st.div(className="d-flex w-100"): - if st.button( + with gui.div(className="d-flex w-100"): + if gui.button( "Downgrade", className="btn btn-theme bg-danger border-danger text-white", key=f"{key}-confirm", ): change_subscription(user, plan) - if st.button( + if gui.button( "Cancel", className="border border-danger text-danger", key=f"{key}-cancel", ): downgrade_modal.close() case _: - if st.button(label, className=className, key=key): + if gui.button(label, className=className, key=key): change_subscription( user, plan, @@ -319,19 +318,19 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): current_plan = PricingPlan.from_sub(user.subscription) if new_plan == current_plan: - raise RedirectException(get_app_route_url(account_route), status_code=303) + raise gui.RedirectException(get_app_route_url(account_route), status_code=303) if new_plan == PricingPlan.STARTER: user.subscription.cancel() user.subscription.delete() - raise RedirectException( + raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) match user.subscription.payment_provider: case PaymentProvider.STRIPE: if not new_plan.supports_stripe(): - st.error(f"Stripe subscription not available for {new_plan}") + gui.error(f"Stripe subscription not available for {new_plan}") subscription = stripe.Subscription.retrieve(user.subscription.external_id) stripe.Subscription.modify( @@ -346,13 +345,13 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): **kwargs, proration_behavior="none", ) - raise RedirectException( + raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) case PaymentProvider.PAYPAL: if not new_plan.supports_paypal(): - st.error(f"Paypal subscription not available for {new_plan}") + gui.error(f"Paypal subscription not available for {new_plan}") subscription = paypal.Subscription.retrieve(user.subscription.external_id) paypal_plan_info = new_plan.get_paypal_plan() @@ -360,16 +359,16 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): plan_id=paypal_plan_info["plan_id"], plan=paypal_plan_info["plan"], ) - raise RedirectException(approval_url, status_code=303) + raise gui.RedirectException(approval_url, status_code=303) case _: - st.error("Not implemented for this payment provider") + gui.error("Not implemented for this payment provider") def payment_provider_radio(**props) -> str | None: - with st.div(className="d-flex"): - st.write("###### Pay Via", className="d-block me-3") - return st.radio( + with gui.div(className="d-flex"): + gui.write("###### Pay Via", className="d-block me-3") + return gui.radio( "", options=PaymentProvider.names, format_func=lambda name: f'{PaymentProvider[name].label}', @@ -379,10 +378,10 @@ def payment_provider_radio(**props) -> str | None: def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider): if user.subscription: - st.write("# Purchase More Credits") + gui.write("# Purchase More Credits") else: - st.write("# Purchase Credits") - st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") + gui.write("# Purchase Credits") + gui.caption(f"Buy more credits. 
$1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") if user.subscription: provider = PaymentProvider(user.subscription.payment_provider) @@ -396,22 +395,22 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid def render_paypal_addon_buttons(): - selected_amt = st.horizontal_radio( + selected_amt = gui.horizontal_radio( "", settings.ADDON_AMOUNT_CHOICES, format_func=lambda amt: f"${amt:,}", checked_by_default=False, ) if selected_amt: - st.js( + gui.js( f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" ) - st.div( + gui.div( id="paypal-addon-buttons", className="mt-2", style={"width": "fit-content"}, ) - st.div(id="paypal-result-message") + gui.div(id="paypal-result-message") def render_stripe_addon_buttons(user: AppUser): @@ -420,10 +419,10 @@ def render_stripe_addon_buttons(user: AppUser): def render_stripe_addon_button(dollat_amt: int, user: AppUser): - confirm_purchase_modal = Modal( + confirm_purchase_modal = gui.Modal( "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" ) - if st.button(f"${dollat_amt:,}", type="primary"): + if gui.button(f"${dollat_amt:,}", type="primary"): if user.subscription: confirm_purchase_modal.open() else: @@ -432,32 +431,32 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser): if not confirm_purchase_modal.is_open(): return with confirm_purchase_modal.container(): - st.write( + gui.write( f""" Please confirm your purchase: **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. """, className="py-4 d-block text-center", ) - with st.div(className="d-flex w-100 justify-content-end"): - if st.session_state.get("--confirm-purchase"): - success = st.run_in_thread( + with gui.div(className="d-flex w-100 justify-content-end"): + if gui.session_state.get("--confirm-purchase"): + success = gui.run_in_thread( user.subscription.stripe_attempt_addon_purchase, args=[dollat_amt], placeholder="Processing payment...", ) if success is None: return - st.session_state.pop("--confirm-purchase") + gui.session_state.pop("--confirm-purchase") if success: confirm_purchase_modal.close() else: - st.error("Payment failed... Please try again.") + gui.error("Payment failed... Please try again.") return - if st.button("Cancel", className="border border-danger text-danger me-2"): + if gui.button("Cancel", className="border border-danger text-danger me-2"): confirm_purchase_modal.close() - st.button("Buy", type="primary", key="--confirm-purchase") + gui.button("Buy", type="primary", key="--confirm-purchase") def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): @@ -478,7 +477,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): "payment_method_save": "enabled", }, ) - raise RedirectException(checkout_session.url, status_code=303) + raise gui.RedirectException(checkout_session.url, status_code=303) def render_stripe_subscription_button( @@ -490,13 +489,13 @@ def render_stripe_subscription_button( key: str, ): if not plan.supports_stripe(): - st.write("Stripe subscription not available") + gui.write("Stripe subscription not available") return # IMPORTANT: key=... is needed here to maintain uniqueness # of buttons with the same label. 
otherwise, all buttons # will be the same to the server - if st.button(label, key=key, type=btn_type): + if gui.button(label, key=key, type=btn_type): stripe_subscription_checkout_redirect(user=user, plan=plan) @@ -522,7 +521,7 @@ def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): "payment_method_save": "enabled", }, ) - raise RedirectException(checkout_session.url, status_code=303) + raise gui.RedirectException(checkout_session.url, status_code=303) def render_paypal_subscription_button( @@ -530,11 +529,11 @@ def render_paypal_subscription_button( plan: PricingPlan, ): if not plan.supports_paypal(): - st.write("Paypal subscription not available") + gui.write("Paypal subscription not available") return lookup_key = plan.key - st.html( + gui.html( f"""
str: @@ -617,8 +616,8 @@ def render_billing_history(user: AppUser, limit: int = 50): if not txns: return - st.write("## Billing History", className="d-block") - st.table( + gui.write("## Billing History", className="d-block") + gui.table( pd.DataFrame.from_records( [ { @@ -633,7 +632,7 @@ def render_billing_history(user: AppUser, limit: int = 50): ), ) if txns.count() > limit: - st.caption(f"Showing only the most recent {limit} transactions.") + gui.caption(f"Showing only the most recent {limit} transactions.") def render_auto_recharge_section(user: AppUser): @@ -643,9 +642,9 @@ def render_auto_recharge_section(user: AppUser): ) subscription = user.subscription - st.write("## Auto Recharge & Limits") - with st.div(className="h4"): - auto_recharge_enabled = st.checkbox( + gui.write("## Auto Recharge & Limits") + with gui.div(className="h4"): + auto_recharge_enabled = gui.checkbox( "Enable auto recharge", value=subscription.auto_recharge_enabled, ) @@ -656,20 +655,20 @@ def render_auto_recharge_section(user: AppUser): subscription.save(update_fields=["auto_recharge_enabled"]) if not auto_recharge_enabled: - st.caption( + gui.caption( "Enable auto recharge to automatically keep your credit balance topped up." ) return - col1, col2 = st.columns(2) - with col1, st.div(className="mb-2"): - subscription.auto_recharge_topup_amount = st.selectbox( + col1, col2 = gui.columns(2) + with col1, gui.div(className="mb-2"): + subscription.auto_recharge_topup_amount = gui.selectbox( "###### Automatically purchase", options=settings.ADDON_AMOUNT_CHOICES, format_func=lambda amt: f"{settings.ADDON_CREDITS_PER_DOLLAR * int(amt):,} credits for ${amt}", value=subscription.auto_recharge_topup_amount, ) - subscription.auto_recharge_balance_threshold = st.selectbox( + subscription.auto_recharge_balance_threshold = gui.selectbox( "###### when balance falls below", options=settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, format_func=lambda c: f"{c:,} credits", @@ -677,49 +676,51 @@ def render_auto_recharge_section(user: AppUser): ) with col2: - st.write("###### Monthly Recharge Budget") - st.caption( + gui.write("###### Monthly Recharge Budget") + gui.caption( """ If your account exceeds this budget in a given calendar month, subsequent runs & API requests will be rejected. """, ) - with st.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_budget = st.number_input( + with gui.div(className="d-flex align-items-center"): + user.subscription.monthly_spending_budget = gui.number_input( "", min_value=10, value=user.subscription.monthly_spending_budget, key="monthly-spending-budget", ) - st.write("USD", className="d-block ms-2") + gui.write("USD", className="d-block ms-2") - st.write("###### Email Notification Threshold") - st.caption( + gui.write("###### Email Notification Threshold") + gui.caption( """ If your account purchases exceed this threshold in a given calendar month, you will receive an email notification. 
""" ) - with st.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_notification_threshold = st.number_input( - "", - min_value=10, - value=user.subscription.monthly_spending_notification_threshold, - key="monthly-spending-notification-threshold", + with gui.div(className="d-flex align-items-center"): + user.subscription.monthly_spending_notification_threshold = ( + gui.number_input( + "", + min_value=10, + value=user.subscription.monthly_spending_notification_threshold, + key="monthly-spending-notification-threshold", + ) ) - st.write("USD", className="d-block ms-2") + gui.write("USD", className="d-block ms-2") - if st.button("Save", type="primary", key="save-auto-recharge-and-limits"): + if gui.button("Save", type="primary", key="save-auto-recharge-and-limits"): try: subscription.full_clean() except ValidationError as e: - st.error(str(e)) + gui.error(str(e)) else: subscription.save() - st.success("Settings saved!") + gui.success("Settings saved!") def left_and_right(*, className: str = "", **props): className += " d-flex flex-row justify-content-between align-items-center" - with st.div(className=className, **props): - return st.div(), st.div() + with gui.div(className=className, **props): + return gui.div(), gui.div() diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index c7174780e..2db423ed1 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -6,7 +6,7 @@ from django.utils.text import slugify from furl import furl -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform from daras_ai_v2 import settings, icons @@ -19,49 +19,49 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): - if st.session_state.get(f"_bi_reset_{bi.id}"): - st.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( + if gui.session_state.get(f"_bi_reset_{bi.id}"): + gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default ) - st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + gui.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( BotIntegration._meta.get_field("show_feedback_buttons").default ) - st.session_state["analysis_urls"] = [] - st.session_state.pop("--list-view:analysis_urls", None) + gui.session_state["analysis_urls"] = [] + gui.session_state.pop("--list-view:analysis_urls", None) if bi.platform != Platform.TWILIO: - bi.streaming_enabled = st.checkbox( + bi.streaming_enabled = gui.checkbox( "**๐Ÿ“ก Streaming Enabled**", value=bi.streaming_enabled, key=f"_bi_streaming_enabled_{bi.id}", ) - st.caption("Responses will be streamed to the user in real-time if enabled.") - bi.show_feedback_buttons = st.checkbox( + gui.caption("Responses will be streamed to the user in real-time if enabled.") + bi.show_feedback_buttons = gui.checkbox( "**๐Ÿ‘๐Ÿพ ๐Ÿ‘Ž๐Ÿฝ Show Feedback Buttons**", value=bi.show_feedback_buttons, key=f"_bi_show_feedback_buttons_{bi.id}", ) - st.caption( + gui.caption( "Users can rate and provide feedback on every copilot response if enabled." ) - st.write( + gui.write( """ ##### ๐Ÿง  Analysis Scripts Analyze each incoming message and the copilot's response using a Gooey.AI /LLM workflow. Must return a JSON object. [Learn more](https://gooey.ai/docs/guides/build-your-ai-copilot/conversation-analysis). 
""" ) - if "analysis_urls" not in st.session_state: - st.session_state["analysis_urls"] = [ + if "analysis_urls" not in gui.session_state: + gui.session_state["analysis_urls"] = [ (anal.published_run or anal.saved_run).get_app_url() for anal in bi.analysis_runs.all() ] - if st.session_state.get("analysis_urls"): + if gui.session_state.get("analysis_urls"): from recipes.VideoBots import VideoBotsPage - st.anchor( + gui.anchor( "๐Ÿ“Š View Results", str( furl( @@ -77,7 +77,7 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): input_analysis_runs = [] def render_workflow_url_input(key: str, del_key: str | None, d: dict): - with st.columns([3, 2])[0]: + with gui.columns([3, 2])[0]: ret = workflow_url_input( page_cls=CompareLLMPage, key=key, @@ -100,10 +100,10 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): flatten_dict_key="url", ) - with st.center(): - with st.div(): - pressed_update = st.button("โœ… Save") - pressed_reset = st.button( + with gui.center(): + with gui.div(): + pressed_update = gui.button("โœ… Save") + pressed_reset = gui.button( "Reset", key=f"_bi_reset_{bi.id}", type="tertiary" ) if pressed_update or pressed_reset: @@ -121,25 +121,25 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): # delete any analysis runs that were removed bi.analysis_runs.all().exclude(id__in=input_analysis_runs).delete() except ValidationError as e: - st.error(str(e)) - st.write("---") + gui.error(str(e)) + gui.write("---") def twilio_specific_settings(bi: BotIntegration): SETTINGS_FIELDS = ["twilio_use_missed_call", "twilio_initial_text", "twilio_initial_audio_url", "twilio_waiting_text", "twilio_waiting_audio_url"] # fmt:skip - if st.session_state.get(f"_bi_reset_{bi.id}"): + if gui.session_state.get(f"_bi_reset_{bi.id}"): for field in SETTINGS_FIELDS: - st.session_state[f"_bi_{field}_{bi.id}"] = BotIntegration._meta.get_field( + gui.session_state[f"_bi_{field}_{bi.id}"] = BotIntegration._meta.get_field( field ).default - bi.twilio_initial_text = st.text_area( + bi.twilio_initial_text = gui.text_area( "###### ๐Ÿ“ Initial Text (said at the beginning of each call)", value=bi.twilio_initial_text, key=f"_bi_twilio_initial_text_{bi.id}", ) bi.twilio_initial_audio_url = ( - st.file_uploader( + gui.file_uploader( "###### ๐Ÿ”Š Initial Audio (played at the beginning of each call)", accept=["audio/*"], key=f"_bi_twilio_initial_audio_url_{bi.id}", @@ -147,35 +147,35 @@ def twilio_specific_settings(bi: BotIntegration): or "" ) bi.twilio_waiting_audio_url = ( - st.file_uploader( + gui.file_uploader( "###### ๐ŸŽต Waiting Audio (played while waiting for a response -- Voice)", accept=["audio/*"], key=f"_bi_twilio_waiting_audio_url_{bi.id}", ) or "" ) - bi.twilio_waiting_text = st.text_area( + bi.twilio_waiting_text = gui.text_area( "###### ๐Ÿ“ Waiting Text (texted while waiting for a response -- SMS)", key=f"_bi_twilio_waiting_text_{bi.id}", ) - bi.twilio_use_missed_call = st.checkbox( + bi.twilio_use_missed_call = gui.checkbox( "๐Ÿ“ž Use Missed Call", value=bi.twilio_use_missed_call, key=f"_bi_twilio_use_missed_call_{bi.id}", ) - st.caption( + gui.caption( "When enabled, immediately hangs up incoming calls and calls back the user so they don't incur charges (depending on their carrier/plan)." 
) def slack_specific_settings(bi: BotIntegration, default_name: str): - if st.session_state.get(f"_bi_reset_{bi.id}"): - st.session_state[f"_bi_name_{bi.id}"] = default_name - st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + if gui.session_state.get(f"_bi_reset_{bi.id}"): + gui.session_state[f"_bi_name_{bi.id}"] = default_name + gui.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( BotIntegration._meta.get_field("slack_read_receipt_msg").default ) - bi.slack_read_receipt_msg = st.text_input( + bi.slack_read_receipt_msg = gui.text_input( """ ##### ✅ Read Receipt This message is sent immediately after recieving a user message and replaced with the copilot's response once it's ready. @@ -185,7 +185,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str): value=bi.slack_read_receipt_msg, key=f"_bi_slack_read_receipt_msg_{bi.id}", ) - bi.name = st.text_input( + bi.name = gui.text_input( """ ##### 🪪 Channel Specific Bot Name This is the name the bot will post as in this specific channel (to be displayed in Slack) @@ -194,7 +194,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str): value=bi.name, key=f"_bi_name_{bi.id}", ) - st.caption("Enable streaming messages to Slack in real-time.") + gui.caption("Enable streaming messages to Slack in real-time.") def broadcast_input(bi: BotIntegration): @@ -209,7 +209,7 @@ def broadcast_input(bi: BotIntegration): ) / "docs" ) - text = st.text_area( + text = gui.text_area( f""" ###### Broadcast Message 📢 Broadcast a message to all users of this integration using this bot account. \\ @@ -218,7 +218,7 @@ def broadcast_input(bi: BotIntegration): key=key + ":text", placeholder="Type your message here...", ) - audio = st.file_uploader( + audio = gui.file_uploader( "**🎤 Audio**", key=key + ":audio", help="Attach a video to this message.", @@ -229,20 +229,20 @@ def broadcast_input(bi: BotIntegration): documents = None medium = "Voice Call" if bi.platform == Platform.TWILIO: - medium = st.selectbox( + medium = gui.selectbox( "###### 📱 Medium", ["Voice Call", "SMS/MMS"], key=key + ":medium", ) else: - video = st.file_uploader( + video = gui.file_uploader( "**🎥 Video**", key=key + ":video", help="Attach a video to this message.", optional=True, accept=["video/*"], ) - documents = st.file_uploader( + documents = gui.file_uploader( "**📄 Documents**", key=key + ":documents", help="Attach documents to this message.", @@ -252,16 +252,16 @@ def broadcast_input(bi: BotIntegration): should_confirm_key = key + ":should_confirm" confirmed_send_btn = key + ":confirmed_send" - if st.button("📤 Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"): - st.session_state[should_confirm_key] = True - if not st.session_state.get(should_confirm_key): + if gui.button("📤 Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"): + gui.session_state[should_confirm_key] = True + if not gui.session_state.get(should_confirm_key): return convos = bi.conversations.all() - if st.session_state.get(confirmed_send_btn): - st.success("Started sending broadcast!") - st.session_state.pop(confirmed_send_btn) - st.session_state.pop(should_confirm_key) + if gui.session_state.get(confirmed_send_btn): + gui.success("Started sending broadcast!") + gui.session_state.pop(confirmed_send_btn) + gui.session_state.pop(should_confirm_key) send_broadcast_msgs_chunked( text=text, audio=audio, @@ -273,12 +273,12 @@ def broadcast_input(bi: BotIntegration): ) else: if not convos.exists(): - st.error("No users have interacted
with this bot yet.", icon="⚠️") + gui.error("No users have interacted with this bot yet.", icon="⚠️") return - st.write( + gui.write( f"Are you sure? This will send a message to all {convos.count()} users that have ever interacted with this bot.\n" ) - st.button("✅ Yes, Send", key=confirmed_send_btn) + gui.button("✅ Yes, Send", key=confirmed_send_btn) def get_bot_test_link(bi: BotIntegration) -> str | None: @@ -328,43 +328,43 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str: def web_widget_config(bi: BotIntegration, user: AppUser | None): - with st.div(style={"width": "100%", "textAlign": "left"}): - col1, col2 = st.columns(2) + with gui.div(style={"width": "100%", "textAlign": "left"}): + col1, col2 = gui.columns(2) with col1: - if st.session_state.get("--update-display-picture"): - display_pic = st.file_uploader( + if gui.session_state.get("--update-display-picture"): + display_pic = gui.file_uploader( label="###### Display Picture", accept=["image/*"], ) if display_pic: bi.photo_url = display_pic else: - if st.button(f"{icons.camera} Change Photo"): - st.session_state["--update-display-picture"] = True - st.experimental_rerun() - bi.name = st.text_input("###### Name", value=bi.name) - bi.descripton = st.text_area( + if gui.button(f"{icons.camera} Change Photo"): + gui.session_state["--update-display-picture"] = True + gui.rerun() + bi.name = gui.text_input("###### Name", value=bi.name) + bi.descripton = gui.text_area( "###### Description", value=bi.descripton, ) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - bi.by_line = st.text_input( + bi.by_line = gui.text_input( "###### By Line", value=bi.by_line or (user and f"By {user.display_name}"), ) with scol2: - bi.website_url = st.text_input( + bi.website_url = gui.text_input( "###### Website Link", value=bi.website_url or (user and user.website_url), ) - st.write("###### Conversation Starters") + gui.write("###### Conversation Starters") bi.conversation_starters = list( filter( None, [ - st.text_input("", key=f"--question-{i}", value=value) + gui.text_input("", key=f"--question-{i}", value=value) for i, value in zip_longest(range(4), bi.conversation_starters) ], ) @@ -385,39 +385,39 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): | bi.web_config_extras ) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - config["showSources"] = st.checkbox( + config["showSources"] = gui.checkbox( "Show Sources", value=config["showSources"] ) - config["enablePhotoUpload"] = st.checkbox( + config["enablePhotoUpload"] = gui.checkbox( "Allow Photo Upload", value=config["enablePhotoUpload"] ) with scol2: - config["enableAudioMessage"] = st.checkbox( + config["enableAudioMessage"] = gui.checkbox( "Enable Audio Message", value=config["enableAudioMessage"] ) - config["enableLipsyncVideo"] = st.checkbox( + config["enableLipsyncVideo"] = gui.checkbox( "Enable Lipsync Video", value=config["enableLipsyncVideo"] ) - # config["branding"]["showPoweredByGooey"] = st.checkbox( + # config["branding"]["showPoweredByGooey"] = gui.checkbox( # "Show Powered By Gooey", value=config["branding"]["showPoweredByGooey"] # ) - with st.expander("Embed Settings"): - st.caption( + with gui.expander("Embed Settings"): + gui.caption( "These settings will take effect when you embed the widget on your website.
) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - config["mode"] = st.selectbox( + config["mode"] = gui.selectbox( "###### Mode", ["popup", "inline", "fullscreen"], value=config["mode"], format_func=lambda x: x.capitalize(), ) if config["mode"] == "popup": - config["branding"]["fabLabel"] = st.text_input( + config["branding"]["fabLabel"] = gui.text_input( "###### Label", value=config["branding"].get("fabLabel", "Help"), ) @@ -427,28 +427,28 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): # remove defaults bi.web_config_extras = config - with st.div(className="d-flex justify-content-end"): - if st.button( + with gui.div(className="d-flex justify-content-end"): + if gui.button( f"{icons.save} Update Web Preview", type="primary", className="align-right", ): bi.save() - st.experimental_rerun() + gui.rerun() with col2: - with st.center(), st.div(): + with gui.center(), gui.div(): web_preview_tab = f"{icons.chat} Web Preview" api_tab = f"{icons.api} API" - selected = st.horizontal_radio("", [web_preview_tab, api_tab]) + selected = gui.horizontal_radio("", [web_preview_tab, api_tab]) if selected == web_preview_tab: - st.html( + gui.html( # language=html f"""
""" ) - st.js( + gui.js( # language=javascript """ async function loadGooeyEmbed() { diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index e00805385..e7721a6fe 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -2,6 +2,7 @@ import typing from datetime import datetime +import gooey_gui as gui from django.db import transaction from django.utils import timezone from fastapi import HTTPException @@ -24,7 +25,6 @@ from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT from daras_ai_v2.vector_search import doc_url_to_file_metadata -from gooey_ui.pubsub import realtime_subscribe from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe from recipes.VideoBots import VideoBotsPage, ReplyButton from routers.api import submit_api_call @@ -402,7 +402,7 @@ def _process_and_send_msg( if bot.streaming_enabled: # subscribe to the realtime channel for updates channel = page.realtime_channel_name(run_id, uid) - with realtime_subscribe(channel) as realtime_gen: + with gui.realtime_subscribe(channel) as realtime_gen: for state in realtime_gen: bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index c6d389b24..a0d06101e 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -1,6 +1,6 @@ import typing -import gooey_ui as st +import gooey_gui as gui from bots.models import ( SavedRun, PublishedRun, @@ -32,7 +32,7 @@ def has_breadcrumbs(self): def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, *, is_api_call: bool = False): - st.html( + gui.html( """ - """ - ) - - with st.div(className="blur-background"): - with st.div(className="modal-parent"): - container_class = "modal-container " + props.pop("className", "") - self._container = st.div(className=container_class, **props) - - with self._container: - with st.div(className="d-flex justify-content-between align-items-center"): - if self.title: - st.markdown(f"### {self.title}") - else: - st.div() - - close_ = st.button( - "✖", - type="tertiary", - key=f"{self.key}-close", - style={"padding": "0.375rem 0.75rem"}, - ) - if close_: - self.close() - yield self._container diff --git a/gooey_ui/components/pills.py b/gooey_ui/components/pills.py deleted file mode 100644 index c5f268099..000000000 --- a/gooey_ui/components/pills.py +++ /dev/null @@ -1,23 +0,0 @@ -import html -from typing import Literal - -import gooey_ui as gui - - -def pill( - title: str, - *, - unsafe_allow_html: bool = False, - text_bg: Literal["primary", "secondary", "light", "dark", None] = "light", - className: str = "", - **props, -): - if not unsafe_allow_html: - title = html.escape(title) - - className += f" badge rounded-pill" - if text_bg: - className += f" text-bg-{text_bg}" - - with gui.tag("span", className=className): - gui.html(title, **props) diff --git a/gooey_ui/components/url_button.py b/gooey_ui/components/url_button.py deleted file mode 100644 index 6b6732ad9..000000000 --- a/gooey_ui/components/url_button.py +++ /dev/null @@ -1,13 +0,0 @@ -import gooey_ui as st - - -def url_button(url): - st.html( - f""" - - - - """ - ) diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py deleted file mode 100644 index eb823ba37..000000000 --- a/gooey_ui/pubsub.py +++ /dev/null @@ -1,194 +0,0 @@ -import hashlib -import json -import threading -import typing -import uuid -from contextlib import contextmanager -from functools import 
lru_cache -from time import time - -from loguru import logger - -from daras_ai_v2 import settings - -T = typing.TypeVar("T") - -threadlocal = threading.local() - - -@lru_cache -def get_redis(): - import redis - - return redis.Redis.from_url(settings.REDIS_URL) - - -def realtime_clear_subs(): - threadlocal.channels = [] - - -def get_subscriptions() -> list[str]: - try: - return threadlocal.channels - except AttributeError: - threadlocal.channels = [] - return threadlocal.channels - - -def run_in_thread( - fn: typing.Callable, - *, - args: typing.Sequence = None, - kwargs: typing.Mapping = None, - placeholder: str = "...", - cache: bool = False, - ex=60, -): - from .state import session_state - from .components import write - - channel_key = f"--thread/{fn}" - try: - channel = session_state[channel_key] - except KeyError: - channel = session_state[channel_key] = ( - f"gooey-thread-fn/{fn.__name__}/{uuid.uuid1()}" - ) - - if args is None: - args = [] - if kwargs is None: - kwargs = {} - - def target(): - realtime_push(channel, dict(y=fn(*args, **kwargs)), ex=ex) - - threading.Thread(target=target).start() - - try: - return session_state[channel] - except KeyError: - pass - - result = realtime_pull([channel])[0] - if result: - ret = result["y"] - if cache: - session_state[channel] = ret - else: - session_state.pop(channel_key) - return ret - elif placeholder: - write(placeholder) - - -def realtime_pull(channels: list[str]) -> list[typing.Any]: - channels = [f"gooey-gui/state/{channel}" for channel in channels] - threadlocal.channels = channels - r = get_redis() - out = [ - json.loads(value) if (value := r.get(channel)) else None for channel in channels - ] - return out - - -def realtime_push(channel: str, value: typing.Any = "ping", ex=None): - from fastapi.encoders import jsonable_encoder - - channel = f"gooey-gui/state/{channel}" - msg = json.dumps(jsonable_encoder(value)) - r = get_redis() - r.set(channel, msg, ex=ex) - r.publish(channel, json.dumps(time())) - if isinstance(value, dict): - run_status = value.get("__run_status") - logger.info(f"publish {channel=} {run_status=}") - else: - logger.info(f"publish {channel=}") - - -@contextmanager -def realtime_subscribe(channel: str) -> typing.Generator: - channel = f"gooey-gui/state/{channel}" - r = get_redis() - pubsub = r.pubsub() - pubsub.subscribe(channel) - logger.info(f"subscribe {channel=}") - try: - yield _realtime_sub_gen(channel, pubsub) - finally: - logger.info(f"unsubscribe {channel=}") - pubsub.unsubscribe(channel) - pubsub.close() - - -def _realtime_sub_gen(channel: str, pubsub: "redis.client.PubSub") -> typing.Generator: - while True: - message = pubsub.get_message(timeout=10) - if not (message and message["type"] == "message"): - continue - r = get_redis() - value = json.loads(r.get(channel)) - if isinstance(value, dict): - run_status = value.get("__run_status") - logger.info(f"realtime_subscribe: {channel=} {run_status=}") - else: - logger.info(f"realtime_subscribe: {channel=}") - yield value - - -# def use_state( -# value: T = None, *, key: str = None -# ) -> tuple[T, typing.Callable[[T], None]]: -# if key is None: -# # use the session id & call stack to generate a unique key -# key = md5_values(get_session_id(), *traceback.format_stack()[:-1]) -# -# key = f"gooey-gui/state-hooks/{key}" -# -# hooks = get_hooks() -# hooks.setdefault(key, value) -# state = hooks[key] -# -# def set_state(value: T): -# publish_state(key, value) -# -# return state, set_state -# -# -# def publish_state(key: str, value: typing.Any): -# r = 
get_redis() -# jsonval = json.dumps(jsonable_encoder(value)) -# r.set(key, jsonval) -# r.publish(key, jsonval) -# -# -# def set_hooks(hooks: typing.Dict[str, typing.Any]): -# old = get_hooks() -# old.clear() -# old.update(hooks) -# -# -# def get_hooks() -> typing.Dict[str, typing.Any]: -# try: -# return threadlocal.hooks -# except AttributeError: -# threadlocal.hooks = {} -# return threadlocal.hooks -# -# -# def get_session_id() -> str | None: -# try: -# return threadlocal.session_id -# except AttributeError: -# threadlocal.session_id = None -# return threadlocal.session_id -# -# -# def set_session_id(session_id: str): -# threadlocal.session_id = session_id - - -def md5_values(*values) -> str: - strval = ".".join(map(repr, values)) - return hashlib.md5(strval.encode()).hexdigest() diff --git a/gooey_ui/state.py b/gooey_ui/state.py deleted file mode 100644 index 1bbcf88b7..000000000 --- a/gooey_ui/state.py +++ /dev/null @@ -1,229 +0,0 @@ -import hashlib -import inspect -import threading -import typing -import urllib.parse -from functools import wraps, partial - -from fastapi import Depends -from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel -from starlette.requests import Request -from starlette.responses import Response, RedirectResponse, JSONResponse - -from gooey_ui.pubsub import ( - get_subscriptions, - realtime_clear_subs, -) - -threadlocal = threading.local() - -session_state: dict[str, typing.Any] - - -def __getattr__(name): - if name == "session_state": - return get_session_state() - else: - raise AttributeError(name) - - -def get_query_params() -> dict[str, str]: - try: - return threadlocal.query_params - except AttributeError: - threadlocal.query_params = {} - return threadlocal.query_params - - -def set_query_params(params: dict[str, str]): - threadlocal.query_params = params - - -def get_session_state() -> dict[str, typing.Any]: - try: - return threadlocal.session_state - except AttributeError: - threadlocal.session_state = {} - return threadlocal.session_state - - -def set_session_state(state: dict[str, typing.Any]): - threadlocal.session_state = state - - -F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any]) - - -def cache_in_session_state(fn: F = None, key="__cache__") -> F: - def decorator(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - # hash the args and kwargs so they are not too long - args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest() - # create a readable cache key - cache_key = fn.__name__ + ":" + args_hash - state = get_session_state() - try: - # if the cache exists, return it - result = state[key][cache_key] - except KeyError: - # otherwise, run the function and cache the result - result = fn(*args, **kwargs) - state.setdefault(key, {})[cache_key] = result - return result - - return wrapper - - if fn: - return decorator(fn) - else: - return decorator - - -Style = dict[str, str | None] -ReactHTMLProps = dict[str, typing.Any] - - -class RenderTreeNode(BaseModel): - name: str - props: ReactHTMLProps = {} - children: list["RenderTreeNode"] = [] - - def mount(self) -> "RenderTreeNode": - threadlocal._render_root.children.append(self) - return self - - -class NestingCtx: - def __init__(self, node: RenderTreeNode = None): - self.node = node or threadlocal._render_root - self.parent = None - - def __enter__(self): - try: - self.parent = threadlocal._render_root - except AttributeError: - pass - threadlocal._render_root = self.node - - def __exit__(self, exc_type, exc_val, exc_tb): - threadlocal._render_root = 
self.parent - - def empty(self): - """Empty the children of the node""" - self.node.children = [] - return self - - -class RedirectException(Exception): - def __init__(self, url, status_code=302): - self.url = url - self.status_code = status_code - - -class QueryParamsRedirectException(RedirectException): - def __init__(self, query_params: dict, status_code=303): - query_params = {k: v for k, v in query_params.items() if v is not None} - url = "?" + urllib.parse.urlencode(query_params) - super().__init__(url, status_code) - - -def route(app, *paths, **kwargs): - def decorator(fn): - @wraps(fn) - def wrapper(request: Request, json_data: dict | None, **kwargs): - if "request" in fn_sig.parameters: - kwargs["request"] = request - if "json_data" in fn_sig.parameters: - kwargs["json_data"] = json_data - return renderer( - partial(fn, **kwargs), - query_params=dict(request.query_params), - state=json_data and json_data.get("state"), - ) - - fn_sig = inspect.signature(fn) - mod_params = dict(fn_sig.parameters) | dict( - request=inspect.Parameter( - "request", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=Request, - ), - json_data=inspect.Parameter( - "json_data", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=Depends(_request_json), - annotation=typing.Optional[dict], - ), - ) - mod_sig = fn_sig.replace(parameters=list(mod_params.values())) - wrapper.__signature__ = mod_sig - - for path in reversed(paths): - wrapper = app.get(path)(wrapper) - wrapper = app.post(path)(wrapper) - - return wrapper - - return decorator - - -def renderer( - fn: typing.Callable, - state: dict[str, typing.Any] = None, - query_params: dict[str, str] = None, -) -> dict | Response: - set_session_state(state or {}) - set_query_params(query_params or {}) - realtime_clear_subs() - while True: - try: - root = RenderTreeNode(name="root") - try: - with NestingCtx(root): - ret = fn() - except StopException: - ret = None - except RedirectException as e: - return RedirectResponse(e.url, status_code=e.status_code) - if isinstance(ret, Response): - return ret - return JSONResponse( - jsonable_encoder( - dict( - children=root.children, - state=get_session_state(), - channels=get_subscriptions(), - **(ret or {}), - ) - ), - headers={"X-GOOEY-GUI-ROUTE": "1"}, - ) - except RerunException: - continue - - -async def _request_json(request: Request) -> dict | None: - if request.headers.get("content-type") == "application/json": - return await request.json() - - -def experimental_rerun(): - raise RerunException() - - -def stop(): - raise StopException() - - -class StopException(Exception): - pass - - -class RerunException(Exception): - pass - - -class UploadedFile: - pass diff --git a/pages/Stats.py b/pages/Stats.py index 93716cd43..e2d64bf4d 100644 --- a/pages/Stats.py +++ b/pages/Stats.py @@ -5,7 +5,7 @@ import datetime import pandas as pd -import streamlit as st +import streamlit as gui from django.db.models import Count from bots.models import ( @@ -16,10 +16,10 @@ Conversation, ) -st.set_page_config(layout="wide") +gui.set_page_config(layout="wide") -@st.cache_data +@gui.cache_data def convert_df(df): return df.to_csv(index=True).encode("utf-8") @@ -29,7 +29,7 @@ def convert_df(df): def main(): - st.markdown( + gui.markdown( """ # Gooey.AI Bot Stats """ @@ -40,18 +40,18 @@ def main(): if bot.id == DEFAULT_BOT_ID: default_bot_index = index - col1, col2 = st.columns([25, 75]) + col1, col2 = gui.columns([25, 75]) with col1: - bot = st.selectbox( + bot = gui.selectbox( "Select Bot", index=default_bot_index, options=[b for b 
in bots], # format_func=lambda b: f"{b}", ) - view = st.radio("View", ["Daily", "Weekly"], index=0) - start_date = st.date_input("Start date", START_DATE) + view = gui.radio("View", ["Daily", "Weekly"], index=0) + start_date = gui.date_input("Start date", START_DATE) if bot and start_date: - with st.spinner("Loading stats..."): + with gui.spinner("Loading stats..."): date = start_date data = [] convo_by_msg_exchanged = ( @@ -137,19 +137,19 @@ def main(): df = pd.DataFrame(data) csv = convert_df(df) - st.write("### Stats") - st.download_button( + gui.write("### Stats") + gui.download_button( "Download as csv", csv, "file.csv", "text/csv", key="download-csv" ) with col2: - st.line_chart( + gui.line_chart( df, x="date", y=[ "Messages_Sent", ], ) - st.line_chart( + gui.line_chart( df, x="date", y=[ @@ -161,16 +161,16 @@ def main(): ], ) - st.dataframe(df, use_container_width=True) - col1, col2 = st.columns(2) + gui.dataframe(df, use_container_width=True) + col1, col2 = gui.columns(2) with col1: - st.write("### Top 5 users") - st.caption(f"{start_date} - {datetime.date.today()}") - st.dataframe(top_5_conv_by_message) + gui.write("### Top 5 users") + gui.caption(f"{start_date} - {datetime.date.today()}") + gui.dataframe(top_5_conv_by_message) with col2: - st.write("### Bottom 5 users") - st.caption(f"{start_date} - {datetime.date.today()}") - st.dataframe(bottom_5_conv_by_message) + gui.write("### Bottom 5 users") + gui.caption(f"{start_date} - {datetime.date.today()}") + gui.dataframe(bottom_5_conv_by_message) main() diff --git a/poetry.lock b/poetry.lock index 8c8439c32..73fc4dd20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1596,6 +1596,30 @@ face = ">=20.1.0" [package.extras] yaml = ["PyYAML"] +[[package]] +name = "gooey-gui" +version = "0.1.0" +description = "" +optional = false +python-versions = "<4.0,>=3.10" +files = [ + {file = "gooey_gui-0.1.0-py3-none-any.whl", hash = "sha256:508d1ac31b1f18e371a949aa6cc92b473d93db24f1856d69d4a0df588a726633"}, + {file = "gooey_gui-0.1.0.tar.gz", hash = "sha256:22c44ed4476c573b510d5610d71f3600a60341246322507f0e3e2f81c76c5c6d"}, +] + +[package.dependencies] +fastapi = ">=0.85.2,<0.86.0" +furl = ">=2.1.3,<3.0.0" +loguru = ">=0.7.2,<0.8.0" +pydantic = ">=1.10.12,<2.0.0" +python-decouple = ">=3.6,<4.0" +python-multipart = ">=0.0.6,<0.0.7" +redis = ">=4.5.1,<5.0.0" +uvicorn = {version = ">=0.18.3,<0.19.0", extras = ["standard"]} + +[package.extras] +image = ["numpy (>=1.25.0,<2.0.0)", "opencv-contrib-python (>=4.7.0.72,<5.0.0.0)"] + [[package]] name = "google-api-core" version = "1.34.0" @@ -4366,16 +4390,17 @@ cli = ["click (>=5.0)"] [[package]] name = "python-multipart" -version = "0.0.5" +version = "0.0.6" description = "A streaming multipart parser for Python" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "python-multipart-0.0.5.tar.gz", hash = "sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43"}, + {file = "python_multipart-0.0.6-py3-none-any.whl", hash = "sha256:ee698bab5ef148b0a760751c261902cd096e57e10558e11aca17646b74ee1c18"}, + {file = "python_multipart-0.0.6.tar.gz", hash = "sha256:e9925a80bb668529f1b67c7fdb0a5dacdd7cbfc6fb0bff3ea443fe22bdd62132"}, ] -[package.dependencies] -six = ">=1.4.0" +[package.extras] +dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"] 
[[package]] name = "python-slugify" @@ -6441,4 +6466,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "13471406b7a3cbbd4df6d929cc5e2fc796ba31445756da270a8e06573383567b" +content-hash = "3955eb5901ce23cc6e25cf4d45c9a742d830ea2a63a60000e1cfc1d93c6299a6" diff --git a/pyproject.toml b/pyproject.toml index 73d5af60f..d44685d2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ Wand = "^0.6.10" readability-lxml = "^0.8.1" transformers = "^4.24.0" stripe = "^10.3.0" -python-multipart = "^0.0.5" +python-multipart = "^0.0.6" html-sanitizer = "^1.9.3" plotly = "^5.11.0" httpx = "^0.23.1" @@ -78,7 +78,7 @@ ua-parser = "^0.18.0" user-agents = "^2.2.0" openpyxl = "^3.1.2" loguru = "^0.7.2" -aifail = {git = "https://github.com/GooeyAI/aifail/", rev = "0.3.0"} +aifail = "^0.3.0" pytest-playwright = "^0.4.3" emoji = "^2.10.1" pyvespa = "^0.39.0" @@ -86,6 +86,7 @@ anthropic = "^0.25.5" azure-cognitiveservices-speech = "^1.37.0" twilio = "^9.2.3" sentry-sdk = {version = "1.45.0", extras = ["loguru"]} +gooey-gui = "^0.1.0" [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index 5cb6ddf34..ce25e512d 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -8,7 +8,7 @@ import typing_extensions from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.base import BasePage @@ -82,7 +82,7 @@ def _render_results(results: list[AggFunctionResult]): from plotly.colors import sample_colorscale for k, g in itertools.groupby(results, key=lambda d: d["function"]): - st.write("---\n###### **Aggregate**: " + k.capitalize()) + gui.write("---\n###### **Aggregate**: " + k.capitalize()) g = list(g) @@ -95,7 +95,7 @@ def _render_results(results: list[AggFunctionResult]): colors = sample_colorscale("RdYlGn", norm_values, colortype="tuple") colors = [f"rgba{(r * 255, g * 255, b * 255, 0.5)}" for r, g, b in colors] - st.data_table( + gui.data_table( [ ["Metric", k.capitalize(), "Count"], ] @@ -132,7 +132,7 @@ def _render_results(results: list[AggFunctionResult]): margin=dict(l=0, r=0, t=24, b=0), ), ) - st.plotly_chart(fig) + gui.plotly_chart(fig) class BulkEvalPage(BasePage): @@ -146,7 +146,7 @@ def preview_image(self, state: dict) -> str | None: return "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/9631fb74-9a97-11ee-971f-02420a0001c4/evaluator.png.png" def render_description(self): - st.write( + gui.write( """ Summarize and score every row of any CSV, google sheet or excel with GPT4 (or any LLM you choose). Then average every score in any column to generate automated evaluations. 
""" @@ -202,31 +202,31 @@ def render_form_v2(self): f"##### {field_title_desc(self.RequestModel, 'documents')}", accept=SUPPORTED_SPREADSHEET_TYPES, ) - st.session_state[NROWS_CACHE_KEY] = get_nrows(files) + gui.session_state[NROWS_CACHE_KEY] = get_nrows(files) if not files: return - st.write( + gui.write( """ ##### Input Data Preview Here's what you uploaded: """ ) for file in files: - st.data_table(file) - st.write("---") + gui.data_table(file) + gui.write("---") def render_inputs(key: str, del_key: str, d: EvalPrompt): - col1, col2 = st.columns([8, 1], responsive=False) + col1, col2 = gui.columns([8, 1], responsive=False) with col1: - d["name"] = st.text_input( + d["name"] = gui.text_input( label="", label_visibility="collapsed", placeholder="Metric Name", key=key + ":name", value=d.get("name"), ).strip() - d["prompt"] = st.text_area( + d["prompt"] = gui.text_area( label="", label_visibility="collapsed", placeholder="Prompt", @@ -237,7 +237,7 @@ def render_inputs(key: str, del_key: str, d: EvalPrompt): with col2: del_button(del_key) - st.write("##### " + field_title_desc(self.RequestModel, "eval_prompts")) + gui.write("##### " + field_title_desc(self.RequestModel, "eval_prompts")) list_view_editor( add_btn_label="โž• Add a Prompt", key="eval_prompts", @@ -245,10 +245,10 @@ def render_inputs(key: str, del_key: str, d: EvalPrompt): ) def render_agg_inputs(key: str, del_key: str, d: AggFunction): - col1, col3 = st.columns([8, 1], responsive=False) + col1, col3 = gui.columns([8, 1], responsive=False) with col1: - with st.div(className="pt-1"): - d["function"] = st.selectbox( + with gui.div(className="pt-1"): + d["function"] = gui.selectbox( "", label_visibility="collapsed", key=key + ":func", @@ -258,8 +258,8 @@ def render_agg_inputs(key: str, del_key: str, d: AggFunction): with col3: del_button(del_key) - st.html("
") - st.write("##### " + field_title_desc(self.RequestModel, "agg_functions")) + gui.html("
") + gui.write("##### " + field_title_desc(self.RequestModel, "agg_functions")) list_view_editor( add_btn_label="โž• Add an Aggregation", key="agg_functions", @@ -273,12 +273,12 @@ def render_example(self, state: dict): render_documents(state) def render_output(self): - files = st.session_state.get("output_documents", []) - aggregations = st.session_state.get("aggregations", []) + files = gui.session_state.get("output_documents", []) + aggregations = gui.session_state.get("aggregations", []) for file, results in zip_longest(files, aggregations): - st.write(file) - st.data_table(file) + gui.write(file) + gui.data_table(file) if not results: continue @@ -313,14 +313,14 @@ def get_raw_price(self, state: dict) -> float: return price * nprompts * nrows def render_steps(self): - documents = st.session_state.get("documents") or [] - final_prompts = st.session_state.get("final_prompts") or [] + documents = gui.session_state.get("documents") or [] + final_prompts = gui.session_state.get("final_prompts") or [] for doc, prompts in zip_longest(documents, final_prompts): if not prompts: continue - st.write(f"###### {doc}") + gui.write(f"###### {doc}") for i, prompt in enumerate(prompts): - st.text_area("", value=prompt, key=f"--final-prompt-{i}") + gui.text_area("", value=prompt, key=f"--final-prompt-{i}") class TaskResult(typing.NamedTuple): @@ -348,7 +348,7 @@ def submit( for ep_ix, ep in enumerate(request.eval_prompts): prompt = render_prompt_vars( ep["prompt"], - st.session_state | {"columns": current_rec}, + gui.session_state | {"columns": current_rec}, ) response.final_prompts[doc_ix].append(prompt) futs.append( @@ -429,7 +429,7 @@ def iterate( response.aggregations[result.doc_ix] = aggs -@st.cache_in_session_state +@gui.cache_in_session_state def get_nrows(files: list[str]) -> int: try: dfs = map_parallel(read_df_any, files) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index bde4e96c2..2cbde30e0 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -5,7 +5,7 @@ from furl import furl from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow, SavedRun from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.base import BasePage @@ -31,7 +31,6 @@ get_published_run_options, edit_done_button, ) -from gooey_ui.components.url_button import url_button from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.DocSearch import render_documents @@ -96,7 +95,7 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_BULK_META_IMG def render_form_v2(self): - st.write(f"##### {field_title_desc(self.RequestModel, 'run_urls')}") + gui.write(f"##### {field_title_desc(self.RequestModel, 'run_urls')}") run_urls = list_view_editor( add_btn_label="โž• Add a Workflow", key="run_urls", @@ -155,19 +154,19 @@ def render_form_v2(self): for field, model_field in page_cls.ResponseModel.__fields__.items() } - st.write( + gui.write( """ ###### **Preview**: Here's what you uploaded """ ) for file in files: - st.data_table(file) + gui.data_table(file) if not (required_input_fields or optional_input_fields): return - with st.div(className="pt-3"): - st.write( + with gui.div(className="pt-3"): + gui.write( """ ###### **Columns** Please select which CSV column corresponds to your workflow's input fields. 
@@ -176,17 +175,17 @@ def render_form_v2(self): """, ) - visible_col1, visible_col2 = st.columns(2) - with st.expander("๐Ÿคฒ Show All Columns"): - hidden_col1, hidden_col2 = st.columns(2) + visible_col1, visible_col2 = gui.columns(2) + with gui.expander("๐Ÿคฒ Show All Columns"): + hidden_col1, hidden_col2 = gui.columns(2) with visible_col1: - st.write("##### Inputs") + gui.write("##### Inputs") with hidden_col1: - st.write("##### Inputs") + gui.write("##### Inputs") - input_columns_old = st.session_state.pop("input_columns", {}) - input_columns_new = st.session_state.setdefault("input_columns", {}) + input_columns_old = gui.session_state.pop("input_columns", {}) + input_columns_new = gui.session_state.setdefault("input_columns", {}) column_options = [None, *get_columns(files)] for fields, div in ( @@ -195,7 +194,7 @@ def render_form_v2(self): ): for field, title in fields.items(): with div: - col = st.selectbox( + col = gui.selectbox( label="`" + title + "`", options=column_options, key="--input-mapping:" + field, @@ -205,9 +204,9 @@ def render_form_v2(self): input_columns_new[field] = col with visible_col2: - st.write("##### Outputs") + gui.write("##### Outputs") with hidden_col2: - st.write("##### Outputs") + gui.write("##### Outputs") visible_out_fields = {} # only show the first output & run url field by default, and hide others @@ -227,8 +226,8 @@ def render_form_v2(self): "error_msg": "Error Msg", } | {k: v for k, v in output_fields.items() if k not in visible_out_fields} - output_columns_old = st.session_state.pop("output_columns", {}) - output_columns_new = st.session_state.setdefault("output_columns", {}) + output_columns_old = gui.session_state.pop("output_columns", {}) + output_columns_new = gui.session_state.setdefault("output_columns", {}) for fields, div, checked in ( (visible_out_fields, visible_col2, True), @@ -236,7 +235,7 @@ def render_form_v2(self): ): for field, title in fields.items(): with div: - col = st.checkbox( + col = gui.checkbox( label="`" + title + "`", key="--output-mapping:" + field, value=bool(output_columns_old.get(field, checked)), @@ -244,8 +243,8 @@ def render_form_v2(self): if col: output_columns_new[field] = title - st.write("---") - st.write(f"##### {field_title_desc(self.RequestModel, 'eval_urls')}") + gui.write("---") + gui.write(f"##### {field_title_desc(self.RequestModel, 'eval_urls')}") list_view_editor( add_btn_label="โž• Add an Eval", key="eval_urls", @@ -257,28 +256,28 @@ def render_example(self, state: dict): render_documents(state) def render_output(self): - eval_runs = st.session_state.get("eval_runs") + eval_runs = gui.session_state.get("eval_runs") if eval_runs: - _backup = st.session_state + _backup = gui.session_state for url in eval_runs: try: page_cls, sr, _ = url_to_runs(url) except SavedRun.DoesNotExist: continue - st.set_session_state(sr.state) + gui.set_session_state(sr.state) try: page_cls().render_output() except Exception as e: - st.error(repr(e)) - st.write(url) - st.write("---") - st.set_session_state(_backup) + gui.error(repr(e)) + gui.write(url) + gui.write("---") + gui.set_session_state(_backup) else: - files = st.session_state.get("output_documents", []) + files = gui.session_state.get("output_documents", []) for file in files: - st.write(file) - st.data_table(file) + gui.write(file) + gui.data_table(file) def run_v2( self, @@ -406,7 +405,7 @@ def preview_description(self, state: dict) -> str: """ def render_description(self): - st.write( + gui.write( """ Building complex AI workflows like copilot) and then evaluating 
each iteration is complex. Workflows are affected by the particular LLM used (GPT4 vs PalM2), their vector DB knowledge sets (e.g. your google docs), how synthetic data creation happened (e.g. how you transformed your video transcript or PDF into structured data), which translation or speech engine you used and your LLM prompts. Every change can affect the quality of your outputs. @@ -430,10 +429,10 @@ def render_run_url_inputs(self, key: str, del_key: str, d: dict): init_workflow_selector(d, key) - col1, col2, col3, col4 = st.columns([9, 1, 1, 1], responsive=False) + col1, col2, col3, col4 = gui.columns([9, 1, 1, 1], responsive=False) if not d.get("workflow") and d.get("url"): with col1: - url = st.text_input( + url = gui.text_input( "", key=key, value=d.get("url"), @@ -443,34 +442,35 @@ def render_run_url_inputs(self, key: str, del_key: str, d: dict): edit_done_button(key) else: with col1: - scol1, scol2 = st.columns([1, 1], responsive=False) + scol1, scol2 = gui.columns([1, 1], responsive=False) with scol1: - with st.div(className="pt-1"): + with gui.div(className="pt-1"): options = { page_cls.workflow: page_cls.get_recipe_title() for page_cls in all_home_pages } last_workflow_key = "__last_run_url_workflow" - workflow = st.selectbox( + workflow = gui.selectbox( "", key=key + ":workflow", value=( - d.get("workflow") or st.session_state.get(last_workflow_key) + d.get("workflow") + or gui.session_state.get(last_workflow_key) ), options=options, format_func=lambda x: options[x], ) d["workflow"] = workflow # use this to set default for next time - st.session_state[last_workflow_key] = workflow + gui.session_state[last_workflow_key] = workflow with scol2: page_cls = Workflow(workflow).page_cls options = get_published_run_options( page_cls, current_user=self.request.user ) options.update(d.get("--added_workflows", {})) - with st.div(className="pt-1"): - url = st.selectbox( + with gui.div(className="pt-1"): + url = gui.selectbox( "", key=key, options=options, @@ -480,14 +480,14 @@ def render_run_url_inputs(self, key: str, del_key: str, d: dict): with col2: edit_button(key) with col3: - url_button(url) + gui.url_button(url) with col4: del_button(del_key) try: url_to_runs(url) except Exception as e: - st.error(repr(e)) + gui.error(repr(e)) d["url"] = url def render_eval_url_inputs(self, key: str, del_key: str | None, d: dict): @@ -590,7 +590,7 @@ def is_obj(field_props: dict | None) -> bool: return bool(field_props.get("type") == "object" or field_props.get("$ref")) -@st.cache_in_session_state +@gui.cache_in_session_state def get_columns(files: list[str]) -> list[str]: try: dfs = map_parallel(read_df_any, files) @@ -627,9 +627,9 @@ def list_view_editor( ): if flatten_dict_key: list_key = f"--list-view:{key}" - st.session_state.setdefault( + gui.session_state.setdefault( list_key, - [{flatten_dict_key: val} for val in st.session_state.get(key, [])], + [{flatten_dict_key: val} for val in gui.session_state.get(key, [])], ) new_lst = list_view_editor( add_btn_label=add_btn_label, @@ -638,25 +638,25 @@ def list_view_editor( render_inputs=render_inputs, ) ret = [d[flatten_dict_key] for d in new_lst] - st.session_state[key] = ret + gui.session_state[key] = ret return ret - old_lst = st.session_state.get(key) or [] + old_lst = gui.session_state.get(key) or [] add_key = f"--{key}:add" - if st.session_state.get(add_key): + if gui.session_state.get(add_key): old_lst.append({}) - label_placeholder = st.div() + label_placeholder = gui.div() new_lst = [] for d in old_lst: entry_key = d.setdefault("__key__", 
f"--{key}:{uuid.uuid1()}") del_key = entry_key + ":del" - if st.session_state.pop(del_key, None): + if gui.session_state.pop(del_key, None): continue render_inputs(entry_key, del_key, d) new_lst.append(d) if new_lst and render_labels: with label_placeholder: render_labels() - st.session_state[key] = new_lst - st.button(add_btn_label, key=add_key) + gui.session_state[key] = new_lst + gui.button(add_btn_label, key=add_key) return new_lst diff --git a/recipes/ChyronPlant.py b/recipes/ChyronPlant.py index 09b1076b7..fef4714b9 100644 --- a/recipes/ChyronPlant.py +++ b/recipes/ChyronPlant.py @@ -1,8 +1,7 @@ +import gooey_gui as gui from pydantic import BaseModel -import gooey_ui as st from bots.models import Workflow -from daras_ai_v2 import settings from daras_ai_v2.base import ( BasePage, ) @@ -32,7 +31,7 @@ class ResponseModel(BaseModel): chyron_output: str def render_form_v2(self): - st.text_input( + gui.text_input( """ ### Input Midi notes """, @@ -40,25 +39,25 @@ def render_form_v2(self): ) def render_output(self): - st.text_area( + gui.text_area( """ **MIDI translation** """, - value=st.session_state.get("midi_translation", ""), + value=gui.session_state.get("midi_translation", ""), disabled=True, ) - st.text_area( + gui.text_area( """ ### Chyron Output """, disabled=True, - value=st.session_state.get("chyron_output", ""), + value=gui.session_state.get("chyron_output", ""), height=300, ) def render_settings(self): - st.text_area( + gui.text_area( """ ### Midi Notes -> English GPT script """, @@ -66,7 +65,7 @@ def render_settings(self): height=500, ) - st.text_area( + gui.text_area( """ ### Chyron Plant Radbot script """, @@ -128,8 +127,8 @@ def run_chyron(self, state: dict): return "" def render_example(self, state): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.write(state.get("midi_translation", "")) + gui.write(state.get("midi_translation", "")) with col2: - st.write(state.get("chyron_output", "")) + gui.write(state.get("chyron_output", "")) diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index b5b1cc2ee..cc8c93313 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -67,7 +67,7 @@ def get_example_preferred_fields(cls, state: dict) -> list[str]: return ["input_prompt", "selected_models"] def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt """, @@ -84,8 +84,8 @@ def render_form_v2(self): ) def validate_form_v2(self): - assert st.session_state["input_prompt"], "Please enter a Prompt" - assert st.session_state["selected_models"], "Please select at least one model" + assert gui.session_state["input_prompt"], "Please enter a Prompt" + assert gui.session_state["selected_models"], "Please select at least one model" def render_usage_guide(self): youtube_video("dhexRRDAuY8") @@ -118,15 +118,15 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." 
def render_output(self): - _render_outputs(st.session_state, 450) + _render_outputs(gui.session_state, 450) def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.write("**Prompt**") - st.write("```jinja2\n" + state.get("input_prompt", "") + "\n```") + gui.write("**Prompt**") + gui.write("```jinja2\n" + state.get("input_prompt", "") + "\n```") for key, value in state.get("variables", {}).items(): - st.text_area(f"`{key}`", value=str(value), disabled=True) + gui.text_area(f"`{key}`", value=str(value), disabled=True) with col2: _render_outputs(state, 300) @@ -163,7 +163,7 @@ def _render_outputs(state, height): for key in selected_models: output_text: dict = state.get("output_text", {}).get(key, []) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( f"**{LargeLanguageModels[key].value}**", help=f"output {key} {idx} {random.random()}", disabled=True, diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index ee1f4cacd..f41a46170 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -3,7 +3,7 @@ from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.descriptions import prompting101 @@ -100,7 +100,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt Describe the scene that you'd like to generate. @@ -108,16 +108,16 @@ def render_form_v2(self): key="text_prompt", placeholder="Iron man", ) - st.caption( + gui.caption( """ Refer to the saved examples or our basic prompt guide in the โ€˜Detailsโ€™ dropdown menu. """ ) - st.write("#### ๐Ÿงจ Compare Image Models") - st.caption( + gui.write("#### ๐Ÿงจ Compare Image Models") + gui.caption( "Each selected model costs 2 credits to run except for Dall-E which is 15 credits per rendered image." ) - st.caption( + gui.caption( """ Confused about what each model looks like? [Check out our prompt guide](https://docs.google.com/presentation/d/1RaoMP0l7FnBZovDAR42zVmrUND9W5DW6eWet-pi6kiE/edit#slide=id.g210b1678eba_0_26). @@ -129,14 +129,14 @@ def render_form_v2(self): ) def validate_form_v2(self): - assert st.session_state["text_prompt"], "Please provide a prompt" - assert st.session_state["selected_models"], "Please select at least one model" + assert gui.session_state["text_prompt"], "Please provide a prompt" + assert gui.session_state["selected_models"], "Please select at least one model" def render_usage_guide(self): youtube_video("TxT-mTYP0II") def render_description(self): - st.markdown( + gui.markdown( """ This recipe takes any text and renders an image using multiple Text2Image engines. Use it to understand which image generator e.g. DallE or Stable Diffusion is best for your particular prompt. @@ -145,42 +145,42 @@ def render_description(self): prompting101() def render_settings(self): - st.write( + gui.write( """ Customize the image output for your text prompt with these Settings. """ ) - st.caption( + gui.caption( """ You can also enable โ€˜Edit Instructionsโ€™ to use InstructPix2Pix that allows you to change your generated image output with a follow-up written instruction. 
""" ) - if st.checkbox("๐Ÿ“ Edit Instructions"): - st.text_area( + if gui.checkbox("๐Ÿ“ Edit Instructions"): + gui.text_area( """ Describe how you want to change the generated image using [InstructPix2Pix](https://www.timothybrooks.com/instruct-pix2pix). """, key="__edit_instruction", placeholder="Give it sunglasses and a mustache", ) - st.session_state["edit_instruction"] = st.session_state.get( + gui.session_state["edit_instruction"] = gui.session_state.get( "__edit_instruction" ) negative_prompt_setting() output_resolution_setting() - num_outputs_setting(st.session_state.get("selected_models", [])) + num_outputs_setting(gui.session_state.get("selected_models", [])) sd_2_upscaling_setting() - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: guidance_scale_setting() scheduler_setting() with col2: - if st.session_state.get("edit_instruction"): + if gui.session_state.get("edit_instruction"): instruct_pix2pix_settings() def render_output(self): - self._render_outputs(st.session_state) + self._render_outputs(gui.session_state) def run(self, state: dict) -> typing.Iterator[str | None]: request: CompareText2ImgPage.RequestModel = self.RequestModel.parse_obj(state) @@ -241,9 +241,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ] def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.markdown("```properties\n" + state.get("text_prompt", "") + "\n```") + gui.markdown("```properties\n" + state.get("text_prompt", "") + "\n```") with col2: self._render_outputs(state) @@ -252,7 +252,7 @@ def _render_outputs(self, state): for key in selected_models: output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: - st.image( + gui.image( img, caption=Text2ImgModels[key].value, show_download_button=True ) diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 55f6c4a0f..4684ab309 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -46,11 +46,11 @@ class ResponseModel(BaseModel): ] = Field({}, description="Output Videos") def validate_form_v2(self): - assert st.session_state.get( + assert gui.session_state.get( "selected_models" ), "Please select at least one model" - assert st.session_state.get("input_image") or st.session_state.get( + assert gui.session_state.get("input_image") or gui.session_state.get( "input_video" ), "Please provide an Input Image or Video" @@ -83,22 +83,22 @@ def run_v2( ) def render_form_v2(self): - selected_input_type = st.horizontal_radio( + selected_input_type = gui.horizontal_radio( "", options=["Image", "Video"], - value="Video" if st.session_state.get("input_video") else "Image", + value="Video" if gui.session_state.get("input_video") else "Image", ) if selected_input_type == "Video": - st.session_state.pop("input_image", None) - st.file_uploader( + gui.session_state.pop("input_image", None) + gui.file_uploader( """ #### Input Video """, key="input_video", ) else: - st.session_state.pop("input_video", None) - st.file_uploader( + gui.session_state.pop("input_video", None) + gui.file_uploader( """ #### Input Image """, @@ -107,20 +107,20 @@ def render_form_v2(self): ) if selected_input_type == "Video": - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - 
st.multiselect( + gui.multiselect( label="##### Upscaler Models", options=[e.name for e in UpscalerModels if e.supports_video], format_func=lambda x: UpscalerModels[x].label, key="selected_models", ) with col2: - st.selectbox( + gui.selectbox( label="##### Background Upscaler", options=[e.name for e in UpscalerModels if e.is_bg_model], format_func=lambda x: ( - UpscalerModels[x].label if x else st.BLANK_OPTION + UpscalerModels[x].label if x else gui.BLANK_OPTION ), allow_none=True, key="selected_bg_model", @@ -132,7 +132,7 @@ def render_form_v2(self): key="selected_models", ) - st.slider( + gui.slider( """ ### Scale Factor to scale image by @@ -147,19 +147,19 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_UPSCALER_META_IMG def render_description(self): - st.write( + gui.write( """ Have an old photo or just a funky AI picture? Run this workflow to compare the top image upscalers. """ ) def render_output(self): - _render_outputs(st.session_state) + _render_outputs(gui.session_state) def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.image(state.get("input_image"), caption="Input Image") + gui.image(state.get("input_image"), caption="Input Image") with col2: _render_outputs(state) @@ -185,8 +185,8 @@ def _render_outputs(state): for key in state.get("selected_models") or []: img = (state.get("output_images") or {}).get(key) if img: - st.image(img, caption=UpscalerModels[key].label, show_download_button=True) + gui.image(img, caption=UpscalerModels[key].label, show_download_button=True) vid = (state.get("output_videos") or {}).get(key) if vid: - st.video(vid, caption=UpscalerModels[key].label, show_download_button=True) + gui.video(vid, caption=UpscalerModels[key].label, show_download_button=True) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 02d6682de..ef4b186b5 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from typing_extensions import TypedDict -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector @@ -68,19 +68,19 @@ def animation_prompts_editor( input_prompt_key: str = "input_prompt", ): st_list_key = f"{animation_prompts_key}/st_list" - if st_list_key in st.session_state: - prompt_st_list = st.session_state[st_list_key] + if st_list_key in gui.session_state: + prompt_st_list = gui.session_state[st_list_key] else: - animation_prompts = st.session_state.get( + animation_prompts = gui.session_state.get( animation_prompts_key ) or input_prompt_to_animation_prompts( - st.session_state.get(input_prompt_key, "0:") + gui.session_state.get(input_prompt_key, "0:") ) prompt_st_list = animation_prompts_to_st_list(animation_prompts) - st.session_state[st_list_key] = prompt_st_list + gui.session_state[st_list_key] = prompt_st_list - st.write("#### ๐Ÿ‘ฉโ€๐Ÿ’ป Animation Prompts") - st.caption( + gui.write("#### ๐Ÿ‘ฉโ€๐Ÿ’ป Animation Prompts") + gui.caption( """ Describe the scenes or series of images that you want to generate into an animation. You can add as many prompts as you like. Mention the keyframe number for each prompt i.e. the transition point from the first prompt to the next. View the โ€˜Detailsโ€™ drop down menu to get started. 
@@ -91,33 +91,33 @@ def animation_prompts_editor( fp_key = fp["key"] frame_key = f"{st_list_key}/frame/{fp_key}" prompt_key = f"{st_list_key}/prompt/{fp_key}" - if frame_key not in st.session_state: - st.session_state[frame_key] = fp["frame"] - if prompt_key not in st.session_state: - st.session_state[prompt_key] = fp["prompt"] + if frame_key not in gui.session_state: + gui.session_state[frame_key] = fp["frame"] + if prompt_key not in gui.session_state: + gui.session_state[prompt_key] = fp["prompt"] - col1, col2 = st.columns([8, 3], responsive=False) + col1, col2 = gui.columns([8, 3], responsive=False) with col1: - st.text_area( + gui.text_area( label="*Prompt*", key=prompt_key, height=100, ) with col2: - st.number_input( + gui.number_input( label="*Frame*", key=frame_key, min_value=0, step=1, ) - if st.button("๐Ÿ—‘๏ธ", help=f"Remove Frame {idx + 1}"): + if gui.button("๐Ÿ—‘๏ธ", help=f"Remove Frame {idx + 1}"): prompt_st_list.pop(idx) - st.experimental_rerun() + gui.rerun() updated_st_list.append( { - "frame": st.session_state.get(frame_key), - "prompt": st.session_state.get(prompt_key), + "frame": gui.session_state.get(frame_key), + "prompt": gui.session_state.get(prompt_key), "key": fp_key, } ) @@ -125,15 +125,15 @@ def animation_prompts_editor( prompt_st_list.clear() prompt_st_list.extend(updated_st_list) - if st.button("โž• Add a Prompt"): - max_frames = st.session_state.get("max_frames", 100) + if gui.button("โž• Add a Prompt"): + max_frames = gui.session_state.get("max_frames", 100) if prompt_st_list: next_frame = get_last_frame(prompt_st_list) next_frame += max(min(max_frames - next_frame, 10), 1) else: next_frame = 0 if next_frame > max_frames: - st.error("Please increase Frame Count") + gui.error("Please increase Frame Count") else: prompt_st_list.append( { @@ -142,12 +142,12 @@ def animation_prompts_editor( "key": str(uuid.uuid1()), } ) - st.experimental_rerun() + gui.rerun() - st.session_state[animation_prompts_key] = st_list_to_animation_prompt( + gui.session_state[animation_prompts_key] = st_list_to_animation_prompt( prompt_st_list ) - st.caption( + gui.caption( """ Pro-tip: To avoid abrupt endings on your animation, ensure that the last keyframe prompt is set for a higher number of keyframes/time than the previous transition rate. There should be an ample number of frames between the last frame and the total frame count of the animation. """ @@ -220,9 +220,9 @@ def related_workflows(self) -> list: def render_form_v2(self): animation_prompts_editor() - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.slider( + gui.slider( """ #### Frame Count Choose the number of frames in your animation. @@ -232,7 +232,7 @@ def render_form_v2(self): step=10, key="max_frames", ) - st.caption( + gui.caption( """ Pro-tip: The more frames you add, the longer it will take to render the animation. Test your prompts before adding more frames. 
""" @@ -249,10 +249,10 @@ def get_raw_price(self, state: dict) -> float: return max_frames * CREDITS_PER_FRAME def validate_form_v2(self): - prompt_list = st.session_state.get("animation_prompts") + prompt_list = gui.session_state.get("animation_prompts") assert prompt_list, "Please provide animation prompts" - max_frames = st.session_state["max_frames"] + max_frames = gui.session_state["max_frames"] assert ( get_last_frame(prompt_list) <= max_frames ), "Please make sure that Frame Count matches the Animation Prompts" @@ -261,7 +261,7 @@ def render_usage_guide(self): youtube_video("sUvica6UuQU") def render_settings(self): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: enum_selector( AnimationModels, @@ -272,11 +272,11 @@ def render_settings(self): use_selectbox=True, ) - animation_mode = st.selectbox( + animation_mode = gui.selectbox( "Animation Mode", key="animation_mode", options=["2D", "3D"] ) - st.text_input( + gui.text_input( """ ###### Zoom How should the camera zoom in or out? This setting scales the canvas size, multiplicatively. @@ -284,19 +284,19 @@ def render_settings(self): """, key="zoom", ) - st.caption( + gui.caption( """ With 0 as the starting keyframe, the input of 0: (1.004) can be used to zoom in moderately, starting at frame 0 and continuing until the end. """ ) - st.text_input( + gui.text_input( """ ###### Horizontal Pan How should the camera pan horizontally? This parameter uses positive values to move right and negative values to move left. """, key="translation_x", ) - st.text_input( + gui.text_input( """ ###### Vertical Pan How should the camera pan vertically? This parameter uses positive values to move up and negative values to move down. @@ -304,28 +304,28 @@ def render_settings(self): key="translation_y", ) if animation_mode == "3D": - st.text_input( + gui.text_input( """ ###### Roll Clockwise/Counterclockwise Gradually moves the camera on a focal axis. Roll the camera clockwise or counterclockwise in a specific degree per frame. This parameter uses positive values to roll counterclockwise and negative values to roll clockwise. E.g. use `0:(-1), 20:(0)` to roll the camera 1 degree clockwise for the first 20 frames. """, key="rotation_3d_z", ) - st.text_input( + gui.text_input( """ ###### Pan Left/Right Pans the canvas left or right in degrees per frame. This parameter uses positive values to pan right and negative values to pan left. """, key="rotation_3d_y", ) - st.text_input( + gui.text_input( """ ###### Tilt Up/Down Tilts the camera up or down in degrees per frame. This parameter uses positive values to tilt up and negative values to tilt down. """, key="rotation_3d_x", ) - st.slider( + gui.slider( """ ###### FPS (Frames per second) Choose fps for the video. @@ -336,7 +336,7 @@ def render_settings(self): key="fps", ) - # st.selectbox( + # gui.selectbox( # """ # ###### Sampler # What Stable Diffusion sampler should be used. @@ -365,7 +365,7 @@ def preview_description(self, state: dict) -> str: return "Create AI-generated Animation without relying on complex CoLab notebooks. Input your prompts + keyframes and bring your ideas to life using the animation capabilities of Gooey & Stable Diffusion's Deforum. For more help on how to use the tool visit https://www.help.gooey.ai/learn-animation" def render_description(self): - st.markdown( + gui.markdown( f""" - Every Submit will require approximately 3-5 minutes to render. 
@@ -377,7 +377,7 @@ def render_description(self): """ ) - st.markdown( + gui.markdown( """ #### Resources: @@ -391,8 +391,8 @@ def render_description(self): """ ) - st.write("---") - st.markdown( + gui.write("---") + gui.markdown( """ Animation Length: You can indicate how long you want your animation to be by increasing or decreasing your frame count. @@ -407,7 +407,7 @@ def render_description(self): Use the Camera Settings to generate animations with depth and other 3D parameters. """ ) - st.markdown( + gui.markdown( """ Prompt Construction Tip: @@ -421,20 +421,20 @@ def render_description(self): ) def render_output(self): - output_video = st.session_state.get("output_video") + output_video = gui.session_state.get("output_video") if output_video: - st.write("#### Output Video") - st.video(output_video, autoplay=True, show_download_button=True) + gui.write("#### Output Video") + gui.video(output_video, autoplay=True, show_download_button=True) def estimate_run_duration(self): # in seconds - return st.session_state.get("max_frames", 100) * MODEL_ESTIMATED_TIME_PER_FRAME + return gui.session_state.get("max_frames", 100) * MODEL_ESTIMATED_TIME_PER_FRAME def render_example(self, state: dict): display = self.preview_input(state) - st.markdown("```lua\n" + display + "\n```") + gui.markdown("```lua\n" + display + "\n```") - st.video(state.get("output_video"), autoplay=True) + gui.video(state.get("output_video"), autoplay=True) @classmethod def preview_input(cls, state: dict) -> str: diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 0ef63c4e9..010ed8ad0 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -10,7 +10,7 @@ from furl import furl from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings @@ -117,28 +117,28 @@ def render_form_v2(self): "#### ๐Ÿค– Youtube/PDF/Drive URLs", accept=("audio/*", "application/pdf", "video/*"), ) - st.text_input( + gui.text_input( "#### ๐Ÿ“Š Google Sheets URL", key="sheet_url", ) def validate_form_v2(self): - assert st.session_state.get("documents"), "Please enter Youtube/PDF/Drive URLs" - assert st.session_state.get("sheet_url"), "Please enter a Google Sheet URL" + assert gui.session_state.get("documents"), "Please enter Youtube/PDF/Drive URLs" + assert gui.session_state.get("sheet_url"), "Please enter a Google Sheet URL" def preview_description(self, state: dict) -> str: return "Transcribe YouTube videos in any language with Whisper, Google Chirp & more, run your own GPT4 prompt on each transcript and save it all to a Google Sheet. Perfect for making a YouTube-based dataset to create your own chatbot or enterprise copilot (ie. just add the finished Google sheet url to the doc section in https://gooey.ai/copilot)." 
def render_example(self, state: dict): render_documents(state) - st.write("**Google Sheets URL**") - st.write(state.get("sheet_url")) + gui.write("**Google Sheets URL**") + gui.write(state.get("sheet_url")) def render_usage_guide(self): youtube_video("p7ZLb-loR_4") def render_settings(self): - st.text_area( + gui.text_area( "##### ๐Ÿ‘ฉโ€๐Ÿซ Task Instructions", key="task_instructions", height=300, @@ -147,15 +147,15 @@ def render_settings(self): language_model_settings(selected_model) enum_selector(AsrModels, label="##### ASR Model", key="selected_asr_model") - st.write("---") + gui.write("---") google_translate_language_selector() - st.file_uploader( + gui.file_uploader( label=f"###### {field_title_desc(self.RequestModel, 'glossary_document')}", key="glossary_document", accept=SUPPORTED_SPREADSHEET_TYPES, ) - st.write("---") + gui.write("---") def related_workflows(self) -> list: from recipes.asr_page import AsrPage diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 7a0d0c388..3bbccef25 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -4,7 +4,7 @@ from furl import furl from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import ( @@ -92,13 +92,13 @@ def get_example_preferred_fields(self, state: dict) -> list[str]: return ["documents"] def render_form_v2(self): - st.text_area("#### Search Query", key="search_query") + gui.text_area("#### Search Query", key="search_query") bulk_documents_uploader("#### Documents") def validate_form_v2(self): - search_query = st.session_state.get("search_query", "").strip() + search_query = gui.session_state.get("search_query", "").strip() assert search_query, "Please enter a Search Query" - assert st.session_state.get("documents"), "Please provide at least 1 Document" + assert gui.session_state.get("documents"), "Please provide at least 1 Document" def related_workflows(self) -> list: from recipes.EmailFaceInpainting import EmailFaceInpaintingPage @@ -114,31 +114,31 @@ def related_workflows(self) -> list: ] def render_output(self): - render_output_with_refs(st.session_state) - refs = st.session_state.get("references", []) + render_output_with_refs(gui.session_state) + refs = gui.session_state.get("references", []) render_sources_widget(refs) def render_example(self, state: dict): render_documents(state) - st.html("**Search Query**") - st.write("```properties\n" + state.get("search_query", "") + "\n```") + gui.html("**Search Query**") + gui.write("```properties\n" + state.get("search_query", "") + "\n```") render_output_with_refs(state, 200) def render_settings(self): - st.text_area( + gui.text_area( "##### ๐Ÿ‘ฉโ€๐Ÿซ Task Instructions", key="task_instructions", height=300, ) - st.write("---") + gui.write("---") selected_model = language_model_selector() language_model_settings(selected_model) - st.write("---") - st.write("##### ๐Ÿ”Ž Document Search Settings") + gui.write("---") + gui.write("##### ๐Ÿ”Ž Document Search Settings") citation_style_selector() doc_extract_selector(self.request and self.request.user) query_instructions_widget() - st.write("---") + gui.write("---") doc_search_advanced_settings() def preview_image(self, state: dict) -> str | None: @@ -148,7 +148,7 @@ def preview_description(self, state: dict) -> str: return "Add your PDF, Word, HTML or Text docs, train our AI on them with OpenAI embeddings & vector search and then process results with a GPT3 script. 
This workflow is perfect for anything NOT in ChatGPT: 250-page compliance PDFs, training manuals, your diary, etc." def render_steps(self): - render_doc_search_step(st.session_state) + render_doc_search_step(gui.session_state) def render_usage_guide(self): youtube_video("Xe4L_dQ2KvU") @@ -224,7 +224,7 @@ def get_raw_price(self, state: dict) -> float: def additional_notes(self): try: - model = LargeLanguageModels[st.session_state["selected_model"]].value + model = LargeLanguageModels[gui.session_state["selected_model"]].value except KeyError: model = "LLM" return f"\n*Breakdown: {math.ceil(self.get_total_linked_usage_cost_in_credits())} ({model}) + {self.PROFIT_CREDITS}/run*" @@ -234,29 +234,29 @@ def render_documents(state, label="**Documents**", *, key="documents"): documents = state.get(key, []) if not documents: return - st.write(label) + gui.write(label) for doc in documents: if is_user_uploaded_url(doc): f = furl(doc) filename = f.path.segments[-1] else: filename = doc - st.write(f"๐Ÿ”—[*{filename}*]({doc})") + gui.write(f"๐Ÿ”—[*{filename}*]({doc})") def render_doc_search_step(state: dict): final_search_query = state.get("final_search_query") if final_search_query: - st.text_area("**Final Search Query**", value=final_search_query, disabled=True) + gui.text_area("**Final Search Query**", value=final_search_query, disabled=True) references = state.get("references") if references: - st.write("**References**") - st.json(references, expanded=False) + gui.write("**References**") + gui.json(references, expanded=False) final_prompt = state.get("final_prompt") if final_prompt: - st.text_area( + gui.text_area( "**Final Prompt**", value=final_prompt, height=400, @@ -265,7 +265,7 @@ def render_doc_search_step(state: dict): output_text = state.get("output_text", []) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( f"**Output Text**", help=f"output {idx}", disabled=True, diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index f26398575..18412e197 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -4,7 +4,7 @@ from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.asr import AsrModels from daras_ai_v2.base import BasePage @@ -95,10 +95,10 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): bulk_documents_uploader("#### ๐Ÿ“Ž Documents") - st.text_area("#### ๐Ÿ‘ฉโ€๐Ÿ’ป Instructions", key="task_instructions") + gui.text_area("#### ๐Ÿ‘ฉโ€๐Ÿ’ป Instructions", key="task_instructions") def render_settings(self): - st.text_area( + gui.text_area( """ ##### ๐Ÿ“„+๐Ÿ“„ Merge Instructions Prompt for merging several outputs together @@ -112,7 +112,7 @@ def render_settings(self): # """, # key="chain_type", # ) - st.write("---") + gui.write("---") selected_model = language_model_selector() language_model_settings(selected_model) @@ -121,38 +121,38 @@ def preview_description(self, state: dict) -> str: return "Upload any collection of PDFs, docs and/or audio files and we'll transcribe them. Then give any GPT based instruction and we'll do a map-reduce and return the result. Great for summarizing large data sets to create structured data. Check out the examples for more." 
def validate_form_v2(self): - search_query = st.session_state.get("task_instructions", "").strip() + search_query = gui.session_state.get("task_instructions", "").strip() assert search_query, "Please enter the Instructions" - assert st.session_state.get("documents"), "Please provide at least 1 Document" + assert gui.session_state.get("documents"), "Please provide at least 1 Document" def render_output(self): - render_output_with_refs(st.session_state) + render_output_with_refs(gui.session_state) def render_example(self, state: dict): render_documents(state) - st.write("**Instructions**") - st.write("```properties\n" + state.get("task_instructions", "") + "\n```") + gui.write("**Instructions**") + gui.write("```properties\n" + state.get("task_instructions", "") + "\n```") render_output_with_refs(state, 200) def render_steps(self): - prompt_tree = st.session_state.get("prompt_tree", {}) + prompt_tree = gui.session_state.get("prompt_tree", {}) if prompt_tree: - st.write("**Prompt Tree**") - st.json(prompt_tree, expanded=False) + gui.write("**Prompt Tree**") + gui.json(prompt_tree, expanded=False) - final_prompt = st.session_state.get("final_prompt") + final_prompt = gui.session_state.get("final_prompt") if final_prompt: - st.text_area( + gui.text_area( "**Final Prompt**", value=final_prompt, disabled=True, ) else: - st.div() + gui.div() - output_text: list = st.session_state.get("output_text", []) + output_text: list = gui.session_state.get("output_text", []) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( f"**Output Text**", help=f"output {idx}", disabled=True, diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index ab51aba3d..e5b9b24a7 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -4,7 +4,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import db, settings @@ -100,7 +100,7 @@ def preview_description(self, state: dict) -> str: return "Find an email's public photo and then draw the face into an AI generated scene using your own prompt + the latest Stable Diffusion or DallE image generator." def render_description(self): - st.write( + gui.write( """ *EmailID > Profile pic > Face Masking + Zoom > Stable Diffusion > GFPGAN > Email* @@ -118,7 +118,7 @@ def render_description(self): ) def render_form_v2(self): - st.text_area( + gui.text_area( """ #### Prompt Describe the scene that you'd like to generate around the face. 
@@ -126,14 +126,14 @@ def render_form_v2(self): key="text_prompt", placeholder="winter's day in paris", ) - if "__photo_source" not in st.session_state: - st.session_state["__photo_source"] = ( + if "__photo_source" not in gui.session_state: + gui.session_state["__photo_source"] = ( "Email Address" - if st.session_state.get("email_address") + if gui.session_state.get("email_address") else "Twitter Handle" ) - source = st.radio( + source = gui.radio( """ #### Photo Source From where we should get the photo?""", @@ -141,7 +141,7 @@ def render_form_v2(self): key="__photo_source", ) if source == "Email Address": - st.text_input( + gui.text_input( """ #### Email Address Give us your email address and we'll try to get your photo @@ -149,9 +149,9 @@ def render_form_v2(self): key="email_address", placeholder="john@appleseed.com", ) - st.session_state["twitter_handle"] = None + gui.session_state["twitter_handle"] = None else: - st.text_input( + gui.text_input( """ #### Twitter Handle Give us your twitter handle, we'll try to get your photo from there @@ -159,28 +159,28 @@ def render_form_v2(self): key="twitter_handle", max_chars=15, ) - st.session_state["email_address"] = None + gui.session_state["email_address"] = None def validate_form_v2(self): - text_prompt = st.session_state.get("text_prompt") - email_address = st.session_state.get("email_address") - twitter_handle = st.session_state.get("twitter_handle") + text_prompt = gui.session_state.get("text_prompt") + email_address = gui.session_state.get("email_address") + twitter_handle = gui.session_state.get("twitter_handle") assert text_prompt, "Please provide a Prompt and your Email Address" - if st.session_state.get("twitter_handle"): + if gui.session_state.get("twitter_handle"): assert re.fullmatch( twitter_handle_regex, twitter_handle ), "Please provide a valid Twitter Handle" - elif st.session_state.get("email_address"): + elif gui.session_state.get("email_address"): assert re.fullmatch( email_regex, email_address ), "Please provide a valid Email Address" else: raise AssertionError("Please provide an Email Address or Twitter Handle") - from_email = st.session_state.get("email_from") - email_subject = st.session_state.get("email_subject") - email_body = st.session_state.get("email_body") + from_email = gui.session_state.get("email_from") + email_subject = gui.session_state.get("email_subject") + email_body = gui.session_state.get("email_body") assert ( from_email and email_subject and email_body ), "Please provide a From Email, Subject & Body" @@ -203,43 +203,43 @@ def render_usage_guide(self): def render_settings(self): super().render_settings() - st.write( + gui.write( """ ### Email settings """ ) - st.checkbox( + gui.checkbox( "Send email", key="should_send_email", ) - st.text_input( + gui.text_input( label="From", key="email_from", ) - st.text_input( + gui.text_input( label="Cc (You can enter multiple emails separated by comma)", key="email_cc", placeholder="john@gmail.com, cathy@gmail.com", ) - st.text_input( + gui.text_input( label="Bcc (You can enter multiple emails separated by comma)", key="email_bcc", placeholder="john@gmail.com, cathy@gmail.com", ) - st.text_input( + gui.text_input( label="Subject", key="email_subject", ) - st.checkbox( + gui.checkbox( label="Enable HTML Body", key="email_body_enable_html", ) - st.text_area( + gui.text_area( label="Body (use {{output_images}} to insert the images into the email)", key="email_body", ) - st.text_area( + gui.text_area( label="Fallback Body (in case of failure)", 
key="fallback_email_body", ) @@ -247,10 +247,10 @@ def render_settings(self): def render_output(self): super().render_output() - if st.session_state.get("email_sent"): - st.write(f"โœ… Email sent to {st.session_state.get('email_address')}") + if gui.session_state.get("email_sent"): + gui.write(f"โœ… Email sent to {gui.session_state.get('email_address')}") else: - st.div() + gui.div() def run(self, state: dict): request: EmailFaceInpaintingPage.RequestModel = self.RequestModel.parse_obj( @@ -335,13 +335,13 @@ def _get_email_body( def render_example(self, state: dict): if state.get("email_address"): - st.write("**Input Email** -", state.get("email_address")) + gui.write("**Input Email** -", state.get("email_address")) elif state.get("twitter_handle"): - st.write("**Input Twitter Handle** -", state.get("twitter_handle")) + gui.write("**Input Twitter Handle** -", state.get("twitter_handle")) output_images = state.get("output_images") if output_images: for img in output_images: - st.image( + gui.image( img, caption="```" + state.get("text_prompt", "").replace("\n", "") diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 40e4274e0..1e97a3fbe 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -5,7 +5,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.extract_face import extract_and_reposition_face_cv2 from daras_ai.image_input import ( @@ -88,7 +88,7 @@ def preview_description(self, state: dict) -> str: return "Upload & extract a face into an AI-generated photo using your text + the latest Stable Diffusion or DallE image generator." def render_description(self): - st.write( + gui.write( """ This recipe takes a photo with a face and then uses the text prompt to paint a background. @@ -104,7 +104,7 @@ def render_description(self): ) def render_form_v2(self): - st.text_area( + gui.text_area( """ #### Prompt Describe the character that you'd like to generate. 
@@ -113,7 +113,7 @@ def render_form_v2(self): placeholder="Iron man", ) - st.file_uploader( + gui.file_uploader( """ #### Face Photo Give us a photo of yourself, or anyone else @@ -123,16 +123,16 @@ def render_form_v2(self): ) def validate_form_v2(self): - text_prompt = st.session_state.get("text_prompt") - input_image = st.session_state.get("input_image") + text_prompt = gui.session_state.get("text_prompt") + input_image = gui.session_state.get("input_image") assert text_prompt and input_image, "Please provide a Prompt and a Face Photo" def render_settings(self): img_model_settings(InpaintingModels) - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.slider( + gui.slider( "##### Upscaling", min_value=1.0, max_value=4.0, @@ -140,35 +140,35 @@ def render_settings(self): key="upscale_factor", ) - st.write("---") + gui.write("---") - st.write( + gui.write( """ #### Face Repositioning Settings """ ) - st.write("How _big_ should the face look?") - col1, _ = st.columns(2) + gui.write("How _big_ should the face look?") + col1, _ = gui.columns(2) with col1: - face_scale = st.slider( + face_scale = gui.slider( "Scale", min_value=0.1, max_value=1.0, key="face_scale", ) - st.write("_Where_ would you like to place the face in the scene?") - col1, col2 = st.columns(2) + gui.write("_Where_ would you like to place the face in the scene?") + col1, col2 = gui.columns(2) with col1: - pos_x = st.slider( + pos_x = gui.slider( "Position X", min_value=0.0, max_value=1.0, key="face_pos_x", ) with col2: - pos_y = st.slider( + pos_y = gui.slider( "Position Y", min_value=0.0, max_value=1.0, @@ -183,8 +183,8 @@ def render_settings(self): img, _ = extract_and_reposition_face_cv2( img_cv2, out_size=( - st.session_state["output_width"], - st.session_state["output_height"], + gui.session_state["output_width"], + gui.session_state["output_height"], ), out_face_scale=face_scale, out_pos_x=pos_x, @@ -193,57 +193,57 @@ def render_settings(self): repositioning_preview_img(img) def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") - output_images = st.session_state.get("output_images") + text_prompt = gui.session_state.get("text_prompt", "") + output_images = gui.session_state.get("output_images") if output_images: - st.write("#### Output Image") + gui.write("#### Output Image") for url in output_images: - st.image(url, show_download_button=True) + gui.image(url, show_download_button=True) else: - st.div() + gui.div() def render_steps(self): - input_file = st.session_state.get("input_file") - input_image = st.session_state.get("input_image") + input_file = gui.session_state.get("input_file") + input_image = gui.session_state.get("input_image") input_image_or_file = input_image or input_file - output_images = st.session_state.get("output_images") + output_images = gui.session_state.get("output_images") - col1, col2, col3, col4 = st.columns(4) + col1, col2, col3, col4 = gui.columns(4) with col1: if input_image_or_file: - st.image(input_image_or_file, caption="Input Image") + gui.image(input_image_or_file, caption="Input Image") else: - st.div() + gui.div() with col2: - resized_image = st.session_state.get("resized_image") + resized_image = gui.session_state.get("resized_image") if resized_image: - st.image(resized_image, caption="Repositioned Face") + gui.image(resized_image, caption="Repositioned Face") else: - st.div() + gui.div() - face_mask = st.session_state.get("face_mask") + face_mask = gui.session_state.get("face_mask") if face_mask: - st.image(face_mask, caption="Face 
Mask") + gui.image(face_mask, caption="Face Mask") else: - st.div() + gui.div() with col3: - diffusion_images = st.session_state.get("diffusion_images") + diffusion_images = gui.session_state.get("diffusion_images") if diffusion_images: for url in diffusion_images: - st.image(url, caption="Generated Image") + gui.image(url, caption="Generated Image") else: - st.div() + gui.div() with col4: if output_images: for url in output_images: - st.image(url, caption="gfpgan - Face Restoration") + gui.image(url, caption="gfpgan - Face Restoration") else: - st.div() + gui.div() def render_usage_guide(self): youtube_video("To4Oc_d4Nus") @@ -321,17 +321,17 @@ def related_workflows(self) -> list: ] def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col2: output_images = state.get("output_images") if output_images: for img in output_images: - st.image(img, caption="Generated Image") + gui.image(img, caption="Generated Image") with col1: input_image = state.get("input_image") - st.image(input_image, caption="Input Image") - st.write("**Prompt**") - st.write("```properties\n" + state.get("text_prompt", "") + "\n```") + gui.image(input_image, caption="Input Image") + gui.write("**Prompt**") + gui.write("```properties\n" + state.get("text_prompt", "") + "\n```") def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") diff --git a/recipes/Functions.py b/recipes/Functions.py index 4850214ef..e376bc061 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -3,7 +3,7 @@ import requests from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2 import settings from daras_ai_v2.base import BasePage @@ -56,7 +56,7 @@ def run_v2( request: "FunctionsPage.RequestModel", response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: - query_params = st.get_query_params() + query_params = gui.get_query_params() run_id = query_params.get("run_id") uid = query_params.get("uid") tag = f"run_id={run_id}&uid={uid}" @@ -75,7 +75,7 @@ def run_v2( response.error = data.get("error") def render_form_v2(self): - st.code_editor( + gui.code_editor( label="##### " + field_title_desc(self.RequestModel, "code"), key="code", language="javascript", @@ -86,21 +86,21 @@ def render_variables(self): variables_input(template_keys=["code"], allow_add=True) def render_output(self): - if error := st.session_state.get("error"): - with st.tag("pre", className="bg-danger bg-opacity-25"): - st.html(error) + if error := gui.session_state.get("error"): + with gui.tag("pre", className="bg-danger bg-opacity-25"): + gui.html(error) - if return_value := st.session_state.get("return_value"): - st.write("**Return value**") - st.json(return_value) + if return_value := gui.session_state.get("return_value"): + gui.write("**Return value**") + gui.json(return_value) - logs = st.session_state.get("logs") + logs = gui.session_state.get("logs") if not logs: return - st.write("---") - st.write("**Logs**") - with st.tag( + gui.write("---") + gui.write("**Logs**") + with gui.tag( "pre", style=dict(maxHeight=500, overflowY="auto"), className="bg-light p-2" ): for i, log in enumerate(logs): @@ -112,7 +112,7 @@ def render_output(self): borderClass = "border-top" else: borderClass = "" - st.html( + gui.html( log.get("message"), className=f"d-block py-1 {borderClass} {textClass}", ) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 81b09bd97..6a9eaf999 100644 --- 
a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -3,7 +3,7 @@ from furl import furl from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import ( @@ -114,43 +114,43 @@ class ResponseModel(BaseModel): final_search_query: str | None def render_form_v2(self): - st.text_area("#### Google Search Query", key="search_query") - st.text_input("Search on a specific site *(optional)*", key="site_filter") + gui.text_area("#### Google Search Query", key="search_query") + gui.text_input("Search on a specific site *(optional)*", key="site_filter") def validate_form_v2(self): - assert st.session_state.get( + assert gui.session_state.get( "search_query", "" ).strip(), "Please enter a search query" def render_output(self): - render_output_with_refs(st.session_state) + render_output_with_refs(gui.session_state) - refs = st.session_state.get("references", []) + refs = gui.session_state.get("references", []) render_sources_widget(refs) def render_example(self, state: dict): - st.write("**Search Query**") - st.write("```properties\n" + state.get("search_query", "") + "\n```") + gui.write("**Search Query**") + gui.write("```properties\n" + state.get("search_query", "") + "\n```") site_filter = state.get("site_filter") if site_filter: - st.write(f"**Site** \\\n{site_filter}") + gui.write(f"**Site** \\\n{site_filter}") render_output_with_refs(state, 200) def render_settings(self): - st.text_area( + gui.text_area( "### Task Instructions", key="task_instructions", height=300, ) - st.write("---") + gui.write("---") selected_model = language_model_selector() language_model_settings(selected_model) - st.write("---") + gui.write("---") serp_search_settings() - st.write("---") - st.write("##### ๐Ÿ”Ž Document Search Settings") + gui.write("---") + gui.write("##### ๐Ÿ”Ž Document Search Settings") query_instructions_widget() - st.write("---") + gui.write("---") doc_search_advanced_settings() def related_workflows(self) -> list: @@ -176,41 +176,41 @@ def render_usage_guide(self): youtube_video("mcscNaUIosA") def render_steps(self): - final_search_query = st.session_state.get("final_search_query") + final_search_query = gui.session_state.get("final_search_query") if final_search_query: - st.text_area( + gui.text_area( "**Final Search Query**", value=final_search_query, disabled=True ) - serp_results = st.session_state.get( - "serp_results", st.session_state.get("scaleserp_results") + serp_results = gui.session_state.get( + "serp_results", gui.session_state.get("scaleserp_results") ) if serp_results: - st.write("**Web Search Results**") - st.json(serp_results) + gui.write("**Web Search Results**") + gui.json(serp_results) - final_prompt = st.session_state.get("final_prompt") + final_prompt = gui.session_state.get("final_prompt") if final_prompt: - st.text_area( + gui.text_area( "**Final Prompt**", value=final_prompt, height=400, disabled=True, ) - output_text: list = st.session_state.get("output_text", []) + output_text: list = gui.session_state.get("output_text", []) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( f"**Output Text**", help=f"output {idx}", disabled=True, value=text, ) - references = st.session_state.get("references", []) + references = gui.session_state.get("references", []) if references: - st.write("**References**") - st.json(references) + gui.write("**References**") + gui.json(references) def run_v2( self, diff --git 
a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index 86cd58157..a94e812ce 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -3,7 +3,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import ( upload_file_from_bytes, @@ -93,7 +93,7 @@ def related_workflows(self): ] def render_description(self): - st.write( + gui.write( """ This workflow creates unique, relevant images to help your site rank well for a given search query. @@ -178,7 +178,7 @@ def run(self, state: dict): ) def render_form_v2(self): - st.text_input( + gui.text_input( """ #### ๐Ÿ”Ž Google Image Search Type a query you'd use in [Google image search](https://images.google.com/?gws_rd=ssl) @@ -186,13 +186,13 @@ def render_form_v2(self): key="search_query", ) model_selector(Img2ImgModels) - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt Describe how you want to edit the photo in words """, key="text_prompt", - disabled=st.session_state.get("selected_model") is None, + disabled=gui.session_state.get("selected_model") is None, ) def render_usage_guide(self): @@ -203,29 +203,31 @@ def render_settings(self): serp_search_location_selectbox() def render_output(self): - out_imgs = st.session_state.get("output_images") + out_imgs = gui.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="#### Generated Image", show_download_button=True) + gui.image( + img, caption="#### Generated Image", show_download_button=True + ) else: - st.div() + gui.div() def render_steps(self): - image_urls = st.session_state.get("image_urls") + image_urls = gui.session_state.get("image_urls") if image_urls: - st.write("**Image URLs**") - st.json(image_urls, expanded=False) + gui.write("**Image URLs**") + gui.json(image_urls, expanded=False) else: - st.div() + gui.div() - selected_image = st.session_state.get("selected_image") + selected_image = gui.session_state.get("selected_image") if selected_image: - st.image(selected_image, caption="Selected Image") + gui.image(selected_image, caption="Selected Image") else: - st.div() + gui.div() def render_example(self, state: dict): - st.write( + gui.write( f""" **Google Search Query** `{state.get("search_query", '')}` \\ **Prompt** `{state.get("text_prompt", '')}` @@ -234,7 +236,7 @@ def render_example(self, state: dict): out_imgs = state.get("output_images") if out_imgs: - st.image(out_imgs[0], caption="Generated Image") + gui.image(out_imgs[0], caption="Generated Image") def preview_description(self, state: dict) -> str: return "Enter a Google Image Search query + your Img2Img text prompt describing how to alter the result to create a unique, relevant ai generated images for any search query." 
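For orientation between hunks: every file above and below follows the same mechanical substitution — the `gooey_ui`-as-`st` import becomes `gooey_gui`-as-`gui`, and each widget/state call is renamed to match (including `st.experimental_rerun()` → `gui.rerun()`). The following is a minimal, illustrative sketch of the post-rename call pattern, assuming `gooey_gui` keeps the same widget API these hunks exercise; the function name and state keys are hypothetical, not part of this patch.

import gooey_gui as gui  # was: import gooey_ui as st


def render_demo():
    # Widgets read and write gui.session_state under the given key
    # (keys here are hypothetical, chosen only for illustration).
    gui.text_input("#### Search Query", key="search_query")

    if gui.button("Run", key="run_btn"):
        gui.session_state["last_query"] = gui.session_state.get("search_query", "")
        gui.rerun()  # replaces the old st.experimental_rerun()

    col1, col2 = gui.columns(2)
    with col1:
        gui.write("**Query**")
        gui.write(gui.session_state.get("last_query", ""))
    with col2:
        gui.text_area("Output", value="", disabled=True)
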
diff --git a/recipes/GoogleTTS.py b/recipes/GoogleTTS.py index 49de78c04..a3cdc6415 100644 --- a/recipes/GoogleTTS.py +++ b/recipes/GoogleTTS.py @@ -1,5 +1,5 @@ import uuid -import gooey_ui as st +import gooey_gui as gui from google.cloud import texttospeech from enum import Enum @@ -25,23 +25,23 @@ class VoiceGender(Enum): def main(): - st.write("# GOOGLE Text To Speach") + gui.write("# GOOGLE Text To Speach") - with st.form(key="send_email", clear_on_submit=False): - voice_name = st.text_input(label="Voice name", value="en-US-Neural2-F") - st.write( + with gui.form(key="send_email", clear_on_submit=False): + voice_name = gui.text_input(label="Voice name", value="en-US-Neural2-F") + gui.write( "Get more voice names [here](https://cloud.google.com/text-to-speech/docs/voices)" ) - text = st.text_area(label="Text input", value="This is a test.") - pitch = st.slider("Pitch", min_value=-20.0, max_value=20.0, value=0.0) - speaking_rate = st.slider( + text = gui.text_area(label="Text input", value="This is a test.") + pitch = gui.slider("Pitch", min_value=-20.0, max_value=20.0, value=0.0) + speaking_rate = gui.slider( "Speaking rate (1.0 is the normal native speed)", min_value=0.25, max_value=4.0, value=1.0, ) - # voice_gender = st.selectbox("Voice", (voice.name for voice in VoiceGender)) - submitted = st.form_submit_button("Generate") + # voice_gender = gui.selectbox("Voice", (voice.name for voice in VoiceGender)) + submitted = gui.form_submit_button("Generate") if submitted: client = texttospeech.TextToSpeechClient(credentials=credentials) @@ -58,23 +58,23 @@ def main(): # Perform the text-to-speech request on the text input with the selected # voice parameters and audio file type - with st.spinner("Generating audio..."): + with gui.spinner("Generating audio..."): response = client.synthesize_speech( input=synthesis_input, voice=voice, audio_config=audio_config ) if not response: - st.error("Error: Audio generation failed") + gui.error("Error: Audio generation failed") return - with st.spinner("Uploading file..."): + with gui.spinner("Uploading file..."): audio_url = upload_file_from_bytes( f"google_tts_{uuid.uuid4()}.mp3", response.audio_content ) if not audio_url: - st.error("Error: Uploading failed") + gui.error("Error: Uploading failed") return - st.audio(audio_url) + gui.audio(audio_url) main() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index d35256ba2..f886c2b4e 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -5,7 +5,7 @@ import numpy as np from daras_ai_v2.pydantic_validation import FieldHttpUrl import requests -import gooey_ui as st +import gooey_gui as gui from pydantic import BaseModel from bots.models import Workflow @@ -85,7 +85,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.file_uploader( + gui.file_uploader( """ #### Input Photo Give us a photo of anything @@ -95,7 +95,7 @@ def render_form_v2(self): ) def validate_form_v2(self): - input_image = st.session_state.get("input_image") + input_image = gui.session_state.get("input_image") assert input_image, "Please provide an Input Photo" def render_settings(self): @@ -105,7 +105,7 @@ def render_settings(self): key="selected_model", ) - st.slider( + gui.slider( """ #### Edge Threshold Helps to remove edge artifacts. `0` will turn this off. `0.9` will aggressively cut down edges. 
@@ -115,63 +115,63 @@ def render_settings(self): key="mask_threshold", ) - st.write( + gui.write( """ #### Fix Skewed Perspective Automatically transform the perspective of the image to make objects look like a perfect rectangle """ ) - st.checkbox( + gui.checkbox( "Fix Skewed Perspective", key="rect_persepective_transform", ) - st.write( + gui.write( """ #### Add reflections """ ) - col1, _ = st.columns(2) + col1, _ = gui.columns(2) with col1: - st.slider("Reflection Opacity", key="reflection_opacity") + gui.slider("Reflection Opacity", key="reflection_opacity") - # st.write( + # gui.write( # """ # ##### Add Drop shadow # """ # ) - # col1, _ = st.columns(2) + # col1, _ = gui.columns(2) # with col1: - # st.slider("Shadow ", key="reflection_opacity") + # gui.slider("Shadow ", key="reflection_opacity") - st.write( + gui.write( """ #### Object Repositioning Settings """ ) - st.write("How _big_ should the object look?") - col1, _ = st.columns(2) + gui.write("How _big_ should the object look?") + col1, _ = gui.columns(2) with col1: - obj_scale = st.slider( + obj_scale = gui.slider( "Scale", min_value=0.1, max_value=1.0, key="obj_scale", ) - st.write("_Where_ would you like to place the object in the scene?") - col1, col2 = st.columns(2) + gui.write("_Where_ would you like to place the object in the scene?") + col1, col2 = gui.columns(2) with col1: - pos_x = st.slider( + pos_x = gui.slider( "Position X", min_value=0.0, max_value=1.0, key="obj_pos_x", ) with col2: - pos_y = st.slider( + pos_y = gui.slider( "Position Y", min_value=0.0, max_value=1.0, @@ -298,61 +298,61 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield def render_output(self): - self.render_example(st.session_state) + self.render_example(gui.session_state) def render_steps(self): - col1, col2, col3, col4 = st.columns(4) + col1, col2, col3, col4 = gui.columns(4) with col1: - input_image = st.session_state.get("input_image") + input_image = gui.session_state.get("input_image") if input_image: - st.image(input_image, caption="Input Photo") + gui.image(input_image, caption="Input Photo") else: - st.div() + gui.div() with col2: - output_image = st.session_state.get("output_image") + output_image = gui.session_state.get("output_image") if output_image: - st.image(output_image, caption=f"Segmentation Mask") + gui.image(output_image, caption=f"Segmentation Mask") else: - st.div() + gui.div() with col3: - resized_image = st.session_state.get("resized_image") + resized_image = gui.session_state.get("resized_image") if resized_image: - st.image(resized_image, caption=f"Resized Image") + gui.image(resized_image, caption=f"Resized Image") else: - st.div() + gui.div() - resized_mask = st.session_state.get("resized_mask") + resized_mask = gui.session_state.get("resized_mask") if resized_mask: - st.image(resized_mask, caption=f"Resized Mask") + gui.image(resized_mask, caption=f"Resized Mask") else: - st.div() + gui.div() with col4: - cutout_image = st.session_state.get("cutout_image") + cutout_image = gui.session_state.get("cutout_image") if cutout_image: - st.image(cutout_image, caption=f"Cutout Image") + gui.image(cutout_image, caption=f"Cutout Image") else: - st.div() + gui.div() def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: input_image = state.get("input_image") if input_image: - st.image(input_image, caption="Input Photo", show_download_button=True) + gui.image(input_image, caption="Input Photo", show_download_button=True) else: - st.div() + gui.div() with col2: 
cutout_image = state.get("cutout_image") if cutout_image: - st.image(cutout_image, caption=f"Cutout Image") + gui.image(cutout_image, caption=f"Cutout Image") else: - st.div() + gui.div() def preview_description(self, state: dict) -> str: return "Use Dichotomous Image Segmentation to remove unwanted backgrounds from your images and correct perspective. Awesome when used with other Gooey.AI steps." diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 1bee3efd2..d41deb1f5 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -4,7 +4,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.img_model_settings_widgets import img_model_settings @@ -93,7 +93,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.file_uploader( + gui.file_uploader( """ #### Input Image """, @@ -101,7 +101,7 @@ def render_form_v2(self): upload_meta=dict(resize=f"{SD_IMG_MAX_SIZE[0] * SD_IMG_MAX_SIZE[1]}@>"), ) - st.text_area( + gui.text_area( """ #### Prompt Describe your edits @@ -111,11 +111,11 @@ def render_form_v2(self): ) def validate_form_v2(self): - input_image = st.session_state.get("input_image") + input_image = gui.session_state.get("input_image") assert input_image, "Please provide an Input Image" def render_description(self): - st.write( + gui.write( """ This recipe takes an image and a prompt and then attempts to alter the image, based on the text. @@ -130,24 +130,24 @@ def render_usage_guide(self): youtube_video("narcZNyuNAg") def render_output(self): - output_images = st.session_state.get("output_images", []) + output_images = gui.session_state.get("output_images", []) if not output_images: return - st.write("#### Output Image") + gui.write("#### Output Image") for img in output_images: - st.image(img, show_download_button=True) + gui.image(img, show_download_button=True) def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col2: output_images = state.get("output_images", []) for img in output_images: - st.image(img, caption="Generated Image") + gui.image(img, caption="Generated Image") with col1: input_image = state.get("input_image") - st.image(input_image, caption="Input Image") - st.write("**Prompt**") - st.write("```properties\n" + state.get("text_prompt", "") + "\n```") + gui.image(input_image, caption="Input Image") + gui.write("**Prompt**") + gui.write("```properties\n" + state.get("text_prompt", "") + "\n```") def run(self, state: dict) -> typing.Iterator[str | None]: request: Img2ImgPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/LetterWriter.py b/recipes/LetterWriter.py index 76c46add6..ff39fb3aa 100644 --- a/recipes/LetterWriter.py +++ b/recipes/LetterWriter.py @@ -4,7 +4,7 @@ import requests from pydantic.main import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.text_format import daras_ai_format_str from daras_ai_v2.base import BasePage @@ -47,7 +47,7 @@ class ResponseModel(BaseModel): final_prompt: str def render_description(self): - st.write( + gui.write( """ *ID > Call Custom API > Build Training Data > GPT3* @@ -61,21 +61,21 @@ def render_description(self): ) def render_form_v2(self): - st.text_input( + gui.text_input( "### Action ID", key="action_id", ) - col1, col2 = st.columns(2, gap="medium") + col1, col2 = gui.columns(2, gap="medium") with col1: - st.slider( + gui.slider( 
label="Number of Outputs", key="num_outputs", min_value=1, max_value=4, ) with col2: - st.slider( + gui.slider( label="Quality", key="quality", min_value=1.0, @@ -84,14 +84,14 @@ def render_form_v2(self): ) def render_settings(self): - st.write("### Model Settings") + gui.write("### Model Settings") - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: # select text api api_provider_options = ["openai", "goose.ai"] - api_provider = st.selectbox( + api_provider = gui.selectbox( label="Language Model Provider", options=api_provider_options, key="lm_selected_api", @@ -131,13 +131,13 @@ def render_settings(self): raise ValueError() with col2: - st.selectbox( + gui.selectbox( label="Engine", options=engine_choices, key="lm_selected_engine", ) - st.slider( + gui.slider( """ ##### Model Risk Factor @@ -152,9 +152,9 @@ def render_settings(self): max_value=1.0, ) - st.write("---") + gui.write("---") - st.text_area( + gui.text_area( """ ### Task description Briefly describe the task for the language model @@ -162,11 +162,11 @@ def render_settings(self): key="prompt_header", ) - st.write("---") + gui.write("---") - st.write("### Example letters") + gui.write("### Example letters") - st.write( + gui.write( """ A set of example letters for the model to learn your writing style """ @@ -174,11 +174,11 @@ def render_settings(self): text_training_data("Talking points", "Letter", key="example_letters") - st.write("---") + gui.write("---") - st.write("### Custom API settings") + gui.write("### Custom API settings") - st.write( + gui.write( """ Call any external API to get the talking points from an input Action ID @@ -186,29 +186,29 @@ def render_settings(self): """ ) - col1, col2 = st.columns([1, 4]) + col1, col2 = gui.columns([1, 4]) with col1: - st.text_input( + gui.text_input( "HTTP Method", key="api_http_method", ) with col2: - st.text_input( + gui.text_input( "URL", key="api_url", ) - st.text_area( + gui.text_area( "Headers as JSON (optional)", key="api_headers", ) - st.text_area( + gui.text_area( "JSON Body (optional)", key="api_json_body", ) - st.write("---") + gui.write("---") - st.text_area( + gui.text_area( """ ##### Input Talking Points (Prompt) @@ -219,7 +219,7 @@ def render_settings(self): """, key="input_prompt", ) - st.checkbox("Strip all HTML -> Text?", key="strip_html_2_text") + gui.checkbox("Strip all HTML -> Text?", key="strip_html_2_text") def run(self, state: dict) -> typing.Iterator[str | None]: yield "Calling API.." 
@@ -304,14 +304,14 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - st.write("### Generated Letters") - output_letters = st.session_state.get( + gui.write("### Generated Letters") + output_letters = gui.session_state.get( "output_letters", # this default value makes a nicer output while running :) - [""] * st.session_state["num_outputs"], + [""] * gui.session_state["num_outputs"], ) for idx, out in enumerate(output_letters): - st.text_area( + gui.text_area( f"output {idx}", label_visibility="collapsed", value=out, @@ -320,22 +320,22 @@ def render_output(self): ) def render_steps(self): - response_json = st.session_state.get("response_json", {}) - st.write("**API Response**") - st.json( + response_json = gui.session_state.get("response_json", {}) + gui.write("**API Response**") + gui.json( response_json, expanded=False, ) - input_prompt = st.session_state.get("generated_input_prompt", "") - st.text_area( + input_prompt = gui.session_state.get("generated_input_prompt", "") + gui.text_area( "**Input Talking Points (Prompt)**", value=input_prompt, disabled=True, ) - final_prompt = st.session_state.get("final_prompt", "") - st.text_area( + final_prompt = gui.session_state.get("final_prompt", "") + gui.text_area( "**Final Language Model Prompt**", value=final_prompt, disabled=True, diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 260ba9630..df1354272 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -3,7 +3,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector @@ -36,7 +36,7 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_LIPSYNC_META_IMG def render_form_v2(self): - st.file_uploader( + gui.file_uploader( """ #### Input Face Upload a video/image that contains faces to use @@ -45,7 +45,7 @@ def render_form_v2(self): key="input_face", ) - st.file_uploader( + gui.file_uploader( """ #### Input Audio Upload the video/audio file to use as audio source for lipsyncing @@ -62,11 +62,11 @@ def render_form_v2(self): ) def validate_form_v2(self): - assert st.session_state.get("input_audio"), "Please provide an Audio file" - assert st.session_state.get("input_face"), "Please provide an Input Face" + assert gui.session_state.get("input_audio"), "Please provide an Audio file" + assert gui.session_state.get("input_face"), "Please provide an Input Face" def render_settings(self): - lipsync_settings(st.session_state.get("selected_model")) + lipsync_settings(gui.session_state.get("selected_model")) def run(self, state: dict) -> typing.Iterator[str | None]: request = self.RequestModel.parse_obj(state) @@ -95,13 +95,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]: def render_example(self, state: dict): output_video = state.get("output_video") if output_video: - st.write("#### Output Video") - st.video(output_video, autoplay=True, show_download_button=True) + gui.write("#### Output Video") + gui.video(output_video, autoplay=True, show_download_button=True) else: - st.div() + gui.div() def render_output(self): - self.render_example(st.session_state) + self.render_example(gui.session_state) def related_workflows(self) -> list: from recipes.DeforumSD import DeforumSDPage @@ -120,7 +120,7 @@ def preview_description(self, state: dict) -> str: def get_cost_note(self) -> str | None: multiplier = ( 3 - if st.session_state.get("lipsync_model") == 
LipsyncModel.SadTalker.name + if gui.session_state.get("lipsync_model") == LipsyncModel.SadTalker.name else 1 ) return f"{CREDITS_PER_MB * multiplier} credits per MB" diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 8ca7df4e5..79983f36e 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -3,7 +3,7 @@ from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel @@ -52,7 +52,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.file_uploader( + gui.file_uploader( """ #### Input Face Upload a video/image that contains faces to use @@ -60,7 +60,7 @@ def render_form_v2(self): """, key="input_face", ) - st.text_area( + gui.text_area( """ #### Input Text This generates audio for your video @@ -78,10 +78,10 @@ def render_form_v2(self): text_to_speech_provider_selector(self) def validate_form_v2(self): - assert st.session_state.get( + assert gui.session_state.get( "text_prompt", "" ).strip(), "Text input cannot be empty" - assert st.session_state.get("input_face"), "Please provide an Input Face" + assert gui.session_state.get("input_face"), "Please provide an Input Face" def preview_image(self, state: dict) -> str | None: return DEFAULT_LIPSYNC_TTS_META_IMG @@ -90,7 +90,7 @@ def preview_description(self, state: dict) -> str: return "Add your text prompt, pick a voice & upload a sample video to quickly create realistic lipsync videos. Discover the ease of text-to-video AI." def render_description(self): - st.write( + gui.write( """ This recipe takes any text and a video of a person (plus the voice defined in Settings) to create a lipsync'd video of that person speaking your text. @@ -104,8 +104,8 @@ def render_description(self): ) def render_steps(self): - audio_url = st.session_state.get("audio_url") - st.audio(audio_url, caption="Output Audio", show_download_button=True) + audio_url = gui.session_state.get("audio_url") + gui.audio(audio_url, caption="Output Audio", show_download_button=True) def render_settings(self): LipsyncPage.render_settings(self) @@ -125,17 +125,17 @@ def run(self, state: dict) -> typing.Iterator[str | None]: def render_example(self, state: dict): output_video = state.get("output_video") if output_video: - st.video( + gui.video( output_video, caption="#### Output Video", autoplay=True, show_download_button=True, ) else: - st.div() + gui.div() def render_output(self): - self.render_example(st.session_state) + self.render_example(gui.session_state) def get_raw_price(self, state: dict): # _get_tts_provider comes from TextToSpeechPage diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index ab510e741..3ec1f6f89 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -4,7 +4,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import ( upload_file_from_bytes, @@ -94,7 +94,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.text_area( + gui.text_area( """ #### Prompt Describe the scene that you'd like to generate. 
@@ -103,7 +103,7 @@ def render_form_v2(self): placeholder="Iron man", ) - st.file_uploader( + gui.file_uploader( """ #### Object Photo Give us a photo of anything @@ -112,12 +112,12 @@ def render_form_v2(self): ) def validate_form_v2(self): - text_prompt = st.session_state.get("text_prompt", "").strip() - input_image = st.session_state.get("input_image") + text_prompt = gui.session_state.get("text_prompt", "").strip() + input_image = gui.session_state.get("input_image") assert text_prompt and input_image, "Please provide a Prompt and a Object Photo" def render_description(self): - st.write( + gui.write( """ This recipe an image of an object, masks it and then renders the background around the object according to the prompt. @@ -133,35 +133,35 @@ def render_description(self): def render_settings(self): img_model_settings(InpaintingModels) - st.write("---") + gui.write("---") - st.write( + gui.write( """ #### Object Repositioning Settings """ ) - st.write("How _big_ should the object look?") - col1, _ = st.columns(2) + gui.write("How _big_ should the object look?") + col1, _ = gui.columns(2) with col1: - obj_scale = st.slider( + obj_scale = gui.slider( "Scale", min_value=0.1, max_value=1.0, key="obj_scale", ) - st.write("_Where_ would you like to place the object in the scene?") - col1, col2 = st.columns(2) + gui.write("_Where_ would you like to place the object in the scene?") + col1, col2 = gui.columns(2) with col1: - pos_x = st.slider( + pos_x = gui.slider( "Position X", min_value=0.0, max_value=1.0, key="obj_pos_x", ) with col2: - pos_y = st.slider( + pos_y = gui.slider( "Position Y", min_value=0.0, max_value=1.0, @@ -181,12 +181,12 @@ def render_settings(self): pos_x=pos_x, pos_y=pos_y, out_size=( - st.session_state["output_width"], - st.session_state["output_height"], + gui.session_state["output_width"], + gui.session_state["output_height"], ), ) - st.slider( + gui.slider( """ ##### Edge Threshold Helps to remove edge artifacts. `0` will turn this off. `0.9` will aggressively cut down edges. 
@@ -197,47 +197,47 @@ def render_settings(self): ) def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") - output_images = st.session_state.get("output_images") + text_prompt = gui.session_state.get("text_prompt", "") + output_images = gui.session_state.get("output_images") if output_images: for url in output_images: - st.image(url, caption=f"{text_prompt}", show_download_button=True) + gui.image(url, caption=f"{text_prompt}", show_download_button=True) else: - st.div() + gui.div() def render_steps(self): - input_file = st.session_state.get("input_file") - input_image = st.session_state.get("input_image") + input_file = gui.session_state.get("input_file") + input_image = gui.session_state.get("input_image") input_image_or_file = input_image or input_file - col1, col2, col3 = st.columns(3) + col1, col2, col3 = gui.columns(3) with col1: if input_image_or_file: - st.image(input_image_or_file, caption="Input Image") + gui.image(input_image_or_file, caption="Input Image") else: - st.div() + gui.div() with col2: - resized_image = st.session_state.get("resized_image") + resized_image = gui.session_state.get("resized_image") if resized_image: - st.image(resized_image, caption="Repositioned Object") + gui.image(resized_image, caption="Repositioned Object") else: - st.div() + gui.div() - obj_mask = st.session_state.get("obj_mask") + obj_mask = gui.session_state.get("obj_mask") if obj_mask: - st.image(obj_mask, caption="Object Mask") + gui.image(obj_mask, caption="Object Mask") else: - st.div() + gui.div() with col3: - diffusion_images = st.session_state.get("output_images") + diffusion_images = gui.session_state.get("output_images") if diffusion_images: for url in diffusion_images: - st.image(url, caption=f"Generated Image") + gui.image(url, caption=f"Generated Image") else: - st.div() + gui.div() def run(self, state: dict): request: ObjectInpaintingPage.RequestModel = self.RequestModel.parse_obj(state) @@ -293,17 +293,17 @@ def run(self, state: dict): state["output_images"] = diffusion_images def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col2: output_images = state.get("output_images") if output_images: for img in output_images: - st.image(img, caption="Generated Image") + gui.image(img, caption="Generated Image") with col1: input_image = state.get("input_image") - st.image(input_image, caption="Input Image") - st.write("**Prompt**") - st.write("```properties\n" + state.get("text_prompt", "") + "\n```") + gui.image(input_image, caption="Input Image") + gui.write("**Prompt**") + gui.write("```properties\n" + state.get("text_prompt", "") + "\n```") def preview_description(self, state: dict) -> str: return "Upload your product photo and describe the background. Then use Stable Diffusion's Inpainting AI to create professional background scenery without the photoshoot." diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 288eb70a8..84c092182 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from pyzbar import pyzbar -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import Workflow from daras_ai.image_input import ( @@ -152,7 +152,7 @@ def get_example_preferred_fields(cls, state: dict) -> list[str]: return ["qr_code_data"] def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt Describe the subject/scene of the QR Code. 
@@ -163,12 +163,12 @@ def render_form_v2(self): ) qr_code_source_key = "__qr_code_source" - if qr_code_source_key not in st.session_state: + if qr_code_source_key not in gui.session_state: for key in QrSources._member_names_: - if st.session_state.get(key): - st.session_state[qr_code_source_key] = key + if gui.session_state.get(key): + gui.session_state[qr_code_source_key] = key break - source = st.horizontal_radio( + source = gui.horizontal_radio( "", options=QrSources._member_names_, key=qr_code_source_key, @@ -178,7 +178,7 @@ def render_form_v2(self): _set_selected_qr_input_field(source) match source: case QrSources.qr_code_data.name: - st.text_area( + gui.text_area( """ Enter your URL/Text below. """, @@ -187,7 +187,7 @@ def render_form_v2(self): ) case QrSources.qr_code_input_image.name: - st.file_uploader( + gui.file_uploader( """ It will be reformatted and cleaned """, @@ -196,27 +196,27 @@ def render_form_v2(self): ) case QrSources.qr_code_vcard.name: - st.caption( + gui.caption( "We'll use the prompt above to create a beautiful QR code that when scanned on a phone, will add the info below as a contact. Great for conferences and geeky parties." ) vcard_form(key=QrSources.qr_code_vcard.name) case QrSources.qr_code_file.name: - st.file_uploader( + gui.file_uploader( "Upload any file. Contact cards and PDFs work great.", key=QrSources.qr_code_file.name, ) if source != QrSources.qr_code_vcard: - st.checkbox( + gui.checkbox( "๐Ÿ”— Shorten URL", key="use_url_shortener", ) - st.caption( + gui.caption( 'A shortened URL enables the QR code to be more beautiful and less "QR-codey" with fewer blocky pixels.' ) - st.file_uploader( + gui.file_uploader( """ #### ๐Ÿž๏ธ Reference Image *[optional]* This image will be used as inspiration to blend with the QR Code. @@ -226,13 +226,13 @@ def render_form_v2(self): ) def validate_form_v2(self): - assert st.session_state.get("text_prompt"), "Please provide a prompt" + assert gui.session_state.get("text_prompt"), "Please provide a prompt" assert any( - st.session_state.get(k) for k in QrSources._member_names_ + gui.session_state.get(k) for k in QrSources._member_names_ ), "Please provide QR Code URL, text content, contact info, or upload an image" def render_description(self): - st.markdown( + gui.markdown( """ Create interactive and engaging QR codes with stunning visuals that are amazing for marketing, branding, and more. Combining AI Art and QR Code has never been easier! Enter your URL and image prompt, and in just 30 seconds, we'll generate an artistic QR codes tailored to your style. 
@@ -242,30 +242,30 @@ def render_description(self): prompting101() def render_steps(self): - email_import = st.session_state.get("__email_imported") + email_import = gui.session_state.get("__email_imported") if email_import: - st.markdown("#### Import contact info from email") - st.json(email_import) - shortened_url = st.session_state.get("shortened_url") + gui.markdown("#### Import contact info from email") + gui.json(email_import) + shortened_url = gui.session_state.get("shortened_url") if shortened_url: - st.markdown( + gui.markdown( f""" #### Shorten the URL For more aesthetic and reliable QR codes with fewer black squares, we automatically shorten the URL: {shortened_url} """ ) - img = st.session_state.get("cleaned_qr_code") + img = gui.session_state.get("cleaned_qr_code") if img: - st.image( + gui.image( img, caption=""" #### Generate clean QR code Having consistent padding, formatting, and using high error correction in the QR Code encoding makes the QR code more readable and robust to damage and thus yields more reliable results with the model. """, ) - raw_images = st.session_state.get("raw_images", []) + raw_images = gui.session_state.get("raw_images", []) if raw_images: - st.markdown( + gui.markdown( """ #### Generate the QR Codes We use the model and controlnet constraints to generate QR codes that blend the prompt with the cleaned QR Code. We generate them one at a time and check if they work. If they don't work, we try again. If they work, we stop. @@ -274,10 +274,10 @@ def render_steps(self): """ ) for img in raw_images: - st.image(img) - output_images = st.session_state.get("output_images", []) + gui.image(img) + output_images = gui.session_state.get("output_images", []) if output_images: - st.markdown( + gui.markdown( """ #### Run quality control We programatically scan the QR Codes to make sure they are readable. Once a working one is found, it becomes the output. @@ -286,10 +286,10 @@ def render_steps(self): """ ) for img in output_images: - st.image(img) + gui.image(img) def render_settings(self): - st.write( + gui.write( """ Customize the QR Code output for your text prompt with these Settings. """ @@ -307,29 +307,29 @@ def render_settings(self): low_explanation="At {low} the prompted visual will be intact and the QR code will be more artistic but less readable", high_explanation="At {high} the control settings that blend the QR code will be applied tightly, possibly overriding the image prompt, but the QR code will be more readable", ) - st.write("---") + gui.write("---") output_resolution_setting() - st.write( + gui.write( """ ##### โŒ– QR Positioning Use this to control where the QR code is placed in the image, and how big it should be. 
""", className="gui-input", ) - col1, _ = st.columns(2) + col1, _ = gui.columns(2) with col1: - obj_scale = st.slider( + obj_scale = gui.slider( "Scale", min_value=0.1, max_value=1.0, step=0.05, key="obj_scale", ) - col1, col2 = st.columns(2, responsive=False) + col1, col2 = gui.columns(2, responsive=False) with col1: - pos_x = st.slider( + pos_x = gui.slider( "Position X", min_value=0.0, max_value=1.0, @@ -337,7 +337,7 @@ def render_settings(self): key="obj_pos_x", ) with col2: - pos_y = st.slider( + pos_y = gui.slider( "Position Y", min_value=0.0, max_value=1.0, @@ -355,22 +355,22 @@ def render_settings(self): pos_x=pos_x, pos_y=pos_y, out_size=( - st.session_state["output_width"], - st.session_state["output_height"], + gui.session_state["output_width"], + gui.session_state["output_height"], ), color=255, ) - if st.session_state.get("image_prompt"): - st.write("---") - st.write( + if gui.session_state.get("image_prompt"): + gui.write("---") + gui.write( """ ##### ๐ŸŽจ Inspiration Use this to control how the image prompt should influence the output. """, className="gui-input", ) - st.slider( + gui.slider( "Inspiration Strength", min_value=0.0, max_value=1.0, @@ -384,25 +384,25 @@ def render_settings(self): checkboxes=False, allow_none=False, ) - st.write( + gui.write( """ ##### โŒ– Reference Image Positioning Use this to control where the reference image is placed, and how big it should be. """, className="gui-input", ) - col1, _ = st.columns(2) + col1, _ = gui.columns(2) with col1: - image_prompt_scale = st.slider( + image_prompt_scale = gui.slider( "Scale", min_value=0.1, max_value=1.0, step=0.05, key="image_prompt_scale", ) - col1, col2 = st.columns(2, responsive=False) + col1, col2 = gui.columns(2, responsive=False) with col1: - image_prompt_pos_x = st.slider( + image_prompt_pos_x = gui.slider( "Position X", min_value=0.0, max_value=1.0, @@ -410,7 +410,7 @@ def render_settings(self): key="image_prompt_pos_x", ) with col2: - image_prompt_pos_y = st.slider( + image_prompt_pos_y = gui.slider( "Position Y", min_value=0.0, max_value=1.0, @@ -419,7 +419,7 @@ def render_settings(self): ) img_cv2 = mask_cv2 = bytes_to_cv2_img( - requests.get(st.session_state["image_prompt"]).content, + requests.get(gui.session_state["image_prompt"]).content, ) repositioning_preview_widget( img_cv2=img_cv2, @@ -428,14 +428,14 @@ def render_settings(self): pos_x=image_prompt_pos_x, pos_y=image_prompt_pos_y, out_size=( - st.session_state["output_width"], - st.session_state["output_height"], + gui.session_state["output_width"], + gui.session_state["output_height"], ), color=255, ) def render_output(self): - state = st.session_state + state = gui.session_state self._render_outputs(state) def render_example(self, state: dict): @@ -446,7 +446,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): if max_count: output_images = output_images[:max_count] for img in output_images: - st.image(img, show_download_button=True) + gui.image(img, show_download_button=True) qr_code_data = ( state.get(QrSources.qr_code_data.name) or state.get(QrSources.qr_code_input_image.name) @@ -457,7 +457,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): continue shortened_url = state.get("shortened_url") if not shortened_url: - st.caption(qr_code_data) + gui.caption(qr_code_data) continue hashid = furl(shortened_url.strip("/")).path.segments[-1] try: @@ -465,9 +465,9 @@ def _render_outputs(self, state: dict, max_count: int | None = None): except ShortenedURL.DoesNotExist: clicks = None if clicks 
is not None: - st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") + gui.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") else: - st.caption(f"{shortened_url} → {qr_code_data}") + gui.caption(f"{shortened_url} → {qr_code_data}") def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) @@ -557,44 +557,44 @@ def render_usage_guide(self): def vcard_form(*, key: str) -> VCARD: - vcard_data = st.session_state.get(key, {}) + vcard_data = gui.session_state.get(key, {}) # populate inputs for k in VCARD.__fields__.keys(): - st.session_state.setdefault(f"__vcard_data__{k}", vcard_data.get(k) or "") + gui.session_state.setdefault(f"__vcard_data__{k}", vcard_data.get(k) or "") vcard = VCARD.construct() - vcard.email = st.text_input( + vcard.email = gui.text_input( "Email", key="__vcard_data__email", placeholder="dev@gooey.ai" ) - if vcard.email and st.button( + if vcard.email and gui.button( "Import other contact info from my email - magic!", type="link", ): imported_vcard = get_vcard_from_email(vcard.email) if not imported_vcard or not imported_vcard.format_name: - st.error("No contact info found for that email") + gui.error("No contact info found for that email") else: vcard = imported_vcard # update inputs for k, v in vcard.dict().items(): - st.session_state[f"__vcard_data__{k}"] = v + gui.session_state[f"__vcard_data__{k}"] = v - vcard.format_name = st.text_input( + vcard.format_name = gui.text_input( "Name*", key="__vcard_data__format_name", placeholder="Supreme Overlord Alex Metzger, PhD", ) - vcard.tel = st.text_input( + vcard.tel = gui.text_input( "Phone Number", key="__vcard_data__tel", placeholder="+1 (420) 669-6969", ) - vcard.role = st.text_input("Role", key="__vcard_data__role", placeholder="Intern") + vcard.role = gui.text_input("Role", key="__vcard_data__role", placeholder="Intern") - st.session_state.setdefault("__vcard_data__urls_text", "\n".join(vcard.urls or [])) + gui.session_state.setdefault("__vcard_data__urls_text", "\n".join(vcard.urls or [])) vcard.urls = ( - st.text_area( + gui.text_area( """ Website Links *([calend.ly](https://calend.ly) works great!)* @@ -606,28 +606,28 @@ def vcard_form(*, key: str) -> VCARD: .splitlines() ) - vcard.photo_url = st.text_input( + vcard.photo_url = gui.text_input( "Photo URL", key="__vcard_data__photo_url", placeholder="https://www.gooey.ai/static/images/logo.png", ) - with st.expander("More Contact Fields"): - vcard.gender = st.text_input( + with gui.expander("More Contact Fields"): + vcard.gender = gui.text_input( "Gender", key="__vcard_data__gender", placeholder="F" ) - vcard.note = st.text_area( + vcard.note = gui.text_area( "Notes", key="__vcard_data__note", placeholder="- awesome person\n- loves pizza\n- plays tons of chess\n- absolutely a genius", ) - vcard.address = st.text_area( + vcard.address = gui.text_area( "Address", key="__vcard_data__address", - placeholder="123 Main St, San Francisco, CA 94105", + placeholder="123 Main St, San Francisco, CA 94105", ) - st.session_state[key] = vcard.dict() + gui.session_state[key] = vcard.dict() return vcard @@ -645,7 +645,7 @@ def _set_selected_qr_input_field( - caching any previous form data from other fields with hidden keys - restoring previously saved data for this field """ - state = st.session_state + state = gui.session_state all_fields = QrSources._member_names_ if selected not in all_fields: diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py index
efddba5ba..6372ce65e 100644 --- a/recipes/RelatedQnA.py +++ b/recipes/RelatedQnA.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.functional import apply_parallel @@ -60,32 +60,32 @@ def validate_form_v2(self): GoogleGPTPage.validate_form_v2(self) def render_steps(self): - serp_results = st.session_state.get( - "serp_results", st.session_state.get("scaleserp_results") + serp_results = gui.session_state.get( + "serp_results", gui.session_state.get("scaleserp_results") ) if serp_results: - st.write("**Web Search Results**") - st.json(serp_results) + gui.write("**Web Search Results**") + gui.json(serp_results) - output_queries = st.session_state.get("output_queries", []) + output_queries = gui.session_state.get("output_queries", []) for i, result in enumerate(output_queries): - st.write("---") - st.write(f"##### {i+1}. _{result.get('search_query')}_") + gui.write("---") + gui.write(f"##### {i+1}. _{result.get('search_query')}_") serp_results = result.get("serp_results", result.get("scaleserp_results")) if serp_results: - st.write("**Web Search Results**") - st.json(serp_results) + gui.write("**Web Search Results**") + gui.json(serp_results) render_doc_search_step(result) def render_output(self): - render_qna_outputs(st.session_state, 300) + render_qna_outputs(gui.session_state, 300) def render_example(self, state: dict): - st.write("**Search Query**") - st.write("```properties\n" + state.get("search_query", "") + "\n```") + gui.write("**Search Query**") + gui.write("```properties\n" + state.get("search_query", "") + "\n```") site_filter = state.get("site_filter") if site_filter: - st.write(f"**Site** \\\n{site_filter}") + gui.write(f"**Site** \\\n{site_filter}") render_qna_outputs(state, 200, show_count=1) def render_settings(self): diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py index 73b97c4d5..3f8c2d2d8 100644 --- a/recipes/RelatedQnADoc.py +++ b/recipes/RelatedQnADoc.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.functional import apply_parallel @@ -57,14 +57,14 @@ def validate_form_v2(self): DocSearchPage.validate_form_v2(self) def render_output(self): - render_qna_outputs(st.session_state) + render_qna_outputs(gui.session_state) def render_example(self, state: dict): - st.write("**Search Query**") - st.write("```properties\n" + state.get("search_query", "") + "\n```") + gui.write("**Search Query**") + gui.write("```properties\n" + state.get("search_query", "") + "\n```") site_filter = state.get("site_filter") if site_filter: - st.write(f"**Site** \\\n{site_filter}") + gui.write(f"**Site** \\\n{site_filter}") render_qna_outputs(state, 200, show_count=1) def render_settings(self): @@ -87,17 +87,17 @@ def preview_description(self, state: dict) -> str: return 'This workflow finds the related queries (aka "People also ask") for a Google search, searches your doc, pdf or file (from a URL or via an upload) and then generates answers using vector DB results from your docs.' 
def render_steps(self): - serp_results = st.session_state.get( - "serp_results", st.session_state.get("scaleserp_results") + serp_results = gui.session_state.get( + "serp_results", gui.session_state.get("scaleserp_results") ) if serp_results: - st.write("**Web Search Results**") - st.json(serp_results) + gui.write("**Web Search Results**") + gui.json(serp_results) - output_queries = st.session_state.get("output_queries", []) + output_queries = gui.session_state.get("output_queries", []) for i, result in enumerate(output_queries): - st.write("---") - st.write(f"##### {i + 1}. _{result.get('search_query')}_") + gui.write("---") + gui.write(f"##### {i + 1}. _{result.get('search_query')}_") render_doc_search_step(result) def run_v2( @@ -150,9 +150,9 @@ def render_qna_outputs(state, height=500, show_count=None): if not output_text: continue references = result.get("references", []) - st.write(f"##### _{i + 1}. {result.get('search_query')}_") + gui.write(f"##### _{i + 1}. {result.get('search_query')}_") render_output_with_refs( {"output_text": output_text, "references": references}, height ) render_sources_widget(references) - st.html("
") + gui.html("
") diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 2f95f3bd1..62af754a5 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -9,7 +9,7 @@ from loguru import logger from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.exceptions import raise_for_status @@ -140,7 +140,7 @@ def render_usage_guide(self): youtube_video("8VDYTYWhOaw") def render_description(self): - st.write( + gui.write( """ This workflow is designed to make it incredibly easy to create a webpage that Google's search engine will rank well. @@ -160,45 +160,45 @@ def render_description(self): ) def render_form_v2(self): - st.write("#### Inputs") - st.text_input("Google Search Query", key="search_query") - st.text_input("Website Name", key="title") - st.text_input("Website URL", key="company_url") - st.text_area("Focus Keywords *(optional)*", key="keywords") + gui.write("#### Inputs") + gui.text_input("Google Search Query", key="search_query") + gui.text_input("Website Name", key="title") + gui.text_input("Website URL", key="company_url") + gui.text_area("Focus Keywords *(optional)*", key="keywords") def validate_form_v2(self): - assert st.session_state["search_query"], "Please provide Google Search Query" - assert st.session_state["title"], "Please provide Website Name" - assert st.session_state["company_url"], "Please provide Website URL" - # assert st.session_state["keywords"], "Please provide Focus Keywords" + assert gui.session_state["search_query"], "Please provide Google Search Query" + assert gui.session_state["title"], "Please provide Website Name" + assert gui.session_state["company_url"], "Please provide Website URL" + # assert gui.session_state["keywords"], "Please provide Focus Keywords" def render_settings(self): - st.text_area( + gui.text_area( "### Task Instructions", key="task_instructions", height=300, ) - # st.checkbox("Blog Generator Mode", key="enable_blog_mode") - st.checkbox("Enable Internal Cross-Linking", key="enable_crosslinks") - st.checkbox("Enable HTML Formatting", key="enable_html") + # gui.checkbox("Blog Generator Mode", key="enable_blog_mode") + gui.checkbox("Enable Internal Cross-Linking", key="enable_crosslinks") + gui.checkbox("Enable HTML Formatting", key="enable_html") selected_model = language_model_selector() language_model_settings(selected_model) - st.write("---") + gui.write("---") serp_search_settings() def render_output(self): - output_content = st.session_state.get("output_content") + output_content = gui.session_state.get("output_content") if output_content: - st.write("#### Generated Content") + gui.write("#### Generated Content") for idx, text in enumerate(output_content): - if st.session_state.get("enable_html"): + if gui.session_state.get("enable_html"): scrollable_html(text) else: - st.text_area( + gui.text_area( f"output {idx}", label_visibility="collapsed", value=text, @@ -207,47 +207,47 @@ def render_output(self): ) else: - st.div() + gui.div() def render_steps(self): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - serp_results = st.session_state.get( - "serp_results", st.session_state.get("scaleserp_results") + serp_results = gui.session_state.get( + "serp_results", gui.session_state.get("scaleserp_results") ) if serp_results: - st.write("**Web Search Results**") - st.json(serp_results) + gui.write("**Web Search Results**") + gui.json(serp_results) with col2: - search_urls = st.session_state.get("search_urls") + 
search_urls = gui.session_state.get("search_urls") if search_urls: - st.write("**Search URLs**") - st.json(search_urls, expanded=False) + gui.write("**Search URLs**") + gui.json(search_urls, expanded=False) else: - st.div() + gui.div() - summarized_urls = st.session_state.get("summarized_urls") + summarized_urls = gui.session_state.get("summarized_urls") if summarized_urls: - st.write("**Summarized URLs**") - st.json(summarized_urls, expanded=False) + gui.write("**Summarized URLs**") + gui.json(summarized_urls, expanded=False) else: - st.div() + gui.div() - final_prompt = st.session_state.get("final_prompt") + final_prompt = gui.session_state.get("final_prompt") if final_prompt: - st.text_area( + gui.text_area( "Final Prompt", value=final_prompt, height=400, disabled=True, ) else: - st.div() + gui.div() def render_example(self, state: dict): - st.write( + gui.write( f""" Search Query `{state.get('search_query', '')}` \\ Company Name `{state.get('title', '')}` \\ @@ -261,7 +261,7 @@ def render_example(self, state: dict): if state.get("enable_html"): scrollable_html(output_content[0], height=300) else: - st.text_area( + gui.text_area( "Generated Content", value=output_content[0], height=200, diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index dea7e121e..fa4066fc7 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -3,7 +3,7 @@ import jinja2.sandbox from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.functional import map_parallel @@ -54,7 +54,7 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_SMARTGPT_META_IMG def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt """, @@ -64,19 +64,19 @@ def render_form_v2(self): ) def render_settings(self): - st.text_area( + gui.text_area( """ ##### Step 1: CoT Prompt """, key="cot_prompt", ) - st.text_area( + gui.text_area( """ ##### Step 2: Reflexion Prompt """, key="reflexion_prompt", ) - st.text_area( + gui.text_area( """ ##### Step 3: DERA Prompt """, @@ -181,22 +181,22 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_text"] = dera_outputs def render_output(self): - render_output_with_refs(st.session_state) + render_output_with_refs(gui.session_state) def render_example(self, state: dict): - st.write("**Prompt**") - st.write("```properties\n" + state.get("input_prompt", "") + "\n```") + gui.write("**Prompt**") + gui.write("```properties\n" + state.get("input_prompt", "") + "\n```") render_output_with_refs(state, 200) def render_steps(self): - prompt_tree = st.session_state.get("prompt_tree", {}) + prompt_tree = gui.session_state.get("prompt_tree", {}) if prompt_tree: - st.write("**Prompt Tree**") - st.json(prompt_tree, expanded=True) + gui.write("**Prompt Tree**") + gui.json(prompt_tree, expanded=True) - output_text: list = st.session_state.get("output_text", []) + output_text: list = gui.session_state.get("output_text", []) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( f"**Output Text**", help=f"output {idx}", disabled=True, diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index b51682055..bc3a0dea1 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -4,7 +4,7 @@ import requests from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.text_format import daras_ai_format_str from daras_ai_v2 
import settings @@ -87,7 +87,7 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_SOCIAL_LOOKUP_EMAIL_META_IMG def render_description(self): - st.write( + gui.write( """ This recipe takes an email address and a sample email body. It attempts to pull the social profile of the email address and then personlize the email using AI. @@ -107,7 +107,7 @@ def render_usage_guide(self): youtube_video("lVWQbS_rFaM") def render_form_v2(self): - st.text_input( + gui.text_input( """ #### Email Address Give us an email address and we'll try to get determine the profile data associated with it @@ -115,11 +115,11 @@ def render_form_v2(self): key="email_address", placeholder="john@appleseed.com", ) - st.caption( + gui.caption( "By providing an email address, you agree to Gooey.AI's [Privacy Policy](https://gooey.ai/privacy)" ) - st.text_area( + gui.text_area( """ #### Email Prompt """, @@ -130,15 +130,15 @@ def render_settings(self): selected_model = language_model_selector() language_model_settings(selected_model) - # st.text_input("URL 1", key="url1") - # st.text_input("URL 2", key="url2") - # st.text_input("Company", key="company") - # st.text_input("Article Title", key="article_title") - # st.text_input("Domain", key="domain") - # st.text_input("Key Words", key="key_words") + # gui.text_input("URL 1", key="url1") + # gui.text_input("URL 2", key="url2") + # gui.text_input("Company", key="company") + # gui.text_input("Article Title", key="article_title") + # gui.text_input("Domain", key="domain") + # gui.text_input("Key Words", key="key_words") def validate_form_v2(self): - email_address = st.session_state.get("email_address") + email_address = gui.session_state.get("email_address") assert email_address, "Please provide a Prompt and an Email Address" @@ -186,16 +186,16 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." 
def render_output(self): - output_text = st.session_state.get("output_text", "") + output_text = gui.session_state.get("output_text", "") if not output_text: return - st.write( + gui.write( """ #### Email Body Output """ ) for idx, text in enumerate(output_text): - st.text_area( + gui.text_area( "", disabled=True, value=text, @@ -204,24 +204,24 @@ def render_output(self): ) def render_steps(self): - person_data = st.session_state.get("person_data") + person_data = gui.session_state.get("person_data") if person_data: - st.write("**Input Variables**") - st.json(_input_variables(st.session_state)) + gui.write("**Input Variables**") + gui.json(_input_variables(gui.session_state)) else: - st.div() + gui.div() - final_prompt = st.session_state.get("final_prompt") + final_prompt = gui.session_state.get("final_prompt") if final_prompt: - st.text_area("Final Prompt", disabled=True, value=final_prompt) + gui.text_area("Final Prompt", disabled=True, value=final_prompt) else: - st.div() + gui.div() def render_example(self, state: dict): - st.write("**Email Address**") - st.write(state.get("email_address", "")) - st.write("**Email Body Output**") - st.write(state.get("output_email_body", "")) + gui.write("**Email Address**") + gui.write(state.get("email_address", "")) + gui.write("**Email Body Output**") + gui.write(state.get("output_email_body", "")) def _input_variables(state: dict): diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index c0a851c44..f3199a95b 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -4,7 +4,7 @@ from daras_ai_v2.pydantic_validation import FieldHttpUrl from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -64,7 +64,7 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_TEXT2AUDIO_META_IMG def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt Describe the audio that you'd like to generate. 
@@ -73,18 +73,18 @@ def render_form_v2(self): placeholder="Iron man", ) - st.write("#### ๐Ÿงจ Audio Model") + gui.write("#### ๐Ÿงจ Audio Model") enum_multiselect( Text2AudioModels, key="selected_models", ) def validate_form_v2(self): - assert st.session_state["text_prompt"], "Please provide a prompt" - assert st.session_state["selected_models"], "Please select at least one model" + assert gui.session_state["text_prompt"], "Please provide a prompt" + assert gui.session_state["selected_models"], "Please select at least one model" def render_settings(self): - st.slider( + gui.slider( label=""" ##### โฑ๏ธ Audio Duration (sec) """, @@ -130,12 +130,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(st.session_state) + _render_output(gui.session_state) def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: - st.markdown("```properties\n" + state.get("text_prompt", "") + "\n```") + gui.markdown("```properties\n" + state.get("text_prompt", "") + "\n```") with col2: _render_output(state) @@ -148,6 +148,6 @@ def _render_output(state): for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: - st.audio( + gui.audio( audio, caption=Text2AudioModels[key].value, show_download_button=True ) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 48a4baaf9..dbc877a17 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -5,7 +5,7 @@ import requests from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings @@ -117,7 +117,7 @@ def preview_description(self, state: dict) -> str: return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app." 
def render_description(self): - st.write( + gui.write( """ *Convert text into audio in the voice of your choice* @@ -130,7 +130,7 @@ def render_description(self): ) def render_form_v2(self): - st.text_area( + gui.text_area( """ #### Prompt Enter text you want to convert to speech @@ -140,11 +140,11 @@ def render_form_v2(self): text_to_speech_provider_selector(self) def validate_form_v2(self): - assert st.session_state.get("text_prompt"), "Text input cannot be empty" - assert st.session_state.get("tts_provider"), "Please select a TTS provider" + assert gui.session_state.get("text_prompt"), "Text input cannot be empty" + assert gui.session_state.get("tts_provider"), "Please select a TTS provider" def render_settings(self): - text_to_speech_settings(self, st.session_state.get("tts_provider")) + text_to_speech_settings(self, gui.session_state.get("tts_provider")) def get_raw_price(self, state: dict): tts_provider = self._get_tts_provider(state) @@ -159,8 +159,8 @@ def render_usage_guide(self): # loom_video("2d853b7442874b9cbbf3f27b98594add") def render_output(self): - audio_url = st.session_state.get("audio_url") - st.audio(audio_url, show_download_button=True) + audio_url = gui.session_state.get("audio_url") + gui.audio(audio_url, show_download_button=True) def _get_elevenlabs_price(self, state: dict): _, is_user_provided_key = self._get_elevenlabs_api_key(state) @@ -177,9 +177,9 @@ def _get_tts_provider(self, state: dict): return TextToSpeechProviders[tts_provider] def get_cost_note(self): - tts_provider = st.session_state.get("tts_provider") + tts_provider = gui.session_state.get("tts_provider") if tts_provider == TextToSpeechProviders.ELEVEN_LABS.name: - _, is_user_provided_key = self._get_elevenlabs_api_key(st.session_state) + _, is_user_provided_key = self._get_elevenlabs_api_key(gui.session_state) if is_user_provided_key: return "*No additional credit charge given we'll use your API key*" else: @@ -382,12 +382,12 @@ def run(self, state: dict): client = OpenAI() model = OpenAI_TTS_Models[ - st.session_state.get( + gui.session_state.get( "openai_tts_model", OpenAI_TTS_Models.tts_1.name ) ].value voice = OpenAI_TTS_Voices[ - st.session_state.get( + gui.session_state.get( "openai_voice_name", OpenAI_TTS_Voices.alloy.name ) ].value @@ -442,12 +442,12 @@ def related_workflows(self) -> list: ] def render_example(self, state: dict): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: text = state.get("text_prompt") if text: - st.write(text) + gui.write(text) with col2: audio_url = state.get("audio_url") if audio_url: - st.audio(audio_url) + gui.audio(audio_url) diff --git a/recipes/Translation.py b/recipes/Translation.py index 6583f78fa..dada4cf2a 100644 --- a/recipes/Translation.py +++ b/recipes/Translation.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.asr import ( TranslationModels, @@ -88,7 +88,7 @@ def related_workflows(self) -> list: ] def render_form_v2(self): - st.write("###### Source Texts") + gui.write("###### Source Texts") list_view_editor( add_btn_label="โž• Add Text", key="texts", @@ -100,7 +100,7 @@ def render_form_v2(self): key="selected_model", allow_none=False, ) - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: translation_language_selector( model=translation_model, @@ -118,21 +118,23 @@ def render_form_v2(self): def render_settings(self): try: translation_model = TranslationModels[ - st.session_state.get("selected_model") + 
gui.session_state.get("selected_model") ] except KeyError: translation_model = None if translation_model and translation_model.supports_glossary: - st.file_uploader( + gui.file_uploader( label=f"###### {field_title_desc(self.RequestModel, 'glossary_document')}", key="glossary_document", accept=SUPPORTED_SPREADSHEET_TYPES, ) else: - st.session_state["glossary_document"] = None + gui.session_state["glossary_document"] = None def validate_form_v2(self): - non_empty_text_inputs = [text for text in st.session_state.get("texts") if text] + non_empty_text_inputs = [ + text for text in gui.session_state.get("texts") if text + ] assert non_empty_text_inputs, "Please provide at least 1 non-empty text input" def render_output(self): @@ -142,22 +144,22 @@ def render_example(self, state: dict): text_outputs("**Translations**", value=state.get("output_texts", [])) def render_steps(self): - st.markdown( + gui.markdown( """ 1. Apply Transliteration as necessary. """ ) - st.markdown( + gui.markdown( """ 2. Detect the source language if not provided. """ ) - st.markdown( + gui.markdown( """ 3. Translate with the selected API (for Auto, we look up the optimal API based on the detected language and script). """ ) - st.markdown( + gui.markdown( """ 4. Apply romanization if requested and applicable. """ @@ -173,10 +175,10 @@ def get_raw_price(self, state: dict): def render_text_input(key: str, del_key: str, d: dict): - col1, col2 = st.columns([8, 1], responsive=False) + col1, col2 = gui.columns([8, 1], responsive=False) with col1: - with st.div(className="pt-1"): - d["text"] = st.text_area( + with gui.div(className="pt-1"): + d["text"] = gui.text_area( "", label_visibility="collapsed", key=key + ":text", diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index b0fe0b6bc..46ce36ef5 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -3,11 +3,11 @@ import mimetypes import typing +import gooey_gui as gui from django.db.models import QuerySet, Q from furl import furl from pydantic import BaseModel, Field -import gooey_ui as st from bots.models import BotIntegration, Platform from bots.models import Workflow from celeryapp.tasks import send_integration_attempt_email @@ -78,7 +78,6 @@ from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query -from daras_ai_v2.query_params import gooey_get_query_params from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.search_ref import ( parse_refs, @@ -95,7 +94,6 @@ ) from daras_ai_v2.vector_search import DocSearchRequest from functions.recipe_functions import LLMTools -from gooey_ui import RedirectException from recipes.DocSearch import ( get_top_k_references, references_as_prompt, @@ -307,7 +305,7 @@ def get_submit_container_props(self): return {} def render_description(self): - st.write( + gui.write( """ Have you ever wanted to create a bot that you could talk to about anything? Ever wanted to create your own https://dara.network/RadBots or https://Farmer.CHAT? This is how. 
@@ -327,7 +325,7 @@ def render_description(self): ) def render_form_v2(self): - st.text_area( + gui.text_area( """ #### ๐Ÿ“ Instructions """, @@ -345,23 +343,23 @@ def render_form_v2(self): accept=["audio/*", "application/*", "video/*", "text/*"], ) - st.markdown("#### Capabilities") - if st.checkbox( + gui.markdown("#### Capabilities") + if gui.checkbox( "##### ๐Ÿ—ฃ๏ธ Text to Speech & Lipsync", - value=bool(st.session_state.get("tts_provider")), + value=bool(gui.session_state.get("tts_provider")), ): text_to_speech_provider_selector(self) - st.write("---") + gui.write("---") - enable_video = st.checkbox( + enable_video = gui.checkbox( "##### ๐Ÿซฆ Add Lipsync Video", - value=bool(st.session_state.get("input_face")), + value=bool(gui.session_state.get("input_face")), ) else: - st.session_state["tts_provider"] = None + gui.session_state["tts_provider"] = None enable_video = False if enable_video: - st.file_uploader( + gui.file_uploader( """ ###### ๐Ÿ‘ฉโ€๐Ÿฆฐ Input Face Upload a video or image (with a human face) to lipsync responses. mp4, mov, png, jpg or gif preferred. @@ -374,20 +372,20 @@ def render_form_v2(self): key="lipsync_model", use_selectbox=True, ) - st.write("---") + gui.write("---") else: - st.session_state["input_face"] = None - st.session_state.pop("lipsync_model", None) + gui.session_state["input_face"] = None + gui.session_state.pop("lipsync_model", None) - if st.checkbox( + if gui.checkbox( "##### ๐Ÿ”  Translation & Speech Recognition", value=bool( - st.session_state.get("user_language") - or st.session_state.get("asr_model") + gui.session_state.get("user_language") + or gui.session_state.get("asr_model") ), ): - st.caption(field_desc(self.RequestModel, "user_language")) - col1, col2 = st.columns(2) + gui.caption(field_desc(self.RequestModel, "user_language")) + col1, col2 = gui.columns(2) with col1: translation_model = translation_model_selector(allow_none=False) with col2: @@ -396,9 +394,9 @@ def render_form_v2(self): label=f"###### {field_title(self.RequestModel, 'user_language')}", key="user_language", ) - st.write("---") + gui.write("---") - col1, col2 = st.columns(2, responsive=False) + col1, col2 = gui.columns(2, responsive=False) with col1: selected_model = enum_selector( AsrModels, @@ -416,37 +414,37 @@ def render_form_v2(self): key="asr_language", ) else: - st.caption( + gui.caption( f"We'll automatically select an [ASR](https://gooey.ai/asr) model for you based on the {field_title(self.RequestModel, 'user_language')}." 
) - st.write("---") + gui.write("---") else: - st.session_state["translation_model"] = None - st.session_state["asr_model"] = None - st.session_state["user_language"] = None + gui.session_state["translation_model"] = None + gui.session_state["asr_model"] = None + gui.session_state["user_language"] = None - if st.checkbox( + if gui.checkbox( "##### ๐Ÿฉป Photo & Document Intelligence", value=bool( - st.session_state.get("document_model"), + gui.session_state.get("document_model"), ), ): if settings.AZURE_FORM_RECOGNIZER_KEY: doc_model_descriptions = azure_form_recognizer_models() else: doc_model_descriptions = {} - st.selectbox( + gui.selectbox( f"{field_desc(self.RequestModel, 'document_model')}", key="document_model", options=doc_model_descriptions, format_func=lambda x: f"{doc_model_descriptions[x]} ({x})", ) else: - st.session_state["document_model"] = None + gui.session_state["document_model"] = None def validate_form_v2(self): - input_glossary = st.session_state.get("input_glossary_document", "") - output_glossary = st.session_state.get("output_glossary_document", "") + input_glossary = gui.session_state.get("input_glossary_document", "") + output_glossary = gui.session_state.get("output_glossary_document", "") if input_glossary: validate_glossary_document(input_glossary) if output_glossary: @@ -456,57 +454,57 @@ def render_usage_guide(self): youtube_video("-j2su1r8pEg") def render_settings(self): - tts_provider = st.session_state.get("tts_provider") + tts_provider = gui.session_state.get("tts_provider") if tts_provider: text_to_speech_settings(self, tts_provider) - st.write("---") + gui.write("---") - lipsync_model = st.session_state.get("lipsync_model") + lipsync_model = gui.session_state.get("lipsync_model") if lipsync_model: lipsync_settings(lipsync_model) - st.write("---") + gui.write("---") - translation_model = st.session_state.get( + translation_model = gui.session_state.get( "translation_model", TranslationModels.google.name ) if ( - st.session_state.get("user_language") + gui.session_state.get("user_language") and TranslationModels[translation_model].supports_glossary ): - st.markdown("##### ๐Ÿ”  Translation Settings") - enable_glossary = st.checkbox( + gui.markdown("##### ๐Ÿ”  Translation Settings") + enable_glossary = gui.checkbox( "๐Ÿ“– Add Glossary", value=bool( - st.session_state.get("input_glossary_document") - or st.session_state.get("output_glossary_document") + gui.session_state.get("input_glossary_document") + or gui.session_state.get("output_glossary_document") ), ) if enable_glossary: - st.caption( + gui.caption( """ Provide a glossary to customize translation and improve accuracy of domain-specific terms. If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing). 
""" ) - st.file_uploader( + gui.file_uploader( f"##### {field_title_desc(self.RequestModel, 'input_glossary_document')}", key="input_glossary_document", accept=SUPPORTED_SPREADSHEET_TYPES, ) - st.file_uploader( + gui.file_uploader( f"##### {field_title_desc(self.RequestModel, 'output_glossary_document')}", key="output_glossary_document", accept=SUPPORTED_SPREADSHEET_TYPES, ) else: - st.session_state["input_glossary_document"] = None - st.session_state["output_glossary_document"] = None - st.write("---") + gui.session_state["input_glossary_document"] = None + gui.session_state["output_glossary_document"] = None + gui.write("---") - documents = st.session_state.get("documents") + documents = gui.session_state.get("documents") if documents: - st.write("#### ๐Ÿ“„ Knowledge Base") - st.text_area( + gui.write("#### ๐Ÿ“„ Knowledge Base") + gui.text_area( """ ###### ๐Ÿ‘ฉโ€๐Ÿซ Search Instructions How should the LLM interpret the results from your knowledge base? @@ -516,13 +514,13 @@ def render_settings(self): ) citation_style_selector() - st.checkbox("๐Ÿ”— Shorten Citation URLs", key="use_url_shortener") + gui.checkbox("๐Ÿ”— Shorten Citation URLs", key="use_url_shortener") doc_extract_selector(self.request and self.request.user) - st.write("---") + gui.write("---") - st.markdown( + gui.markdown( """ #### Advanced Settings In general, you should not need to adjust these. @@ -532,14 +530,14 @@ def render_settings(self): if documents: query_instructions_widget() keyword_instructions_widget() - st.write("---") + gui.write("---") doc_search_advanced_settings() - st.write("---") + gui.write("---") - st.write("##### ๐Ÿ”  Language Model Settings") - language_model_settings(st.session_state.get("selected_model")) + gui.write("##### ๐Ÿ”  Language Model Settings") + language_model_settings(gui.session_state.get("selected_model")) - st.write("---") + gui.write("---") enum_multiselect( enum_cls=LLMTools, @@ -565,25 +563,25 @@ def run_as_api_tab(self): def render_example(self, state: dict): input_prompt = state.get("input_prompt") if input_prompt: - st.write( + gui.write( "**Prompt**\n```properties\n" + truncate_text_words(input_prompt, maxlen=200) + "\n```" ) - st.write("**Response**") + gui.write("**Response**") output_video = state.get("output_video") if output_video: - st.video(output_video[0], autoplay=True) + gui.video(output_video[0], autoplay=True) output_text = state.get("output_text") if output_text: - st.write(output_text[0], line_clamp=5) + gui.write(output_text[0], line_clamp=5) def render_output(self): # chat window - with st.div(className="pb-3"): + with gui.div(className="pb-3"): chat_list_view() pressed_send, new_inputs = chat_input_view() @@ -591,28 +589,28 @@ def render_output(self): self.on_send(*new_inputs) # clear chat inputs - if st.button("๐Ÿ—‘๏ธ Clear"): - st.session_state["messages"] = [] - st.session_state["input_prompt"] = "" - st.session_state["input_images"] = None - st.session_state["input_audio"] = None - st.session_state["input_documents"] = None - st.session_state["raw_input_text"] = "" + if gui.button("๐Ÿ—‘๏ธ Clear"): + gui.session_state["messages"] = [] + gui.session_state["input_prompt"] = "" + gui.session_state["input_images"] = None + gui.session_state["input_audio"] = None + gui.session_state["input_documents"] = None + gui.session_state["raw_input_text"] = "" self.clear_outputs() - st.session_state["final_keyword_query"] = "" - st.session_state["final_search_query"] = "" - st.experimental_rerun() + gui.session_state["final_keyword_query"] = "" + 
gui.session_state["final_search_query"] = "" + gui.rerun() # render sources - references = st.session_state.get("references", []) + references = gui.session_state.get("references", []) if not references: return key = "sources-expander" - with st.expander("๐Ÿ’โ€โ™€๏ธ Sources", key=key): - if not st.session_state.get(key): + with gui.expander("๐Ÿ’โ€โ™€๏ธ Sources", key=key): + if not gui.session_state.get(key): return for idx, ref in enumerate(references): - st.write(f"**{idx + 1}**. [{ref['title']}]({ref['url']})") + gui.write(f"**{idx + 1}**. [{ref['title']}]({ref['url']})") text_output( "Source Document", value=ref["snippet"], @@ -626,17 +624,17 @@ def on_send( new_input_audio: str, new_input_documents: list[str], ): - prev_input = st.session_state.get("raw_input_text") or "" - prev_output = (st.session_state.get("raw_output_text") or [""])[0] - prev_input_images = st.session_state.get("input_images") - prev_input_audio = st.session_state.get("input_audio") - prev_input_documents = st.session_state.get("input_documents") + prev_input = gui.session_state.get("raw_input_text") or "" + prev_output = (gui.session_state.get("raw_output_text") or [""])[0] + prev_input_images = gui.session_state.get("input_images") + prev_input_audio = gui.session_state.get("input_audio") + prev_input_documents = gui.session_state.get("input_documents") if ( prev_input or prev_input_images or prev_input_audio or prev_input_documents ) and prev_output: # append previous input to the history - st.session_state["messages"] = st.session_state.get("messages", []) + [ + gui.session_state["messages"] = gui.session_state.get("messages", []) + [ format_chat_entry( role=CHATML_ROLE_USER, content=prev_input, @@ -654,59 +652,59 @@ def on_send( furl(url.strip("/")).path.segments[-1] for url in new_input_documents ) new_input_text = f"Files: {filenames}\n\n{new_input_text}" - st.session_state["input_prompt"] = new_input_text - st.session_state["input_audio"] = new_input_audio or None - st.session_state["input_images"] = new_input_images or None - st.session_state["input_documents"] = new_input_documents or None + gui.session_state["input_prompt"] = new_input_text + gui.session_state["input_audio"] = new_input_audio or None + gui.session_state["input_images"] = new_input_images or None + gui.session_state["input_documents"] = new_input_documents or None self.on_submit() def render_steps(self): - if st.session_state.get("tts_provider"): - st.video(st.session_state.get("input_face"), caption="Input Face") + if gui.session_state.get("tts_provider"): + gui.video(gui.session_state.get("input_face"), caption="Input Face") - final_search_query = st.session_state.get("final_search_query") + final_search_query = gui.session_state.get("final_search_query") if final_search_query: - st.text_area( + gui.text_area( "###### `final_search_query`", value=final_search_query, disabled=True ) - final_keyword_query = st.session_state.get("final_keyword_query") + final_keyword_query = gui.session_state.get("final_keyword_query") if final_keyword_query: if isinstance(final_keyword_query, list): - st.write("###### `final_keyword_query`") - st.json(final_keyword_query) + gui.write("###### `final_keyword_query`") + gui.json(final_keyword_query) else: - st.text_area( + gui.text_area( "###### `final_keyword_query`", value=str(final_keyword_query), disabled=True, ) - references = st.session_state.get("references", []) + references = gui.session_state.get("references", []) if references: - st.write("###### `references`") - st.json(references) + 
gui.write("###### `references`") + gui.json(references) - final_prompt = st.session_state.get("final_prompt") + final_prompt = gui.session_state.get("final_prompt") if final_prompt: if isinstance(final_prompt, str): text_output("###### `final_prompt`", value=final_prompt, height=300) else: - st.write("###### `final_prompt`") - st.json(final_prompt) + gui.write("###### `final_prompt`") + gui.json(final_prompt) for k in ["raw_output_text", "output_text", "raw_tts_text"]: - for idx, text in enumerate(st.session_state.get(k) or []): - st.text_area( + for idx, text in enumerate(gui.session_state.get(k) or []): + gui.text_area( f"###### `{k}[{idx}]`", value=text, disabled=True, ) - for idx, audio_url in enumerate(st.session_state.get("output_audio", [])): - st.write(f"###### `output_audio[{idx}]`") - st.audio(audio_url) + for idx, audio_url in enumerate(gui.session_state.get("output_audio", [])): + gui.write(f"###### `output_audio[{idx}]`") + gui.audio(audio_url) def get_raw_price(self, state: dict): total = self.get_total_linked_usage_cost_in_credits() + self.PROFIT_CREDITS @@ -725,18 +723,18 @@ def get_raw_price(self, state: dict): def additional_notes(self): try: - model = LargeLanguageModels[st.session_state["selected_model"]].value + model = LargeLanguageModels[gui.session_state["selected_model"]].value except KeyError: model = "LLM" notes = f"\n*Breakdown: {math.ceil(self.get_total_linked_usage_cost_in_credits())} ({model}) + {self.PROFIT_CREDITS}/run*" if ( - st.session_state.get("tts_provider") + gui.session_state.get("tts_provider") == TextToSpeechProviders.ELEVEN_LABS.name ): notes += f" *+ {TextToSpeechPage().get_cost_note()} (11labs)*" - if st.session_state.get("input_face"): + if gui.session_state.get("input_face"): notes += " *+ 1 (lipsync)*" return notes @@ -821,7 +819,7 @@ def run_v2( # consturct the system prompt bot_script = (request.bot_script or "").strip() if bot_script: - bot_script = render_prompt_vars(bot_script, st.session_state) + bot_script = render_prompt_vars(bot_script, gui.session_state) # insert to top system_prompt = {"role": CHATML_ROLE_SYSTEM, "content": bot_script} else: @@ -847,7 +845,7 @@ def run_v2( response.final_search_query = generate_final_search_query( request=request, instructions=query_instructions, - context={**st.session_state, "messages": chat_history}, + context={**gui.session_state, "messages": chat_history}, ) else: query_msgs.reverse() @@ -865,7 +863,7 @@ def run_v2( generate_final_search_query( request=k_request, instructions=keyword_instructions, - context={**st.session_state, "messages": chat_history}, + context={**gui.session_state, "messages": chat_history}, response_format_type="json_object", ), ) @@ -878,7 +876,7 @@ def run_v2( response.references = yield from get_top_k_references( DocSearchRequest.parse_obj( { - **st.session_state, + **gui.session_state, "search_query": response.final_search_query, "keyword_query": response.final_keyword_query, }, @@ -896,7 +894,7 @@ def run_v2( if response.references: # add task instructions task_instructions = render_prompt_vars( - request.task_instructions, st.session_state + request.task_instructions, gui.session_state ) user_input = ( references_as_prompt(response.references) @@ -1011,7 +1009,7 @@ def run_v2( response.output_audio = [] for text in response.raw_tts_text or response.raw_output_text: tts_state = TextToSpeechPage.RequestModel.parse_obj( - {**st.session_state, "text_prompt": text} + {**gui.session_state, "text_prompt": text} ).dict() yield from TextToSpeechPage( request=self.request, 
run_user=self.run_user @@ -1024,7 +1022,7 @@ def run_v2( for audio_url in response.output_audio: lip_state = LipsyncPage.RequestModel.parse_obj( { - **st.session_state, + **gui.session_state, "input_audio": audio_url, "selected_model": request.lipsync_model, } @@ -1048,14 +1046,14 @@ def render_selected_tab(self): def render_integrations_tab(self): from daras_ai_v2.breadcrumbs import get_title_breadcrumbs - st.newline() + gui.newline() # not signed in case if not self.request.user or self.request.user.is_anonymous: integration_welcome_screen(title="Connect your Copilot") - st.newline() - with st.center(): - st.anchor( + gui.newline() + with gui.center(): + gui.anchor( "Get Started", href=self.get_auth_url(self.app_url()), type="primary", @@ -1063,15 +1061,15 @@ def render_integrations_tab(self): return current_run, published_run = self.get_runs_from_query_params( - *extract_query_params(gooey_get_query_params()) + *extract_query_params(gui.get_query_params()) ) # type: ignore # signed in but not on a run the user can edit (admins will never see this) if not self.can_user_edit_run(current_run, published_run): integration_welcome_screen(title="Create your Saved Copilot") - st.newline() - with st.center(): - st.anchor( + gui.newline() + with gui.center(): + gui.anchor( "Run & Save this Copilot", href=self.get_auth_url(self.app_url()), type="primary", @@ -1082,8 +1080,8 @@ def render_integrations_tab(self): # note: this means we no longer allow botintegrations on non-published runs which is a breaking change requested by Sean if not self.can_user_edit_published_run(published_run): integration_welcome_screen(title="Save your Published Copilot") - st.newline() - with st.center(): + gui.newline() + with gui.center(): self._render_published_run_buttons( current_run=current_run, published_run=published_run, @@ -1123,7 +1121,7 @@ def render_integrations_tab(self): return # this gets triggered on the /add route - if st.session_state.pop("--add-integration", None): + if gui.session_state.pop("--add-integration", None): cancel_url = self.current_app_url(RecipeTabs.integrations) self.render_integrations_add( label=f""" @@ -1132,12 +1130,12 @@ def render_integrations_tab(self): """, run_title=run_title, ) - with st.center(): - if st.button("Return to Test & Configure"): - raise RedirectException(cancel_url) + with gui.center(): + if gui.button("Return to Test & Configure"): + raise gui.RedirectException(cancel_url) return - with st.center(): + with gui.center(): # signed in, can edit, and has connected botintegrations on this run self.render_integrations_settings( integrations=list(integrations_qs), run_title=run_title @@ -1147,25 +1145,25 @@ def render_integrations_add(self, label: str, run_title: str): from routers.facebook_api import fb_connect_url, wa_connect_url from routers.slack_api import slack_connect_url - st.write(label, unsafe_allow_html=True, className="text-center") - st.newline() + gui.write(label, unsafe_allow_html=True, className="text-center") + gui.newline() pressed_platform = None with ( - st.tag("table", className="d-flex justify-content-center"), - st.tag("tbody"), + gui.tag("table", className="d-flex justify-content-center"), + gui.tag("tbody"), ): for choice in connect_choices: - with st.tag("tr"): - with st.tag("td"): - if st.button( + with gui.tag("tr"): + with gui.tag("td"): + if gui.button( f'{choice.platform.name}', className="p-0 border border-1 border-secondary rounded", style=dict(width="160px", height="60px"), ): pressed_platform = choice.platform - with st.tag("td", 
className="ps-3"): - st.caption(choice.label) + with gui.tag("td", className="ps-3"): + gui.caption(choice.label) if pressed_platform: on_connect = self.current_app_url(RecipeTabs.integrations) @@ -1195,11 +1193,11 @@ def render_integrations_add(self, label: str, run_title: str): platform=pressed_platform, run_url=self.current_app_url() or "", ) - raise RedirectException(redirect_url) + raise gui.RedirectException(redirect_url) - st.newline() + gui.newline() api_tab_url = self.current_app_url(RecipeTabs.run_as_api) - st.write( + gui.write( f"Or use [our API]({api_tab_url}) to build custom integrations with your server.", className="text-center", ) @@ -1209,45 +1207,45 @@ def render_integrations_settings( ): from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button - st.markdown("#### Configure your Copilot") + gui.markdown("#### Configure your Copilot") if len(integrations) > 1: - with st.div(style={"minWidth": "500px", "textAlign": "left"}): + with gui.div(style={"minWidth": "500px", "textAlign": "left"}): integrations_map = {i.id: i for i in integrations} - bi_id = st.selectbox( + bi_id = gui.selectbox( label="", options=integrations_map.keys(), format_func=lambda bi_id: f"{Platform(integrations_map[bi_id].platform).get_icon()}   {integrations_map[bi_id].name}", key="bi_id", ) bi = integrations_map[bi_id] - old_bi_id = st.session_state.get("old_bi_id", bi_id) + old_bi_id = gui.session_state.get("old_bi_id", bi_id) if bi_id != old_bi_id: - raise RedirectException( + raise gui.RedirectException( self.current_app_url( RecipeTabs.integrations, path_params=dict(integration_id=bi.api_integration_id()), ) ) - st.session_state["old_bi_id"] = bi_id + gui.session_state["old_bi_id"] = bi_id else: bi = integrations[0] icon = Platform(bi.platform).get_icon() if bi.platform == Platform.WEB: web_widget_config(bi, self.request.user) - st.newline() + gui.newline() - st.newline() - with st.div(style={"width": "100%", "textAlign": "left"}): + gui.newline() + with gui.div(style={"width": "100%", "textAlign": "left"}): test_link = get_bot_test_link(bi) - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### Connected To") - st.write(f"{icon} {bi}", unsafe_allow_html=True) + gui.write("###### Connected To") + gui.write(f"{icon} {bi}", unsafe_allow_html=True) with col2: if not test_link: - st.write("Message quicklink not available.") + gui.write("Message quicklink not available.") elif bi.platform == Platform.TWILIO: copy_to_clipboard_button( ' Copy Phone Number', @@ -1262,7 +1260,7 @@ def render_integrations_settings( ) if bi.platform == Platform.FACEBOOK: - st.anchor( + gui.anchor( ' Open Inbox', "https://www.facebook.com/latest/inbox", unsafe_allow_html=True, @@ -1276,53 +1274,53 @@ def render_integrations_settings( type="secondary", ) - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### Test") - st.caption(f"Send a test message via {Platform(bi.platform).label}.") + gui.write("###### Test") + gui.caption(f"Send a test message via {Platform(bi.platform).label}.") with col2: if not test_link: - st.write("Message quicklink not available.") + gui.write("Message quicklink not available.") elif bi.platform == Platform.FACEBOOK: - st.anchor( + gui.anchor( f"{icon} Open Profile", test_link, unsafe_allow_html=True, new_tab=True, ) - st.anchor( + gui.anchor( f' Open Messenger', 
str(furl("https://www.messenger.com/t/") / bi.fb_page_id), unsafe_allow_html=True, new_tab=True, ) elif bi.platform == Platform.TWILIO: - st.anchor( + gui.anchor( ' Start Voice Call', test_link, unsafe_allow_html=True, new_tab=True, ) - st.anchor( + gui.anchor( ' Send SMS', str(furl("sms:") / bi.twilio_phone_number.as_e164), unsafe_allow_html=True, new_tab=True, ) else: - st.anchor( + gui.anchor( f"{icon} Message {bi.get_display_name()}", test_link, unsafe_allow_html=True, new_tab=True, ) - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### Understand your Users") - st.caption(f"See real-time analytics.") + gui.write("###### Understand your Users") + gui.caption(f"See real-time analytics.") with col2: - st.anchor( + gui.anchor( "๐Ÿ“Š View Analytics", str( furl( @@ -1338,7 +1336,7 @@ def render_integrations_settings( new_tab=True, ) if bi.platform == Platform.TWILIO and bi.twilio_phone_number_sid: - st.anchor( + gui.anchor( f"{icon} Open Twilio Console", str( furl( @@ -1352,14 +1350,14 @@ def render_integrations_settings( ) if bi.platform == Platform.WHATSAPP and bi.wa_business_waba_id: - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### WhatsApp Business Management") - st.caption( + gui.write("###### WhatsApp Business Management") + gui.caption( f"Access your WhatsApp account on Meta to approve message templates, etc." ) with col2: - st.anchor( + gui.anchor( "Business Settings", str( furl( @@ -1369,7 +1367,7 @@ def render_integrations_settings( ), new_tab=True, ) - st.anchor( + gui.anchor( "WhatsApp Manager", str( furl( @@ -1380,18 +1378,18 @@ def render_integrations_settings( new_tab=True, ) - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### Add Integration") - st.caption(f"Add another connection for {run_title}.") + gui.write("###### Add Integration") + gui.caption(f"Add another connection for {run_title}.") with col2: - st.anchor( + gui.anchor( f'   Add Integration', str(furl(self.current_app_url(RecipeTabs.integrations)) / "add/"), unsafe_allow_html=True, ) - with st.expander("Configure Settings ๐Ÿ› ๏ธ"): + with gui.expander("Configure Settings ๐Ÿ› ๏ธ"): if bi.platform == Platform.SLACK: slack_specific_settings(bi, run_title) if bi.platform == Platform.TWILIO: @@ -1399,25 +1397,25 @@ def render_integrations_settings( general_integration_settings(bi, self.request.user) if bi.platform in [Platform.SLACK, Platform.WHATSAPP, Platform.TWILIO]: - st.newline() + gui.newline() broadcast_input(bi) - st.write("---") + gui.write("---") - col1, col2 = st.columns(2, style={"alignItems": "center"}) + col1, col2 = gui.columns(2, style={"alignItems": "center"}) with col1: - st.write("###### Disconnect") - st.caption( + gui.write("###### Disconnect") + gui.caption( f"Disconnect {run_title} from {Platform(bi.platform).label} {bi.get_display_name()}." 
) with col2: - if st.button( + if gui.button( "๐Ÿ’”๏ธ Disconnect", key="btn_disconnect", ): bi.saved_run = None bi.published_run = None bi.save() - st.experimental_rerun() + gui.rerun() def integrations_on_connect(self, current_run, published_run): from app_users.models import AppUser @@ -1430,8 +1428,8 @@ def integrations_on_connect(self, current_run, published_run): except BotIntegration.DoesNotExist: continue if bi.saved_run is not None: - with st.center(): - st.write( + with gui.center(): + gui.write( f"โš ๏ธ {bi.get_display_name()} is already connected to a different published run by {AppUser.objects.filter(uid=bi.billing_account_uid).first().display_name}. Please disconnect it first." ) return @@ -1451,7 +1449,7 @@ def integrations_on_connect(self, current_run, published_run): path_params = dict(integration_id=bi.api_integration_id()) else: path_params = dict() - raise RedirectException( + raise gui.RedirectException( self.current_app_url(RecipeTabs.integrations, path_params=path_params) ) @@ -1486,7 +1484,7 @@ def infer_asr_model_and_language( def chat_list_view(): # render a reversed list view - with st.div( + with gui.div( className="pb-1", style=dict( maxHeight="80vh", @@ -1496,39 +1494,39 @@ def chat_list_view(): border="1px solid #c9c9c9", ), ): - with st.div(className="px-3"): - show_raw_msgs = st.checkbox("_Show Raw Output_") + with gui.div(className="px-3"): + show_raw_msgs = gui.checkbox("_Show Raw Output_") # render the last output with msg_container_widget(CHATML_ROLE_ASSISTANT): if show_raw_msgs: - output_text = st.session_state.get("raw_output_text", []) + output_text = gui.session_state.get("raw_output_text", []) else: - output_text = st.session_state.get("output_text", []) - output_video = st.session_state.get("output_video", []) - output_audio = st.session_state.get("output_audio", []) + output_text = gui.session_state.get("output_text", []) + output_video = gui.session_state.get("output_video", []) + output_audio = gui.session_state.get("output_audio", []) if output_text: - st.write(f"**Assistant**") + gui.write(f"**Assistant**") for idx, text in enumerate(output_text): - st.write(text) + gui.write(text) try: - st.video(output_video[idx]) + gui.video(output_video[idx]) except IndexError: try: - st.audio(output_audio[idx]) + gui.audio(output_audio[idx]) except IndexError: pass - output_documents = st.session_state.get("output_documents", []) + output_documents = gui.session_state.get("output_documents", []) if output_documents: for doc in output_documents: - st.write(doc) - messages = st.session_state.get("messages", []).copy() + gui.write(doc) + messages = gui.session_state.get("messages", []).copy() # add last input to history if present if show_raw_msgs: - input_prompt = st.session_state.get("raw_input_text") + input_prompt = gui.session_state.get("raw_input_text") else: - input_prompt = st.session_state.get("input_prompt") - input_images = st.session_state.get("input_images") - input_audio = st.session_state.get("input_audio") + input_prompt = gui.session_state.get("input_prompt") + input_images = gui.session_state.get("input_images") + input_audio = gui.session_state.get("input_audio") if input_prompt or input_images or input_audio: messages += [ format_chat_entry( @@ -1541,36 +1539,36 @@ def chat_list_view(): images = get_entry_images(entry) text = get_entry_text(entry) if text or images or input_audio: - st.write(f"**{entry['role'].capitalize()}** \n{text}") + gui.write(f"**{entry['role'].capitalize()}** \n{text}") if images: for im in images: - st.image(im, 
style={"maxHeight": "200px"}) + gui.image(im, style={"maxHeight": "200px"}) if input_audio: - st.audio(input_audio) + gui.audio(input_audio) input_audio = None def chat_input_view() -> tuple[bool, tuple[str, list[str], str, list[str]]]: - with st.div( + with gui.div( className="px-3 pt-3 d-flex gap-1", style=dict(background="rgba(239, 239, 239, 0.6)"), ): show_uploader_key = "--show-file-uploader" - show_uploader = st.session_state.setdefault(show_uploader_key, False) - if st.button( + show_uploader = gui.session_state.setdefault(show_uploader_key, False) + if gui.button( "๐Ÿ“Ž", style=dict(height="3.2rem", backgroundColor="white"), ): show_uploader = not show_uploader - st.session_state[show_uploader_key] = show_uploader + gui.session_state[show_uploader_key] = show_uploader - with st.div(className="flex-grow-1"): - new_input_text = st.text_area("", placeholder="Send a message", height=50) + with gui.div(className="flex-grow-1"): + new_input_text = gui.text_area("", placeholder="Send a message", height=50) - pressed_send = st.button("โœˆ Send", style=dict(height="3.2rem")) + pressed_send = gui.button("โœˆ Send", style=dict(height="3.2rem")) if show_uploader: - uploaded_files = st.file_uploader("", accept_multiple_files=True) + uploaded_files = gui.file_uploader("", accept_multiple_files=True) new_input_images = [] new_input_audio = None new_input_documents = [] @@ -1599,7 +1597,7 @@ def chat_input_view() -> tuple[bool, tuple[str, list[str], str, list[str]]]: def msg_container_widget(role: str): - return st.div( + return gui.div( className="px-3 py-1 pt-2", style=dict( background=( @@ -1622,10 +1620,10 @@ def convo_window_clipper( def integration_welcome_screen(title: str): - with st.center(): - st.markdown(f"#### {title}") + with gui.center(): + gui.markdown(f"#### {title}") - col1, col2, col3 = st.columns( + col1, col2, col3 = gui.columns( 3, column_props=dict( style=dict( @@ -1639,21 +1637,21 @@ def integration_welcome_screen(title: str): style={"justifyContent": "center"}, ) with col1: - st.html("๐Ÿƒโ€โ™€๏ธ", style={"fontSize": "4rem"}) - st.markdown( + gui.html("๐Ÿƒโ€โ™€๏ธ", style={"fontSize": "4rem"}) + gui.markdown( """ 1. Fork & Save your Run """ ) - st.caption("Make changes, Submit & Save your perfect workflow") + gui.caption("Make changes, Submit & Save your perfect workflow") with col2: - st.image(INTEGRATION_IMG, alt="Integrations", style={"height": "5rem"}) - st.markdown("2. Connect to Slack, Whatsapp or your App") - st.caption("Or Facebook, Instagram and the web. Wherever your users chat.") + gui.image(INTEGRATION_IMG, alt="Integrations", style={"height": "5rem"}) + gui.markdown("2. Connect to Slack, Whatsapp or your App") + gui.caption("Or Facebook, Instagram and the web. Wherever your users chat.") with col3: - st.html("๐Ÿ“ˆ", style={"fontSize": "4rem"}) - st.markdown("3. Test, Analyze & Iterate") - st.caption("Analyze your usage. Update your Saved Run to test changes.") + gui.html("๐Ÿ“ˆ", style={"fontSize": "4rem"}) + gui.markdown("3. Test, Analyze & Iterate") + gui.caption("Analyze your usage. 
Update your Saved Run to test changes.") class ConnectChoice(typing.NamedTuple): diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 8a7130240..648a2894f 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -1,6 +1,7 @@ import base64 from datetime import datetime, timedelta +import gooey_gui as gui from dateutil.relativedelta import relativedelta from django.db.models import Count, Avg, Q from django.db.models.functions import ( @@ -15,7 +16,6 @@ from furl import furl from pydantic import BaseModel -import gooey_ui as st from app_users.models import AppUser from bots.models import ( Workflow, @@ -34,7 +34,6 @@ CHATML_ROLE_ASSISTANT, CHATML_ROLE_USER, ) -from gooey_ui import RedirectException from recipes.VideoBots import VideoBotsPage @@ -67,23 +66,23 @@ def current_app_url( def show_title_breadcrumb_share( self, bi: BotIntegration, run_title: str, run_url: str ): - with st.div(className="d-flex justify-content-between mt-4"): - with st.div(className="d-lg-flex d-block align-items-center"): - with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - with st.breadcrumbs(): + with gui.div(className="d-flex justify-content-between mt-4"): + with gui.div(className="d-lg-flex d-block align-items-center"): + with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with gui.breadcrumbs(): metadata = VideoBotsPage.workflow.get_or_create_metadata() - st.breadcrumb_item( + gui.breadcrumb_item( metadata.short_title, link_to=VideoBotsPage.app_url(), className="text-muted", ) if not (bi.published_run_id and bi.published_run.is_root()): - st.breadcrumb_item( + gui.breadcrumb_item( run_title, link_to=run_url, className="text-muted", ) - st.breadcrumb_item( + gui.breadcrumb_item( "Integrations", link_to=VideoBotsPage.current_app_url( RecipeTabs.integrations, @@ -102,11 +101,11 @@ def show_title_breadcrumb_share( show_as_link=self.is_current_user_admin(), ) - with st.div(className="d-flex align-items-center"): - with st.div(className="d-flex align-items-start right-action-icons"): + with gui.div(className="d-flex align-items-center"): + with gui.div(className="d-flex align-items-start right-action-icons"): self._render_social_buttons(show_button_text=True) - st.markdown("# " + self.get_dynamic_meta_title()) + gui.markdown("# " + self.get_dynamic_meta_title()) def get_dynamic_meta_title(self): return f"๐Ÿ“Š {self.bi.name} Analytics" if self.bi else self.title @@ -115,7 +114,7 @@ def render(self): self.setup_sentry() if not self.request.user or self.request.user.is_anonymous: - st.write("**Please Login to view stats for your bot integrations**") + gui.write("**Please Login to view stats for your bot integrations**") return if self.is_current_user_admin(): bi_qs = BotIntegration.objects.all().order_by("platform", "-created_at") @@ -125,12 +124,12 @@ def render(self): ).order_by("platform", "-created_at") if not bi_qs.exists(): - st.write( + gui.write( "**Please connect a bot to a platform to view stats for your bot integrations or login to an account with connected bot integrations**" ) return - bi_id = self.request.query_params.get("bi_id") or st.session_state.get("bi_id") + bi_id = self.request.query_params.get("bi_id") or gui.session_state.get("bi_id") try: self.bi = bi = bi_qs.get(id=bi_id) except BotIntegration.DoesNotExist: @@ -138,7 +137,7 @@ def render(self): # for backwards compatibility with old urls if self.request.query_params.get("bi_id"): - raise RedirectException( + raise gui.RedirectException( str( furl( VideoBotsPage.app_url( @@ 
-160,7 +159,7 @@ def render(self): run_title = bi.saved_run.page_title # this is mostly for backwards compat self.show_title_breadcrumb_share(bi, run_title, run_url) - col1, col2 = st.columns([1, 2]) + col1, col2 = gui.columns([1, 2]) with col1: conversations, messages = calculate_overall_stats( @@ -184,22 +183,24 @@ def render(self): ) if df.empty or "date" not in df.columns: - st.write("No data to show yet.") + gui.write("No data to show yet.") self.update_url( view, - st.session_state.get("details"), + gui.session_state.get("details"), start_date, end_date, - st.session_state.get("sort_by"), + gui.session_state.get("sort_by"), ) return with col2: plot_graphs(view, df) - st.write("---") - st.session_state.setdefault("details", self.request.query_params.get("details")) - details = st.horizontal_radio( + gui.write("---") + gui.session_state.setdefault( + "details", self.request.query_params.get("details") + ) + details = gui.horizontal_radio( "### Details", options=( [ @@ -256,15 +257,15 @@ def render(self): sort_by = None if options: query_sort_by = self.request.query_params.get("sort_by") - st.session_state.setdefault( + gui.session_state.setdefault( "sort_by", query_sort_by if query_sort_by in options else options[0] ) - st.selectbox( + gui.selectbox( "Sort by", options=options, key="sort_by", ) - sort_by = st.session_state["sort_by"] + sort_by = gui.session_state["sort_by"] df = get_tabular_data( bi=bi, @@ -279,7 +280,7 @@ def render(self): if not df.empty: columns = df.columns.tolist() - st.data_table( + gui.data_table( [columns] + [ [ @@ -294,8 +295,8 @@ def render(self): ] ) # download as csv button - st.html("
") - if st.checkbox("Export"): + gui.html("
") + if gui.checkbox("Export"): df = get_tabular_data( bi=bi, conversations=conversations, @@ -307,12 +308,12 @@ def render(self): ) csv = df.to_csv() b64 = base64.b64encode(csv.encode()).decode() - st.html( + gui.html( f'Download CSV File' ) - st.caption("Includes full data (UI only shows first 500 rows)") + gui.caption("Includes full data (UI only shows first 500 rows)") else: - st.write("No data to show yet.") + gui.write("No data to show yet.") self.update_url(view, details, start_date, end_date, sort_by) @@ -330,25 +331,25 @@ def update_url(self, view, details, start_date, end_date, sort_by): } if f.query.params == new_query_params: return - raise RedirectException(str(f.set(query_params=new_query_params))) + raise gui.RedirectException(str(f.set(query_params=new_query_params))) def render_date_view_inputs(self, bi): - if st.checkbox("Show All"): + if gui.checkbox("Show All"): start_date = bi.created_at end_date = timezone.now() + timedelta(days=1) else: fifteen_days_ago = timezone.now() - timedelta(days=15) fifteen_days_ago = fifteen_days_ago.replace(hour=0, minute=0, second=0) - st.session_state.setdefault( + gui.session_state.setdefault( "start_date", self.request.query_params.get( "start_date", fifteen_days_ago.strftime("%Y-%m-%d") ), ) start_date: datetime = ( - st.date_input("Start date", key="start_date") or fifteen_days_ago + gui.date_input("Start date", key="start_date") or fifteen_days_ago ) - st.session_state.setdefault( + gui.session_state.setdefault( "end_date", self.request.query_params.get( "end_date", @@ -356,13 +357,13 @@ def render_date_view_inputs(self, bi): ), ) end_date: datetime = ( - st.date_input("End date", key="end_date") or timezone.now() + gui.date_input("End date", key="end_date") or timezone.now() ) - st.session_state.setdefault( + gui.session_state.setdefault( "view", self.request.query_params.get("view", "Daily") ) - st.write("---") - view = st.horizontal_radio( + gui.write("---") + view = gui.horizontal_radio( "### View", options=["Daily", "Weekly", "Monthly"], key="view", @@ -373,7 +374,7 @@ def render_date_view_inputs(self, bi): trunc_fn = TruncWeek elif view == "Daily": if end_date - start_date > timedelta(days=31): - st.write( + gui.write( "**Note: Date ranges greater than 31 days show weekly averages in daily view**" ) factor = 1.0 / 7.0 @@ -383,10 +384,10 @@ def render_date_view_inputs(self, bi): elif view == "Monthly": trunc_fn = TruncMonth start_date = start_date.replace(day=1) - st.session_state["start_date"] = start_date.strftime("%Y-%m-%d") + gui.session_state["start_date"] = start_date.strftime("%Y-%m-%d") if end_date.day != 1: end_date = end_date.replace(day=1) + relativedelta(months=1) - st.session_state["end_date"] = end_date.strftime("%Y-%m-%d") + gui.session_state["end_date"] = end_date.strftime("%Y-%m-%d") else: trunc_fn = TruncYear return start_date, end_date, view, factor, trunc_fn @@ -430,7 +431,7 @@ def calculate_overall_stats(*, bi, run_title, run_url): connection_detail = f"- Connected to: {bi.get_display_name()}" else: connection_detail = "" - st.markdown( + gui.markdown( f""" - Platform: {Platform(bi.platform).name.capitalize()} - Created on: {bi.created_at.strftime("%b %d, %Y")} @@ -629,8 +630,8 @@ def plot_graphs(view, df): template="plotly_white", ), ) - st.plotly_chart(fig) - st.write("---") + gui.plotly_chart(fig) + gui.write("---") fig = go.Figure( data=[ go.Scatter( @@ -734,8 +735,8 @@ def plot_graphs(view, df): ) ], ) - st.plotly_chart(fig) - st.write("---") + gui.plotly_chart(fig) + gui.write("---") fig = go.Figure( 
data=[ go.Scatter( @@ -760,8 +761,8 @@ def plot_graphs(view, df): template="plotly_white", ), ) - st.plotly_chart(fig) - st.write("---") + gui.plotly_chart(fig) + gui.write("---") fig = go.Figure( data=[ go.Scatter( @@ -810,7 +811,7 @@ def plot_graphs(view, df): template="plotly_white", ), ) - st.plotly_chart(fig) + gui.plotly_chart(fig) def get_tabular_data( diff --git a/recipes/asr_page.py b/recipes/asr_page.py index c18f806d9..58d49ffa4 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -3,7 +3,7 @@ from jinja2.lexer import whitespace_re from pydantic import BaseModel, Field -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow, SavedRun from daras_ai_v2.asr import ( AsrModels, @@ -89,12 +89,12 @@ def preview_description(self, state: dict): return "Transcribe mp3, WhatsApp audio + wavs with OpenAI's Whisper or AI4Bharat / Bhashini ASR models. Optionally translate to any language too." def render_description(self): - st.markdown( + gui.markdown( """ This workflow let's you compare the latest and finest speech recognition models from [OpenAI](https://openai.com/research/whisper), [AI4Bharat](https://ai4bharat.org) and [Bhashini](https://bhashini.gov.in) and Google's USM coming soon. """ ) - st.markdown( + gui.markdown( """ Just upload an audio file (mp3, wav, ogg or aac file) setting its language and then choose a speech recognition engine. You can also translate the output to any language too (using Google's Translation APIs). """ @@ -119,7 +119,7 @@ def render_form_v2(self): "#### Audio Files", accept=("audio/*", "video/*", "application/octet-stream"), ) - col1, col2 = st.columns(2, responsive=False) + col1, col2 = gui.columns(2, responsive=False) with col1: selected_model = enum_selector( AsrModels, @@ -131,7 +131,7 @@ def render_form_v2(self): asr_language_selector(AsrModels[selected_model]) def render_settings(self): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: translation_model = translation_model_selector() with col2: @@ -141,13 +141,13 @@ def render_settings(self): key="translation_target", ) if translation_model and translation_model.supports_glossary: - st.file_uploader( + gui.file_uploader( label=f"###### {field_title_desc(self.RequestModel, 'glossary_document')}", key="glossary_document", accept=SUPPORTED_SPREADSHEET_TYPES, ) - st.write("---") - selected_model = st.session_state.get("selected_model") + gui.write("---") + selected_model = gui.session_state.get("selected_model") if selected_model: translation_language_selector( model=translation_model, @@ -155,16 +155,18 @@ def render_settings(self): key="translation_source", allow_none=True, ) - st.caption( + gui.caption( "This is usually inferred from the spoken `language`, but in case that is set to Auto detect, you can specify one explicitly.", ) - st.write("---") + gui.write("---") enum_selector( AsrOutputFormat, label="###### Output Format", key="output_format" ) def validate_form_v2(self): - assert st.session_state.get("documents"), "Please provide at least 1 Audio File" + assert gui.session_state.get( + "documents" + ), "Please provide at least 1 Audio File" def render_output(self): text_outputs("**Transcription**", key="output_text", height=300) @@ -174,8 +176,8 @@ def render_example(self, state: dict): text_outputs("**Transcription**", value=state.get("output_text")) def render_steps(self): - if st.session_state.get("translation_model"): - col1, col2 = st.columns(2) + if gui.session_state.get("translation_model"): + col1, col2 = gui.columns(2) with col1: 
text_outputs("**Transcription**", key="raw_output_text") with col2: diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py index 72bc6f5ac..e65580681 100644 --- a/recipes/embeddings_page.py +++ b/recipes/embeddings_page.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -import gooey_ui as st +import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import embeddings_model_selector @@ -27,32 +27,32 @@ class ResponseModel(BaseModel): embeddings: list[list[float]] def render_form_v2(self): - col1, col2 = st.columns(2) + col1, col2 = gui.columns(2) with col1: embeddings_model_selector(key="selected_model") - texts = st.session_state.setdefault("texts", [""]) + texts = gui.session_state.setdefault("texts", [""]) for i, text in enumerate(texts): - col1, col2 = st.columns([8, 3], responsive=False) + col1, col2 = gui.columns([8, 3], responsive=False) with col1: - texts[i] = st.text_area(f"##### `texts[{i}]`", value=text) + texts[i] = gui.text_area(f"##### `texts[{i}]`", value=text) with col2: - if st.button("๐Ÿ—‘๏ธ", className="mt-5"): + if gui.button("๐Ÿ—‘๏ธ", className="mt-5"): texts.pop(i) - st.experimental_rerun() - if st.button("โž• Add"): + gui.rerun() + if gui.button("โž• Add"): texts.append("") - st.experimental_rerun() + gui.rerun() def render_output(self): - for i, embedding in enumerate(st.session_state.get("embeddings", [])): - st.write(f"##### `embeddings[{i}]`") - st.json(embedding, depth=0) + for i, embedding in enumerate(gui.session_state.get("embeddings", [])): + gui.write(f"##### `embeddings[{i}]`") + gui.json(embedding, depth=0) def render_example(self, state: dict): - texts = st.session_state.setdefault("texts", [""]) + texts = gui.session_state.setdefault("texts", [""]) for i, text in enumerate(texts): - texts[i] = st.text_area(f"`texts[{i}]`", value=text, disabled=True) + texts[i] = gui.text_area(f"`texts[{i}]`", value=text, disabled=True) def run(self, state: dict) -> typing.Iterator[str | None]: request: EmbeddingsPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/uberduck.py b/recipes/uberduck.py index 19502c053..5b09a3367 100644 --- a/recipes/uberduck.py +++ b/recipes/uberduck.py @@ -1,35 +1,30 @@ import json -import smtplib import time -from email.mime.application import MIMEApplication -from email.mime.multipart import MIMEMultipart -from email.mime.text import MIMEText -from os.path import basename +import gooey_gui as gui import requests -import gooey_ui as st from decouple import config def get_audio(uuid): - with st.spinner(f"Generating your audio file ..."): + with gui.spinner(f"Generating your audio file ..."): while True: data = requests.get(f"https://api.uberduck.ai/speak-status?uuid={uuid}") path = json.loads(data.text)["path"] if path: - st.audio(path) + gui.audio(path) break else: time.sleep(2) def main(): - st.write("# Text To Audio") + gui.write("# Text To Audio") - with st.form(key="send_email", clear_on_submit=False): - voice = st.text_input(label="Voice", value="zwf") - text = st.text_area(label="Text input", value="This is a test.") - submitted = st.form_submit_button("Generate") + with gui.form(key="send_email", clear_on_submit=False): + voice = gui.text_input(label="Voice", value="zwf") + text = gui.text_area(label="Text input", value="This is a test.") + submitted = gui.form_submit_button("Generate") if submitted: response = requests.post( "https://api.uberduck.ai/speak", diff --git a/routers/account.py b/routers/account.py 
index a630e4238..552e193d9 100644 --- a/routers/account.py +++ b/routers/account.py @@ -2,30 +2,28 @@ from contextlib import contextmanager from enum import Enum +import gooey_gui as gui from fastapi import APIRouter from fastapi.requests import Request from furl import furl from loguru import logger from requests.models import HTTPError -import gooey_ui as st from bots.models import PublishedRun, PublishedRunVisibility, Workflow from daras_ai_v2 import icons, paypal -from daras_ai_v2.base import RedirectException from daras_ai_v2.billing import billing_page from daras_ai_v2.fastapi_tricks import get_route_path, get_app_route_url from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import raw_build_meta_tags from daras_ai_v2.profiles import edit_user_profile_page -from gooey_ui.components.pills import pill from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path app = APIRouter() -@st.route(app, "/payment-processing/") +@gui.route(app, "/payment-processing/") def payment_processing_route( request: Request, provider: str | None = None, subscription_id: str | None = None ): @@ -33,7 +31,7 @@ def payment_processing_route( subtext = None if provider == "paypal": - success = st.run_in_thread( + success = gui.run_in_thread( threaded_paypal_handle_subscription_updated, args=[subscription_id], ) @@ -48,18 +46,18 @@ def payment_processing_route( ) with page_wrapper(request, className="m-auto"): - with st.center(): - with st.div(className="d-flex align-items-center"): - st.div( + with gui.center(): + with gui.div(className="d-flex align-items-center"): + gui.div( className="gooey-spinner me-4", style=dict(height="3rem", width="3rem"), ) - st.write("# Processing payment...") + gui.write("# Processing payment...") if subtext: - st.caption(subtext) + gui.caption(subtext) - st.js( + gui.js( # language=JavaScript """ setTimeout(() => { @@ -75,7 +73,7 @@ def payment_processing_route( ) -@st.route(app, "/account/") +@gui.route(app, "/account/") def account_route(request: Request): with account_page_wrapper(request, AccountTabs.billing): billing_tab(request) @@ -91,7 +89,7 @@ def account_route(request: Request): ) -@st.route(app, "/account/profile/") +@gui.route(app, "/account/profile/") def profile_route(request: Request): with account_page_wrapper(request, AccountTabs.profile): profile_tab(request) @@ -107,7 +105,7 @@ def profile_route(request: Request): ) -@st.route(app, "/saved/") +@gui.route(app, "/saved/") def saved_route(request: Request): with account_page_wrapper(request, AccountTabs.saved): all_saved_runs_tab(request) @@ -123,7 +121,7 @@ def saved_route(request: Request): ) -@st.route(app, "/account/api-keys/") +@gui.route(app, "/account/api-keys/") def api_keys_route(request: Request): with account_page_wrapper(request, AccountTabs.api_keys): api_keys_tab(request) @@ -172,39 +170,39 @@ def _render_run(pr: PublishedRun): workflow = Workflow(pr.workflow) visibility = PublishedRunVisibility(pr.visibility) - with st.div(className="mb-2 d-flex justify-content-between align-items-start"): - pill( + with gui.div(className="mb-2 d-flex justify-content-between align-items-start"): + gui.pill( visibility.get_badge_html(), unsafe_allow_html=True, className="border border-dark", ) - pill(workflow.short_title, className="border border-dark") + gui.pill(workflow.short_title, className="border border-dark") workflow.page_cls().render_published_run_preview(pr) - 
st.write("# Saved Workflows") + gui.write("# Saved Workflows") if prs: if request.user.handle: - st.caption( + gui.caption( "All your Saved workflows are here, with public ones listed on your " f"profile page at {request.user.handle.get_app_url()}." ) else: edit_profile_url = AccountTabs.profile.url_path - st.caption( + gui.caption( "All your Saved workflows are here. Public ones will be listed on your " f"profile page if you [create a username]({edit_profile_url})." ) - with st.div(className="mt-4"): + with gui.div(className="mt-4"): grid_layout(3, prs, _render_run) else: - st.write("No saved runs yet", className="text-muted") + gui.write("No saved runs yet", className="text-muted") def api_keys_tab(request: Request): - st.write("# ๐Ÿ” API Keys") + gui.write("# ๐Ÿ” API Keys") manage_api_keys(request.user) @@ -213,22 +211,22 @@ def account_page_wrapper(request: Request, current_tab: TabData): if not request.user or request.user.is_anonymous: next_url = request.query_params.get("next", "/account/") redirect_url = furl("/login", query_params={"next": next_url}) - raise RedirectException(str(redirect_url)) + raise gui.RedirectException(str(redirect_url)) with page_wrapper(request): - st.div(className="mt-5") - with st.nav_tabs(): + gui.div(className="mt-5") + with gui.nav_tabs(): for tab in AccountTabs: - with st.nav_item(tab.url_path, active=tab == current_tab): - st.html(tab.title) + with gui.nav_item(tab.url_path, active=tab == current_tab): + gui.html(tab.title) - with st.nav_tab_content(): + with gui.nav_tab_content(): yield def threaded_paypal_handle_subscription_updated(subscription_id: str) -> bool: """ - Always returns True when completed (for use in st.run_in_thread()) + Always returns True when completed (for use in gui.run_in_thread()) """ try: subscription = paypal.Subscription.retrieve(subscription_id) diff --git a/routers/api.py b/routers/api.py index dbba1faf7..aab15806c 100644 --- a/routers/api.py +++ b/routers/api.py @@ -27,7 +27,7 @@ HTTP_400_BAD_REQUEST, ) -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from auth.token_authentication import api_auth_header from bots.models import RetentionPolicy @@ -405,8 +405,8 @@ def submit_api_call( state.update(request_body) # set streamlit session state - st.set_session_state(state) - st.set_query_params(query_params) + gui.set_session_state(state) + gui.set_query_params(query_params) # create a new run try: @@ -416,7 +416,7 @@ def submit_api_call( retention_policy=retention_policy or RetentionPolicy.keep, ) except ValidationError as e: - raise RequestValidationError(e.raw_errors, body=st.session_state) from e + raise RequestValidationError(e.raw_errors, body=gui.session_state) from e # submit the task result = self.call_runner_task(sr) return self, result, sr.run_id, sr.uid diff --git a/routers/root.py b/routers/root.py index 0eda96071..2e1e7077e 100644 --- a/routers/root.py +++ b/routers/root.py @@ -21,7 +21,7 @@ FileResponse, ) -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import Workflow, BotIntegration from daras_ai.image_input import upload_file_from_bytes, safe_filename @@ -42,7 +42,6 @@ from daras_ai_v2.profiles import user_profile_page, get_meta_tags_for_profile from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates -from gooey_ui.components.url_button import url_button from handles.models import Handle app = APIRouter() @@ -195,7 +194,7 @@ def file_upload(form_data: FormData = 
fastapi_request_form): return {"url": upload_file_from_bytes(filename, data, content_type)} -@st.route(app, "/GuiComponents/") +@gui.route(app, "/GuiComponents/") def component_page(request: Request): import components_doc @@ -211,7 +210,7 @@ def component_page(request: Request): } -@st.route(app, "/explore/") +@gui.route(app, "/explore/") def explore_page(request: Request): import explore @@ -227,7 +226,7 @@ def explore_page(request: Request): } -@st.route(app, "/api/") +@gui.route(app, "/api/") def api_docs_page(request: Request): with page_wrapper(request): _api_docs_page(request) @@ -248,7 +247,7 @@ def _api_docs_page(request): api_docs_url = str(furl(settings.API_BASE_URL) / "docs") - st.markdown( + gui.markdown( f""" # Gooey.AI API Platform @@ -271,34 +270,34 @@ def _api_docs_page(request): unsafe_allow_html=True, ) - st.write("---") + gui.write("---") options = { page_cls.workflow.value: page_cls().get_recipe_title() for page_cls in all_api_pages } - st.write( + gui.write( "##### โš• API Generator\nChoose a workflow to see how you can interact with it via the API" ) - col1, col2 = st.columns([11, 1], responsive=False) + col1, col2 = gui.columns([11, 1], responsive=False) with col1: - with st.div(className="pt-1"): + with gui.div(className="pt-1"): workflow = Workflow( - st.selectbox( + gui.selectbox( "", options=options, format_func=lambda x: options[x], ) ) with col2: - url_button(workflow.page_cls.app_url()) + gui.url_button(workflow.page_cls.app_url()) - st.write("###### ๐Ÿ“ค Example Request") + gui.write("###### ๐Ÿ“ค Example Request") - include_all = st.checkbox("Show all fields") - as_async = st.checkbox("Run Async") - as_form_data = st.checkbox("Upload Files via Form Data") + include_all = gui.checkbox("Show all fields") + as_async = gui.checkbox("Run Async") + as_form_data = gui.checkbox("Upload Files via Form Data") page = workflow.page_cls(request=request) state = page.get_root_published_run().saved_run.to_dict() @@ -313,17 +312,17 @@ def _api_docs_page(request): as_form_data=as_form_data, as_async=as_async, ) - st.write("") + gui.write("") - st.write("###### ๐ŸŽ Example Response") - st.json(response_body, expanded=True) + gui.write("###### ๐ŸŽ Example Response") + gui.json(response_body, expanded=True) - st.write("---") - with st.tag("a", id="api-keys"): - st.write("##### ๐Ÿ” API keys") + gui.write("---") + with gui.tag("a", id="api-keys"): + gui.write("##### ๐Ÿ” API keys") if not page.request.user or page.request.user.is_anonymous: - st.write( + gui.write( "**Please [Login](/login/?next=/api/) to generate the `$GOOEY_API_KEY`**" ) return @@ -331,7 +330,7 @@ def _api_docs_page(request): manage_api_keys(page.request.user) -@st.route( +@gui.route( app, "/{page_slug}/examples/", "/{page_slug}/{run_slug}/examples/", @@ -343,7 +342,7 @@ def examples_route( return render_page(request, page_slug, RecipeTabs.examples, example_id) -@st.route( +@gui.route( app, "/{page_slug}/api/", "/{page_slug}/{run_slug}/api/", @@ -355,7 +354,7 @@ def api_route( return render_page(request, page_slug, RecipeTabs.run_as_api, example_id) -@st.route( +@gui.route( app, "/{page_slug}/history/", "/{page_slug}/{run_slug}/history/", @@ -367,7 +366,7 @@ def history_route( return render_page(request, page_slug, RecipeTabs.history, example_id) -@st.route( +@gui.route( app, "/{page_slug}/saved/", "/{page_slug}/{run_slug}/saved/", @@ -379,7 +378,7 @@ def save_route( return render_page(request, page_slug, RecipeTabs.saved, example_id) -@st.route( +@gui.route( app, "/{page_slug}/integrations/add/", 
"/{page_slug}/{run_slug}/integrations/add/", @@ -391,11 +390,11 @@ def add_integrations_route( run_slug: str = None, example_id: str = None, ): - st.session_state["--add-integration"] = True + gui.session_state["--add-integration"] = True return render_page(request, page_slug, RecipeTabs.integrations, example_id) -@st.route( +@gui.route( app, "/{page_slug}/integrations/{integration_id}/stats/", "/{page_slug}/{run_slug}/integrations/{integration_id}/stats/", @@ -411,13 +410,13 @@ def integrations_stats_route( from routers.bots_api import api_hashids try: - st.session_state.setdefault("bi_id", api_hashids.decode(integration_id)[0]) + gui.session_state.setdefault("bi_id", api_hashids.decode(integration_id)[0]) except IndexError: raise HTTPException(status_code=404) return render_page(request, "stats", RecipeTabs.integrations, example_id) -@st.route( +@gui.route( app, "/{page_slug}/integrations/{integration_id}/analysis/", "/{page_slug}/{run_slug}/integrations/{integration_id}/analysis/", @@ -453,7 +452,7 @@ def integrations_analysis_route( ) -@st.route( +@gui.route( app, "/{page_slug}/integrations/", "/{page_slug}/{run_slug}/integrations/", @@ -474,13 +473,13 @@ def integrations_route( if integration_id: try: - st.session_state.setdefault("bi_id", api_hashids.decode(integration_id)[0]) + gui.session_state.setdefault("bi_id", api_hashids.decode(integration_id)[0]) except IndexError: raise HTTPException(status_code=404) return render_page(request, page_slug, RecipeTabs.integrations, example_id) -@st.route( +@gui.route( app, "/chat/", "/chats/", @@ -572,7 +571,7 @@ def chat_lib_route(request: Request, integration_id: str, integration_name: str ) -@st.route( +@gui.route( app, "/{page_slug}/", "/{page_slug}/{run_slug}/", @@ -632,13 +631,13 @@ def render_page( return RedirectResponse(str(new_url.set(origin=None)), status_code=301) # this is because the code still expects example_id to be in the query params - st.set_query_params(dict(request.query_params) | dict(example_id=example_id)) + gui.set_query_params(dict(request.query_params) | dict(example_id=example_id)) _, run_id, uid = extract_query_params(request.query_params) page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) - if not st.session_state: + if not gui.session_state: sr = page.get_sr_from_query_params(example_id, run_id, uid) - st.session_state.update(page.load_state_from_sr(sr)) + gui.session_state.update(page.load_state_from_sr(sr)) with page_wrapper(request): page.render() @@ -647,7 +646,7 @@ def render_page( meta=build_meta_tags( url=get_og_url_path(request), page=page, - state=st.session_state, + state=gui.session_state, run_id=run_id, uid=uid, example_id=example_id, @@ -684,16 +683,16 @@ def page_wrapper(request: Request, className=""): request.user.uid ).decode() - with st.div(className="d-flex flex-column min-vh-100"): - st.html(templates.get_template("gtag.html").render(**context)) - st.html(templates.get_template("header.html").render(**context)) - st.html(copy_to_clipboard_scripts) + with gui.div(className="d-flex flex-column min-vh-100"): + gui.html(templates.get_template("gtag.html").render(**context)) + gui.html(templates.get_template("header.html").render(**context)) + gui.html(copy_to_clipboard_scripts) - with st.div(id="main-content", className="container " + className): + with gui.div(id="main-content", className="container " + className): yield - st.html(templates.get_template("footer.html").render(**context)) - st.html(templates.get_template("login_scripts.html").render(**context)) + 
gui.html(templates.get_template("footer.html").render(**context)) + gui.html(templates.get_template("login_scripts.html").render(**context)) INTEGRATION_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/c3ba2392-d6b9-11ee-a67b-6ace8d8c9501/image.png" diff --git a/tests/test_checkout.py b/tests/test_checkout.py index 19a988543..21a9a3291 100644 --- a/tests/test_checkout.py +++ b/tests/test_checkout.py @@ -4,7 +4,7 @@ from app_users.models import AppUser from daras_ai_v2 import settings from daras_ai_v2.billing import stripe_subscription_checkout_redirect -from gooey_ui import RedirectException +from gooey_gui import RedirectException from payments.plans import PricingPlan from server import app diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 33fb7dbed..6ac6e0591 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,13 +1,13 @@ +import gooey_gui as gui import pytest -import math +from starlette.testclient import TestClient + from bots.models import AppUser +from bots.models import SavedRun, Workflow from recipes.CompareLLM import CompareLLMPage from recipes.VideoBots import VideoBotsPage -from usage_costs.models import UsageCost, ModelPricing -from bots.models import SavedRun, Workflow, WorkflowMetadata -from gooey_ui.state import set_query_params -from starlette.testclient import TestClient from server import app +from usage_costs.models import UsageCost, ModelPricing client = TestClient(app) @@ -47,7 +47,7 @@ def test_copilot_get_raw_price_round_up(): dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) copilot_page = VideoBotsPage(run_user=user) - set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) assert ( copilot_page.get_price_roundoff(state=state) == 210 + copilot_page.PROFIT_CREDITS @@ -108,7 +108,7 @@ def test_multiple_llm_sums_usage_cost(): ) llm_page = CompareLLMPage(run_user=user) - set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -153,7 +153,7 @@ def test_workflowmetadata_2x_multiplier(): metadata.save() llm_page = CompareLLMPage(run_user=user) - set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 ) diff --git a/url_shortener/models.py b/url_shortener/models.py index b7daabc6b..21b0864d7 100644 --- a/url_shortener/models.py +++ b/url_shortener/models.py @@ -8,8 +8,8 @@ from bots.models import Workflow, SavedRun from daras_ai.image_input import truncate_filename from daras_ai_v2 import settings -from daras_ai_v2.query_params import gooey_get_query_params from daras_ai_v2.query_params_util import extract_query_params +import gooey_gui as gui class ShortenedURLQuerySet(models.QuerySet): @@ -17,7 +17,7 @@ def get_or_create_for_workflow( self, *, user: AppUser, workflow: Workflow, **kwargs ) -> tuple["ShortenedURL", bool]: surl, created = self.filter_first_or_create(user=user, **kwargs) - _, run_id, uid = extract_query_params(gooey_get_query_params()) + _, run_id, uid = extract_query_params(gui.get_query_params()) surl.saved_runs.add( SavedRun.objects.get_or_create( workflow=workflow, diff --git 
a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py
index ea622e06a..6596c0d22 100644
--- a/usage_costs/cost_utils.py
+++ b/usage_costs/cost_utils.py
@@ -1,18 +1,18 @@
 from loguru import logger
 
-from daras_ai_v2.query_params import gooey_get_query_params
 from daras_ai_v2.query_params_util import extract_query_params
 from usage_costs.models import (
     UsageCost,
     ModelSku,
     ModelPricing,
 )
+import gooey_gui as gui
 
 
 def record_cost_auto(model: str, sku: ModelSku, quantity: int):
     from bots.models import SavedRun
 
-    _, run_id, uid = extract_query_params(gooey_get_query_params())
+    _, run_id, uid = extract_query_params(gui.get_query_params())
     if not run_id or not uid:
         return