Skip to content

Commit

Permalink
fix: 🐛 Fixed OjasRule
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 24, 2024
1 parent 74545fe commit 13f6ce1
Showing 1 changed file with 48 additions and 52 deletions.
100 changes: 48 additions & 52 deletions src/leibnetz/nets/bio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -288,6 +288,7 @@ def update(self, module, args, kwargs, output):
with torch.no_grad():
inputs = args[0]
if hasattr(module, "kernel_size"):
# d_W = Y*(X - Y*W)
ndims = len(module.kernel_size)
# Extract patches for convolutional layers
X = extract_kernel_patches(
Expand All @@ -299,8 +300,16 @@ def update(self, module, args, kwargs, output):
) # = c1 x 3 x 3 x N
Y = extract_image_patches(
output, module.out_channels, ndims
).T # = N x c2
d_W = X @ Y # = c1 x 3 x 3 x 3 x c2
) # = c2 x N
W = module.weight # = c2 x c1 x 3 x 3 x 3
if ndims == 2:
W = W.permute(1, 2, 3, 0)
if ndims == 3:
W = W.permute(1, 2, 3, 4, 0)
elif ndims == 4:
W = W.permute(1, 2, 3, 4, 5, 0)

d_W = (X - W @ Y) @ Y.T # = c1 x 3 x 3 x 3 x c2
if ndims == 2:
d_W = d_W.permute(3, 0, 1, 2)
if ndims == 3:
Expand All @@ -309,12 +318,11 @@ def update(self, module, args, kwargs, output):
d_W = d_W.permute(5, 0, 1, 2, 3, 4)
else:
ndims = None
d_W = inputs * output # = c1 x c2
W = module.weight
d_W = output @ (inputs - output @ W)

d_W = self.normalize_fcn(d_W)
module.weight.data += d_W * self.learning_rate
d_W = self.c * ((inputs**2) * module.weight - (inputs**2) * (module.weight**3))
return d_W


def extract_kernel_patches(x, channels, kernel_size, stride, dilation, padding=0):
Expand Down Expand Up @@ -415,51 +423,39 @@ def convert_to_backprop(model: LeibNet):
return model


# %%
from leibnetz.nets import build_unet

unet = build_unet()
inputs = unet.get_example_inputs()["input"]
batch_size = 2
inputs = torch.cat(
[
inputs,
]
* batch_size,
dim=0,
)
for module in unet.modules():
if hasattr(module, "weight"):
break
output = module(inputs)
X = extract_kernel_patches(
inputs, module.in_channels, module.kernel_size, module.stride, module.dilation
)
Y = extract_image_patches(output, module.out_channels, len(module.kernel_size)).T
weights = module.weight
print(X.shape)
print(Y.shape)
# weights = weights.view(
# *module.weight.shape[: -len(module.kernel_size)],
# torch.prod(torch.as_tensor(module.kernel_size)),
# )
# # %%
# from leibnetz.nets import build_unet

# print(inputs.shape)
# print(weights.shape)
# ((inputs**2) * weights).shape

# %%
from leibnetz.nets import build_unet

unet = build_unet()
model = convert_to_bio(unet, HebbsRule())
batch = model(model.get_example_inputs())
# %%
import numpy as np

c = 0.1
pad = (np.array(kernel_size) - 1) // 2
c * output * (
inputs[..., pad[0] : -pad[0], pad[1] : -pad[1], pad[2] : -pad[2]] - output * weights
)
# %%
# unet = build_unet()
# inputs = unet.get_example_inputs()["input"]
# batch_size = 2
# inputs = torch.cat(
# [
# inputs,
# ]
# * batch_size,
# dim=0,
# )
# for module in unet.modules():
# if hasattr(module, "weight"):
# break
# output = module(inputs)
# X = extract_kernel_patches(
# inputs, module.in_channels, module.kernel_size, module.stride, module.dilation
# )
# Y = extract_image_patches(output, module.out_channels, len(module.kernel_size))
# W = module.weight # == c2 x c1 x 3 x 3 x 3
# print(X.shape)
# print(Y.shape)
# # d_W = Y*(X - Y*W)
# (X - (W.permute(1, 2, 3, 4, 0) @ Y)) # == c1 x 3 x 3 x 3 x N
# ((X - (W.permute(1, 2, 3, 4, 0) @ Y)) @ Y.T).permute(
# 4, 0, 1, 2, 3
# ) # == c2 x c1 x 3 x 3 x 3
# # %%
# from leibnetz.nets import build_unet

# unet = build_unet()
# model = convert_to_bio(unet, OjasRule())
# batch = model(model.get_example_inputs())
# # %%

0 comments on commit 13f6ce1

Please sign in to comment.