Skip to content

Commit

Permalink
update training scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jan 31, 2020
1 parent 7c7dda3 commit 9508860
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 61 deletions.
9 changes: 9 additions & 0 deletions egs/aishell/s10/chain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ def forward(self, x):

return nnet_output, xent_output

def constraint_orthonormal(self):

This comment has been minimized.

Copy link
@danpovey

danpovey Jan 31, 2020

Contributor

should be constrain_orthonormal

This comment has been minimized.

Copy link
@csukuangfj

csukuangfj Jan 31, 2020

Author Contributor

fixed.

for i in range(len(self.tdnnfs)):
self.tdnnfs[i].constraint_orthonormal()

self.prefinal_l.constraint_orthonormal()
self.prefinal_chain.constraint_orthonormal()
self.prefinal_xent.constraint_orthonormal()


if __name__ == '__main__':
feat_dim = 43
Expand All @@ -212,3 +220,4 @@ def forward(self, x):
x = torch.arange(N * T * C).reshape(N, T, C).float()
nnet_output, xent_output = model(x)
print(x.shape, nnet_output.shape, xent_output.shape)
model.constraint_orthonormal()
33 changes: 20 additions & 13 deletions egs/aishell/s10/chain/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,19 @@ def _check_args(args):
assert args.feat_dim > 0
assert args.output_dim > 0
assert args.hidden_dim > 0
assert args.bottleneck_dim > 0

assert args.kernel_size_list is not None
assert len(args.kernel_size_list) > 0
assert args.time_stride_list is not None
assert len(args.time_stride_list) > 0

assert args.stride_list is not None
assert len(args.stride_list) > 0
assert args.conv_stride_list is not None
assert len(args.conv_stride_list) > 0

args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]

args.stride_list = [int(k) for k in args.stride_list.split(', ')]
args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]

assert len(args.kernel_size_list) == len(args.stride_list)
assert len(args.time_stride_list) == len(args.conv_stride_list)

assert args.log_level in ['debug', 'info', 'warning']

Expand Down Expand Up @@ -195,15 +196,21 @@ def get_args():
required=True,
type=int)

parser.add_argument('--kernel-size-list',
dest='kernel_size_list',
help='kernel size list',
parser.add_argument('--bottleneck-dim',
dest='bottleneck_dim',
help='nn bottleneck dimension',
required=True,
type=int)

parser.add_argument('--time-stride-list',
dest='time_stride_list',
help='time stride list',
required=True,
type=str)

parser.add_argument('--stride-list',
dest='stride_list',
help='stride list',
parser.add_argument('--conv-stride-list',
dest='conv_stride_list',
help='conv stride list',
required=True,
type=str)

Expand Down
12 changes: 10 additions & 2 deletions egs/aishell/s10/chain/tdnnf_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, dim, bottleneck_dim, time_stride):
assert time_stride in [0, 1]
# WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
# and [0, 1] for the second affine layer;
# We use [-1, 0, 1] for the first linear layer
# we use [-1, 0, 1] for the first linear layer if time_stride == 1

if time_stride == 0:
kernel_size = 1
Expand Down Expand Up @@ -142,6 +142,9 @@ def forward(self, x):

return x

def constraint_orthonormal(self):
    '''Delegate to the ``constraint_orthonormal`` method of ``self.linear``.

    This method only forwards the call; it takes no arguments and
    returns nothing itself.

    NOTE(review): presumably the callee nudges the linear layer's weight
    towards being (semi-)orthonormal, as in Kaldi's factorized TDNN --
    confirm against the implementation of ``self.linear``.
    '''
    self.linear.constraint_orthonormal()


class FactorizedTDNN(nn.Module):
'''
Expand Down Expand Up @@ -175,6 +178,8 @@ def __init__(self,
time_stride=time_stride)

# affine requires [N, C, T]
# WARNING(fangjun): we do not use nn.Linear here
# since we want to use `stride`
self.affine = nn.Conv1d(in_channels=bottleneck_dim,
out_channels=dim,
kernel_size=1,
Expand All @@ -191,20 +196,23 @@ def forward(self, x):
input_x = x

x = self.linear(x)

# at this point, x is [N, C, T]

x = self.affine(x)

# at this point, x is [N, C, T]

x = F.relu(x)

# at this point, x is [N, C, T]

x = self.batchnorm(x)

# at this point, x is [N, C, T]

# TODO(fangjun): implement GeneralDropoutComponent in PyTorch

# at this point, x is [N, C, T]
if self.linear.kernel_size == 3:
x = self.bypass_scale * input_x[:, :, 1:-1:self.conv_stride] + x
else:
Expand Down
11 changes: 9 additions & 2 deletions egs/aishell/s10/chain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# disable warnings when loading tensorboard
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_value_
Expand Down Expand Up @@ -84,6 +85,11 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
total_weight += objf_l2_term_weight[2].item()
num_frames = nnet_output.shape[0]
total_frames += num_frames

if np.random.choice(4) == 0:
with torch.no_grad():
model.constraint_orthonormal()

if batch_idx % 100 == 0:
logging.info(
'Process {}/{}({:.6f}%) global average objf: {:.6f} over {} '
Expand Down Expand Up @@ -135,8 +141,9 @@ def main():
output_dim=args.output_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
kernel_size_list=args.kernel_size_list,
stride_list=args.stride_list)
bottleneck_dim=args.bottleneck_dim,
time_stride_list=args.time_stride_list,
conv_stride_list=args.conv_stride_list)

start_epoch = 0
num_epochs = args.num_epochs
Expand Down
10 changes: 10 additions & 0 deletions egs/aishell/s10/conf/mfcc_hires.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# config for high-resolution MFCC features, intended for neural network training.
# Note: we keep all cepstra, so it has the same info as filterbank features,
# but MFCC is more easily compressible (because less correlated) which is why
# we prefer this method.
--use-energy=false # use average of log energy, not energy.
--sample-frequency=16000 # the AISHELL corpus is sampled at 16kHz
--num-mel-bins=40 # similar to Google's setup.
--num-ceps=40 # there is no dimensionality reduction.
--low-freq=20 # low cutoff frequency for mel bins
--high-freq=-400 # high cutoff frequency, relative to Nyquist of 8000 (=7600)
48 changes: 28 additions & 20 deletions egs/aishell/s10/local/run_chain.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

set -e

stage=0
stage=8

# GPU device id to use (count from 0).
# you can also set `CUDA_VISIBLE_DEVICES` and set `device_id=0`
device_id=0
device_id=6

nj=10

Expand All @@ -19,8 +19,8 @@ lat_dir=exp/tri5a_lats # input lat dir
treedir=exp/chain/tri5_tree # output tree dir

# You should know how to calculate your model's left/right context **manually**
model_left_context=12
model_right_context=12
model_left_context=28
model_right_context=28
egs_left_context=$[$model_left_context + 1]
egs_right_context=$[$model_right_context + 1]
frames_per_eg=150,110,90
Expand All @@ -30,9 +30,10 @@ minibatch_size=128
num_epochs=6
lr=1e-3

hidden_dim=625
kernel_size_list="1, 3, 3, 3, 3, 3" # comma separated list
stride_list="1, 1, 3, 1, 1, 1" # comma separated list
hidden_dim=1024
bottleneck_dim=128
time_stride_list="1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
conv_stride_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list

log_level=info # valid values: debug, info, warning

Expand All @@ -47,11 +48,16 @@ save_nn_output_as_compressed=false

if [[ $stage -le 0 ]]; then
for datadir in train dev test; do
dst_dir=data/fbank_pitch/$datadir
dst_dir=data/mfcc_hires/$datadir
if [[ ! -f $dst_dir/feats.scp ]]; then
echo "making mfcc-pitch features for LF-MMI training"
utils/copy_data_dir.sh data/$datadir $dst_dir
echo "making fbank-pitch features for LF-MMI training"
steps/make_fbank_pitch.sh --cmd $train_cmd --nj $nj $dst_dir || exit 1
steps/make_mfcc_pitch.sh \
--mfcc-config conf/mfcc_hires.conf \
--pitch-config conf/pitch.conf \
--cmd "$train_cmd" \
--nj $nj \
$dst_dir || exit 1
steps/compute_cmvn_stats.sh $dst_dir || exit 1
utils/fix_data_dir.sh $dst_dir
else
Expand Down Expand Up @@ -80,12 +86,12 @@ if [[ $stage -le 2 ]]; then
# step compared with other recipes.
steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \
--context-opts "--context-width=2 --central-position=1" \
--cmd "$train_cmd" 5000 data/train $lang $ali_dir $treedir
--cmd "$train_cmd" 5000 data/mfcc/train $lang $ali_dir $treedir
fi

if [[ $stage -le 3 ]]; then
echo "creating phone language-model"
$train_cmd exp/chain/log/make_phone_lm.log \
"$train_cmd" exp/chain/log/make_phone_lm.log \
chain-est-phone-lm \
"ark:gunzip -c $treedir/ali.*.gz | ali-to-phones $treedir/final.mdl ark:- ark:- |" \
exp/chain/phone_lm.fst || exit 1
Expand All @@ -95,7 +101,7 @@ if [[ $stage -le 4 ]]; then
echo "creating denominator FST"
copy-transition-model $treedir/final.mdl exp/chain/0.trans_mdl
cp $treedir/tree exp/chain
$train_cmd exp/chain/log/make_den_fst.log \
"$train_cmd" exp/chain/log/make_den_fst.log \
chain-make-den-fst exp/chain/tree exp/chain/0.trans_mdl exp/chain/phone_lm.fst \
exp/chain/den.fst exp/chain/normalization.fst || exit 1
fi
Expand All @@ -119,7 +125,7 @@ if [[ $stage -le 5 ]]; then
--right-tolerance 5 \
--srand 0 \
--stage -10 \
data/fbank_pitch/train \
data/mfcc_hires/train \
exp/chain $lat_dir exp/chain/egs
fi

Expand Down Expand Up @@ -157,16 +163,17 @@ if [[ $stage -le 8 ]]; then

# sort the options alphabetically
python3 ./chain/train.py \
--bottleneck-dim $bottleneck_dim \
--checkpoint=${train_checkpoint:-} \
--conv-stride-list "$conv_stride_list" \
--device-id $device_id \
--dir exp/chain/train \
--feat-dim $feat_dim \
--hidden-dim $hidden_dim \
--is-training true \
--kernel-size-list "$kernel_size_list" \
--log-level $log_level \
--output-dim $output_dim \
--stride-list "$stride_list" \
--time-stride-list "$time_stride_list" \
--train.cegs-dir exp/chain/merged_egs \
--train.den-fst exp/chain/den.fst \
--train.egs-left-context $egs_left_context \
Expand All @@ -186,20 +193,21 @@ if [[ $stage -le 9 ]]; then
best_epoch=$(cat exp/chain/train/best-epoch-info | grep 'best epoch' | awk '{print $NF}')
inference_checkpoint=exp/chain/train/epoch-${best_epoch}.pt
python3 ./chain/inference.py \
--bottleneck-dim $bottleneck_dim \
--checkpoint $inference_checkpoint \
--conv-stride-list "$conv_stride_list" \
--device-id $device_id \
--dir exp/chain/inference/$x \
--feat-dim $feat_dim \
--feats-scp data/fbank_pitch/$x/feats.scp \
--feats-scp data/mfcc_hires/$x/feats.scp \
--hidden-dim $hidden_dim \
--is-training false \
--kernel-size-list "$kernel_size_list" \
--log-level $log_level \
--model-left-context $model_left_context \
--model-right-context $model_right_context \
--output-dim $output_dim \
--save-as-compressed $save_nn_output_as_compressed \
--stride-list "$stride_list" || exit 1
--time-stride-list "$time_stride_list" || exit 1
fi
done
fi
Expand Down Expand Up @@ -228,7 +236,7 @@ if [[ $stage -le 11 ]]; then

for x in test dev; do
./local/score.sh --cmd "$decode_cmd" \
data/fbank_pitch/$x \
data/mfcc_hires/$x \
exp/chain/graph \
exp/chain/decode_res/$x || exit 1
done
Expand Down
Loading

0 comments on commit 9508860

Please sign in to comment.