diff --git a/utils/module.py b/utils/module.py index aada32e..85b5b3b 100644 --- a/utils/module.py +++ b/utils/module.py @@ -182,3 +182,23 @@ def decode1d(self, stft_pair: Tensor) -> Tensor: stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) return self.decode(stft_a, stft_b) + +def crop(x1, x2): + shape_x1 = list(x1.shape) + shape_x2 = list(x2.shape) + + diff = [shape_x1[i] - shape_x2[i] for i in range(len(shape_x1))] + + assert(diff[0] == 0 and diff[1] == 0) + if diff[-1] == 0: + return x1, x2 + + crop_start = [d // 2 for d in diff] + crop_end = [d - s for d, s in zip(diff, crop_start)] + + if diff[-1] > 0: + x1_cropped = x1[:, :, crop_start[-1]: -crop_end[-1]] + return x1_cropped, x2 + else: + x2_cropped = x2[:, :, crop_start[-1]: -crop_end[-1]] + return x1, x2_cropped