From 13f6ce122eaa3cf014ed68abb30029920e29f705 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 23 Apr 2024 21:36:29 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fixed=20OjasRule?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/leibnetz/nets/bio.py | 100 +++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/src/leibnetz/nets/bio.py b/src/leibnetz/nets/bio.py index 98c6aa3..f730f2a 100644 --- a/src/leibnetz/nets/bio.py +++ b/src/leibnetz/nets/bio.py @@ -288,6 +288,7 @@ def update(self, module, args, kwargs, output): with torch.no_grad(): inputs = args[0] if hasattr(module, "kernel_size"): + # d_W = Y*(X - Y*W) ndims = len(module.kernel_size) # Extract patches for convolutional layers X = extract_kernel_patches( @@ -299,8 +300,16 @@ def update(self, module, args, kwargs, output): ) # = c1 x 3 x3 x N Y = extract_image_patches( output, module.out_channels, ndims - ).T # = N x c2 - d_W = X @ Y # = c1 x 3 x 3 x 3 x c2 + ) # = c2 x N + W = module.weight # = c2 x c1 x 3 x 3 x 3 + if ndims == 2: + W = W.permute(1, 2, 3, 0) + if ndims == 3: + W = W.permute(1, 2, 3, 4, 0) + elif ndims == 4: + W = W.permute(1, 2, 3, 4, 5, 0) + + d_W = (X - W @ Y) @ Y.T # = c1 x 3 x 3 x 3 x c2 if ndims == 2: d_W = d_W.permute(3, 0, 1, 2) if ndims == 3: @@ -309,12 +318,11 @@ def update(self, module, args, kwargs, output): d_W = d_W.permute(5, 0, 1, 2, 3, 4) else: ndims = None - d_W = inputs * output # = c1 x c2 + W = module.weight + d_W = output @ (inputs - output @ W) d_W = self.normalize_fcn(d_W) module.weight.data += d_W * self.learning_rate - d_W = self.c * ((inputs**2) * module.weight - (inputs**2) * (module.weight**3)) - return d_W def extract_kernel_patches(x, channels, kernel_size, stride, dilation, padding=0): @@ -415,51 +423,39 @@ def convert_to_backprop(model: LeibNet): return model -# %% -from leibnetz.nets import build_unet - -unet = build_unet() -inputs = unet.get_example_inputs()["input"] -batch_size = 2 -inputs = torch.cat( - [ - inputs, - ] - * batch_size, - dim=0, -) -for module in unet.modules(): - if hasattr(module, "weight"): - break -output = module(inputs) -X = extract_kernel_patches( - inputs, module.in_channels, module.kernel_size, module.stride, module.dilation -) -Y = extract_image_patches(output, module.out_channels, len(module.kernel_size)).T -weights = module.weight -print(X.shape) -print(Y.shape) -# weights = weights.view( -# *module.weight.shape[: -len(module.kernel_size)], -# torch.prod(torch.as_tensor(module.kernel_size)), -# ) +# # %% +# from leibnetz.nets import build_unet -# print(inputs.shape) -# print(weights.shape) -# ((inputs**2) * weights).shape - -# %% -from leibnetz.nets import build_unet - -unet = build_unet() -model = convert_to_bio(unet, HebbsRule()) -batch = model(model.get_example_inputs()) -# %% -import numpy as np - -c = 0.1 -pad = (np.array(kernel_size) - 1) // 2 -c * output * ( - inputs[..., pad[0] : -pad[0], pad[1] : -pad[1], pad[2] : -pad[2]] - output * weights -) -# %% +# unet = build_unet() +# inputs = unet.get_example_inputs()["input"] +# batch_size = 2 +# inputs = torch.cat( +# [ +# inputs, +# ] +# * batch_size, +# dim=0, +# ) +# for module in unet.modules(): +# if hasattr(module, "weight"): +# break +# output = module(inputs) +# X = extract_kernel_patches( +# inputs, module.in_channels, module.kernel_size, module.stride, module.dilation +# ) +# Y = extract_image_patches(output, module.out_channels, len(module.kernel_size)) +# W = module.weight # == c2 x c1 x 3 x 3 x 3 +# print(X.shape) +# print(Y.shape) +# # d_W = Y*(X - Y*W) +# (X - (W.permute(1, 2, 3, 4, 0) @ Y)) # == c1 x 3 x 3 x 3 x N +# ((X - (W.permute(1, 2, 3, 4, 0) @ Y)) @ Y.T).permute( +# 4, 0, 1, 2, 3 +# ) # == c2 x c1 x 3 x 3 x 3 +# # %% +# from leibnetz.nets import build_unet + +# unet = build_unet() +# model = convert_to_bio(unet, OjasRule()) +# batch = model(model.get_example_inputs()) +# # %%