Skip to content

Commit

Permalink
Zipformer recipe for CommonVoice (#1546)
Browse files Browse the repository at this point in the history
* added scripts for char-based lang prep training scripts

* added `Zipformer` recipe for commonvoice

---------

Co-authored-by: Fangjun Kuang <[email protected]>
  • Loading branch information
JinZr and csukuangfj authored Apr 9, 2024
1 parent 87843e9 commit f2e36ec
Show file tree
Hide file tree
Showing 43 changed files with 6,764 additions and 573 deletions.
89 changes: 78 additions & 11 deletions egs/commonvoice/ASR/RESULTS.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,91 @@
## Results
### GigaSpeech BPE training results (Pruned Stateless Transducer 7)

### Commonvoice Cantonese (zh-HK) Char training results (Zipformer)

See #1546 for more details.

Number of model parameters: 72526519, i.e., 72.53 M

The best CER, for CommonVoice 16.1 (cv-corpus-16.1-2023-12-06/zh-HK) is below:

| | Dev | Test | Note |
|----------------------|-------|------|--------------------|
| greedy_search | 1.17 | 1.22 | --epoch 24 --avg 5 |
| modified_beam_search | 0.98 | 1.11 | --epoch 24 --avg 5 |
| fast_beam_search | 1.08 | 1.27 | --epoch 24 --avg 5 |

When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (w/o blank penalty),
the best CER is below:

| | Dev | Test | Note |
|----------------------|-------|------|--------------------|
| greedy_search | 42.40 | 42.03| --epoch 24 --avg 5 |
| modified_beam_search | 39.73 | 39.19| --epoch 24 --avg 5 |
| fast_beam_search | 42.14 | 41.98| --epoch 24 --avg 5 |

When doing the cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (with blank penalty set to 2.2),
the best CER is below:

| | Dev | Test | Note |
|----------------------|-------|------|----------------------------------------|
| greedy_search | 39.19 | 39.09| --epoch 24 --avg 5 --blank-penalty 2.2 |
| modified_beam_search | 37.73 | 37.65| --epoch 24 --avg 5 --blank-penalty 2.2 |
| fast_beam_search | 37.73 | 37.74| --epoch 24 --avg 5 --blank-penalty 2.2 |

To reproduce the above result, use the following commands for training:

```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train_char.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--cv-manifest-dir data/zh-HK/fbank \
--language zh-HK \
--use-validated-set 1 \
--context-size 1 \
--max-duration 1000
```

and the following commands for decoding:

```bash
for method in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode_char.py \
--epoch 24 \
--avg 5 \
--decoding-method $method \
--exp-dir zipformer/exp \
--cv-manifest-dir data/zh-HK/fbank \
--context-size 1 \
--language zh-HK
done
```

Detailed experimental results and pre-trained model are available at:
<https://huggingface.co/zrjin/icefall-asr-commonvoice-zh-HK-zipformer-2024-03-20>


### CommonVoice English (en) BPE training results (Pruned Stateless Transducer 7)

#### [pruned_transducer_stateless7](./pruned_transducer_stateless7)

See #997 for more details.
See #997 for more details.

Number of model parameters: 70369391, i.e., 70.37 M

Note that the result is obtained using GigaSpeech transcript trained BPE model

The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below:

Results are:

| | Dev | Test |
|----------------------|-------|-------|
| greedy search | 9.96 | 12.54 |
| modified beam search | 9.86 | 12.48 |
| greedy_search | 9.96 | 12.54 |
| modified_beam_search | 9.86 | 12.48 |

To reproduce the above result, use the following commands for training:

Expand Down Expand Up @@ -55,10 +126,6 @@ and the following commands for decoding:
Pretrained model is available at
<https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17>

The tensorboard log for training is available at
<https://tensorboard.dev/experiment/j4pJQty6RMOkMJtRySREKw/>


### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)

#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
Expand All @@ -73,9 +140,9 @@ Results are:

| decoding method | Test |
|----------------------|-------|
| greedy search | 9.95 |
| modified beam search | 9.57 |
| fast beam search | 9.67 |
| greedy_search | 9.95 |
| modified_beam_search | 9.57 |
| fast_beam_search | 9.67 |

Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice.

Expand Down
33 changes: 27 additions & 6 deletions egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Yifan Yang)
# Copyright 2023-2024 Xiaomi Corp. (Yifan Yang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand All @@ -17,7 +18,6 @@

import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
Expand All @@ -30,6 +30,8 @@
set_caching_enabled,
)

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
Expand All @@ -41,6 +43,14 @@
def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--subset",
type=str,
default="train",
choices=["train", "validated", "invalidated"],
help="""Dataset parts to compute fbank. """,
)

parser.add_argument(
"--language",
type=str,
Expand All @@ -66,28 +76,35 @@ def get_args():
"--num-splits",
type=int,
required=True,
help="The number of splits of the train subset",
help="The number of splits of the subset",
)

parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
help="Process pieces starting from this number (included).",
)

parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
help="Stop processing pieces until this number (excluded).",
)

parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)

return parser.parse_args()


def compute_fbank_commonvoice_splits(args):
subset = "train"
subset = args.subset
num_splits = args.num_splits
language = args.language
output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}"
Expand Down Expand Up @@ -130,6 +147,10 @@ def compute_fbank_commonvoice_splits(args):
keep_overlapping=False, min_duration=None
)

if args.perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)

logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
Expand Down
1 change: 1 addition & 0 deletions egs/commonvoice/ASR/local/prepare_char.py
1 change: 1 addition & 0 deletions egs/commonvoice/ASR/local/prepare_lang.py
1 change: 1 addition & 0 deletions egs/commonvoice/ASR/local/prepare_lang_fst.py
46 changes: 37 additions & 9 deletions egs/commonvoice/ASR/local/preprocess_commonvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pathlib import Path
from typing import Optional

from lhotse import CutSet, SupervisionSegment
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached


Expand Down Expand Up @@ -52,14 +52,20 @@ def normalize_text(utt: str, language: str) -> str:
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
elif language == "pl":
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
elif language == "yue":
return (
utt.replace(" ", "")
.replace(",", "")
.replace("。", " ")
.replace("?", "")
.replace("!", "")
.replace("?", "")
elif language in ["yue", "zh-HK"]:
# Mozilla Common Voice uses both "yue" and "zh-HK" for Cantonese
# Not sure why they decided to do this...
# None en/zh-yue tokens are manually removed here

# fmt: off
tokens_to_remove = [",", "。", "?", "!", "?", "!", "‘", "、", ",", "\.", ":", ";", "「", "」", "“", "”", "~", "—", "ㄧ", "《", "》", "…", "⋯", "·", "﹒", ".", ":", "︰", "﹖", "(", ")", "-", "~", ";", "", "⠀", "﹔", "/", "A", "B", "–", "‧"]

# fmt: on
utt = utt.upper().replace("\\", "")
return re.sub(
pattern="|".join([f"[{token}]" for token in tokens_to_remove]),
repl="",
string=utt,
)
else:
raise NotImplementedError(
Expand Down Expand Up @@ -130,6 +136,28 @@ def preprocess_commonvoice(
supervisions=m["supervisions"],
).resample(16000)

if partition == "validated":
logging.warning(
"""
The 'validated' partition contains the data of both 'train', 'dev'
and 'test' partitions. We filter out the 'dev' and 'test' partition
here.
"""
)
dev_ids = src_dir / f"cv-{language}_dev_ids"
test_ids = src_dir / f"cv-{language}_test_ids"
assert (
dev_ids.is_file()
), f"{dev_ids} does not exist, please check stage 1 of the prepare.sh"
assert (
test_ids.is_file()
), f"{test_ids} does not exist, please check stage 1 of the prepare.sh"
dev_ids = dev_ids.read_text().strip().split("\n")
test_ids = test_ids.read_text().strip().split("\n")
cut_set = cut_set.filter(
lambda x: x.supervisions[0].id not in dev_ids + test_ids
)

# Run data augmentation that needs to be done in the
# time domain.
logging.info(f"Saving to {raw_cuts_path}")
Expand Down
Loading

0 comments on commit f2e36ec

Please sign in to comment.