Skip to content

Commit

Permalink
add crop function for matchdimention
Browse files Browse the repository at this point in the history
  • Loading branch information
0417keito committed Dec 2, 2023
1 parent 2a44992 commit bc1c94f
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,23 @@ def decode1d(self, stft_pair: Tensor) -> Tensor:
stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
return self.decode(stft_a, stft_b)

def crop(x1, x2):
shape_x1 = list(x1.shape)
shape_x2 = list(x2.shape)

diff = [shape_x1[i] - shape_x2[i] for i in range(len(shape_x1))]

assert(diff[0] == 0 and diff[1] == 0)
if diff[-1] == 0:
return x1, x2

crop_start = [d // 2 for d in diff]
crop_end = [d - s for d, s in zip(diff, crop_start)]

if diff[-1] > 0:
x1_cropped = x1[:, :, crop_start[-1]: -crop_end[-1]]
return x1_cropped, x2
else:
x2_cropped = x2[:, :, crop_start[-1]: -crop_end[-1]]
return x1, x2_cropped

0 comments on commit bc1c94f

Please sign in to comment.