Skip to content

Commit

Permalink
Fix chunk issue for sherpa (#1316)
Browse files Browse the repository at this point in the history
  • Loading branch information
ezerhouni authored Oct 18, 2023
1 parent d2bd093 commit 807816f
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions egs/librispeech/ASR/zipformer/zipformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,33 @@
# limitations under the License.

import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union
import logging

import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationDropoutAndLinear,
Balancer,
BiasNorm,
Dropout2,
ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Dropout2,
FloatLike,
ScheduledFloat,
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
convert_num_channels,
limit_param_value,
penalize_abs_values_gt,
softmax,
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
)
from torch import Tensor, nn

Expand Down Expand Up @@ -2098,7 +2103,7 @@ def forward(
(seq_len, batch_size, _) = x.shape
hidden_channels = self.hidden_channels

s, x, y = x.chunk(3, dim=-1)
s, x, y = x.chunk(3, dim=2)

# s will go through tanh.

Expand Down Expand Up @@ -2151,7 +2156,7 @@ def streaming_forward(
(seq_len, batch_size, _) = x.shape
hidden_channels = self.hidden_channels

s, x, y = x.chunk(3, dim=-1)
s, x, y = x.chunk(3, dim=2)

# s will go through tanh.
s = self.tanh(s)
Expand Down Expand Up @@ -2308,7 +2313,7 @@ def forward(

x = self.in_proj(x) # (time, batch, 2*channels)

x, s = x.chunk(2, dim=-1)
x, s = x.chunk(2, dim=2)
s = self.balancer1(s)
s = self.sigmoid(s)
x = self.activation1(x) # identity.
Expand Down

0 comments on commit 807816f

Please sign in to comment.