Commit 8464f2e: Simplified language codes back to Canary1 format but expanded to all ISO lang codes

Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed Nov 26, 2024
1 parent f8f4964 commit 8464f2e
Showing 3 changed files with 215 additions and 66 deletions.
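
In practice, the language special tokens drop their region suffixes; a quick before/after illustration (token spellings taken from the diffs below):

    # removed (locale-tagged format):
    #     <|en-US|>, <|en-GB|>, <|de-DE|>, <|fr-FR|>, ...
    # restored (Canary1-style, one token per ISO 639-1 code):
    #     <|en|>, <|de|>, <|fr|>, ...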
5 changes: 0 additions & 5 deletions nemo/collections/common/prompts/canary2.py
@@ -110,11 +110,9 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s
 
     any_special_token_present = False
 
-    lang_dict_compat = {"en": "en-US", "es": "es-ES", "fr": "fr-FR", "de": "de-DE"}
     for k in ("source_lang", "target_lang"):
         if k in slot_values and not ((v := slot_values[k]).startswith("<|") and v.endswith("|>")):
             val = slot_values[k]
-            val = lang_dict_compat.get(val, val)
             slot_values[k] = "<|" + val + "|>"
             any_special_token_present = True
 
@@ -130,9 +128,6 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s
     # and slots for this turn correspond to user prompt.
     if any_special_token_present and PromptFormatter.PROMPT_LANGUAGE_SLOT not in slot_values:
         slot_values[PromptFormatter.PROMPT_LANGUAGE_SLOT] = CANARY_SPECIAL_TOKENIZER
-    else:
-        if (l := slot_values.get(PromptFormatter.PROMPT_LANGUAGE_SLOT)) is not None:
-            slot_values[PromptFormatter.PROMPT_LANGUAGE_SLOT] = lang_dict_compat.get(l, l)
 
     return slot_values
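
With the lang_dict_compat remap gone, the slot mapping reduces to wrapping bare language codes in <|...|> verbatim. A minimal standalone sketch of the resulting behavior (mirroring the loop kept in the hunk above; not part of the commit):

    # Plain ISO 639-1 codes are wrapped as-is; values that already look
    # like special tokens ("<|...|>") pass through untouched.
    slot_values = {"source_lang": "en", "target_lang": "de"}
    for k in ("source_lang", "target_lang"):
        v = slot_values[k]
        if not (v.startswith("<|") and v.endswith("|>")):
            slot_values[k] = "<|" + v + "|>"
    assert slot_values == {"source_lang": "<|en|>", "target_lang": "<|de|>"}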

244 changes: 199 additions & 45 deletions scripts/speech_recognition/canary/build_canary_2_special_tokenizer.py
@@ -38,53 +38,22 @@ def main(output_dir: str) -> None:
             "<|emo:sad|>",
             "<|emo:angry|>",
         ]
-        # Language special tokens
         + [
+            # Language special tokens
             "<|unklang|>",
-            "<|ar-AR|>",
-            "<|cs-CZ|>",
-            "<|da-DA|>",
-            "<|de-DE|>",
-            "<|en-US|>",
-            "<|en-GB|>",
-            "<|es-US|>",
-            "<|es-ES|>",
-            "<|fr-CA|>",
-            "<|fr-FR|>",
-            "<|hi-IN|>",
-            "<|he-IL|>",
-            "<|it-IT|>",
-            "<|ja-JP|>",
-            "<|ko-KR|>",
-            "<|nb-NO|>",
-            "<|nl-NL|>",
-            "<|nn-NO|>",
-            "<|pl-PO|>",
-            "<|pt-PT|>",
-            "<|pt-BR|>",
-            "<|ru-RU|>",
-            "<|sv-SW|>",
-            "<|th-TH|>",
-            "<|tr-TR|>",
-            "<|zh-CN|>",
         ]
-        + [
-            # Timestamp frame special tokens
-            f"<|{i}|>"
-            for i in range(900)
-        ]
-        + [
-            # Speaker indicator special tokens
-            f"<|spk{i}|>"
-            for i in range(16)
-        ]
+        + ISO_LANGS
+        # Timestamp frame special tokens
+        + [f"<|{i}|>" for i in range(900)]
+        # Speaker indicator special tokens
+        + [f"<|spk{i}|>" for i in range(16)]
     )
 
     num_tokens = len(tokens) + 3  # count "<pad>", "<unk>", "_" too
     print(f"We have {num_tokens} special tokens.")
-    next_pow_of_2 = next_power_of_2(num_tokens)
-    num_extra_tokens = next_pow_of_2 - num_tokens
-    print(f"Adding extra {num_extra_tokens} unused special tokens for a total vocab size of {next_pow_of_2}")
+    final_num_tokens = next_multiple_of_64(num_tokens)
+    num_extra_tokens = final_num_tokens - num_tokens
+    print(f"Adding extra {num_extra_tokens} unused special tokens for a total vocab size of {final_num_tokens}")

     tokens += [
         # Timestamp related special tokens
@@ -98,13 +67,198 @@ def main(output_dir: str) -> None:
         force_rebuild=True,
     )
 
-    assert tokenizer.vocab_size == 1024, tokenizer.vocab_size
+    assert tokenizer.vocab_size == 1152, tokenizer.vocab_size
 
 
+def next_multiple_of_64(n):
+    return ((n + 63) // 64) * 64
 
 
-def next_power_of_2(n):
-    if n <= 0:
-        return 1
-    return 2 ** math.ceil(math.log2(n))
+ISO_LANGS = [
+    "<|aa|>",
+    "<|ab|>",
+    "<|af|>",
+    "<|ak|>",
+    "<|sq|>",
+    "<|am|>",
+    "<|ar|>",
+    "<|an|>",
+    "<|hy|>",
+    "<|as|>",
+    "<|av|>",
+    "<|ae|>",
+    "<|ay|>",
+    "<|az|>",
+    "<|bm|>",
+    "<|ba|>",
+    "<|eu|>",
+    "<|be|>",
+    "<|bn|>",
+    "<|bi|>",
+    "<|bs|>",
+    "<|br|>",
+    "<|bg|>",
+    "<|my|>",
+    "<|ca|>",
+    "<|ch|>",
+    "<|ce|>",
+    "<|ny|>",
+    "<|zh|>",
+    "<|cu|>",
+    "<|cv|>",
+    "<|kw|>",
+    "<|co|>",
+    "<|cr|>",
+    "<|hr|>",
+    "<|cs|>",
+    "<|da|>",
+    "<|dv|>",
+    "<|nl|>",
+    "<|dz|>",
+    "<|en|>",
+    "<|eo|>",
+    "<|et|>",
+    "<|ee|>",
+    "<|fo|>",
+    "<|fj|>",
+    "<|fi|>",
+    "<|fr|>",
+    "<|fy|>",
+    "<|ff|>",
+    "<|gd|>",
+    "<|gl|>",
+    "<|lg|>",
+    "<|ka|>",
+    "<|de|>",
+    "<|el|>",
+    "<|kl|>",
+    "<|gn|>",
+    "<|gu|>",
+    "<|ht|>",
+    "<|ha|>",
+    "<|he|>",
+    "<|hz|>",
+    "<|hi|>",
+    "<|ho|>",
+    "<|hu|>",
+    "<|is|>",
+    "<|io|>",
+    "<|ig|>",
+    "<|id|>",
+    "<|ia|>",
+    "<|ie|>",
+    "<|iu|>",
+    "<|ik|>",
+    "<|ga|>",
+    "<|it|>",
+    "<|ja|>",
+    "<|jv|>",
+    "<|kn|>",
+    "<|kr|>",
+    "<|ks|>",
+    "<|kk|>",
+    "<|km|>",
+    "<|ki|>",
+    "<|rw|>",
+    "<|ky|>",
+    "<|kv|>",
+    "<|kg|>",
+    "<|ko|>",
+    "<|kj|>",
+    "<|ku|>",
+    "<|lo|>",
+    "<|la|>",
+    "<|lv|>",
+    "<|li|>",
+    "<|ln|>",
+    "<|lt|>",
+    "<|lu|>",
+    "<|lb|>",
+    "<|mk|>",
+    "<|mg|>",
+    "<|ms|>",
+    "<|ml|>",
+    "<|mt|>",
+    "<|gv|>",
+    "<|mi|>",
+    "<|mr|>",
+    "<|mh|>",
+    "<|mn|>",
+    "<|na|>",
+    "<|nv|>",
+    "<|nd|>",
+    "<|nr|>",
+    "<|ng|>",
+    "<|ne|>",
+    "<|no|>",
+    "<|nb|>",
+    "<|nn|>",
+    "<|oc|>",
+    "<|oj|>",
+    "<|or|>",
+    "<|om|>",
+    "<|os|>",
+    "<|pi|>",
+    "<|ps|>",
+    "<|fa|>",
+    "<|pl|>",
+    "<|pt|>",
+    "<|pa|>",
+    "<|qu|>",
+    "<|ro|>",
+    "<|rm|>",
+    "<|rn|>",
+    "<|ru|>",
+    "<|se|>",
+    "<|sm|>",
+    "<|sg|>",
+    "<|sa|>",
+    "<|sc|>",
+    "<|sr|>",
+    "<|sn|>",
+    "<|sd|>",
+    "<|si|>",
+    "<|sk|>",
+    "<|sl|>",
+    "<|so|>",
+    "<|st|>",
+    "<|es|>",
+    "<|su|>",
+    "<|sw|>",
+    "<|ss|>",
+    "<|sv|>",
+    "<|tl|>",
+    "<|ty|>",
+    "<|tg|>",
+    "<|ta|>",
+    "<|tt|>",
+    "<|te|>",
+    "<|th|>",
+    "<|bo|>",
+    "<|ti|>",
+    "<|to|>",
+    "<|ts|>",
+    "<|tn|>",
+    "<|tr|>",
+    "<|tk|>",
+    "<|tw|>",
+    "<|ug|>",
+    "<|uk|>",
+    "<|ur|>",
+    "<|uz|>",
+    "<|ve|>",
+    "<|vi|>",
+    "<|vo|>",
+    "<|wa|>",
+    "<|cy|>",
+    "<|wo|>",
+    "<|xh|>",
+    "<|ii|>",
+    "<|yi|>",
+    "<|yo|>",
+    "<|za|>",
+    "<|zu|>",
+]


if __name__ == "__main__":
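
The assert moves from 1024 to 1152 because padding now rounds the vocab up to the next multiple of 64 rather than the next power of two. A quick check of the arithmetic (the exact pre-padding token count is not printed in the diff, so the sample input of 1100 below is an assumption):

    import math

    def next_multiple_of_64(n):
        return ((n + 63) // 64) * 64

    # 17 * 64 = 1088 and 18 * 64 = 1152, so any count in 1089..1152 pads to 1152,
    # matching the new `assert tokenizer.vocab_size == 1152`. With 900 timestamp
    # tokens, 16 speaker tokens, and 184 language tokens (183 ISO codes plus
    # <|unklang|>), a count in that range is plausible.
    assert next_multiple_of_64(1100) == 1152
    # The removed power-of-two scheme would have padded the same count to 2048,
    # leaving far more unused slots.
    assert 2 ** math.ceil(math.log2(1100)) == 2048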
32 changes: 16 additions & 16 deletions tests/collections/asr/test_asr_multitask_model_bpe.py
@@ -623,8 +623,8 @@ def canary2_tokenizer(asr_model, tmp_path):
         "spl_tokens": CanaryTokenizer.build_special_tokenizer(
             [
                 "<|startofcontext|>",
-                "<|en-US|>",
-                "<|de-DE|>",
+                "<|en|>",
+                "<|de|>",
                 "<|pnc|>",
                 "<|nopnc|>",
                 "<|itn|>",
@@ -639,8 +639,8 @@ def canary2_tokenizer(asr_model, tmp_path):
             tmp_path,
             force_rebuild=False,
         ),
-        "en-US": asr_model.tokenizer.tokenizers_dict["en"],
-        "de-DE": asr_model.tokenizer.tokenizers_dict["de"],
+        "en": asr_model.tokenizer.tokenizers_dict["en"],
+        "de": asr_model.tokenizer.tokenizers_dict["de"],
     }
 )

@@ -660,10 +660,10 @@ def test_prompted_dataset_canary2(canary2_tokenizer):
 
     # new format
     c = cuts[1]
-    c.supervisions[0].language = "en-US"
+    c.supervisions[0].language = "en"
     c.supervisions[0].text = "asd"
-    c.source_lang = "en-US"
-    c.target_lang = "en-US"
+    c.source_lang = "en"
+    c.target_lang = "en"
     c.pnc = "yes"
     c.itn = "yes"
     c.diarize = "yes"
@@ -673,10 +673,10 @@ def test_prompted_dataset_canary2(canary2_tokenizer):
 
     # new format with extra context
    c = cuts[2]
-    c.supervisions[0].language = "en-US"
+    c.supervisions[0].language = "en"
     c.supervisions[0].text = "asd"
-    c.source_lang = "en-US"
-    c.target_lang = "en-US"
+    c.source_lang = "en"
+    c.target_lang = "en"
     c.pnc = "<|pnc|>"
     c.itn = "<|noitn|>"
     c.diarize = "<|diarize|>"
@@ -694,43 +694,43 @@ def test_prompted_dataset_canary2(canary2_tokenizer):
     i = 0
     assert (
         canary2_tokenizer.ids_to_text(batch.prompt[i])
-        == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en-US|><|en-US|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
+        == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
     )
     assert batch.prompt_lens[i] == 9
     assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'i##r##r##el##e##v##a##nt'
     assert batch.transcript_lens[i] == 8
     assert (
         canary2_tokenizer.ids_to_text(batch.prompted_transcript[i])
-        == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en-US|><|en-US|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>i##r##r##el##e##v##a##nt<|endoftext|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
+        == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>i##r##r##el##e##v##a##nt<|endoftext|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
     )
     assert batch.prompted_transcript_lens[i] == 18
 
     # Test example 1
     i = 1
     assert (
         canary2_tokenizer.ids_to_text(batch.prompt[i])
-        == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|itn|><|timestamp|><|diarize|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
+        == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|itn|><|timestamp|><|diarize|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
     )
     assert batch.prompt_lens[i] == 9
     assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'a##s##d<pad><pad><pad><pad><pad>'
     assert batch.transcript_lens[i] == 3
     assert (
         canary2_tokenizer.ids_to_text(batch.prompted_transcript[i])
-        == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|itn|><|timestamp|><|diarize|>a##s##d<|endoftext|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
+        == '<|startofcontext|><|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|itn|><|timestamp|><|diarize|>a##s##d<|endoftext|><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'
     )
     assert batch.prompted_transcript_lens[i] == 13
 
     # Test example 2
     i = 2
     assert (
         canary2_tokenizer.ids_to_text(batch.prompt[i])
-        == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|noitn|><|timestamp|><|diarize|>'
+        == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>'
     )
     assert batch.prompt_lens[i] == 25
     assert canary2_tokenizer.ids_to_text(batch.transcript[i]) == 'a##s##d<pad><pad><pad><pad><pad>'
     assert batch.transcript_lens[i] == 3
     assert (
         canary2_tokenizer.ids_to_text(batch.prompted_transcript[i])
-        == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en-US|><|en-US|><|pnc|><|noitn|><|timestamp|><|diarize|>a##s##d<|endoftext|>'
+        == '<|startofcontext|>s##o##m##ed##e##c##o##d##erc##o##nt##e##x##t<|startoftranscript|><|emo:happy|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>a##s##d<|endoftext|>'
     )
     assert batch.prompted_transcript_lens[i] == 29
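
The expected strings in these asserts pin down the Canary2 prompt slot order; a compact illustration (the per-slot comments are inferred from the token names, not stated in the test):

    # Slot order encoded by the expected prompts above (illustrative):
    PROMPT_SLOTS = [
        "<|startofcontext|>",     # optional decoder context follows this token
        "<|startoftranscript|>",
        "<|emo:undefined|>",      # emotion
        "<|en|>",                 # source_lang -- now a plain ISO 639-1 code
        "<|en|>",                 # target_lang
        "<|nopnc|>",              # punctuation & capitalization
        "<|noitn|>",              # inverse text normalization
        "<|notimestamp|>",        # timestamps
        "<|nodiarize|>",          # diarization
    ]
    assert "".join(PROMPT_SLOTS) == (
        "<|startofcontext|><|startoftranscript|><|emo:undefined|>"
        "<|en|><|en|><|nopnc|><|noitn|><|notimestamp|><|nodiarize|>"
    )
    assert len(PROMPT_SLOTS) == 9  # matches batch.prompt_lens[0] == 9 above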
