From 8966fd95851f5d621b48796ae7b57fe7db276553 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 19 Apr 2024 17:43:23 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9A=A1=EF=B8=8F=20Vectorize=20Ojas?= =?UTF-8?q?=20rule.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/leibnetz/leibnet.py | 8 + src/leibnetz/nets/attentive_scalenet.py | 7 +- src/leibnetz/nets/bio.py | 178 +++++++++++++++--- src/leibnetz/nets/scalenet.py | 11 +- .../nodes/additive_attention_gate_node.py | 1 + src/leibnetz/nodes/node_ops.py | 7 +- 6 files changed, 179 insertions(+), 33 deletions(-) diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index f0fa324..de5d0da 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -374,6 +374,14 @@ def output_shapes(self): def array_shapes(self): return self._get_shapes(self._array_shapes) + @property + def param_num(self): + param_num = 0 + for key, val in self.named_parameters(): + # print(f"{key}: {val.shape}") + param_num += val.numel() + return param_num + def mps(self): if torch.backends.mps.is_available(): self.to("mps") diff --git a/src/leibnetz/nets/attentive_scalenet.py b/src/leibnetz/nets/attentive_scalenet.py index 2e50e3e..5f3ac5a 100644 --- a/src/leibnetz/nets/attentive_scalenet.py +++ b/src/leibnetz/nets/attentive_scalenet.py @@ -151,7 +151,12 @@ def build_subnet( # %% -def build_attentive_scale_net(subnet_dict_list: list[dict]): +def build_attentive_scale_net( + subnet_dict_list: list[dict] = [ + {"top_resolution": (32, 32, 32)}, + {"top_resolution": (8, 8, 8)}, + ] +): nodes = [] outputs = {} bottleneck_input_dict = None diff --git a/src/leibnetz/nets/bio.py b/src/leibnetz/nets/bio.py index a73d151..26ebff5 100644 --- a/src/leibnetz/nets/bio.py +++ b/src/leibnetz/nets/bio.py @@ -60,6 +60,93 @@ def update(self, inputs: torch.Tensor, weights: torch.Tensor): return torch.mean(d_ws, dim=0) +class RhoadesRule(LearningRule): + """Rule modifying Krotov-Hopfield Hebbian learning rule fast implementation. + + Args: + precision: Numerical precision of the weight updates. + delta: Anti-hebbian learning strength. + norm: Lebesgue norm of the weights. + k_ratio: Ranking parameter + """ + + def __init__( + self, k_ratio=0.5, delta=0.4, norm=2, normalize=False, precision=1e-30 + ): + super().__init__() + self.precision = precision + self.delta = delta + self.norm = norm + assert k_ratio <= 1, "k_ratio should be smaller or equal to 1" + self.k_ratio = k_ratio + self.normalize = normalize + + def __str__(self): + return f"RhoadesRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize={self.normalize})" + + def init_layers(self, layer): + if hasattr(layer, "weight"): + layer.weight.data.normal_(mean=0.0, std=1.0) + + def update(self, inputs: torch.Tensor, module: torch.nn.Module): + # TODO: WIP + if hasattr(module, "kernel_size"): + # Extract patches for convolutional layers + inputs = extract_image_patches( + inputs, module.kernel_size, module.stride, module.dilation + ) + weights = module.weight.view( + -1, torch.prod(torch.as_tensor(module.kernel_size)) + ) + else: + weights = module.weight + inputs = inputs.view(inputs.size(0), -1) + + # TODO: needs re-implementation + batch_size = inputs.shape[0] + num_hidden_units = torch.prod(torch.as_tensor(weights.shape)) + input_size = inputs[0].shape[0] + k = int(self.k_ratio * num_hidden_units) + + # TODO: WIP + if self.normalize: + norm = torch.norm(inputs, dim=1) + norm[norm == 0] = 1 + inputs = torch.div(inputs, norm.view(-1, 1)) + + inputs = torch.t(inputs) + + # Calculate overlap for each hidden unit and input sample + tot_input = torch.matmul( + torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs + ) + + # Get the top k activations for each input sample (hidden units ranked per input sample) + _, indices = torch.topk(tot_input, k=k, dim=0) + + # Apply the activation function for each input sample + activations = torch.zeros((num_hidden_units, batch_size), device=weights.device) + activations[indices[0], torch.arange(batch_size)] = 1.0 + activations[indices[k - 1], torch.arange(batch_size)] = -self.delta + + # Sum the activations for each hidden unit, the batch dimension is removed here + xx = torch.sum(torch.mul(activations, tot_input), 1) + + # Apply the actual learning rule, from here on the tensor has the same dimension as the weights + norm_factor = torch.mul( + xx.view(xx.shape[0], 1).repeat((1, input_size)), weights + ) + ds = torch.matmul(activations, torch.t(inputs)) - norm_factor + + # Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate) + nc = torch.max(torch.abs(ds)) + if nc < self.precision: + nc = self.precision + d_w = torch.true_divide(ds, nc) + + return d_w + + class KrotovsRule(LearningRule): """Krotov-Hopfield Hebbian learning rule fast implementation. @@ -69,7 +156,7 @@ class KrotovsRule(LearningRule): precision: Numerical precision of the weight updates. delta: Anti-hebbian learning strength. norm: Lebesgue norm of the weights. - k: Ranking parameter + k_ratio: Ranking parameter """ def __init__( @@ -90,7 +177,19 @@ def init_layers(self, layer): if hasattr(layer, "weight"): layer.weight.data.normal_(mean=0.0, std=1.0) - def update(self, inputs: torch.Tensor, weights: torch.Tensor): + def update(self, inputs: torch.Tensor, module: torch.nn.Module): + if hasattr(module, "kernel_size"): + # Extract patches for convolutional layers + inputs = extract_image_patches( + inputs, module.kernel_size, module.stride, module.dilation + ) + weights = module.weight.view( + -1, torch.prod(torch.as_tensor(module.kernel_size)) + ) + else: + weights = module.weight + inputs = inputs.view(inputs.size(0), -1) + batch_size = inputs.shape[0] num_hidden_units = weights.shape[0] input_size = inputs[0].shape[0] @@ -144,20 +243,17 @@ def __init__(self, c=0.1): def __str__(self): return f"OjasRule(c={self.c})" - def update(self, inputs: torch.Tensor, weights: torch.Tensor): - # TODO: needs re-implementation - d_ws = torch.zeros(inputs.size(0), *weights.shape) - for idx, x in enumerate(inputs): - y = torch.mm(weights, x.unsqueeze(1)) - - d_w = torch.zeros(weights.shape) - for i in range(y.shape[0]): - for j in range(x.shape[0]): - d_w[i, j] = self.c * y[i] * (x[j] - y[i] * weights[i, j]) - - d_ws[idx] = d_w + def update(self, inputs: torch.Tensor, module: torch.nn.Module): + # dW = c (I**2 * W - I**2 * W**3) / n - return torch.mean(d_ws, dim=0) + if hasattr(module, "kernel_size"): + # Extract patches for convolutional layers + inputs = extract_image_patches( + inputs, module.kernel_size, module.stride, module.dilation + ) + d_W = self.c * ((inputs**2) * module.weight - (inputs**2) * (module.weight**3)) + d_W = d_W.sum(dim=1) / inputs.shape[0] + return d_W def extract_image_patches(x, kernel_size, stride, dilation, padding=0): @@ -198,36 +294,38 @@ def convert_to_bio( LeibNet: Model with local learning rules applied in forward hooks. """ + @torch.no_grad() def hook(module, args, kwargs, output): if module.training: with torch.no_grad(): - if hasattr(module, "kernel_size"): - # Extract patches for convolutional layers - inputs = extract_image_patches( - args[0], module.kernel_size, module.stride, module.dilation - ) - weights = module.weight.view( - -1, torch.prod(torch.as_tensor(module.kernel_size)) - ) + inputs = args[0] + out = learning_rule.update(inputs, module) + if isinstance(out, tuple): + d_w = out[0] + output = out[1] else: - inputs = args[0] - weights = module.weight - inputs = inputs.view(inputs.size(0), -1) - d_w = learning_rule.update(inputs, weights) + d_w = out d_w = d_w.view(module.weight.size()) module.weight.data += d_w * learning_rate + output = output.detach().requires_grad_(False) + + return output hooks = [] for module in model.modules(): if hasattr(module, "weight"): hooks.append(module.register_forward_hook(hook, with_kwargs=True)) - module.weight.requires_grad = False if init_layers: learning_rule.init_layers(module) + module.weight.requires_grad = False + setattr(module, "learning_rule", learning_rule) + setattr(module, "learning_rate", learning_rate) + setattr(module, "learning_hook", hooks[-1]) setattr(model, "learning_rule", learning_rule) setattr(model, "learning_rate", learning_rate) setattr(model, "learning_hooks", hooks) + model.requires_grad_(False) return model @@ -255,3 +353,27 @@ def convert_to_backprop(model: LeibNet): ) return model + + +# %% +from leibnetz.nets import build_unet + +unet = build_unet() +inputs = unet.get_example_inputs()["input"] +for module in unet.modules(): + if hasattr(module, "weight"): + break +output = module(inputs) +inputs = extract_image_patches( + inputs, module.kernel_size, module.stride, module.dilation +) +# inputs = inputs.view(inputs.size(0), -1) +weights = module.weight +# weights = weights.view(-1, torch.prod(torch.as_tensor(module.kernel_size))) + +print(inputs.shape) +print(weights.shape) +((inputs**2) * weights).shape +# # # %% + +# %% diff --git a/src/leibnetz/nets/scalenet.py b/src/leibnetz/nets/scalenet.py index 76f18f1..2a0beae 100644 --- a/src/leibnetz/nets/scalenet.py +++ b/src/leibnetz/nets/scalenet.py @@ -132,7 +132,12 @@ def build_subnet( # %% -def build_scale_net(subnet_dict_list: list[dict]): +def build_scale_net( + subnet_dict_list: list[dict] = [ + {"top_resolution": (32, 32, 32)}, + {"top_resolution": (8, 8, 8)}, + ] +): nodes = [] outputs = {} bottleneck_input_dict = None @@ -158,6 +163,10 @@ def testing(): {"top_resolution": (8, 8, 8)}, ] scalenet = build_scale_net(subnet_dict_list) + param_num = 0 + for key, val in scalenet.named_parameters(): + print(f"{key}: {val.shape}") + param_num += val.numel() scalenet.array_shapes # %% inputs = scalenet.get_example_inputs() diff --git a/src/leibnetz/nodes/additive_attention_gate_node.py b/src/leibnetz/nodes/additive_attention_gate_node.py index 935d77a..ffba406 100644 --- a/src/leibnetz/nodes/additive_attention_gate_node.py +++ b/src/leibnetz/nodes/additive_attention_gate_node.py @@ -5,6 +5,7 @@ from leibnetz.nodes.node_ops import ConvPass +# TODO: Not 2D compatible class AdditiveAttentionGateNode(Node): def __init__( self, diff --git a/src/leibnetz/nodes/node_ops.py b/src/leibnetz/nodes/node_ops.py index 083ce9d..307e54c 100644 --- a/src/leibnetz/nodes/node_ops.py +++ b/src/leibnetz/nodes/node_ops.py @@ -79,8 +79,6 @@ def __init__( if norm_layer is not None: layers.append(norm_layer(input_nc)) - layers.append(self.activation) - if dropout_prob is not None and dropout_prob > 0: layers.append(nn.Dropout(dropout_prob)) @@ -112,6 +110,9 @@ def __init__( bias=False, groups=groups, ) + else: + layers.append(self.activation) + except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) @@ -139,4 +140,4 @@ def forward(self, x): init_x = self.crop(self.x_init_map(x), res.size()[-self.dims :]) else: init_x = self.x_init_map(x) - return res + init_x + return self.activation(res + init_x)