Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deepspeech workload variants #628

Merged
merged 16 commits into from
Feb 14, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class DeepspeechConfig:
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
use_tanh: bool = False
layernorm_everywhere: bool = False


class Subsample(nn.Module):
Expand All @@ -80,15 +82,18 @@ def __call__(self, inputs, output_paddings, train):
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon,
input_channels=1,
output_channels=config.encoder_dim)(outputs, output_paddings, train)
output_channels=config.encoder_dim,
use_tanh=config.use_tanh
)(outputs, output_paddings, train)

outputs, output_paddings = Conv2dSubsampling(
encoder_dim=config.encoder_dim,
dtype=config.dtype,
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon,
input_channels=config.encoder_dim,
output_channels=config.encoder_dim)(outputs, output_paddings, train)
output_channels=config.encoder_dim,
use_tanh=config.use_tanh)(outputs, output_paddings, train)

batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape

Expand Down Expand Up @@ -127,6 +132,7 @@ class Conv2dSubsampling(nn.Module):
dtype: Any = jnp.float32
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
use_tanh: bool = False

def setup(self):
self.filter_shape = (3, 3, self.input_channels, self.output_channels)
Expand All @@ -150,7 +156,11 @@ def __call__(self, inputs, paddings, train):
feature_group_count=feature_group_count)

outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
outputs = nn.relu(outputs)

if self.use_tanh:
outputs = nn.tanh(outputs)
else:
outputs = nn.relu(outputs)

# Computing correct paddings post input convolution.
input_length = paddings.shape[1]
Expand Down Expand Up @@ -182,16 +192,24 @@ def __call__(self, inputs, input_paddings=None, train=False):
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
config = self.config

inputs = BatchNorm(config.encoder_dim,
config.dtype,
config.batch_norm_momentum,
config.batch_norm_epsilon)(inputs, input_paddings, train)
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
inputs = BatchNorm(config.encoder_dim,
config.dtype,
config.batch_norm_momentum,
config.batch_norm_epsilon)(inputs,
input_paddings,
train)
inputs = nn.Dense(
config.encoder_dim,
use_bias=True,
kernel_init=nn.initializers.xavier_uniform())(
inputs)
inputs = nn.relu(inputs)
if config.use_tanh:
inputs = nn.tanh(inputs)
else:
inputs = nn.relu(inputs)
inputs *= padding_mask

if config.feed_forward_dropout_rate is None:
Expand Down Expand Up @@ -416,10 +434,15 @@ class BatchRNN(nn.Module):
def __call__(self, inputs, input_paddings, train):
config = self.config

inputs = BatchNorm(config.encoder_dim,
config.dtype,
config.batch_norm_momentum,
config.batch_norm_epsilon)(inputs, input_paddings, train)
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
inputs = BatchNorm(config.encoder_dim,
config.dtype,
config.batch_norm_momentum,
config.batch_norm_epsilon)(inputs,
input_paddings,
train)
output = CudnnLSTM(
features=config.encoder_dim // 2,
bidirectional=config.bidirectional,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ def init_model_fn(
model_config = models.DeepspeechConfig(
feed_forward_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
input_dropout_rate=aux_dropout_rate)
input_dropout_rate=aux_dropout_rate,
use_tanh=self.use_tanh,
enable_residual_connections=self.enable_residual_connections,
enable_decoder_layer_norm=self.enable_decoder_layer_norm,
layernorm_everywhere=self.layernorm_everywhere,
freq_mask_count=self.freq_mask_count,
time_mask_count=self.time_mask_count,
)
self._model = models.Deepspeech(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
Expand Down Expand Up @@ -67,3 +74,65 @@ def step_hint(self) -> int:
@property
def max_allowed_runtime_sec(self) -> int:
    """Return the runtime budget for this workload, in seconds."""
    # 55,506 s is roughly 15.4 hours.
    return 55_506

@property
def use_tanh(self) -> bool:
    """Return whether the model is configured with tanh activations.

    `init_model_fn` forwards this value as `use_tanh` to
    `models.DeepspeechConfig`. The base workload leaves it disabled.
    """
    return False

@property
def enable_residual_connections(self) -> bool:
    """Return the `enable_residual_connections` model-config flag.

    `init_model_fn` forwards this value to `models.DeepspeechConfig`.
    The base workload keeps it enabled.
    """
    return True

@property
def enable_decoder_layer_norm(self) -> bool:
    """Return the `enable_decoder_layer_norm` model-config flag.

    `init_model_fn` forwards this value to `models.DeepspeechConfig`.
    The base workload keeps it enabled.
    """
    return True

@property
def layernorm_everywhere(self) -> bool:
    """Return the `layernorm_everywhere` model-config flag.

    `init_model_fn` forwards this value to `models.DeepspeechConfig`.
    The base workload leaves it disabled.
    """
    return False

@property
def freq_mask_count(self) -> int:
    """Return the `freq_mask_count` passed to `models.DeepspeechConfig`.

    NOTE(review): presumably the number of SpecAugment frequency masks;
    confirm against the model code.
    """
    return 2

@property
def time_mask_count(self) -> int:
    """Return the `time_mask_count` passed to `models.DeepspeechConfig`.

    NOTE(review): presumably the number of SpecAugment time masks;
    confirm against the model code.
    """
    return 10


class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant that turns on `use_tanh`.

    Every other setting is inherited from `LibriSpeechDeepSpeechWorkload`.
    """

    @property
    def use_tanh(self) -> bool:
        """Enable tanh activations for this variant."""
        return True


class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant with residual connections disabled.

    Every other setting is inherited from `LibriSpeechDeepSpeechWorkload`.
    """

    @property
    def enable_residual_connections(self) -> bool:
        """Disable residual connections for this variant."""
        return False


class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
    LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant with changed normalization and mask counts.

    Relative to the base workload it sets `layernorm_everywhere`, clears
    `enable_decoder_layer_norm`, raises the frequency/time mask counts
    (2 -> 4 and 10 -> 15), and uses an eval batch size of 128.
    """

    @property
    def eval_batch_size(self) -> int:
        """Return the evaluation batch size used by this variant."""
        return 128

    @property
    def enable_decoder_layer_norm(self) -> bool:
        """Disable the decoder layer norm for this variant."""
        return False

    @property
    def layernorm_everywhere(self) -> bool:
        """Enable the `layernorm_everywhere` model-config flag."""
        return True

    @property
    def freq_mask_count(self) -> int:
        """Return the variant's `freq_mask_count` (base workload uses 2)."""
        return 4

    @property
    def time_mask_count(self) -> int:
        """Return the variant's `time_mask_count` (base workload uses 10)."""
        return 15
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class DeepspeechConfig:
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
use_tanh: bool = False
layernorm_everywhere: bool = False


class LayerNorm(nn.Module):
Expand Down Expand Up @@ -77,9 +79,11 @@ def __init__(self, config: DeepspeechConfig):
self.encoder_dim = encoder_dim

self.conv1 = Conv2dSubsampling(
input_channels=1, output_channels=encoder_dim)
input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
self.conv2 = Conv2dSubsampling(
input_channels=encoder_dim, output_channels=encoder_dim)
input_channels=encoder_dim,
output_channels=encoder_dim,
use_tanh=config.use_tanh)

self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)

Expand Down Expand Up @@ -115,7 +119,8 @@ def __init__(self,
filter_stride: Tuple[int] = (2, 2),
padding: str = 'SAME',
batch_norm_momentum: float = 0.999,
batch_norm_epsilon: float = 0.001):
batch_norm_epsilon: float = 0.001,
use_tanh: bool = False):
super().__init__()

self.input_channels = input_channels
Expand All @@ -129,6 +134,8 @@ def __init__(self,
nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
self.bias = nn.Parameter(torch.zeros(output_channels))

self.use_tanh = use_tanh

def get_same_padding(self, input_shape):
in_height, in_width = input_shape[2:]
stride_height, stride_width = self.filter_stride
Expand Down Expand Up @@ -162,7 +169,10 @@ def forward(self, inputs, paddings):
dilation=(1, 1),
groups=groups)

outputs = F.relu(outputs)
if self.use_tanh:
outputs = F.tanh(outputs)
else:
outputs = F.relu(outputs)

input_length = paddings.shape[1]
stride = self.filter_stride[0]
Expand All @@ -187,10 +197,13 @@ def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config

self.bn = BatchNorm(
dim=config.encoder_dim,
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon)
if config.layernorm_everywhere:
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
self.bn_normalization_layer = BatchNorm(
dim=config.encoder_dim,
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon)
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
if config.feed_forward_dropout_rate is None:
feed_forward_dropout_rate = 0.1
Expand All @@ -200,9 +213,18 @@ def __init__(self, config: DeepspeechConfig):

def forward(self, inputs, input_paddings):
padding_mask = (1 - input_paddings)[:, :, None]
inputs = self.bn(inputs, input_paddings)
if self.config.layernorm_everywhere:
inputs = self.normalization_layer(inputs)
else: # batchnorm
inputs = self.bn_normalization_layer(inputs, input_paddings)

inputs = self.lin(inputs)
inputs = F.relu(inputs)

if self.config.use_tanh:
inputs = F.tanh(inputs)
else:
inputs = F.relu(inputs)

inputs = inputs * padding_mask
inputs = self.dropout(inputs)

Expand Down Expand Up @@ -265,9 +287,12 @@ def __init__(self, config: DeepspeechConfig):
bidirectional = config.bidirectional
self.bidirectional = bidirectional

self.bn = BatchNorm(config.encoder_dim,
config.batch_norm_momentum,
config.batch_norm_epsilon)
if config.layernorm_everywhere:
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
self.bn_normalization_layer = BatchNorm(config.encoder_dim,
config.batch_norm_momentum,
config.batch_norm_epsilon)

if bidirectional:
self.lstm = nn.LSTM(
Expand All @@ -280,7 +305,10 @@ def __init__(self, config: DeepspeechConfig):
input_size=input_size, hidden_size=hidden_size, batch_first=True)

def forward(self, inputs, input_paddings):
inputs = self.bn(inputs, input_paddings)
if self.config.layernorm_everywhere:
inputs = self.normalization_layer(inputs)
else:
inputs = self.bn_normalization_layer(inputs, input_paddings)
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
inputs, lengths, batch_first=True, enforce_sorted=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ def init_model_fn(
DeepspeechConfig(
feed_forward_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
input_dropout_rate=aux_dropout_rate)).eval()
input_dropout_rate=aux_dropout_rate,
use_tanh=self.use_tanh,
enable_residual_connections=self.enable_residual_connections,
enable_decoder_layer_norm=self.enable_decoder_layer_norm,
layernorm_everywhere=self.layernorm_everywhere,
freq_mask_count=self.freq_mask_count,
time_mask_count=self.time_mask_count)).eval()
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
# Run model once to initialize lazy layers.
t = MAX_INPUT_LENGTH
Expand Down Expand Up @@ -76,3 +82,65 @@ def step_hint(self) -> int:
@property
def max_allowed_runtime_sec(self) -> int:
    """Return the runtime budget for this workload, in seconds."""
    # 55,506 s is roughly 15.4 hours.
    return 55_506

@property
def use_tanh(self) -> bool:
    """Return whether the model is configured with tanh activations.

    `init_model_fn` forwards this value as `use_tanh` to
    `DeepspeechConfig`. The base workload leaves it disabled.
    """
    return False

@property
def enable_residual_connections(self) -> bool:
    """Return the `enable_residual_connections` model-config flag.

    `init_model_fn` forwards this value to `DeepspeechConfig`.
    The base workload keeps it enabled.
    """
    return True

@property
def enable_decoder_layer_norm(self) -> bool:
    """Return the `enable_decoder_layer_norm` model-config flag.

    `init_model_fn` forwards this value to `DeepspeechConfig`.
    The base workload keeps it enabled.
    """
    return True

@property
def layernorm_everywhere(self) -> bool:
    """Return whether LayerNorm replaces BatchNorm in the model.

    `init_model_fn` forwards this value as `layernorm_everywhere` to
    `DeepspeechConfig`. The base workload leaves it disabled.
    """
    return False

@property
def freq_mask_count(self) -> int:
    """Return the `freq_mask_count` passed to `DeepspeechConfig`.

    NOTE(review): presumably the number of SpecAugment frequency masks;
    confirm against the model code.
    """
    return 2

@property
def time_mask_count(self) -> int:
    """Return the `time_mask_count` passed to `DeepspeechConfig`.

    NOTE(review): presumably the number of SpecAugment time masks;
    confirm against the model code.
    """
    return 10


class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant that turns on `use_tanh`.

    Every other setting is inherited from `LibriSpeechDeepSpeechWorkload`.
    """

    @property
    def use_tanh(self) -> bool:
        """Enable tanh activations for this variant."""
        return True


class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant with residual connections disabled.

    Every other setting is inherited from `LibriSpeechDeepSpeechWorkload`.
    """

    @property
    def enable_residual_connections(self) -> bool:
        """Disable residual connections for this variant."""
        return False


class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
    LibriSpeechDeepSpeechWorkload):
    """DeepSpeech workload variant with changed normalization and mask counts.

    Relative to the base workload it sets `layernorm_everywhere`, clears
    `enable_decoder_layer_norm`, raises the frequency/time mask counts
    (2 -> 4 and 10 -> 15), and uses an eval batch size of 128.
    """

    @property
    def eval_batch_size(self) -> int:
        """Return the evaluation batch size used by this variant."""
        return 128

    @property
    def enable_decoder_layer_norm(self) -> bool:
        """Disable the decoder layer norm for this variant."""
        return False

    @property
    def layernorm_everywhere(self) -> bool:
        """Use LayerNorm in place of BatchNorm throughout the model."""
        return True

    @property
    def freq_mask_count(self) -> int:
        """Return the variant's `freq_mask_count` (base workload uses 2)."""
        return 4

    @property
    def time_mask_count(self) -> int:
        """Return the variant's `time_mask_count` (base workload uses 10)."""
        return 15
Loading
Loading