WIP: add TDNNF to pytorch. #3892

Merged: 5 commits, Feb 11, 2020

Changes from all commits
10 changes: 6 additions & 4 deletions egs/aishell/s10/chain/egs_dataset.py
@@ -170,10 +170,10 @@ def __call__(self, batch):


 def _test_nnet_chain_example_dataset():
-    egs_dir = '/cache/fangjun/chain/aishell_kaldi_pybind/test'
+    egs_dir = 'exp/chain/merged_egs'
     dataset = NnetChainExampleDataset(egs_dir=egs_dir)
-    egs_left_context = 23
-    egs_right_context = 23
+    egs_left_context = 29
+    egs_right_context = 29
     frame_subsampling_factor = 3

     collate_fn = NnetChainExampleDatasetCollateFunc(
@@ -200,7 +200,9 @@ def _test_nnet_chain_example_dataset():
                             collate_fn=collate_fn)
     for b in dataloader:
         key_list, feature_list, supervision_list = b
-        assert feature_list[0].shape == (128, 192, 120)
+        assert feature_list[0].shape == (128, 204, 129) \
+            or feature_list[0].shape == (128, 144, 129) \
+            or feature_list[0].shape == (128, 165, 129)
         assert supervision_list[0].weight == 1
         assert supervision_list[0].num_sequences == 128  # minibatch size is 128
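An aside on the shape checks above (not part of the diff): the collated features are laid out as [N, T, C]. A minimal sketch of where those numbers come from, assuming feat_dim = 43 and the +/-1 frame splicing applied before the LDA transform:

    # shapes asserted above: (128, 204, 129), (128, 144, 129), (128, 165, 129)
    N = 128           # minibatch size, matching the num_sequences check
    feat_dim = 43     # input feature dimension used by this recipe
    C = feat_dim * 3  # 129: each frame spliced with its +/-1 neighbours
    # T varies across minibatches because egs with different chunk widths
    # are merged; the test in model.py below uses T = 150 + 27 + 27 = 204.
    print(N, C)  # 128 129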

5 changes: 3 additions & 2 deletions egs/aishell/s10/chain/inference.py
@@ -34,8 +34,9 @@ def main():
                             output_dim=args.output_dim,
                             lda_mat_filename=args.lda_mat_filename,
                             hidden_dim=args.hidden_dim,
-                            kernel_size_list=args.kernel_size_list,
-                            stride_list=args.stride_list)
+                            bottleneck_dim=args.bottleneck_dim,
+                            time_stride_list=args.time_stride_list,
+                            conv_stride_list=args.conv_stride_list)

     load_checkpoint(args.checkpoint, model)

240 changes: 143 additions & 97 deletions egs/aishell/s10/chain/model.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
 # Apache 2.0

 import logging
@@ -10,109 +10,118 @@
 import torch.nn.functional as F

 from common import load_lda_mat
-'''
-input dim=$feat_dim name=input
-
-# please note that it is important to have input layer with the name=input
-# as the layer immediately preceding the fixed-affine-layer to enable
-# the use of short notation for the descriptor
-fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=$dir/configs/lda.mat
-
-# the first splicing is moved before the lda layer, so no splicing here
-relu-batchnorm-layer name=tdnn1 dim=625
-relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625
-relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625
-relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625
-relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625
-relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625
-
-## adding the layers for chain branch
-relu-batchnorm-layer name=prefinal-chain input=tdnn6 dim=625 target-rms=0.5
-output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
-
-# adding the layers for xent branch
-# This block prints the configs for a separate output that will be
-# trained with a cross-entropy objective in the 'chain' models... this
-# has the effect of regularizing the hidden parts of the model. we use
-# 0.5 / args.xent_regularize as the learning rate factor- the factor of
-# 0.5 / args.xent_regularize is suitable as it means the xent
-# final-layer learns at a rate independent of the regularization
-# constant; and the 0.5 was tuned so as to make the relative progress
-# similar in the xent and regular final layers.
-relu-batchnorm-layer name=prefinal-xent input=tdnn6 dim=625 target-rms=0.5
-output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
-'''
+from tdnnf_layer import FactorizedTDNN
+from tdnnf_layer import OrthonormalLinear
+from tdnnf_layer import PrefinalLayer


 def get_chain_model(feat_dim,
                     output_dim,
                     hidden_dim,
-                    kernel_size_list,
-                    stride_list,
+                    bottleneck_dim,
+                    time_stride_list,
+                    conv_stride_list,
                     lda_mat_filename=None):
     model = ChainModel(feat_dim=feat_dim,
                        output_dim=output_dim,
                        lda_mat_filename=lda_mat_filename,
                        hidden_dim=hidden_dim,
-                       kernel_size_list=kernel_size_list,
-                       stride_list=stride_list)
+                       bottleneck_dim=bottleneck_dim,
+                       time_stride_list=time_stride_list,
+                       conv_stride_list=conv_stride_list)
     return model
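Not part of the diff: a quick smoke test of the factory, with dimensions taken from the config string quoted below (feat_dim 43, dim 1024, bottleneck-dim 128, output dim 3456) and stride lists that mirror its time-stride values; the recipe's actual command-line flags may differ:

    # assumes model.py is importable; values mirror the config below
    model = get_chain_model(
        feat_dim=43,
        output_dim=3456,
        hidden_dim=1024,
        bottleneck_dim=128,
        # the fourth layer (time_stride=0, conv_stride=3) performs the
        # 3x frame subsampling, like tdnnf5 in the config below
        time_stride_list=[1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3],
        conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1])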


+'''
+input dim=43 name=input
+
+# please note that it is important to have input layer with the name=input
+# as the layer immediately preceding the fixed-affine-layer to enable
+# the use of short notation for the descriptor
+fixed-affine-layer name=lda input=Append(-1,0,1) affine-transform-file=exp/chain_cleaned_1c/tdnn1c_sp/configs/lda.mat
+
+# the first splicing is moved before the lda layer, so no splicing here
+relu-batchnorm-dropout-layer name=tdnn1 l2-regularize=0.008 dropout-proportion=0.0 dropout-per-dim-continuous=true dim=1024
+tdnnf-layer name=tdnnf2 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf3 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf4 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=1
+tdnnf-layer name=tdnnf5 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=0
+tdnnf-layer name=tdnnf6 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf7 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf8 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf9 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf10 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf11 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf12 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+tdnnf-layer name=tdnnf13 l2-regularize=0.008 dropout-proportion=0.0 bypass-scale=0.66 dim=1024 bottleneck-dim=128 time-stride=3
+linear-component name=prefinal-l dim=256 l2-regularize=0.008 orthonormal-constraint=-1.0
+
+prefinal-layer name=prefinal-chain input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output include-log-softmax=false dim=3456 l2-regularize=0.002
+
+prefinal-layer name=prefinal-xent input=prefinal-l l2-regularize=0.008 big-dim=1024 small-dim=256
+output-layer name=output-xent dim=3456 learning-rate-factor=5.0 l2-regularize=0.002
+'''
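FactorizedTDNN, OrthonormalLinear and PrefinalLayer come from the new egs/aishell/s10/chain/tdnnf_layer.py, which is not shown in this excerpt. For orientation, here is a minimal sketch of what a TDNN-F layer computes (Povey et al., "Semi-Orthogonal Low-Rank Matrix Factorization for Deep Neural Networks", Interspeech 2018): a linear bottleneck whose weight is kept semi-orthogonal, an affine map back up, ReLU plus batchnorm, and a scaled bypass connection. Treating time_stride as a dilation and omitting conv_stride and dropout are simplifications of this sketch, not necessarily the PR's exact implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FactorizedTDNNSketch(nn.Module):
        def __init__(self, dim, bottleneck_dim, time_stride,
                     bypass_scale=0.66):
            super().__init__()
            # time_stride t > 0: the layer sees offsets (-t, 0, +t);
            # t == 0: no temporal context (a 1x1 convolution)
            kernel_size = 3 if time_stride > 0 else 1
            # first factor, kept semi-orthogonal during training
            self.linear = nn.Conv1d(dim, bottleneck_dim, kernel_size,
                                    dilation=max(time_stride, 1), bias=False)
            # second factor: an ordinary affine map back up to dim
            self.affine = nn.Conv1d(bottleneck_dim, dim, kernel_size=1)
            self.batchnorm = nn.BatchNorm1d(num_features=dim)
            self.bypass_scale = bypass_scale
            self.time_stride = time_stride

        def forward(self, x):
            # x: [N, C, T]; no padding is used, so T shrinks by
            # 2 * time_stride (the context is baked into the egs)
            y = self.batchnorm(F.relu(self.affine(self.linear(x))))
            t = self.time_stride
            if t > 0:
                x = x[:, :, t:-t]  # trim the input to align with y
            return self.bypass_scale * x + y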


 # Create a network like the one above
 class ChainModel(nn.Module):

     def __init__(self,
                  feat_dim,
                  output_dim,
-                 lda_mat_filename,
-                 hidden_dim=625,
-                 kernel_size_list=[1, 3, 3, 3, 3, 3],
-                 stride_list=[1, 1, 3, 1, 1, 1],
+                 lda_mat_filename=None,
+                 hidden_dim=1024,
+                 bottleneck_dim=128,
+                 time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                 conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
                  frame_subsampling_factor=3):
         super().__init__()

         # at present, only frame_subsampling_factor == 3 is supported
         assert frame_subsampling_factor == 3

-        assert len(kernel_size_list) == len(stride_list)
-        num_layers = len(kernel_size_list)
+        assert len(time_stride_list) == len(conv_stride_list)
+        num_layers = len(time_stride_list)

-        tdnns = []
+        # tdnn1_affine requires [N, T, C]
+        self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
+                                      out_features=hidden_dim)
+
+        # tdnn1_batchnorm requires [N, C, T]
+        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
+
+        tdnnfs = []
         for i in range(num_layers):
-            in_channels = hidden_dim
-            if i == 0:
-                in_channels = feat_dim * 3
-
-            kernel_size = kernel_size_list[i]
-            stride = stride_list[i]
-
-            # we do not need to perform padding in Conv1d because it
-            # has been included in left/right context while generating egs
-            layer = nn.Conv1d(in_channels=in_channels,
-                              out_channels=hidden_dim,
-                              kernel_size=kernel_size,
-                              stride=stride)
-            tdnns.append(layer)
-
-        self.tdnns = nn.ModuleList(tdnns)
-        self.batch_norms = nn.ModuleList([
-            nn.BatchNorm1d(num_features=hidden_dim) for i in range(num_layers)
-        ])
-
-        self.prefinal_chain_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                             out_channels=hidden_dim,
-                                             kernel_size=1)
-        self.prefinal_chain_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_fc = nn.Linear(in_features=hidden_dim,
-                                   out_features=output_dim)
-
-        self.prefinal_xent_tdnn = nn.Conv1d(in_channels=hidden_dim,
-                                            out_channels=hidden_dim,
-                                            kernel_size=1)
-        self.prefinal_xent_batch_norm = nn.BatchNorm1d(num_features=hidden_dim)
-        self.output_xent_fc = nn.Linear(in_features=hidden_dim,
-                                        out_features=output_dim)
+            time_stride = time_stride_list[i]
+            conv_stride = conv_stride_list[i]
+            layer = FactorizedTDNN(dim=hidden_dim,
+                                   bottleneck_dim=bottleneck_dim,
+                                   time_stride=time_stride,
+                                   conv_stride=conv_stride)
+            tdnnfs.append(layer)
+
+        # tdnnfs requires [N, C, T]
+        self.tdnnfs = nn.ModuleList(tdnnfs)
+
+        # prefinal_l affine requires [N, C, T]
+        self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
+                                            bottleneck_dim=bottleneck_dim * 2,
+                                            time_stride=0)
+
+        # prefinal_chain requires [N, C, T]
+        self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
+                                            small_dim=bottleneck_dim * 2)
+
+        # output_affine requires [N, T, C]
+        self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                       out_features=output_dim)
+
+        # prefinal_xent requires [N, C, T]
+        self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
+                                           small_dim=bottleneck_dim * 2)
+
+        self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
+                                            out_features=output_dim)

         if lda_mat_filename:
             logging.info('Use LDA from {}'.format(lda_mat_filename))
@@ -146,32 +155,69 @@ def forward(self, x):

         # at this point, x is [N, C, T]

-        # Conv1d requires input of shape [N, C, T]
-        for i in range(len(self.tdnns)):
-            x = self.tdnns[i](x)
-            x = F.relu(x)
-            x = self.batch_norms[i](x)
-
-        # at this point, x is [N, C, T]
-
-        # first, for the chain branch
-        x_chain = self.prefinal_chain_tdnn(x)
-        x_chain = F.relu(x_chain)
-        x_chain = self.prefinal_chain_batch_norm(x_chain)
-        x_chain = x_chain.permute(0, 2, 1)
-        # at this point, x_chain is [N, T, C]
-        nnet_output = self.output_fc(x_chain)
-
-        # now for the xent branch
-        x_xent = self.prefinal_xent_tdnn(x)
-        x_xent = F.relu(x_xent)
-        x_xent = self.prefinal_xent_batch_norm(x_xent)
-        x_xent = x_xent.permute(0, 2, 1)
-
-        # at this point, x_xent is [N, T, C]
-        xent_output = self.output_xent_fc(x_xent)
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, T, C]
+
+        x = self.tdnn1_affine(x)
+
+        # at this point, x is [N, T, C]
+
+        x = F.relu(x)
+
+        x = x.permute(0, 2, 1)
+
+        # at this point, x is [N, C, T]
+
+        x = self.tdnn1_batchnorm(x)
+
+        # tdnnf requires input of shape [N, C, T]
+        for i in range(len(self.tdnnfs)):
+            x = self.tdnnfs[i](x)
+
+        # at this point, x is [N, C, T]
+
+        # we have two branches from this point on
+        x = self.prefinal_l(x)
+
+        # at this point, x is [N, C, T]
+
+        # for the output node
+        nnet_output = self.prefinal_chain(x)
+
+        # at this point, nnet_output is [N, C, T]
+        nnet_output = nnet_output.permute(0, 2, 1)
+        # at this point, nnet_output is [N, T, C]
+        nnet_output = self.output_affine(nnet_output)
+
+        # for the xent node
+        xent_output = self.prefinal_xent(x)
+
+        # at this point, xent_output is [N, C, T]
+        xent_output = xent_output.permute(0, 2, 1)
+        # at this point, xent_output is [N, T, C]
+        xent_output = self.output_xent_affine(xent_output)
+
         xent_output = F.log_softmax(xent_output, dim=-1)

         return nnet_output, xent_output

+    def constrain_orthonormal(self):
+        for i in range(len(self.tdnnfs)):
+            self.tdnnfs[i].constrain_orthonormal()
+
+        self.prefinal_l.constrain_orthonormal()
+        self.prefinal_chain.constrain_orthonormal()
+        self.prefinal_xent.constrain_orthonormal()
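constrain_orthonormal() on the sublayers presumably applies the semi-orthogonal update from the TDNN-F paper; its implementation is in tdnnf_layer.py and not shown here. A sketch of one such update step, following Kaldi's ConstrainOrthonormalInternal (alpha is the update speed; in Kaldi the constraint is applied stochastically every few minibatches, so its cost is negligible):

    import torch

    def constrain_orthonormal_step(M, scale=1.0, alpha=0.125):
        # push M (rows <= cols) toward M M^T = scale^2 * I by taking a
        # gradient step on ||M M^T - scale^2 I||_F^2, whose gradient
        # with respect to M is 4 (M M^T - scale^2 I) M
        assert M.shape[0] <= M.shape[1]
        I = torch.eye(M.shape[0], dtype=M.dtype, device=M.device)
        err = torch.mm(M, M.t()) - (scale * scale) * I
        return M - 4 * alpha * torch.mm(err, M)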


+if __name__ == '__main__':
+    feat_dim = 43
+    output_dim = 4344
+    model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
+    N = 1
+    T = 150 + 27 + 27
+    C = feat_dim * 3
+    x = torch.arange(N * T * C).reshape(N, T, C).float()
+    nnet_output, xent_output = model(x)
+    print(x.shape, nnet_output.shape, xent_output.shape)
+    model.constrain_orthonormal()
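For reference, the shapes this test should print, assuming each time_stride=1 layer trims one frame of context per side and the conv_stride=3 layer subsamples time by three (assumptions about FactorizedTDNN, not verified here):

    # x:           torch.Size([1, 204, 129])
    # nnet_output: torch.Size([1, 50, 4344])
    # xent_output: torch.Size([1, 50, 4344])
    #
    # T: 204 -> 198 over the first three TDNN-F layers (2 frames each),
    # 198 -> 66 at the conv_stride=3 layer, 66 -> 50 over the remaining
    # eight layers; equivalently (204 - 27 - 27) / 3 = 50.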
33 changes: 20 additions & 13 deletions egs/aishell/s10/chain/options.py
@@ -129,18 +129,19 @@ def _check_args(args):
     assert args.feat_dim > 0
     assert args.output_dim > 0
     assert args.hidden_dim > 0
+    assert args.bottleneck_dim > 0

-    assert args.kernel_size_list is not None
-    assert len(args.kernel_size_list) > 0
+    assert args.time_stride_list is not None
+    assert len(args.time_stride_list) > 0

-    assert args.stride_list is not None
-    assert len(args.stride_list) > 0
+    assert args.conv_stride_list is not None
+    assert len(args.conv_stride_list) > 0

-    args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
+    args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]

-    args.stride_list = [int(k) for k in args.stride_list.split(', ')]
+    args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]

-    assert len(args.kernel_size_list) == len(args.stride_list)
+    assert len(args.time_stride_list) == len(args.conv_stride_list)

     assert args.log_level in ['debug', 'info', 'warning']
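An observation on the parsing contract (not part of the diff): the list-valued flags are split on a comma followed by a single space, so values must be quoted and spaced exactly that way on the command line. A quick check with the time strides from the config in model.py:

    time_stride_list = '1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3'
    parsed = [int(k) for k in time_stride_list.split(', ')]
    print(parsed)  # [1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3]
    # a value without spaces, e.g. '1,1,0', stays in one piece after
    # split(', ') and makes int() raise ValueError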

@@ -195,15 +196,21 @@ def get_args():
                         required=True,
                         type=int)

-    parser.add_argument('--kernel-size-list',
-                        dest='kernel_size_list',
-                        help='kernel size list',
-                        required=True,
-                        type=str)
+    parser.add_argument('--bottleneck-dim',
+                        dest='bottleneck_dim',
+                        help='nn bottleneck dimension',
+                        required=True,
+                        type=int)
+
+    parser.add_argument('--time-stride-list',
+                        dest='time_stride_list',
+                        help='time stride list',
+                        required=True,
+                        type=str)

-    parser.add_argument('--stride-list',
-                        dest='stride_list',
-                        help='stride list',
+    parser.add_argument('--conv-stride-list',
+                        dest='conv_stride_list',
+                        help='conv stride list',
                         required=True,
                         type=str)
