diff --git a/asteroid/masknn/convolutional.py b/asteroid/masknn/convolutional.py
index 333ce5383..1ec58df5d 100644
--- a/asteroid/masknn/convolutional.py
+++ b/asteroid/masknn/convolutional.py
@@ -486,7 +486,20 @@ class DCUMaskNet(BaseDCUMaskNet):
     Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
     by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
     and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time
-    strides of the encoders.
+    strides of the encoders. If `fix_length_mode` is not `None`, the time dimension
+    is automatically padded or trimmed to a valid size before running it through
+    the network.
+
+    .. note::
+        If using `fix_length_mode="trim"`, the network's output will be all-zero at the
+        trimmed time-steps. You might want to ignore those time-steps in your loss function.
+
+        The network's internal time-domain working size (the trimmed length) can be
+        retrieved using :meth:`~asteroid.models.BaseDCUNet.get_masker_working_size`::
+
+            >>> dcu16 = DCUNet("DCUNet-16", fix_length_mode="trim")
+            >>> dcu16.get_masker_working_size(3 * 16000)
+            45568
 
     References
         [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
diff --git a/asteroid/models/base_models.py b/asteroid/models/base_models.py
index 35f4b11e9..30684c370 100644
--- a/asteroid/models/base_models.py
+++ b/asteroid/models/base_models.py
@@ -239,6 +239,20 @@ def forward(self, wav):
         reconstructed = pad_x_to_y(decoded, wav)
         return _shape_reconstructed(reconstructed, shape)
 
+    def get_masker_working_size(self, n_samples: int) -> int:
+        """Get the masker's internal working size for an input of `n_samples`.
+
+        Generally, if `fix_length_mode="pad"`, the internal working size is
+        `>= n_samples`, otherwise it is `<= n_samples`.
+        """
+        x = torch.zeros(1, 1, n_samples)
+        tf_rep = self.forward_encoder(x)
+        tf_rep = self.masker.fix_input_dims(tf_rep)
+        masked = self.apply_masks(tf_rep, 1)
+        decoded = self.forward_decoder(masked)
+        return decoded.shape[-1]
+
+
     def forward_encoder(self, wav: torch.Tensor) -> torch.Tensor:
         """Computes time-frequency representation of `wav`.
 
diff --git a/asteroid/models/dcunet.py b/asteroid/models/dcunet.py
index a83759f82..965914f04 100644
--- a/asteroid/models/dcunet.py
+++ b/asteroid/models/dcunet.py
@@ -1,3 +1,4 @@
+import torch
 from asteroid_filterbanks import make_enc_dec
 from asteroid_filterbanks.transforms import from_torch_complex, to_torch_complex
 from ..masknn.convolutional import DCUMaskNet
diff --git a/tests/models/models_test.py b/tests/models/models_test.py
index 985b60890..0480da78e 100644
--- a/tests/models/models_test.py
+++ b/tests/models/models_test.py
@@ -175,6 +175,16 @@ def test_dcunet():
     DCUNet("mini").masker(torch.zeros((1, 9, 16), dtype=torch.complex64))
 
 
+def test_masker_working_size():
+    dcu_mini_trim = DCUNet("mini", fix_length_mode="trim")
+    dcu_mini_pad = DCUNet("mini", fix_length_mode="pad")
+    dccrn_mini = DCCRNet("mini")
+    inp_size = 3 * 16000
+    assert dcu_mini_trim.get_masker_working_size(inp_size) == 47616
+    assert dcu_mini_pad.get_masker_working_size(inp_size) == 49664
+    assert dccrn_mini.get_masker_working_size(inp_size) == 47872
+
+
 def test_dccrnet():
     _, istft = make_enc_dec("stft", 512, 512)
     input_samples = istft(torch.zeros((514, 16))).shape[0]
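
To illustrate the note above about ignoring trimmed time-steps in the loss: a minimal
sketch, not part of this diff, where random tensors stand in for a real mixture/target
pair and plain MSE stands in for whatever loss you actually use::

    >>> import torch
    >>> from asteroid.models import DCUNet
    >>> model = DCUNet("DCUNet-16", fix_length_mode="trim")
    >>> mixture = torch.randn(1, 1, 3 * 16000)  # (batch, channels, time)
    >>> target = torch.randn(1, 1, 3 * 16000)
    >>> est = model(mixture)  # zero-padded back to the input length
    >>> # Only the first `n_valid` samples carry signal when trimming.
    >>> n_valid = model.get_masker_working_size(mixture.shape[-1])  # 45568
    >>> loss = torch.nn.functional.mse_loss(est[..., :n_valid], target[..., :n_valid])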