Skip to content

Commit

Permalink
Update ASR script to use Seamless M4T v2 and deprecate v1
Browse files Browse the repository at this point in the history
- Replace SEAMLESS_SUPPORTED with SEAMLESS_v2_ASR_SUPPORTED
- Update AsrModels Enum to include seamless_m4t_v2 and mark seamless_m4t as deprecated
- Modify references to use seamless_m4t_v2 model_id and supported languages
- Adjust code for selecting Seamless M4T v2 in relevant functions and scripts
- Auto detect is not technically supported by seamless, so don't allow it
  • Loading branch information
devxpy committed Aug 12, 2024
1 parent f528980 commit ba54f80
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
46 changes: 28 additions & 18 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,15 @@
"fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"
} # fmt: skip

# See page 14 of https://scontent-sea1-1.xx.fbcdn.net/v/t39.2365-6/369747868_602316515432698_2401716319310287708_n.pdf?_nc_cat=106&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=_5cpNOcftdYAX8rCrVo&_nc_ht=scontent-sea1-1.xx&oh=00_AfDVkx7XubifELxmB_Un-yEYMJavBHFzPnvTbTlalbd_1Q&oe=65141B39
# https://huggingface.co/facebook/seamless-m4t-v2-large#supported-languages
# For now, below are listed the languages that support ASR. Note that Seamless only accepts ISO 639-3 codes.
SEAMLESS_SUPPORTED = {
"afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb",
"cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin",
"hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm",
"kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld",
"nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna",
"snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie",
"xho", "yor", "yue", "zlm", "zul"
SEAMLESS_v2_ASR_SUPPORTED = {
"afr", "amh", "arb", "ary", "arz", "asm", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn",
"cmn-Hant", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "fuv", "gaz", "gle", "glg", "guj", "heb",
"hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kan", "kat", "kaz", "khk", "khm", "kir",
"kor", "lao", "lit", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob",
"npi", "nya", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa",
"srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "yor", "yue", "zul",
} # fmt: skip

AZURE_SUPPORTED = {
Expand Down Expand Up @@ -199,7 +198,8 @@
} # fmt: skip

# https://translation.ghananlp.org/api-details#api=ghananlp-translation-webservice-api
GHANA_NLP_SUPPORTED = { 'en': 'English', 'tw': 'Twi', 'gaa': 'Ga', 'ee': 'Ewe', 'fat': 'Fante', 'dag': 'Dagbani', 'gur': 'Gurene', 'yo': 'Yoruba', 'ki': 'Kikuyu', 'luo': 'Luo', 'mer': 'Kimeru' } # fmt: skip
GHANA_NLP_SUPPORTED = {'en': 'English', 'tw': 'Twi', 'gaa': 'Ga', 'ee': 'Ewe', 'fat': 'Fante', 'dag': 'Dagbani',
'gur': 'Gurene', 'yo': 'Yoruba', 'ki': 'Kikuyu', 'luo': 'Luo', 'mer': 'Kimeru'} # fmt: skip
GHANA_NLP_MAXLEN = 500


Expand All @@ -215,11 +215,22 @@ class AsrModels(Enum):
usm = "Chirp / USM (Google V2)"
deepgram = "Deepgram"
azure = "Azure Speech"
seamless_m4t = "Seamless M4T (Facebook Research)"
seamless_m4t_v2 = "Seamless M4T v2 (Facebook Research)"
mms_1b_all = "Massively Multilingual Speech (MMS) (Facebook Research)"

seamless_m4t = "Seamless M4T [Deprecated] (Facebook Research)"

def supports_auto_detect(self) -> bool:
return self not in {self.azure, self.gcp_v1, self.mms_1b_all}
return self not in {
self.azure,
self.gcp_v1,
self.mms_1b_all,
self.seamless_m4t_v2,
}

@classmethod
def _deprecated(cls):
return {cls.seamless_m4t}


asr_model_ids = {
Expand All @@ -230,7 +241,7 @@ def supports_auto_detect(self) -> bool:
AsrModels.vakyansh_bhojpuri: "Harveenchadha/vakyansh-wav2vec2-bhojpuri-bhom-60",
AsrModels.nemo_english: "https://objectstore.e2enetworks.net/indic-asr-public/checkpoints/conformer/english_large_data_fixed.nemo",
AsrModels.nemo_hindi: "https://objectstore.e2enetworks.net/indic-asr-public/checkpoints/conformer/stt_hi_conformer_ctc_large_v2.nemo",
AsrModels.seamless_m4t: "facebook/seamless-m4t-v2-large",
AsrModels.seamless_m4t_v2: "facebook/seamless-m4t-v2-large",
AsrModels.mms_1b_all: "facebook/mms-1b-all",
}

Expand All @@ -248,7 +259,7 @@ def supports_auto_detect(self) -> bool:
AsrModels.gcp_v1: GCP_V1_SUPPORTED,
AsrModels.usm: CHIRP_SUPPORTED,
AsrModels.deepgram: DEEPGRAM_SUPPORTED,
AsrModels.seamless_m4t: SEAMLESS_SUPPORTED,
AsrModels.seamless_m4t_v2: SEAMLESS_v2_ASR_SUPPORTED,
AsrModels.azure: AZURE_SUPPORTED,
AsrModels.mms_1b_all: MMS_SUPPORTED,
}
Expand Down Expand Up @@ -783,15 +794,14 @@ def run_asr(
return "\n".join(
f"Speaker {chunk['speaker']}: {chunk['text']}" for chunk in chunks
)
elif selected_model == AsrModels.seamless_m4t:
elif selected_model == AsrModels.seamless_m4t_v2:
data = call_celery_task(
"seamless",
"seamless.asr",
pipeline=dict(
model_id=asr_model_ids[AsrModels.seamless_m4t],
model_id=asr_model_ids[AsrModels.seamless_m4t_v2],
),
inputs=dict(
audio=audio_url,
task="ASR",
src_lang=language,
),
)
Expand Down
2 changes: 1 addition & 1 deletion recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def infer_asr_model_and_language(
elif "bho" in user_lang:
asr_model = AsrModels.vakyansh_bhojpuri
elif "sw" in user_lang:
asr_model = AsrModels.seamless_m4t
asr_model = AsrModels.seamless_m4t_v2
asr_lang = "swh"
else:
asr_model = default
Expand Down

0 comments on commit ba54f80

Please sign in to comment.