diff --git a/nemo/collections/common/prompts/canary2.py b/nemo/collections/common/prompts/canary2.py
index 169734cadf9e..830af84fc999 100644
--- a/nemo/collections/common/prompts/canary2.py
+++ b/nemo/collections/common/prompts/canary2.py
@@ -110,11 +110,9 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[str, str]:
 
     any_special_token_present = False
 
-    lang_dict_compat = {"en": "en-US", "es": "es-ES", "fr": "fr-FR", "de": "de-DE"}
     for k in ("source_lang", "target_lang"):
         if k in slot_values and not ((v := slot_values[k]).startswith("<|") and v.endswith("|>")):
             val = slot_values[k]
-            val = lang_dict_compat.get(val, val)
             slot_values[k] = "<|" + val + "|>"
             any_special_token_present = True
 
@@ -130,9 +128,6 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[str, str]:
     # and slots for this turn correspond to user prompt.
     if any_special_token_present and PromptFormatter.PROMPT_LANGUAGE_SLOT not in slot_values:
         slot_values[PromptFormatter.PROMPT_LANGUAGE_SLOT] = CANARY_SPECIAL_TOKENIZER
-    else:
-        if (l := slot_values.get(PromptFormatter.PROMPT_LANGUAGE_SLOT)) is not None:
-            slot_values[PromptFormatter.PROMPT_LANGUAGE_SLOT] = lang_dict_compat.get(l, l)
 
     return slot_values
 
diff --git a/scripts/speech_recognition/canary/build_canary_2_special_tokenizer.py b/scripts/speech_recognition/canary/build_canary_2_special_tokenizer.py
index 4efa593a8645..96016fc11974 100644
--- a/scripts/speech_recognition/canary/build_canary_2_special_tokenizer.py
+++ b/scripts/speech_recognition/canary/build_canary_2_special_tokenizer.py
@@ -38,53 +38,22 @@ def main(output_dir: str) -> None:
             "<|emo:sad|>",
             "<|emo:angry|>",
         ]
+        # Language special tokens
         + [
-            # Language special tokens
             "<|unklang|>",
-            "<|ar-AR|>",
-            "<|cs-CZ|>",
-            "<|da-DA|>",
-            "<|de-DE|>",
-            "<|en-US|>",
-            "<|en-GB|>",
-            "<|es-US|>",
-            "<|es-ES|>",
-            "<|fr-CA|>",
-            "<|fr-FR|>",
-            "<|hi-IN|>",
-            "<|he-IL|>",
-            "<|it-IT|>",
-            "<|ja-JP|>",
-            "<|ko-KR|>",
-            "<|nb-NO|>",
-            "<|nl-NL|>",
-            "<|nn-NO|>",
-            "<|pl-PO|>",
-            "<|pt-PT|>",
-            "<|pt-BR|>",
-            "<|ru-RU|>",
-            "<|sv-SW|>",
-            "<|th-TH|>",
-            "<|tr-TR|>",
-            "<|zh-CN|>",
-        ]
-        + [
-            # Timestamp frame special tokens
-            f"<|{i}|>"
-            for i in range(900)
-        ]
-        + [
-            # Speaker indicator special tokens
-            f"<|spk{i}|>"
-            for i in range(16)
         ]
+        + ISO_LANGS
+        # Timestamp frame special tokens
+        + [f"<|{i}|>" for i in range(900)]
+        # Speaker indicator special tokens
+        + [f"<|spk{i}|>" for i in range(16)]
     )
     num_tokens = len(tokens) + 3  # count "<pad>", "<unk>", "_" too
     print(f"We have {num_tokens} special tokens.")
-    next_pow_of_2 = next_power_of_2(num_tokens)
-    num_extra_tokens = next_pow_of_2 - num_tokens
-    print(f"Adding extra {num_extra_tokens} unused special tokens for a total vocab size of {next_pow_of_2}")
+    final_num_tokens = next_multiple_of_64(num_tokens)
+    num_extra_tokens = final_num_tokens - num_tokens
+    print(f"Adding extra {num_extra_tokens} unused special tokens for a total vocab size of {final_num_tokens}")
 
     tokens += [
         # Timestamp related special tokens
@@ -98,13 +67,198 @@ def main(output_dir: str) -> None:
         force_rebuild=True,
     )
 
-    assert tokenizer.vocab_size == 1024, tokenizer.vocab_size
+    assert tokenizer.vocab_size == 1152, tokenizer.vocab_size
+
+
+def next_multiple_of_64(n):
+    return ((n + 63) // 64) * 64
 
 
-def next_power_of_2(n):
-    if n <= 0:
-        return 1
-    return 2 ** math.ceil(math.log2(n))
+ISO_LANGS = [
+    "<|aa|>",
+    "<|ab|>",
+    "<|af|>",
+    "<|ak|>",
+    "<|sq|>",
+    "<|am|>",
+    "<|ar|>",
+    "<|an|>",
+    "<|hy|>",
+    "<|as|>",
+    "<|av|>",
+    "<|ae|>",
+    "<|ay|>",
+    "<|az|>",
"<|bm|>", + "<|ba|>", + "<|eu|>", + "<|be|>", + "<|bn|>", + "<|bi|>", + "<|bs|>", + "<|br|>", + "<|bg|>", + "<|my|>", + "<|ca|>", + "<|ch|>", + "<|ce|>", + "<|ny|>", + "<|zh|>", + "<|cu|>", + "<|cv|>", + "<|kw|>", + "<|co|>", + "<|cr|>", + "<|hr|>", + "<|cs|>", + "<|da|>", + "<|dv|>", + "<|nl|>", + "<|dz|>", + "<|en|>", + "<|eo|>", + "<|et|>", + "<|ee|>", + "<|fo|>", + "<|fj|>", + "<|fi|>", + "<|fr|>", + "<|fy|>", + "<|ff|>", + "<|gd|>", + "<|gl|>", + "<|lg|>", + "<|ka|>", + "<|de|>", + "<|el|>", + "<|kl|>", + "<|gn|>", + "<|gu|>", + "<|ht|>", + "<|ha|>", + "<|he|>", + "<|hz|>", + "<|hi|>", + "<|ho|>", + "<|hu|>", + "<|is|>", + "<|io|>", + "<|ig|>", + "<|id|>", + "<|ia|>", + "<|ie|>", + "<|iu|>", + "<|ik|>", + "<|ga|>", + "<|it|>", + "<|ja|>", + "<|jv|>", + "<|kn|>", + "<|kr|>", + "<|ks|>", + "<|kk|>", + "<|km|>", + "<|ki|>", + "<|rw|>", + "<|ky|>", + "<|kv|>", + "<|kg|>", + "<|ko|>", + "<|kj|>", + "<|ku|>", + "<|lo|>", + "<|la|>", + "<|lv|>", + "<|li|>", + "<|ln|>", + "<|lt|>", + "<|lu|>", + "<|lb|>", + "<|mk|>", + "<|mg|>", + "<|ms|>", + "<|ml|>", + "<|mt|>", + "<|gv|>", + "<|mi|>", + "<|mr|>", + "<|mh|>", + "<|mn|>", + "<|na|>", + "<|nv|>", + "<|nd|>", + "<|nr|>", + "<|ng|>", + "<|ne|>", + "<|no|>", + "<|nb|>", + "<|nn|>", + "<|oc|>", + "<|oj|>", + "<|or|>", + "<|om|>", + "<|os|>", + "<|pi|>", + "<|ps|>", + "<|fa|>", + "<|pl|>", + "<|pt|>", + "<|pa|>", + "<|qu|>", + "<|ro|>", + "<|rm|>", + "<|rn|>", + "<|ru|>", + "<|se|>", + "<|sm|>", + "<|sg|>", + "<|sa|>", + "<|sc|>", + "<|sr|>", + "<|sn|>", + "<|sd|>", + "<|si|>", + "<|sk|>", + "<|sl|>", + "<|so|>", + "<|st|>", + "<|es|>", + "<|su|>", + "<|sw|>", + "<|ss|>", + "<|sv|>", + "<|tl|>", + "<|ty|>", + "<|tg|>", + "<|ta|>", + "<|tt|>", + "<|te|>", + "<|th|>", + "<|bo|>", + "<|ti|>", + "<|to|>", + "<|ts|>", + "<|tn|>", + "<|tr|>", + "<|tk|>", + "<|tw|>", + "<|ug|>", + "<|uk|>", + "<|ur|>", + "<|uz|>", + "<|ve|>", + "<|vi|>", + "<|vo|>", + "<|wa|>", + "<|cy|>", + "<|wo|>", + "<|xh|>", + "<|ii|>", + "<|yi|>", + "<|yo|>", + "<|za|>", + "<|zu|>", +] if __name__ == "__main__": diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 059059984c3c..20b1f64fd212 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -623,8 +623,8 @@ def canary2_tokenizer(asr_model, tmp_path): "spl_tokens": CanaryTokenizer.build_special_tokenizer( [ "<|startofcontext|>", - "<|en-US|>", - "<|de-DE|>", + "<|en|>", + "<|de|>", "<|pnc|>", "<|nopnc|>", "<|itn|>", @@ -639,8 +639,8 @@ def canary2_tokenizer(asr_model, tmp_path): tmp_path, force_rebuild=False, ), - "en-US": asr_model.tokenizer.tokenizers_dict["en"], - "de-DE": asr_model.tokenizer.tokenizers_dict["de"], + "en": asr_model.tokenizer.tokenizers_dict["en"], + "de": asr_model.tokenizer.tokenizers_dict["de"], } ) @@ -660,10 +660,10 @@ def test_prompted_dataset_canary2(canary2_tokenizer): # new format c = cuts[1] - c.supervisions[0].language = "en-US" + c.supervisions[0].language = "en" c.supervisions[0].text = "asd" - c.source_lang = "en-US" - c.target_lang = "en-US" + c.source_lang = "en" + c.target_lang = "en" c.pnc = "yes" c.itn = "yes" c.diarize = "yes" @@ -673,10 +673,10 @@ def test_prompted_dataset_canary2(canary2_tokenizer): # new format with extra context c = cuts[2] - c.supervisions[0].language = "en-US" + c.supervisions[0].language = "en" c.supervisions[0].text = "asd" - c.source_lang = "en-US" - c.target_lang = "en-US" + c.source_lang = "en" + c.target_lang = 
"en" c.pnc = "<|pnc|>" c.itn = "<|noitn|>" c.diarize = "<|diarize|>" @@ -694,14 +694,14 @@ def test_prompted_dataset_canary2(canary2_tokenizer): i = 0 assert ( canary2_tokenizer.ids_to_text(batch.prompt[i]) - == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en-US|><|en-US|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>' + == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>' ) assert batch.prompt_lens[i] == 9 assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'i##r##r##el##e##v##a##nt' assert batch.transcript_lens[i] == 8 assert ( canary2_tokenizer.ids_to_text(batch.prompted_transcript[i]) - == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en-US|><|en-US|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>i##r##r##el##e##v##a##nt<|endoftext|>' + == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>i##r##r##el##e##v##a##nt<|endoftext|>' ) assert batch.prompted_transcript_lens[i] == 18 @@ -709,14 +709,14 @@ def test_prompted_dataset_canary2(canary2_tokenizer): i = 1 assert ( canary2_tokenizer.ids_to_text(batch.prompt[i]) - == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|itn|><|timestamp|><|diarize|>' + == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|itn|><|timestamp|><|diarize|>' ) assert batch.prompt_lens[i] == 9 assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'a##s##d' assert batch.transcript_lens[i] == 3 assert ( canary2_tokenizer.ids_to_text(batch.prompted_transcript[i]) - == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|itn|><|timestamp|><|diarize|>a##s##d<|endoftext|>' + == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|itn|><|timestamp|><|diarize|>a##s##d<|endoftext|>' ) assert batch.prompted_transcript_lens[i] == 13 @@ -724,13 +724,13 @@ def test_prompted_dataset_canary2(canary2_tokenizer): i = 2 assert ( canary2_tokenizer.ids_to_text(batch.prompt[i]) - == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|noitn|><|timestamp|><|diarize|>' + == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>' ) assert batch.prompt_lens[i] == 25 assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'a##s##d' assert batch.transcript_lens[i] == 3 assert ( canary2_tokenizer.ids_to_text(batch.prompted_transcript[i]) - == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|noitn|><|timestamp|><|diarize|>a##s##d<|endoftext|>' + == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>a##s##d<|endoftext|>' ) assert batch.prompted_transcript_lens[i] == 29