Skip to content

Commit

Permalink
Update documentation to PromptASR (#1321)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 authored Oct 19, 2023
1 parent 36c60b0 commit ce372cc
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,)
# Zengwei Yao)
# Mingshuang Luo,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand All @@ -21,21 +22,35 @@
Usage:
# For mix precision training:
# For mixed precision training, using MCP style transcript:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--exp-dir zipformer_prompt_asr/exp \
--transcript-style MCP \
--max-duration 1000
# For mixed precision training, using UC style transcript:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_prompt_asr/exp \
--transcript-style UC \
--max-duration 1000
# To train a streaming model:
./zipformer/train.py \
./zipformer_prompt_asr/train_baseline.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
Expand Down Expand Up @@ -100,7 +115,7 @@
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_first(
def get_mixed_cased_with_punc(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
Expand Down Expand Up @@ -479,6 +494,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--transcript-style",
type=str,
default="UC",
choices=["UC", "MCP"],
help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
MCP stands for mix-cased text with punctuation.
""",
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -1223,7 +1248,11 @@ def remove_short_and_long_utt(c: Cut):
else:
sampler_state_dict = None

text_sampling_func = get_upper_only_alpha
if params.transcript_style == "UC":
text_sampling_func = get_upper_only_alpha
else:
text_sampling_func = get_mixed_cased_with_punc
logging.info(f"Using {params.transcript_style} style for training.")
logging.info(f"Text sampling func: {text_sampling_func}")
train_dl = libriheavy.train_dataloaders(
train_cuts,
Expand Down

0 comments on commit ce372cc

Please sign in to comment.