diff --git a/.github/workflows/gitleaks.yml b/.github/workflows/gitleaks.yml new file mode 100644 index 000000000..79ffcb1a2 --- /dev/null +++ b/.github/workflows/gitleaks.yml @@ -0,0 +1,15 @@ +name: gitleaks +on: [pull_request, push, workflow_dispatch] +jobs: + scan: + name: gitleaks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: gitleaks/gitleaks-action@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITLEAKS_LICENSE: ${{ secrets.GITLEAKS_LICENSE}} + GITLEAKS_NOTIFY_USER_LIST: '@sandergi' diff --git a/.gitleaksignore b/.gitleaksignore new file mode 100644 index 000000000..022ac599f --- /dev/null +++ b/.gitleaksignore @@ -0,0 +1,21 @@ +4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:aws-access-token:16 +4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:private-key:23 +4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:generic-api-key:32 +b0c80dac8e22faafa319d5466947df8723dfaa4a:daras_ai_v2/img_model_settings_widgets.py:generic-api-key:372 +8670036e722f40530dbff3e0e7573e9b5aac85c9:routers/slack.py:slack-webhook-url:73 +b6ad1fc0168832711adcff07287907660f3305fb:bots/location.py:generic-api-key:12 +8c05ec8320a866304842fb5f4df76e0698f1031f:bots/analysis.py:generic-api-key:5 +1c03d569dd30bb9703e4ff968a57a05eb405e398:bots/signals.py:generic-api-key:11 +5e3dd6cf0da20b3e5b1daaca41ad126bc489fbf3:static/js/auth.js:generic-api-key:2 +87e443addbbc49746ab3088307a59b3e2fc2d177:recipes/CompareText2Img.py:generic-api-key:97 +1f109a743b1781c7a21c1b0ca6a3f880f7f7dc84:recipes/CompareText2Img.py:generic-api-key:77 +d18d8b9bb18a9ff8248b16b26f0455f7826ce23a:recipes/CompareText2Img.py:generic-api-key:85 +5471a8ac2d60026b24f21b51ae6f11db8acd160c:pages/CompareText2Img.py:generic-api-key:92 +5471a8ac2d60026b24f21b51ae6f11db8acd160c:daras_ai_v2/img_model_settings_widgets.py:generic-api-key:90 +6fca6072032e4f34d7d571e7de8e0ff05f7a487b:static/js/auth.js:generic-api-key:2 +2292469b22d97263c7c59cf49fae7281ce96a39c:pages/CompareText2Img.py:generic-api-key:137 +aae9d67ed6330a3eb2ede41d5ceeca52a8f0daf4:static/js/auth.js:gcp-api-key:2 +d5866242d107743ab5eebeb284e7e5ee2426d941:pages/SocialLookupEmail.py:generic-api-key:181 +73bef8c3be7682fed0b99ceb6770f599eabbbd08:daras_ai_v2/send_email.py:generic-api-key:25 +fa3f7982fa1527838c2073d2542c83887cc6ebbd:pages/EmailFaceInpainting.py:generic-api-key:189 +e1c218882d288ca1df0225654aae8dd10027e9d0:political_example.py:jwt:30 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 65af59112..1bc20673c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer - id: check-yaml @@ -13,3 +13,7 @@ repos: entry: poetry run black language: system types: [python] +- repo: https://github.com/gitleaks/gitleaks + rev: v8.18.4 + hooks: + - id: gitleaks diff --git a/README.md b/README.md index b3ddddba0..299c84839 100644 --- a/README.md +++ b/README.md @@ -208,3 +208,9 @@ docker run \ ### 📐 Code Formatting Use black - https://pypi.org/project/black + +### 💣 Secret Scanning + +Gitleaks will automatically run pre-commit (see `pre-commit-config.yaml` for details) to prevent commits with secrets in the first place. To test this without committing, run `pre-commit` from the terminal. To skip this check, use `SKIP=gitleaks git commit -m "message"` to commit changes. Preferably, label false positives with the `#gitleaks:allow` comment instead of skipping the check. + +Gitleaks will also run in the CI pipeline as a GitHub action on push and pull request (can also be manually triggered in the actions tab on GitHub). To update the baseline of ignored secrets, run `python ./scripts/create_gitleaks_baseline.py` from the venv and commit the changes to `.gitleaksignore`. diff --git a/bots/admin.py b/bots/admin.py index 4b6a731c8..0b8b28d10 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -438,7 +438,7 @@ def rerun_tasks(self, request, queryset): page = Workflow(sr.workflow).page_cls( request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) ) - page.call_runner_task(sr) + page.call_runner_task(sr, deduct_credits=False) self.message_user( request, f"Started re-running {queryset.count()} tasks in the background.", diff --git a/bots/models.py b/bots/models.py index 0407569d5..51eaebc59 100644 --- a/bots/models.py +++ b/bots/models.py @@ -363,6 +363,7 @@ def submit_api_call( request_body: dict, enable_rate_limits: bool = False, parent_pr: "PublishedRun" = None, + deduct_credits: bool = True, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call @@ -384,6 +385,7 @@ def submit_api_call( user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, + deduct_credits=deduct_credits, ), ) @@ -1818,12 +1820,14 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, + deduct_credits: bool = True, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: return self.saved_run.submit_api_call( current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, parent_pr=self, + deduct_credits=deduct_credits, ) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 794b5a061..2651fd80c 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -41,6 +41,7 @@ def runner_task( uid: str, channel: str, unsaved_state: dict[str, typing.Any] = None, + deduct_credits: bool = True, ) -> int: start_time = time() error_msg = None @@ -107,7 +108,8 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # run completed successfully, deduct credits else: - sr.transaction, sr.price = page.deduct_credits(gui.session_state) + if deduct_credits: + sr.transaction, sr.price = page.deduct_credits(gui.session_state) # save everything, mark run as completed finally: diff --git a/daras_ai_v2/analysis_results.py b/daras_ai_v2/analysis_results.py index 5bd519982..6e43e3585 100644 --- a/daras_ai_v2/analysis_results.py +++ b/daras_ai_v2/analysis_results.py @@ -69,7 +69,7 @@ def render_analysis_results_page( gui.session_state.setdefault("selected_graphs", graphs) selected_graphs = list_view_editor( - add_btn_label="➕ Add a Graph", + add_btn_label="Add a Graph", key="selected_graphs", render_inputs=partial(render_inputs, results), ) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index c09102fe4..96b5f2e70 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -75,16 +75,15 @@ "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy" } # fmt: skip -# See page 14 of https://scontent-sea1-1.xx.fbcdn.net/v/t39.2365-6/369747868_602316515432698_2401716319310287708_n.pdf?_nc_cat=106&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=_5cpNOcftdYAX8rCrVo&_nc_ht=scontent-sea1-1.xx&oh=00_AfDVkx7XubifELxmB_Un-yEYMJavBHFzPnvTbTlalbd_1Q&oe=65141B39 +# https://huggingface.co/facebook/seamless-m4t-v2-large#supported-languages # For now, below are listed the languages that support ASR. Note that Seamless only accepts ISO 639-3 codes. -SEAMLESS_SUPPORTED = { - "afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", - "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin", - "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm", - "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", - "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", - "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", - "xho", "yor", "yue", "zlm", "zul" +SEAMLESS_v2_ASR_SUPPORTED = { + "afr", "amh", "arb", "ary", "arz", "asm", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn", + "cmn-Hant", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "fuv", "gaz", "gle", "glg", "guj", "heb", + "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kan", "kat", "kaz", "khk", "khm", "kir", + "kor", "lao", "lit", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob", + "npi", "nya", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", + "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "yor", "yue", "zul", } # fmt: skip AZURE_SUPPORTED = { @@ -199,7 +198,8 @@ } # fmt: skip # https://translation.ghananlp.org/api-details#api=ghananlp-translation-webservice-api -GHANA_NLP_SUPPORTED = { 'en': 'English', 'tw': 'Twi', 'gaa': 'Ga', 'ee': 'Ewe', 'fat': 'Fante', 'dag': 'Dagbani', 'gur': 'Gurene', 'yo': 'Yoruba', 'ki': 'Kikuyu', 'luo': 'Luo', 'mer': 'Kimeru' } # fmt: skip +GHANA_NLP_SUPPORTED = {'en': 'English', 'tw': 'Twi', 'gaa': 'Ga', 'ee': 'Ewe', 'fat': 'Fante', 'dag': 'Dagbani', + 'gur': 'Gurene', 'yo': 'Yoruba', 'ki': 'Kikuyu', 'luo': 'Luo', 'mer': 'Kimeru'} # fmt: skip GHANA_NLP_MAXLEN = 500 @@ -215,11 +215,22 @@ class AsrModels(Enum): usm = "Chirp / USM (Google V2)" deepgram = "Deepgram" azure = "Azure Speech" - seamless_m4t = "Seamless M4T (Facebook Research)" + seamless_m4t_v2 = "Seamless M4T v2 (Facebook Research)" mms_1b_all = "Massively Multilingual Speech (MMS) (Facebook Research)" + seamless_m4t = "Seamless M4T [Deprecated] (Facebook Research)" + def supports_auto_detect(self) -> bool: - return self not in {self.azure, self.gcp_v1, self.mms_1b_all} + return self not in { + self.azure, + self.gcp_v1, + self.mms_1b_all, + self.seamless_m4t_v2, + } + + @classmethod + def _deprecated(cls): + return {cls.seamless_m4t} asr_model_ids = { @@ -230,7 +241,7 @@ def supports_auto_detect(self) -> bool: AsrModels.vakyansh_bhojpuri: "Harveenchadha/vakyansh-wav2vec2-bhojpuri-bhom-60", AsrModels.nemo_english: "https://objectstore.e2enetworks.net/indic-asr-public/checkpoints/conformer/english_large_data_fixed.nemo", AsrModels.nemo_hindi: "https://objectstore.e2enetworks.net/indic-asr-public/checkpoints/conformer/stt_hi_conformer_ctc_large_v2.nemo", - AsrModels.seamless_m4t: "facebook/hf-seamless-m4t-large", + AsrModels.seamless_m4t_v2: "facebook/seamless-m4t-v2-large", AsrModels.mms_1b_all: "facebook/mms-1b-all", } @@ -248,7 +259,7 @@ def supports_auto_detect(self) -> bool: AsrModels.gcp_v1: GCP_V1_SUPPORTED, AsrModels.usm: CHIRP_SUPPORTED, AsrModels.deepgram: DEEPGRAM_SUPPORTED, - AsrModels.seamless_m4t: SEAMLESS_SUPPORTED, + AsrModels.seamless_m4t_v2: SEAMLESS_v2_ASR_SUPPORTED, AsrModels.azure: AZURE_SUPPORTED, AsrModels.mms_1b_all: MMS_SUPPORTED, } @@ -783,15 +794,14 @@ def run_asr( return "\n".join( f"Speaker {chunk['speaker']}: {chunk['text']}" for chunk in chunks ) - elif selected_model == AsrModels.seamless_m4t: + elif selected_model == AsrModels.seamless_m4t_v2: data = call_celery_task( - "seamless", + "seamless.asr", pipeline=dict( - model_id=asr_model_ids[AsrModels.seamless_m4t], + model_id=asr_model_ids[AsrModels.seamless_m4t_v2], ), inputs=dict( audio=audio_url, - task="ASR", src_lang=language, ), ) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 8e3610a28..7ee73dcfc 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -136,7 +136,7 @@ class BasePage: class RequestModel(BaseModel): functions: list[RecipeFunction] | None = Field( - title="🧩 Functions", + title="🧩 Developer Tools and Functions", ) variables: dict[str, typing.Any] = Field( None, @@ -1308,9 +1308,7 @@ def get_credits_click_url(self): return "/account/" def get_submit_container_props(self): - return dict( - className="position-sticky bottom-0 bg-white", style=dict(zIndex=100) - ) + return dict(className="position-sticky bottom-0 bg-white") def render_submit_button(self, key="--submit-1"): with gui.div(**self.get_submit_container_props()): @@ -1478,15 +1476,17 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True + show_settings = True def _render_input_col(self): self.render_form_v2() placeholder = gui.div() - with gui.expander("⚙️ Settings"): - if self.functions_in_settings: - functions_input(self.request.user) - self.render_settings() + if self.show_settings: + with gui.expander("⚙️ Settings"): + self.render_settings() + if self.functions_in_settings: + functions_input(self.request.user) with placeholder: self.render_variables() @@ -1501,7 +1501,6 @@ def _render_input_col(self): def render_variables(self): if not self.functions_in_settings: - gui.write("---") functions_input(self.request.user) variables_input( template_keys=self.template_keys, allow_add=is_functions_enabled() @@ -1687,7 +1686,7 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun): } ) - def call_runner_task(self, sr: SavedRun): + def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): from celeryapp.tasks import runner_task, post_runner_tasks chain = ( @@ -1698,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun): uid=sr.uid, channel=self.realtime_channel_name(sr.run_id, sr.uid), unsaved_state=self._unsaved_state(), + deduct_credits=deduct_credits, ) | post_runner_tasks.s() ) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 4958f389d..2782c2eab 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -1,12 +1,13 @@ +import json from itertools import zip_longest from textwrap import dedent +import gooey_gui as gui from django.core.exceptions import ValidationError from django.db import transaction from django.utils.text import slugify from furl import furl -import gooey_gui as gui from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform from daras_ai_v2 import settings, icons @@ -94,7 +95,7 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): input_analysis_runs.append(dict(saved_run=sr, published_run=None)) list_view_editor( - add_btn_label="➕ Add", + add_btn_label="Add", key="analysis_urls", render_inputs=render_workflow_url_input, flatten_dict_key="url", @@ -311,7 +312,7 @@ def get_bot_test_link(bi: BotIntegration) -> str | None: return None -def get_web_widget_embed_code(bi: BotIntegration) -> str: +def get_web_widget_embed_code(bi: BotIntegration, *, config: dict = None) -> str: lib_src = get_app_route_url( chat_lib_route, path_params=dict( @@ -319,11 +320,19 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str: integration_name=slugify(bi.name) or "untitled", ), ).rstrip("/") + if config is None: + config = {} return dedent( - f""" + """
- + + """ + % dict(config_json=json.dumps(config), lib_src=lib_src) ).strip() @@ -375,7 +384,7 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): mode="inline", showSources=True, enablePhotoUpload=False, - enableLipsyncVideo=False, + autoPlayResponses=True, enableAudioMessage=True, branding=( dict(showPoweredByGooey=True) @@ -397,8 +406,8 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): config["enableAudioMessage"] = gui.checkbox( "Enable Audio Message", value=config["enableAudioMessage"] ) - config["enableLipsyncVideo"] = gui.checkbox( - "Enable Lipsync Video", value=config["enableLipsyncVideo"] + config["autoPlayResponses"] = gui.checkbox( + "Auto-play responses", value=config["autoPlayResponses"] ) # config["branding"]["showPoweredByGooey"] = gui.checkbox( # "Show Powered By Gooey", value=config["branding"]["showPoweredByGooey"] diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 58a154590..90fef04ee 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -234,18 +234,6 @@ def _echo(bot, input_text): return msgs_to_save, response_audio, response_text, response_video -def _mock_api_output(input_text): - return { - "url": "https://gooey.ai?example_id=mock-api-example", - "output": { - "input_text": input_text, - "raw_input_text": input_text, - "raw_output_text": [f"echo: ```{input_text}```\nhttps://www.youtube.com/"], - "output_text": [f"echo: ```{input_text}```\nhttps://www.youtube.com/"], - }, - } - - def msg_handler(bot: BotInterface): try: _msg_handler(bot) @@ -369,20 +357,25 @@ def _process_and_send_msg( # get latest messages for context saved_msgs = bot.convo.msgs_for_llm_context() - # # mock testing - # result = _mock_api_output(input_text) + variables = (bot.saved_run.state.get("variables") or {}) | build_run_vars( + bot.convo, bot.user_msg_id + ) body = dict( input_prompt=input_text, input_audio=input_audio, input_images=input_images, input_documents=input_documents, messages=saved_msgs, - variables=build_run_vars(bot.convo, bot.user_msg_id), + variables=variables, ) if bot.user_language: body["user_language"] = bot.user_language if bot.request_overrides: body = bot.request_overrides | body + try: + variables.update(bot.request_overrides["variables"]) + except KeyError: + pass page, result, run_id, uid = submit_api_call( page_cls=bot.page_cls, user=billing_account_user, diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py index 52bc36ea4..6f251ff35 100644 --- a/daras_ai_v2/icons.py +++ b/daras_ai_v2/icons.py @@ -17,6 +17,7 @@ preview = '' time = '' email = '' +add = '' code = '' chat = '' diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index dffe49060..533be6410 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -32,7 +32,6 @@ ) from functions.recipe_functions import LLMTools -DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible." DEFAULT_JSON_PROMPT = ( "Please respond directly in JSON format. " "Don't output markdown or HTML, instead print the JSON object directly without formatting." @@ -308,11 +307,26 @@ class LargeLanguageModels(Enum): ) sea_lion_7b_instruct = LLMSpec( - label="SEA-LION-7B-Instruct (aisingapore)", + label="SEA-LION-7B-Instruct [Deprecated] (aisingapore)", model_id="aisingapore/sea-lion-7b-instruct", llm_api=LLMApis.self_hosted, context_window=2048, price=1, + is_deprecated=True, + ) + llama3_8b_cpt_sea_lion_v2_instruct = LLMSpec( + label="Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)", + model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct", + llm_api=LLMApis.self_hosted, + context_window=8192, + price=1, + ) + sarvam_2b = LLMSpec( + label="Sarvam 2B (sarvamai)", + model_id="sarvamai/sarvam-2b-v0.5", + llm_api=LLMApis.self_hosted, + context_window=2048, + price=1, ) # https://platform.openai.com/docs/models/gpt-3 @@ -452,7 +466,6 @@ def run_language_model( if prompt and not messages: # convert text prompt to chat messages messages = [ - format_chat_entry(role=CHATML_ROLE_SYSTEM, content=DEFAULT_SYSTEM_MSG), format_chat_entry(role=CHATML_ROLE_USER, content=prompt), ] if not model.is_vision_model: @@ -599,6 +612,17 @@ def _run_text_model( temperature=temperature, stop=stop, ) + case LLMApis.self_hosted: + return [ + _run_self_hosted_llm( + model=model, + text_inputs=prompt, + max_tokens=max_tokens, + temperature=temperature, + avoid_repetition=avoid_repetition, + stop=stop, + ) + ] case _: raise UserError(f"Unsupported text api: {api}") @@ -674,14 +698,19 @@ def _run_chat_model( stop=stop, ) case LLMApis.self_hosted: - return _run_self_hosted_chat( - model=model, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - avoid_repetition=avoid_repetition, - stop=stop, - ) + return [ + { + "role": CHATML_ROLE_ASSISTANT, + "content": _run_self_hosted_llm( + model=model, + text_inputs=messages, + max_tokens=max_tokens, + temperature=temperature, + avoid_repetition=avoid_repetition, + stop=stop, + ), + }, + ] # case LLMApis.together: # if tools: # raise UserError("Only OpenAI chat models support Tools") @@ -697,32 +726,36 @@ def _run_chat_model( raise UserError(f"Unsupported chat api: {api}") -def _run_self_hosted_chat( +def _run_self_hosted_llm( *, model: str, - messages: list[ConversationEntry], + text_inputs: list[ConversationEntry] | str, max_tokens: int, temperature: float, avoid_repetition: bool, stop: list[str] | None, -) -> list[dict]: +) -> str: from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku # sea lion doesnt support system prompt - if model == LargeLanguageModels.sea_lion_7b_instruct.model_id: - for i, entry in enumerate(messages): + if ( + not isinstance(text_inputs, str) + and model == LargeLanguageModels.sea_lion_7b_instruct.model_id + ): + for i, entry in enumerate(text_inputs): if entry["role"] == CHATML_ROLE_SYSTEM: - messages[i]["role"] = CHATML_ROLE_USER - messages.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content="")) + text_inputs[i]["role"] = CHATML_ROLE_USER + text_inputs.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content="")) ret = call_celery_task( "llm.chat", pipeline=dict( model_id=model, + fallback_chat_template_from="meta-llama/Llama-2-7b-chat-hf", ), inputs=dict( - messages=messages, + text_inputs=text_inputs, max_new_tokens=max_tokens, stop_strings=stop, temperature=temperature, @@ -742,12 +775,7 @@ def _run_self_hosted_chat( quantity=usage["completion_tokens"], ) - return [ - { - "role": CHATML_ROLE_ASSISTANT, - "content": ret["generated_text"], - } - ] + return ret["generated_text"] def _run_anthropic_chat( diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index 7a74c7116..c315da996 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -3,22 +3,31 @@ from datetime import datetime from types import SimpleNamespace +import gooey_gui as gui import jinja2 import jinja2.meta import jinja2.sandbox -import gooey_gui as gui +from daras_ai_v2 import icons def variables_input( *, template_keys: typing.Iterable[str], label: str = "###### ⌥ Variables", + description: str = "Variables let you pass custom parameters to your workflow. Access a variable in your instruction prompt with Jinja, e.g. `{{ my_variable }}`\n ", key: str = "variables", allow_add: bool = False, ): from daras_ai_v2.workflow_url_input import del_button + def render_title_desc(): + gui.write(label) + gui.caption( + f"{description} Learn more.", + unsafe_allow_html=True, + ) + # find all variables in the prompts env = jinja2.sandbox.SandboxedEnvironment() template_var_names = set() @@ -56,7 +65,7 @@ def variables_input( continue if not title_shown: - gui.write(label) + render_title_desc() title_shown = True col1, col2 = gui.columns([11, 1], responsive=False) @@ -93,7 +102,7 @@ def variables_input( if allow_add: if not title_shown: - gui.write(label) + render_title_desc() gui.newline() col1, col2, _ = gui.columns([6, 2, 4], responsive=False) with col1: @@ -105,7 +114,7 @@ def variables_input( ) with col2: gui.button( - ' Add', + f"{icons.add} Add", key=var_add_key, type="tertiary", ) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 841817d94..338d22614 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -33,6 +33,7 @@ def safety_checker_text(text_input: str): .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), + deduct_credits=False, ) ) diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index a37929378..5c5f7b2e9 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -21,6 +21,7 @@ def workflow_url_input( del_key: str = None, current_user: AppUser | None = None, allow_none: bool = False, + include_root: bool = True ) -> tuple[typing.Type[BasePage], SavedRun, PublishedRun | None] | None: init_workflow_selector(internal_state, key) @@ -38,7 +39,9 @@ def workflow_url_input( else: internal_state["workflow"] = page_cls.workflow with col1: - options = get_published_run_options(page_cls, current_user=current_user) + options = get_published_run_options( + page_cls, current_user=current_user, include_root=include_root + ) options.update(internal_state.get("--added_workflows", {})) with gui.div(className="pt-1"): url = gui.selectbox( @@ -143,6 +146,7 @@ def url_to_runs( def get_published_run_options( page_cls: typing.Type[BasePage], current_user: AppUser | None = None, + include_root: bool = True, ) -> dict[str, str]: # approved examples pr_query = Q(is_approved_example=True, visibility=PublishedRunVisibility.PUBLIC) @@ -166,13 +170,15 @@ def get_published_run_options( pr.updated_at, # newer first ), ) - - options = { - # root recipe - page_cls.get_root_published_run().get_app_url(): "Default", - } | { + options_dict = { pr.get_app_url(): get_title_breadcrumbs(page_cls, pr.saved_run, pr).h1_title for pr in saved_runs_and_examples } - return options + if include_root: + # include root recipe if requested + options_dict = { + page_cls.get_root_published_run().get_app_url(): "Default", + } | options_dict + + return options_dict diff --git a/functions/migrations/0002_alter_calledfunction_trigger.py b/functions/migrations/0002_alter_calledfunction_trigger.py new file mode 100644 index 000000000..a373cbc25 --- /dev/null +++ b/functions/migrations/0002_alter_calledfunction_trigger.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-08-12 12:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('functions', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='calledfunction', + name='trigger', + field=models.IntegerField(choices=[(1, 'Before'), (2, 'After')]), + ), + ] diff --git a/functions/models.py b/functions/models.py index 0901fb109..be1be7a5e 100644 --- a/functions/models.py +++ b/functions/models.py @@ -13,8 +13,8 @@ class _TriggerData(typing.NamedTuple): class FunctionTrigger(_TriggerData, GooeyEnum): - pre = _TriggerData(label="Pre", db_value=1) - post = _TriggerData(label="Post", db_value=2) + pre = _TriggerData(label="Before", db_value=1) + post = _TriggerData(label="After", db_value=2) class RecipeFunction(BaseModel): diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index b43ea16e6..b7fd36fdb 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -53,6 +53,7 @@ def call_recipe_functions( request_body=dict( variables=sr.state.get("variables", {}) | variables | fn_vars, ), + deduct_credits=False, ) CalledFunction.objects.create( @@ -117,7 +118,7 @@ def render_function_input(list_key: str, del_key: str, d: dict): from daras_ai_v2.workflow_url_input import workflow_url_input from recipes.Functions import FunctionsPage - col1, col2 = gui.columns([2, 10], responsive=False) + col1, col2 = gui.columns([3, 9], responsive=True) with col1: col1.node.props["className"] += " pt-1" d["trigger"] = enum_selector( @@ -133,7 +134,9 @@ def render_function_input(list_key: str, del_key: str, d: dict): internal_state=d, del_key=del_key, current_user=current_user, + include_root=False, ) + col2.node.children[0].props["className"] += " col-12" if gui.checkbox( f"##### {field_title_desc(BasePage.RequestModel, key)}", @@ -141,12 +144,18 @@ def render_function_input(list_key: str, del_key: str, d: dict): value=key in gui.session_state, ): gui.session_state.setdefault(key, [{}]) + with gui.div(className="d-flex align-items-center"): + gui.write("###### Functions") + gui.caption( + "Functions give your workflow the ability run Javascript code (with webcalls!) allowing it execute logic, use common JS libraries or make external API calls before or after the workflow runs. Learn more.", + unsafe_allow_html=True, + ) list_view_editor( - add_btn_label="➕ Add Function", + add_btn_label="Add Function", + add_btn_type="tertiary", key=key, render_inputs=render_function_input, ) - gui.write("---") else: gui.session_state.pop(key, None) diff --git a/handles/admin.py b/handles/admin.py index cf90e811f..1a86567ee 100644 --- a/handles/admin.py +++ b/handles/admin.py @@ -1,8 +1,17 @@ from django.contrib import admin +from app_users.admin import AppUserAdmin from .models import Handle @admin.register(Handle) class HandleAdmin(admin.ModelAdmin): - search_fields = ["name", "redirect_url"] + search_fields = ["name", "redirect_url"] + [ + f"user__{field}" for field in AppUserAdmin.search_fields + ] + readonly_fields = ["user", "created_at", "updated_at"] + + list_filter = [ + ("user", admin.EmptyFieldListFilter), + ("redirect_url", admin.EmptyFieldListFilter), + ] diff --git a/handles/models.py b/handles/models.py index e3371dbac..b027736a0 100644 --- a/handles/models.py +++ b/handles/models.py @@ -28,6 +28,10 @@ "about", "blog", "sales", + "js", + "css", + "assets", + "favicon.ico", ] COMMON_EMAIL_DOMAINS = [ "gmail.com", diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py index ce25e512d..7f92a53e1 100644 --- a/recipes/BulkEval.py +++ b/recipes/BulkEval.py @@ -239,7 +239,7 @@ def render_inputs(key: str, del_key: str, d: EvalPrompt): gui.write("##### " + field_title_desc(self.RequestModel, "eval_prompts")) list_view_editor( - add_btn_label="➕ Add a Prompt", + add_btn_label="Add a Prompt", key="eval_prompts", render_inputs=render_inputs, ) @@ -261,7 +261,7 @@ def render_agg_inputs(key: str, del_key: str, d: AggFunction): gui.html("