diff --git a/repvgg.py b/repvgg.py index 4714afc..fbf4add 100644 --- a/repvgg.py +++ b/repvgg.py @@ -79,12 +79,12 @@ def _fuse_bn_tensor(self, branch): assert isinstance(branch, nn.BatchNorm2d) if not hasattr(self, 'id_tensor'): input_dim = self.in_channels // self.groups - kernel_value = np.zeros((self.in_channels, input_dim, 3, 3)) + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) for i in range(self.in_channels): kernel_value[i, i % input_dim, 1, 1] = 1 self.id_tensor = torch.from_numpy(kernel_value) - if torch.cuda.is_available(): - self.id_tensor = self.id_tensor.cuda() + # if torch.cuda.is_available(): + # self.id_tensor = self.id_tensor.cuda() kernel = self.id_tensor running_mean = branch.running_mean running_var = branch.running_var