Whisper large fine-tuning on wenetspeech, multi-hans-zh #1483

Merged 50 commits on Mar 7, 2024
Commits
e43c4da
add whisper fbank for wenetspeech
yuekaizhang Jan 19, 2024
315175a
add whisper fbank for other dataset
yuekaizhang Jan 19, 2024
046e071
add str to bool
yuekaizhang Jan 19, 2024
72c9d01
add decode for wenetspeech
yuekaizhang Jan 19, 2024
38f5f45
add requirements.txt
yuekaizhang Jan 19, 2024
d1b0104
add original model decode with 30s
yuekaizhang Jan 19, 2024
aa7b17e
test feature extractor speed
yuekaizhang Jan 23, 2024
f4cf9fb
add aishell2 feat
yuekaizhang Jan 23, 2024
fd77c57
change compute feature batch
yuekaizhang Jan 23, 2024
e46e9b7
fix overwrite
yuekaizhang Jan 23, 2024
f66b266
fix executor
yuekaizhang Jan 23, 2024
08db305
regression
yuekaizhang Jan 23, 2024
af29455
add kaldifeatwhisper fbank
yuekaizhang Jan 23, 2024
df54121
fix io issue
yuekaizhang Jan 23, 2024
cf85019
parallel jobs
yuekaizhang Jan 23, 2024
baa7c5f
use multi machines
yuekaizhang Jan 23, 2024
e1a55b9
add wenetspeech fine-tune scripts
yuekaizhang Jan 25, 2024
e49534f
add monkey patch codes
yuekaizhang Jan 25, 2024
ad796d9
remove useless file
yuekaizhang Jan 25, 2024
b76cd65
fix subsampling factor
yuekaizhang Jan 25, 2024
1600f7d
fix too long audios
yuekaizhang Jan 25, 2024
bb07b65
add remove long short
yuekaizhang Jan 26, 2024
c19891e
add remove long short
yuekaizhang Jan 26, 2024
341c29e
fix whisper version to support multi batch beam
yuekaizhang Jan 28, 2024
d8a329e
decode all wav files
yuekaizhang Jan 28, 2024
4826f08
remove utterance more than 30s in test_net
yuekaizhang Jan 29, 2024
955d16e
only test net
yuekaizhang Jan 29, 2024
97aa482
only test net
yuekaizhang Jan 29, 2024
ff75cf6
using soft links
yuekaizhang Jan 31, 2024
6fd14d2
add kespeech whisper feats
yuekaizhang Feb 19, 2024
be001a8
fix index error
yuekaizhang Feb 20, 2024
910e5db
add manifests for whisper
yuekaizhang Feb 22, 2024
0212266
change to LilcomChunkyWriter
yuekaizhang Feb 22, 2024
f893ae2
add missing option
yuekaizhang Feb 22, 2024
5a62723
decrease cpu
yuekaizhang Feb 22, 2024
73e5cae
add speed perturb for kespeech
yuekaizhang Feb 23, 2024
fa58ed2
fix kespeech speed perturb
yuekaizhang Feb 23, 2024
73a7687
add dataset
yuekaizhang Feb 23, 2024
50b575a
load checkpoint from specific path
yuekaizhang Mar 5, 2024
b422e7a
add speechio
yuekaizhang Mar 6, 2024
a00c0c5
add speechio results
yuekaizhang Mar 7, 2024
19e21ba
Merge branch 'master' into whisper_zh
yuekaizhang Mar 7, 2024
1c6a6a2
Update train.py
JinZr Mar 7, 2024
4cca65a
Update compute_fbank_wenetspeech_splits.py
JinZr Mar 7, 2024
211ce4c
Update train.py
JinZr Mar 7, 2024
262792f
Update prepare.sh
JinZr Mar 7, 2024
e96f533
Update compute_fbank_wenetspeech_splits.py
JinZr Mar 7, 2024
21af721
Update train.py
JinZr Mar 7, 2024
ab57bb5
fixed a formatting issue
JinZr Mar 7, 2024
2c54d6b
remove submodule
yuekaizhang Mar 7, 2024
2 changes: 1 addition & 1 deletion egs/aishell/ASR/RESULTS.md
@@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

```bash
./prepare.sh
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"

2 changes: 1 addition & 1 deletion egs/aishell2/ASR/RESULTS.md
@@ -1,6 +1,6 @@
## Results

### Aishell2 char-based training results
### Aishell2 char-based training results

#### Pruned transducer stateless 5

36 changes: 28 additions & 8 deletions egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -29,7 +29,14 @@
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell2(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_jobs = min(8, os.cpu_count())

dataset_parts = (
"train",
@@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()),
dataset_parts,
)

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -111,7 +124,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()


@@ -122,5 +140,7 @@ def get_args():

args = get_args()
compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)
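
The new `--whisper-fbank` option is parsed with `str2bool`, which is why `prepare.sh` can pass the literal string `true`. The following is a minimal sketch of how such a converter behaves with `argparse`. `str2bool_sketch` is a hypothetical stand-in written for illustration, not the code in `icefall.utils`, and both the accepted spellings and the `--num-mel-bins` default of 80 are assumptions.

```python
import argparse


def str2bool_sketch(value: str) -> bool:
    """Hypothetical stand-in for icefall.utils.str2bool (assumed behavior)."""
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the three options the feature script passes on."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-mel-bins", type=int, default=80)  # assumed default
    parser.add_argument("--perturb-speed", type=str2bool_sketch, default=False)
    parser.add_argument("--whisper-fbank", type=str2bool_sketch, default=False)
    return parser


# Same flags that the Whisper stage in prepare.sh passes on the command line.
args = build_parser().parse_args(
    ["--perturb-speed", "true", "--num-mel-bins", "80", "--whisper-fbank", "true"]
)
print(args.whisper_fbank, args.perturb_speed, args.num_mel_bins)
```

A dedicated converter matters here because `type=bool` would call `bool("false")`, which is `True` for any non-empty string.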
10 changes: 10 additions & 0 deletions egs/aishell2/ASR/prepare.sh
@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
fi

whisper_mel_bins=80
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
log "Stage 30: Compute whisper fbank for aishell2"
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.aishell2.whisper.done
fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
Expand Down
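
The guard on the new Whisper stage, `[ $stage -le 30 ] && [ $stop_stage -ge 30 ]`, runs a stage only when its number lies in the closed range from `stage` to `stop_stage`. Below is a minimal Python sketch of that check. `should_run` is a hypothetical helper, and the example values are illustrative, since the aishell2 defaults for `stage` and `stop_stage` are not shown in this hunk.

```python
def should_run(stage: int, stop_stage: int, n: int) -> bool:
    """Mirror of the shell test `[ $stage -le n ] && [ $stop_stage -ge n ]`."""
    return stage <= n <= stop_stage


# Illustrative values only: a low stop_stage leaves high-numbered stages out.
print(should_run(-1, 7, 3))    # True: stage 3 is inside [-1, 7]
print(should_run(-1, 7, 30))   # False: stage 30 is outside [-1, 7]
print(should_run(30, 30, 30))  # True: the range is exactly stage 30
```

Giving the Whisper feature stage a high number keeps it out of a default run whenever `stop_stage` is set below it, so it has to be requested explicitly.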
2 changes: 1 addition & 1 deletion egs/aishell4/ASR/README.md
@@ -3,7 +3,7 @@

This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).

The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

(From [Open Speech and Language Resources](https://www.openslr.org/111/))

37 changes: 29 additions & 8 deletions egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -29,7 +29,14 @@
from pathlib import Path

import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell4(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_jobs = min(8, os.cpu_count())

dataset_parts = (
"train_S",
@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
dataset_parts,
)

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=ChunkedLilcomHdf5Writer,
storage_type=LilcomChunkyWriter,
)

logging.info("About splitting cuts into smaller chunks")
@@ -121,7 +135,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()


@@ -132,5 +151,7 @@ def get_args():

args = get_args()
compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)
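
The diff lowers the worker cap from `min(15, os.cpu_count())` to `min(8, os.cpu_count())`. Python documents that `os.cpu_count()` may return `None` when the count cannot be determined, in which case `min(8, None)` raises a `TypeError`. The sketch below shows one defensive way to write the same cap. `pick_num_jobs` is a hypothetical helper and is not part of this PR.

```python
import os
from typing import Optional


def pick_num_jobs(cap: int, detected: Optional[int]) -> int:
    """Return min(cap, detected), falling back to 1 job if the CPU count is unknown."""
    usable = detected if detected else 1  # None or 0 means "unknown"
    return max(1, min(cap, usable))


# Same cap as the updated scripts, with a guard for an undetermined CPU count.
num_jobs = pick_num_jobs(8, os.cpu_count())
print(num_jobs)
```

On a machine where the CPU count is known, this returns the same value as the expression in the diff.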
29 changes: 15 additions & 14 deletions egs/aishell4/ASR/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail

stage=-1
stop_stage=100
stop_stage=7
perturb_speed=true


@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4"
log "Stage 2: Compute fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/aishell4/.fbank.done
touch data/fbank/.fbank.done
fi
fi

whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: Compute whisper fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi
fi

@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell4.done
fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
log "Stage 5: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir

Expand Down
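
Each feature stage in this script is made idempotent with a marker file: the work is skipped if the marker exists, and the marker is touched once the work finishes. As displayed in this flattened diff, Stage 2 and Stage 20 both test `data/fbank/aishell4/.fbank.done` and both touch `data/fbank/.fbank.done`, so the file tested is not the file created, and the two stages share one marker. The sketch below shows the pattern with one marker per feature type. `run_once` and both marker names are hypothetical and chosen for illustration only.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def run_once(marker: Path, work: Callable[[], None]) -> bool:
    """Run `work` unless `marker` exists, then create the marker.

    Returns True if the work ran and False if it was skipped.
    """
    if marker.exists():
        return False
    marker.parent.mkdir(parents=True, exist_ok=True)
    work()
    marker.touch()  # written only after the work completed without raising
    return True


calls: List[str] = []
with tempfile.TemporaryDirectory() as tmp:
    fbank_dir = Path(tmp) / "data" / "fbank"
    # Hypothetical marker names: one per feature type, so neither stage masks the other.
    kaldi_marker = fbank_dir / ".aishell4.fbank.done"
    whisper_marker = fbank_dir / ".aishell4.whisper.done"
    print(run_once(kaldi_marker, lambda: calls.append("fbank")))      # True
    print(run_once(whisper_marker, lambda: calls.append("whisper")))  # True
    print(run_once(whisper_marker, lambda: calls.append("whisper")))  # False
print(calls)  # ['fbank', 'whisper']
```

The check and the touch use the same path object here, which rules out a mismatch between the file tested and the file created.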
37 changes: 29 additions & 8 deletions egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -29,7 +29,14 @@
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,18 +49,20 @@
torch.set_num_interop_threads(1)


def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_alimeeting(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_jobs = min(8, os.cpu_count())

dataset_parts = (
"train",
"eval",
"test",
)

prefix = "alimeeting"
prefix = "alimeeting-far"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
dataset_parts,
)

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -121,7 +135,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use the Whisper Fbank feature extractor. Default: False.",
)
return parser.parse_args()


@@ -132,5 +151,7 @@ def get_args():

args = get_args()
compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)
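
Changing `prefix` from `alimeeting` to `alimeeting-far` changes which manifest files `read_manifests_if_cached` looks for under `data/manifests/alimeeting`. The sketch below lists the file names that would be expected under an assumed naming scheme of `<prefix>_<kind>_<part>.<suffix>`. The scheme, the two manifest kinds, and the helper `expected_manifest_paths` are assumptions for illustration and should be checked against the installed lhotse version.

```python
from pathlib import Path
from typing import Dict, Sequence


def expected_manifest_paths(
    src_dir: Path,
    prefix: str,
    parts: Sequence[str],
    suffix: str = "jsonl.gz",
    kinds: Sequence[str] = ("recordings", "supervisions"),
) -> Dict[str, Dict[str, Path]]:
    """List manifest paths under the assumed scheme <prefix>_<kind>_<part>.<suffix>."""
    if prefix and not prefix.endswith("_"):
        prefix = f"{prefix}_"
    return {
        part: {kind: src_dir / f"{prefix}{kind}_{part}.{suffix}" for kind in kinds}
        for part in parts
    }


paths = expected_manifest_paths(
    Path("data/manifests/alimeeting"),
    prefix="alimeeting-far",
    parts=("train", "eval", "test"),
)
for part, by_kind in paths.items():
    for kind, path in by_kind.items():
        print(part, kind, path.as_posix())
```

If the manifests on disk still carry the old `alimeeting_` prefix, a lookup with the new prefix would find nothing, so the manifest preparation step and this script need to agree on the prefix.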
30 changes: 16 additions & 14 deletions egs/alimeeting/ASR/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail

stage=-1
stop_stage=100
stop_stage=7
perturb_speed=true

# We assume dl_dir (download dir) contains the following
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process alimeeting"
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
mkdir -p data/fbank/alimeeting
log "Stage 2: compute fbank for alimeeting"
if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
touch data/fbank/.fbank.done
fi
fi

whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: compute whisper fbank for alimeeting"
if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi
fi

@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting"
if [ ! -f data/fbank/.alimeeting.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed True
touch data/fbank/.alimeeting.done
fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
log "Stage 5: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir

Expand Down
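
With `perturb_speed=true`, the feature scripts build the training set as the original cuts plus copies at speed factors 0.9 and 1.1, which triples the number of cuts. The sketch below estimates the resulting durations, assuming a cut of duration `d` perturbed by factor `f` lasts about `d / f` (exact values depend on sample rounding). `perturbed_durations` is a hypothetical helper, not a lhotse function.

```python
from typing import List, Sequence, Tuple


def perturbed_durations(
    durations: Sequence[float], factors: Sequence[float] = (0.9, 1.1)
) -> List[Tuple[float, float]]:
    """Return (factor, duration) pairs for the originals and each perturbed copy."""
    out = [(1.0, d) for d in durations]
    for factor in factors:
        # Slower speech (factor < 1) gets longer, faster speech gets shorter.
        out.extend((factor, d / factor) for d in durations)
    return out


original = [5.0, 12.5, 28.0]  # illustrative durations in seconds
result = perturbed_durations(original)
print(len(result))  # 9: three copies of each cut
longest = max(duration for _, duration in result)
print(round(longest, 2))  # 31.11: the 28 s cut at speed 0.9
```

The slowed copy of a cut close to 30 seconds can exceed 30 seconds, which is relevant to the length filtering mentioned in the commit list ("fix too long audios", "add remove long short").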