From bc1c94fc130f53387ad0dacb74017b7ff553b99d Mon Sep 17 00:00:00 2001 From: 0417keito <0417keito@gmail.com> Date: Sun, 3 Dec 2023 00:51:44 +0900 Subject: [PATCH] add crop function for matchdimention --- utils/module.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/utils/module.py b/utils/module.py index aada32e..85b5b3b 100644 --- a/utils/module.py +++ b/utils/module.py @@ -182,3 +182,23 @@ def decode1d(self, stft_pair: Tensor) -> Tensor: stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) return self.decode(stft_a, stft_b) + +def crop(x1, x2): + shape_x1 = list(x1.shape) + shape_x2 = list(x2.shape) + + diff = [shape_x1[i] - shape_x2[i] for i in range(len(shape_x1))] + + assert(diff[0] == 0 and diff[1] == 0) + if diff[-1] == 0: + return x1, x2 + + crop_start = [d // 2 for d in diff] + crop_end = [d - s for d, s in zip(diff, crop_start)] + + if diff[-1] > 0: + x1_cropped = x1[:, :, crop_start[-1]: -crop_end[-1]] + return x1_cropped, x2 + else: + x2_cropped = x2[:, :, crop_start[-1]: -crop_end[-1]] + return x1, x2_cropped