Skip to content

Commit

Permalink
fix: 🐛 Fixed HebbsRule
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 23, 2024
1 parent 8966fd9 commit 74545fe
Showing 1 changed file with 165 additions and 79 deletions.
244 changes: 165 additions & 79 deletions src/leibnetz/nets/bio.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
"""
This code is taken from https://github.com/Joxis/pytorch-hebbian.git
The code is licensed under the MIT license.
Please reference the following paper if you use this code:
@inproceedings{talloen2020pytorchhebbian,
author = {Jules Talloen and Joni Dambre and Alexander Vandesompele},
location = {Online},
title = {PyTorch-Hebbian: facilitating local learning in a deep learning framework},
year = {2020},
}
"""

# %%
import logging
from abc import ABC, abstractmethod
Expand All @@ -20,6 +7,18 @@


class LearningRule(ABC):
"""
This code is taken from https://github.com/Joxis/pytorch-hebbian.git
The code is licensed under the MIT license.
Please reference the following paper if you use this code:
@inproceedings{talloen2020pytorchhebbian,
author = {Jules Talloen and Joni Dambre and Alexander Vandesompele},
location = {Online},
title = {PyTorch-Hebbian: facilitating local learning in a deep learning framework},
year = {2020},
}
"""

def __init__(self):
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
Expand All @@ -37,27 +36,51 @@ def update(self, x, w):

class HebbsRule(LearningRule):

def __init__(self, c=0.1):
def __init__(self, learning_rate=0.1, normalize_kwargs={"dim": 0}):
super().__init__()
self.c = c
self.learning_rate = learning_rate
self.normalize_kwargs = normalize_kwargs
if normalize_kwargs is not None:
self.normalize_fcn = lambda x: torch.nn.functional.normalize(
x, **normalize_kwargs
)
else:
self.normalize_fcn = lambda x: x

def __str__(self):
return f"HebbsRule(c={self.c})"

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
# TODO: Needs re-implementation
d_ws = torch.zeros(inputs.size(0))
for idx, x in enumerate(inputs):
y = torch.dot(weights, x)
return f"HebbsRule(learning_rate={self.learning_rate}, normalize_kwargs={self.normalize_kwargs})"

d_w = torch.zeros(weights.shape)
for i in range(y.shape[0]):
for j in range(x.shape[0]):
d_w[i, j] = self.c * x[j] * y[i]

d_ws[idx] = d_w
@torch.no_grad()
def update(self, module, args, kwargs, output):
if module.training:
with torch.no_grad():
inputs = args[0]
if hasattr(module, "kernel_size"):
ndims = len(module.kernel_size)
# Extract patches for convolutional layers
X = extract_kernel_patches(
inputs,
module.in_channels,
module.kernel_size,
module.stride,
module.dilation,
) # = c1 x 3 x3 x N
Y = extract_image_patches(
output, module.out_channels, ndims
).T # = N x c2
d_W = X @ Y # = c1 x 3 x 3 x 3 x c2
if ndims == 2:
d_W = d_W.permute(3, 0, 1, 2)
if ndims == 3:
d_W = d_W.permute(4, 0, 1, 2, 3)
elif ndims == 4:
d_W = d_W.permute(5, 0, 1, 2, 3, 4)
else:
ndims = None
d_W = inputs * output # = c1 x c2

return torch.mean(d_ws, dim=0)
d_W = self.normalize_fcn(d_W)
module.weight.data += d_W * self.learning_rate


class RhoadesRule(LearningRule):
Expand Down Expand Up @@ -92,7 +115,7 @@ def update(self, inputs: torch.Tensor, module: torch.nn.Module):
# TODO: WIP
if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
inputs = extract_kernel_patches(
inputs, module.kernel_size, module.stride, module.dilation
)
weights = module.weight.view(
Expand Down Expand Up @@ -149,7 +172,16 @@ def update(self, inputs: torch.Tensor, module: torch.nn.Module):

class KrotovsRule(LearningRule):
"""Krotov-Hopfield Hebbian learning rule fast implementation.
This code is taken from https://github.com/Joxis/pytorch-hebbian.git
The code is licensed under the MIT license.
Please reference the following paper if you use this code:
@inproceedings{talloen2020pytorchhebbian,
author = {Jules Talloen and Joni Dambre and Alexander Vandesompele},
location = {Online},
title = {PyTorch-Hebbian: facilitating local learning in a deep learning framework},
year = {2020},
}
Original source: https://github.com/DimaKrotov/Biological_Learning
Args:
Expand Down Expand Up @@ -180,7 +212,7 @@ def init_layers(self, layer):
def update(self, inputs: torch.Tensor, module: torch.nn.Module):
if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
inputs = extract_kernel_patches(
inputs, module.kernel_size, module.stride, module.dilation
)
weights = module.weight.view(
Expand Down Expand Up @@ -236,27 +268,56 @@ def update(self, inputs: torch.Tensor, module: torch.nn.Module):

class OjasRule(LearningRule):

def __init__(self, c=0.1):
def __init__(self, learning_rate=0.1, normalize_kwargs={"dim": 0}):
super().__init__()
self.c = c
self.learning_rate = learning_rate
self.normalize_kwargs = normalize_kwargs
if normalize_kwargs is not None:
self.normalize_fcn = lambda x: torch.nn.functional.normalize(
x, **normalize_kwargs
)
else:
self.normalize_fcn = lambda x: x

def __str__(self):
return f"OjasRule(c={self.c})"
return f"OjasRule(learning_rate={self.learning_rate}, normalize_kwargs={self.normalize_kwargs})"

def update(self, inputs: torch.Tensor, module: torch.nn.Module):
# dW = c (I**2 * W - I**2 * W**3) / n
@torch.no_grad()
def update(self, module, args, kwargs, output):
if module.training:
with torch.no_grad():
inputs = args[0]
if hasattr(module, "kernel_size"):
ndims = len(module.kernel_size)
# Extract patches for convolutional layers
X = extract_kernel_patches(
inputs,
module.in_channels,
module.kernel_size,
module.stride,
module.dilation,
) # = c1 x 3 x3 x N
Y = extract_image_patches(
output, module.out_channels, ndims
).T # = N x c2
d_W = X @ Y # = c1 x 3 x 3 x 3 x c2
if ndims == 2:
d_W = d_W.permute(3, 0, 1, 2)
if ndims == 3:
d_W = d_W.permute(4, 0, 1, 2, 3)
elif ndims == 4:
d_W = d_W.permute(5, 0, 1, 2, 3, 4)
else:
ndims = None
d_W = inputs * output # = c1 x c2

if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
inputs, module.kernel_size, module.stride, module.dilation
)
d_W = self.normalize_fcn(d_W)
module.weight.data += d_W * self.learning_rate
d_W = self.c * ((inputs**2) * module.weight - (inputs**2) * (module.weight**3))
d_W = d_W.sum(dim=1) / inputs.shape[0]
return d_W


def extract_image_patches(x, kernel_size, stride, dilation, padding=0):
def extract_kernel_patches(x, channels, kernel_size, stride, dilation, padding=0):
# TODO: implement dilation and padding
# does the order in which the patches are returned matter?
# [b, c, h, w] OR [b, c, d, h, w] OR [b, c, t, d, h, w]
Expand All @@ -269,19 +330,30 @@ def extract_image_patches(x, kernel_size, stride, dilation, padding=0):
d += 1

if len(kernel_size) == 2:
patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
patches = patches.permute(1, 4, 5, 0, 2, 3).contiguous()
if len(kernel_size) == 3:
# TODO: Not sure if this is right
patches = patches.permute(0, 5, 6, 7, 1, 2, 3, 4).contiguous()
patches = patches.permute(1, 5, 6, 7, 0, 2, 3, 4).contiguous()
elif len(kernel_size) == 4:
patches = patches.permute(0, 6, 7, 8, 9, 1, 2, 3, 4, 5).contiguous()
patches = patches.permute(1, 6, 7, 8, 9, 0, 2, 3, 4, 5).contiguous()

return patches.view(channels, *kernel_size, -1)


def extract_image_patches(x, channels, ndims):
# does the order in which the patches are returned matter?
# [b, c, h, w] OR [b, c, d, h, w] OR [b, c, t, d, h, w]

if ndims == 2:
x = x.permute(1, 0, 2, 3).contiguous()
if ndims == 3:
x = x.permute(1, 0, 2, 3, 4).contiguous()
elif ndims == 4:
x = x.permute(1, 0, 2, 3, 4, 5).contiguous()

return patches.view(-1, *kernel_size)
return x.view(channels, -1)


def convert_to_bio(
model: LeibNet, learning_rule: LearningRule, learning_rate=1.0, init_layers=True
):
def convert_to_bio(model: LeibNet, learning_rule: LearningRule, init_layers=True):
"""Converts a LeibNet model to use local bio-inspired learning rules.
Args:
Expand All @@ -294,36 +366,24 @@ def convert_to_bio(
LeibNet: Model with local learning rules applied in forward hooks.
"""

@torch.no_grad()
def hook(module, args, kwargs, output):
if module.training:
with torch.no_grad():
inputs = args[0]
out = learning_rule.update(inputs, module)
if isinstance(out, tuple):
d_w = out[0]
output = out[1]
else:
d_w = out
d_w = d_w.view(module.weight.size())
module.weight.data += d_w * learning_rate
output = output.detach().requires_grad_(False)

return output
hook = learning_rule.update

hooks = []
for module in model.modules():
if hasattr(module, "weight"):
hooks.append(module.register_forward_hook(hook, with_kwargs=True))
if init_layers:
learning_rule.init_layers(module)
if hasattr(learning_rule, "init_layers"):
learning_rule.init_layers(module)
else:
torch.nn.init.sparse_(module.weight, sparsity=0.5)
module.weight.requires_grad = False
setattr(module, "learning_rule", learning_rule)
setattr(module, "learning_rate", learning_rate)
setattr(module, "learning_rate", learning_rule.learning_rate)
setattr(module, "learning_hook", hooks[-1])

setattr(model, "learning_rule", learning_rule)
setattr(model, "learning_rate", learning_rate)
setattr(model, "learning_rate", learning_rule.learning_rate)
setattr(model, "learning_hooks", hooks)
model.requires_grad_(False)

Expand Down Expand Up @@ -360,20 +420,46 @@ def convert_to_backprop(model: LeibNet):

unet = build_unet()
inputs = unet.get_example_inputs()["input"]
batch_size = 2
inputs = torch.cat(
[
inputs,
]
* batch_size,
dim=0,
)
for module in unet.modules():
if hasattr(module, "weight"):
break
output = module(inputs)
inputs = extract_image_patches(
inputs, module.kernel_size, module.stride, module.dilation
X = extract_kernel_patches(
inputs, module.in_channels, module.kernel_size, module.stride, module.dilation
)
# inputs = inputs.view(inputs.size(0), -1)
Y = extract_image_patches(output, module.out_channels, len(module.kernel_size)).T
weights = module.weight
# weights = weights.view(-1, torch.prod(torch.as_tensor(module.kernel_size)))
print(X.shape)
print(Y.shape)
# weights = weights.view(
# *module.weight.shape[: -len(module.kernel_size)],
# torch.prod(torch.as_tensor(module.kernel_size)),
# )

print(inputs.shape)
print(weights.shape)
((inputs**2) * weights).shape
# # # %%
# print(inputs.shape)
# print(weights.shape)
# ((inputs**2) * weights).shape

# %%
from leibnetz.nets import build_unet

unet = build_unet()
model = convert_to_bio(unet, HebbsRule())
batch = model(model.get_example_inputs())
# %%
import numpy as np

c = 0.1
pad = (np.array(kernel_size) - 1) // 2
c * output * (
inputs[..., pad[0] : -pad[0], pad[1] : -pad[1], pad[2] : -pad[2]] - output * weights
)
# %%

0 comments on commit 74545fe

Please sign in to comment.