-
Notifications
You must be signed in to change notification settings - Fork 80
/
benchmarks.yml
72 lines (67 loc) · 2.5 KB
/
benchmarks.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
---
# Shared metric-extraction settings. Benchmark entries pull these in through
# the YAML merge key (`<<: [*common_options, ...]`).
common_options: &common_options
  data:
    throughput:
      # Captures whatever appears between "throughput:" and "samples/sec".
      # Single-quoted so the backslash reaches the regex engine untouched.
      regexp: 'throughput: *(.*?) samples\/sec'
      # NOTE(review): presumably discards the first match (warm-up step);
      # confirm against the benchmark runner.
      skip: 1
  output:
    # NOTE(review): looks like [display label, key under `data`] — confirm.
    - [samples/sec, 'throughput']
# Shared environment/setup settings, merged into every benchmark entry via
# `<<: *config_options`.
config_options: &config_options
  # NOTE(review): presumably resolved relative to this file's directory and
  # installed before the benchmark runs — confirm against the runner.
  requirements_path: requirements.txt
  # Build steps listed in order: clean first, then a full build.
  pre_run_commands:
    - make clean
    - make all
# Training throughput benchmark on generated data
# (`--train_dataset.use_generated_data true`), limited to 2 epochs.
# The 288 in the entry name matches num_replicas (4) x
# gradient_accumulation (18) x batch_size (4).
pytorch_wenet_conformer_globalbs_288_train_gen_pod16:
  # Inherits the throughput metric (common_options) and the build/setup
  # steps (config_options).
  <<: [*common_options, *config_options]
  description: >-
    Conformer public model fp16 precision, 1 replica runs on 4 IPUs
    replicate x4. Performance test.
  # Folded scalar: the lines below are joined with single spaces into one
  # command line.
  cmd: >-
    python3
    main.py
    train
    --trainer.log_every_n_step 1
    --train_dataset.use_generated_data true
    --trainer.num_epochs=2
    --ipu_options.num_replicas=4
    --ipu_options.gradient_accumulation=18
    --train_iterator.batch_size=4
# Convergence training run on the AISHELL data found under $DATASETS_DIR.
# The 288 in the entry name matches num_replicas (4) x
# gradient_accumulation (18) x batch_size (4).
pytorch_wenet_conformer_globalbs_288_train_real_pod16_conv:
  # Inherits the throughput metric (common_options) and the build/setup
  # steps (config_options).
  <<: [*common_options, *config_options]
  description: >-
    Conformer public model fp16 precision, 1 replica runs on 4 IPUs
    replicate x4. Train to convergence.
  # Folded scalar: the lines below are joined with single spaces into one
  # command line. $DATASETS_DIR must be set in the environment that runs it.
  # The wandb run name is identical to this entry's key.
  cmd: >-
    python3
    main.py
    train
    --trainer.log_every_n_step 100
    --train_dataset.data_list $DATASETS_DIR/aishell/train/data.list
    --val_dataset.data_list $DATASETS_DIR/aishell/test/data.list
    --vocab.vocab_path $DATASETS_DIR/aishell/dict/lang_char.txt
    --normalizer.cmvn $DATASETS_DIR/aishell/train/global_cmvn
    --compute_cer.label_text $DATASETS_DIR/aishell/test/text
    --ipu_options.num_replicas=4
    --ipu_options.gradient_accumulation=18
    --train_iterator.batch_size=4
    --trainer.wandb_project_name torch-conformer
    --trainer.wandb_run_name pytorch_wenet_conformer_globalbs_288_train_real_pod16_conv
# Validation run on the AISHELL test set, loading a checkpoint from
# ./checkpoint. Reports validation loss rather than throughput.
pytorch_wenet_conformer_globalbs_288_validation_real_pod4:
  # Only the build/setup steps are inherited; this entry does not merge
  # common_options and declares its own `data`/`output` below.
  <<: *config_options
  description: Conformer public model fp16 precision, validation
  # Folded scalar: the lines below are joined with single spaces into one
  # command line. $DATASETS_DIR must be set in the environment that runs it.
  # NOTE(review): the wandb run name ends in "pod16" while this entry's key
  # ends in "pod4" — confirm which one is intended.
  cmd: >-
    python3
    main.py
    validate
    --val_dataset.data_list $DATASETS_DIR/aishell/test/data.list
    --vocab.vocab_path $DATASETS_DIR/aishell/dict/lang_char.txt
    --normalizer.cmvn $DATASETS_DIR/aishell/train/global_cmvn
    --compute_cer.label_text $DATASETS_DIR/aishell/test/text
    --checkpoints.save_checkpoint_path "./checkpoint"
    --checkpoints.pretrained_checkpoint True
    --trainer.num_epochs 240
    --ipu_options.device_iterations 10
    --trainer.wandb_project_name torch-conformer
    --trainer.wandb_run_name pytorch_wenet_conformer_globalbs_288_validation_real_pod16
  data:
    loss:
      # Captures the number following "loss_valid:". Quoted on both sides:
      # with only a trailing quote the apostrophe becomes part of the
      # pattern and the regex never matches.
      regexp: 'loss_valid:\s+([\d\.]+)'
      # NOTE(review): presumably keeps only the last matched value — confirm
      # against the benchmark runner.
      reduction_type: 'final'
  output:
    # NOTE(review): looks like [display label, key under `data`] — confirm.
    - [Validation loss, 'loss']