From fcccbedf579149acdb8014656aa666ebc885ec98 Mon Sep 17 00:00:00 2001 From: jintonic3561 Date: Wed, 12 Jul 2023 10:38:07 +0900 Subject: [PATCH] fix: inplace operation and avoid for loop --- models/SCINet.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/models/SCINet.py b/models/SCINet.py index 443f955..e510a22 100644 --- a/models/SCINet.py +++ b/models/SCINet.py @@ -166,16 +166,12 @@ def __init__(self, in_planes, current_level, kernel_size, dropout, groups, hidde def zip_up_the_pants(self, even, odd): even = even.permute(1, 0, 2) odd = odd.permute(1, 0, 2) #L, B, D - even_len = even.shape[0] - odd_len = odd.shape[0] - mlen = min((odd_len, even_len)) - _ = [] - for i in range(mlen): - _.append(even[i].unsqueeze(0)) - _.append(odd[i].unsqueeze(0)) - if odd_len < even_len: - _.append(even[-1].unsqueeze(0)) - return torch.cat(_,0).permute(1,0,2) #B, L, D + + total_len = even.shape[0] + odd.shape[0] + x = torch.zeros(total_len, even.shape[1], even.shape[2]).cuda() + x[0::2] = even + x[1::2] = odd + return x.permute(1,0,2) def forward(self, x): x_even_update, x_odd_update= self.workingblock(x) @@ -326,9 +322,9 @@ def forward(self, x): if self.pe: pe = self.get_position_encoding(x) if pe.shape[2] > x.shape[2]: - x += pe[:, :, :-1] + x = x + pe[:, :, :-1] else: - x += self.get_position_encoding(x) + x = x + self.get_position_encoding(x) ### activated when RIN flag is set ### if self.RIN: @@ -338,7 +334,7 @@ def forward(self, x): x = x - means #var stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5) - x /= stdev + x = x / stdev # affine # print(x.shape,self.affine_weight.shape,self.affine_bias.shape) x = x * self.affine_weight + self.affine_bias @@ -346,7 +342,7 @@ def forward(self, x): # the first stack res1 = x x = self.blocks1(x) - x += res1 + x = x + res1 if self.num_decoder_layer == 1: x = self.projection1(x) else: @@ -380,7 +376,7 @@ def forward(self, x): # the second stack res2 = x x = self.blocks2(x) - x += res2 + x = x + res2 x = self.projection2(x) ### Reverse RIN ###