Skip to content

Commit

Permalink
Merge branch 'master' into examples-ux
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Dec 26, 2023
2 parents c3d8d6e + 93715d3 commit 19f682e
Show file tree
Hide file tree
Showing 44 changed files with 918 additions and 302 deletions.
3 changes: 2 additions & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class PublishedRunAdmin(admin.ModelAdmin):
]
list_filter = ["workflow"]
search_fields = ["workflow", "published_run_id"]

autocomplete_fields = ["saved_run", "created_by", "last_edited_by"]
readonly_fields = [
"open_in_gooey",
"created_at",
Expand Down Expand Up @@ -297,6 +297,7 @@ def preview_input(self, saved_run: SavedRun):
@admin.register(PublishedRunVersion)
class PublishedRunVersionAdmin(admin.ModelAdmin):
search_fields = ["id", "version_id", "published_run__published_run_id"]
autocomplete_fields = ["published_run", "saved_run", "changed_by"]


class LastActiveDeltaFilter(admin.SimpleListFilter):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Generated by Django 4.2.7 on 2023-12-21 15:25

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("bots", "0052_alter_publishedrun_options_and_more"),
]

operations = [
migrations.AlterField(
model_name="publishedrun",
name="workflow",
field=models.IntegerField(
choices=[
(1, "Doc Search"),
(2, "Doc Summary"),
(3, "Google GPT"),
(4, "Copilot"),
(5, "Lipysnc + TTS"),
(6, "Text to Speech"),
(7, "Speech Recognition"),
(8, "Lipsync"),
(9, "Deforum Animation"),
(10, "Compare Text2Img"),
(11, "Text2Audio"),
(12, "Img2Img"),
(13, "Face Inpainting"),
(14, "Google Image Gen"),
(15, "Compare AI Upscalers"),
(16, "SEO Summary"),
(17, "Email Face Inpainting"),
(18, "Social Lookup Email"),
(19, "Object Inpainting"),
(20, "Image Segmentation"),
(21, "Compare LLM"),
(22, "Chyron Plant"),
(23, "Letter Writer"),
(24, "Smart GPT"),
(25, "AI QR Code"),
(26, "Doc Extract"),
(27, "Related QnA Maker"),
(28, "Related QnA Maker Doc"),
(29, "Embeddings"),
(30, "Bulk Runner"),
(31, "Bulk Evaluator"),
]
),
),
migrations.AlterField(
model_name="savedrun",
name="workflow",
field=models.IntegerField(
choices=[
(1, "Doc Search"),
(2, "Doc Summary"),
(3, "Google GPT"),
(4, "Copilot"),
(5, "Lipysnc + TTS"),
(6, "Text to Speech"),
(7, "Speech Recognition"),
(8, "Lipsync"),
(9, "Deforum Animation"),
(10, "Compare Text2Img"),
(11, "Text2Audio"),
(12, "Img2Img"),
(13, "Face Inpainting"),
(14, "Google Image Gen"),
(15, "Compare AI Upscalers"),
(16, "SEO Summary"),
(17, "Email Face Inpainting"),
(18, "Social Lookup Email"),
(19, "Object Inpainting"),
(20, "Image Segmentation"),
(21, "Compare LLM"),
(22, "Chyron Plant"),
(23, "Letter Writer"),
(24, "Smart GPT"),
(25, "AI QR Code"),
(26, "Doc Extract"),
(27, "Related QnA Maker"),
(28, "Related QnA Maker Doc"),
(29, "Embeddings"),
(30, "Bulk Runner"),
(31, "Bulk Evaluator"),
],
default=4,
),
),
migrations.AddIndex(
model_name="publishedrun",
index=models.Index(
fields=[
"workflow",
"visibility",
"is_approved_example",
"published_run_id",
],
name="bots_publis_workflo_d3ad4e_idx",
),
),
]
17 changes: 14 additions & 3 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ class SavedRun(models.Model):
workflow = models.IntegerField(
choices=Workflow.choices, default=Workflow.VIDEO_BOTS
)
example_id = models.CharField(max_length=128, default=None, null=True, blank=True)
run_id = models.CharField(max_length=128, default=None, null=True, blank=True)
uid = models.CharField(max_length=128, default=None, null=True, blank=True)

Expand All @@ -189,8 +188,6 @@ class SavedRun(models.Model):
error_msg = models.TextField(default="", blank=True)
run_time = models.DurationField(default=datetime.timedelta, blank=True)
run_status = models.TextField(default="", blank=True)
page_title = models.TextField(default="", blank=True)
page_notes = models.TextField(default="", blank=True)

hidden = models.BooleanField(default=False)
is_flagged = models.BooleanField(default=False)
Expand All @@ -208,6 +205,12 @@ class SavedRun(models.Model):
updated_at = models.DateTimeField(auto_now=True)
created_at = models.DateTimeField(auto_now_add=True)

example_id = models.CharField(
max_length=128, default=None, null=True, blank=True, help_text="(Deprecated)"
)
page_title = models.TextField(default="", blank=True, help_text="(Deprecated)")
page_notes = models.TextField(default="", blank=True, help_text="(Deprecated)")

objects = SavedRunQuerySet.as_manager()

class Meta:
Expand Down Expand Up @@ -1065,6 +1068,14 @@ class Meta:
models.Index(fields=["workflow", "created_by"]),
models.Index(fields=["workflow", "published_run_id"]),
models.Index(fields=["workflow", "visibility", "is_approved_example"]),
models.Index(
fields=[
"workflow",
"visibility",
"is_approved_example",
"published_run_id",
]
),
]

def __str__(self):
Expand Down
35 changes: 19 additions & 16 deletions daras_ai_v2/api_examples_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,20 @@ def api_example_generator(
"""
1. Generate an api key [below👇](#api-keys)
2. Install [curl](https://everything.curl.dev/get) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config).
2. Install [curl](https://everything.curl.dev/get) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config).
```bash
export GOOEY_API_KEY=sk-xxxx
```
3. Run the following `curl` command in your terminal.
3. Run the following `curl` command in your terminal.
If you encounter any issues, write to us at [email protected] and make sure to include the full curl command and the error message.
```bash
%s
```
"""
% curl_code.strip()
% curl_code.strip(),
unsafe_allow_html=True,
)

with python:
Expand Down Expand Up @@ -157,8 +158,8 @@ def api_example_generator(
)
if as_async:
py_code += r"""
from time import sleep
from time import sleep
status_url = response.headers["Location"]
while True:
response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]})
Expand Down Expand Up @@ -188,20 +189,21 @@ def api_example_generator(
rf"""
1. Generate an api key [below👇](#api-keys)
2. Install [requests](https://requests.readthedocs.io/en/latest/) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config).
2. Install [requests](https://requests.readthedocs.io/en/latest/) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config).
```bash
$ python3 -m pip install requests
$ export GOOEY_API_KEY=sk-xxxx
```
3. Use this sample code to call the API.
3. Use this sample code to call the API.
If you encounter any issues, write to us at [email protected] and make sure to include the full code snippet and the error message.
```python
%s
```
"""
% py_code
% py_code,
unsafe_allow_html=True,
)

with js:
Expand Down Expand Up @@ -276,7 +278,7 @@ def api_example_generator(
if (!response.ok) {
throw new Error(response.status);
}
const result = await response.json();
if (result.status === "completed") {
console.log(response.status, result);
Expand All @@ -302,18 +304,19 @@ def api_example_generator(
r"""
1. Generate an api key [below👇](#api-keys)
2. Install [node-fetch](https://www.npmjs.com/package/node-fetch) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config) and don't use direcly in the browser.
2. Install [node-fetch](https://www.npmjs.com/package/node-fetch) & add the `GOOEY_API_KEY` to your environment variables.
Never store the api key [in your code](https://12factor.net/config) and don't use direcly in the browser.
```bash
$ npm install node-fetch
$ export GOOEY_API_KEY=sk-xxxx
```
3. Use this sample code to call the API.
3. Use this sample code to call the API.
If you encounter any issues, write to us at [email protected] and make sure to include the full code snippet and the error message.
```js
%s
```
"""
% js_code
% js_code,
unsafe_allow_html=True,
)
63 changes: 60 additions & 3 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess
import tempfile
from enum import Enum
from time import sleep

import langcodes
import requests
Expand All @@ -12,17 +13,16 @@

import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gdrive_downloader import (
is_gdrive_url,
gdrive_download,
gdrive_metadata,
url_to_gdrive_file_id,
)
from daras_ai_v2 import settings
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gpu_server import call_celery_task
from daras_ai_v2.redis_cache import redis_cache_decorator
from time import sleep

SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB

Expand All @@ -49,6 +49,7 @@

class AsrModels(Enum):
whisper_large_v2 = "Whisper Large v2 (openai)"
whisper_large_v3 = "Whisper Large v3 (openai)"
whisper_hindi_large_v2 = "Whisper Hindi Large v2 (Bhashini)"
whisper_telugu_large_v2 = "Whisper Telugu Large v2 (Bhashini)"
nemo_english = "Conformer English (ai4bharat.org)"
Expand All @@ -66,6 +67,7 @@ def supports_auto_detect(self) -> bool:


asr_model_ids = {
AsrModels.whisper_large_v3: "vaibhavs10/incredibly-fast-whisper:37dfc0d6a7eb43ff84e230f74a24dab84e6bb7756c9b457dbdcceca3de7a4a04",
AsrModels.whisper_large_v2: "openai/whisper-large-v2",
AsrModels.whisper_hindi_large_v2: "vasista22/whisper-hindi-large-v2",
AsrModels.whisper_telugu_large_v2: "vasista22/whisper-telugu-large-v2",
Expand All @@ -84,6 +86,7 @@ def supports_auto_detect(self) -> bool:
}

asr_supported_languages = {
AsrModels.whisper_large_v3: WHISPER_SUPPORTED,
AsrModels.whisper_large_v2: WHISPER_SUPPORTED,
AsrModels.usm: CHIRP_SUPPORTED,
AsrModels.deepgram: DEEPGRAM_SUPPORTED,
Expand Down Expand Up @@ -154,6 +157,34 @@ def google_translate_languages() -> dict[str, str]:
}


@redis_cache_decorator
def google_translate_input_languages() -> dict[str, str]:
"""
Get list of supported languages for Google Translate.
:return: Dictionary of language codes and display names.
"""
from google.cloud import translate

_, project = get_google_auth_session()
parent = f"projects/{project}/locations/global"
client = translate.TranslationServiceClient()
supported_languages = client.get_supported_languages(
parent=parent, display_language_code="en"
)
return {
lang.language_code: lang.display_name
for lang in supported_languages.languages
if lang.support_source
}


def get_language_in_collection(langcode: str, languages):
for lang in languages:
if langcodes.get(lang).language == langcodes.get(langcode).language:
return langcode
return None


def asr_language_selector(
selected_model: AsrModels,
label="##### Spoken Language",
Expand Down Expand Up @@ -206,6 +237,19 @@ def run_google_translate(
"""
from google.cloud import translate_v2 as translate

# convert to BCP-47 format (google handles consistent language codes but sometimes gets confused by a mix of iso2 and iso3 which we have)
if source_language:
source_language = langcodes.Language.get(source_language).to_tag()
source_language = get_language_in_collection(
source_language, google_translate_input_languages().keys()
) # this will default to autodetect if language is not found as supported
target_language = langcodes.Language.get(target_language).to_tag()
target_language: str | None = get_language_in_collection(
target_language, google_translate_languages().keys()
)
if not target_language:
raise ValueError(f"Unsupported target language: {target_language!r}")

# if the language supports transliteration, we should check if the script is Latin
if source_language and source_language not in TRANSLITERATION_SUPPORTED:
language_codes = [source_language] * len(texts)
Expand Down Expand Up @@ -358,6 +402,19 @@ def run_asr(

if selected_model == AsrModels.azure:
return azure_asr(audio_url, language)
elif selected_model == AsrModels.whisper_large_v3:
import replicate

config = {
"audio": audio_url,
"return_timestamps": output_format != AsrOutputFormat.text,
}
if language:
config["language"] = language
data = replicate.run(
asr_model_ids[AsrModels.whisper_large_v3],
input=config,
)
elif selected_model == AsrModels.deepgram:
r = requests.post(
"https://api.deepgram.com/v1/listen",
Expand Down
Loading

0 comments on commit 19f682e

Please sign in to comment.