Bio lr #7

Merged 3 commits on Apr 16, 2024
37 changes: 37 additions & 0 deletions src/leibnetz/leibnet.py
@@ -5,6 +5,7 @@
from torch.nn import Module
import numpy as np
from leibnetz.nodes import Node
from funlib.learn.torch.models.conv4d import Conv4d

# from model_opt.apis import optimize

@@ -19,6 +20,7 @@ def __init__(
nodes: Iterable,
outputs: dict[str, Sequence[Tuple]],
retain_buffer=True,
initialization="kaiming",
):
super().__init__()
full_node_list = []
@@ -35,6 +37,41 @@
self.nodes_dict = torch.nn.ModuleDict({node.id: node for node in nodes})
self.graph = nx.DiGraph()
self.assemble(outputs)
self.initialization = initialization
if initialization == "kaiming":
self.apply(
lambda m: (
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out")
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
else None
)
)
elif initialization == "xavier":
self.apply(
lambda m: (
torch.nn.init.xavier_normal_(m.weight)
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
else None
)
)
elif initialization == "orthogonal":
self.apply(
lambda m: (
torch.nn.init.orthogonal_(m.weight)
if isinstance(m, torch.nn.Conv2d)
or isinstance(m, torch.nn.Conv3d)
or isinstance(m, Conv4d)
else None
)
)
elif initialization is None:
pass
else:
raise ValueError(f"Unknown initialization {initialization}")
self.retain_buffer = retain_buffer
self.retain_buffer = True
if torch.cuda.is_available():
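
The new `initialization` argument selects how convolution weights are initialized when the network is assembled. For reference, here is a minimal standalone sketch of the same dispatch; it mirrors the block added to `LeibNet.__init__` but omits the `Conv4d` case so it runs without funlib installed, and the bare `conv` module at the end is only a stand-in target.

import torch

def init_conv_weights(module, scheme="kaiming"):
    # Same dispatch as the new LeibNet.__init__ block: only convolution
    # weights are touched; every other module is left as constructed.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Conv3d)):
        if scheme == "kaiming":
            torch.nn.init.kaiming_normal_(module.weight, mode="fan_out")
        elif scheme == "xavier":
            torch.nn.init.xavier_normal_(module.weight)
        elif scheme == "orthogonal":
            torch.nn.init.orthogonal_(module.weight)
        elif scheme is not None:
            raise ValueError(f"Unknown initialization {scheme}")

# Applied over a module tree, the way LeibNet applies it to itself via self.apply(...):
conv = torch.nn.Conv3d(1, 8, kernel_size=3)
conv.apply(lambda m: init_conv_weights(m, "xavier"))
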
7 changes: 7 additions & 0 deletions src/leibnetz/nets/__init__.py
@@ -1,5 +1,12 @@
from .attentive_scalenet import build_attentive_scale_net
from .scalenet import build_scale_net
from .unet import build_unet
from .bio import (
convert_to_bio,
convert_to_backprop,
HebbsRule,
KrotovsRule,
OjasRule,
)

# from .resnet import build_resnet
247 changes: 247 additions & 0 deletions src/leibnetz/nets/bio.py
@@ -0,0 +1,247 @@
"""
This code is taken from https://github.com/Joxis/pytorch-hebbian.git
The code is licensed under the MIT license.

Please reference the following paper if you use this code:
@inproceedings{talloen2020pytorchhebbian,
author = {Jules Talloen and Joni Dambre and Alexander Vandesompele},
location = {Online},
title = {PyTorch-Hebbian: facilitating local learning in a deep learning framework},
year = {2020},
}
"""

# %%
import logging
import warnings
from abc import ABC, abstractmethod
import torch

from leibnetz import LeibNet


class LearningRule(ABC):

def __init__(self):
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)

def init_layers(self, model):
pass

@abstractmethod
def update(self, x, w):
pass


class HebbsRule(LearningRule):

def __init__(self, c=0.1):
super().__init__()
self.c = c

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
# TODO: Needs re-implementation
d_ws = torch.zeros(inputs.size(0), *weights.shape)
for idx, x in enumerate(inputs):
y = torch.mv(weights, x)

d_w = torch.zeros(weights.shape)
for i in range(y.shape[0]):
for j in range(x.shape[0]):
d_w[i, j] = self.c * x[j] * y[i]

d_ws[idx] = d_w

return torch.mean(d_ws, dim=0)


class KrotovsRule(LearningRule):
"""Krotov-Hopfield Hebbian learning rule fast implementation.

Original source: https://github.com/DimaKrotov/Biological_Learning

Args:
precision: Numerical precision of the weight updates.
delta: Anti-hebbian learning strength.
norm: Lebesgue norm of the weights.
k: Ranking parameter
"""

def __init__(self, precision=1e-30, delta=0.4, norm=2, k=2, normalize=False):
super().__init__()
self.precision = precision
self.delta = delta
self.norm = norm
self.k = k
self.normalize = normalize

def init_layers(self, layer):
if hasattr(layer, "weight"):
layer.weight.data.normal_(mean=0.0, std=1.0)

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
batch_size = inputs.shape[0]
num_hidden_units = weights.shape[0]
input_size = inputs[0].shape[0]
assert (
self.k <= num_hidden_units
), "The number of hidden units must be greater than or equal to k!"

# TODO: WIP
if self.normalize:
norm = torch.norm(inputs, dim=1)
norm[norm == 0] = 1
inputs = torch.div(inputs, norm.view(-1, 1))

inputs = torch.t(inputs)

# Calculate overlap for each hidden unit and input sample
tot_input = torch.matmul(
torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs
)

# Get the top k activations for each input sample (hidden units ranked per input sample)
_, indices = torch.topk(tot_input, k=self.k, dim=0)

# Apply the activation function for each input sample
activations = torch.zeros((num_hidden_units, batch_size), device=weights.device)
activations[indices[0], torch.arange(batch_size)] = 1.0
activations[indices[self.k - 1], torch.arange(batch_size)] = -self.delta

# Sum the activations for each hidden unit, the batch dimension is removed here
xx = torch.sum(torch.mul(activations, tot_input), 1)

# Apply the actual learning rule, from here on the tensor has the same dimension as the weights
norm_factor = torch.mul(
xx.view(xx.shape[0], 1).repeat((1, input_size)), weights
)
ds = torch.matmul(activations, torch.t(inputs)) - norm_factor

# Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate)
nc = torch.max(torch.abs(ds))
if nc < self.precision:
nc = self.precision
d_w = torch.true_divide(ds, nc)

return d_w


class OjasRule(LearningRule):

def __init__(self, c=0.1):
super().__init__()
self.c = c

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
# TODO: needs re-implementation
d_ws = torch.zeros(inputs.size(0), *weights.shape)
for idx, x in enumerate(inputs):
y = torch.mm(weights, x.unsqueeze(1))

d_w = torch.zeros(weights.shape)
for i in range(y.shape[0]):
for j in range(x.shape[0]):
d_w[i, j] = self.c * y[i] * (x[j] - y[i] * weights[i, j])

d_ws[idx] = d_w

return torch.mean(d_ws, dim=0)


def extract_image_patches(x, kernel_size, stride, dilation, padding=0):
# TODO: implement dilation and padding
# does the order in which the patches are returned matter?
# [b, c, h, w] OR [b, c, d, h, w] OR [b, c, t, d, h, w]

# Extract patches
d = 2
patches = x
for k, s in zip(kernel_size, stride):
patches = patches.unfold(d, k, s)
d += 1

if len(kernel_size) == 2:
patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
elif len(kernel_size) == 3:
# TODO: Not sure if this is right
patches = patches.permute(0, 5, 6, 7, 1, 2, 3, 4).contiguous()
elif len(kernel_size) == 4:
patches = patches.permute(0, 6, 7, 8, 9, 1, 2, 3, 4, 5).contiguous()

return patches.view(-1, *kernel_size)


def convert_to_bio(
model: LeibNet, learning_rule: LearningRule, learning_rate=1.0, init_layers=True
):
"""Converts a LeibNet model to use local bio-inspired learning rules.

Args:
model (LeibNet): Initial LeibNet model to convert.
learning_rule (LearningRule): Learning rule to apply to the model. Can be `HebbsRule`, `KrotovsRule` or `OjasRule`.
learning_rate (float, optional): Learning rate for the learning rule. Defaults to 1.0.
init_layers (bool, optional): Whether to re-initialize the model's layers via the learning rule; this discards any existing weights. Defaults to True.

Returns:
LeibNet: Model with local learning rules applied in forward hooks.
"""

def hook(module, args, kwargs, output):
if module.training:
with torch.no_grad():
if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
args[0], module.kernel_size, module.stride, module.dilation
)
weights = module.weight.view(
-1, torch.prod(torch.as_tensor(module.kernel_size))
)
else:
inputs = args[0]
weights = module.weight
inputs = inputs.view(inputs.size(0), -1)
d_w = learning_rule.update(inputs, weights)
d_w = d_w.view(module.weight.size())
module.weight.data += d_w * learning_rate

hooks = []
for module in model.modules():
if hasattr(module, "weight"):
hooks.append(module.register_forward_hook(hook, with_kwargs=True))
module.weight.requires_grad = False
if init_layers:
learning_rule.init_layers(module)

setattr(model, "learning_rule", learning_rule)
setattr(model, "learning_rate", learning_rate)
setattr(model, "learning_hooks", hooks)

return model


def convert_to_backprop(model: LeibNet):
"""Converts a LeibNet model to use backpropagation for training.

Args:
model (LeibNet): Initial LeibNet model to convert.

Returns:
LeibNet: Model with learning hooks removed and gradient-based training re-enabled.
"""

for module in model.modules():
if hasattr(module, "weight"):
module.weight.requires_grad = True

try:
for hook in model.learning_hooks:
hook.remove()
except AttributeError:
warnings.warn(
"Model does not have learning hooks. It is already using backpropagation."
)

return model


# %%
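
For orientation, a minimal usage sketch of the new conversion API follows. It assumes `build_unet()` can build a small LeibNet with its default arguments and that `batch` matches the inputs the resulting network expects; neither is defined in this diff, so treat both as placeholders.

from leibnetz.nets import build_unet, convert_to_bio, convert_to_backprop, KrotovsRule

# Assumption: build_unet() returns a LeibNet with its default settings.
model = build_unet()

# Attach Krotov-style local updates as forward hooks; convert_to_bio also
# switches off requires_grad on the wrapped weights.
model = convert_to_bio(model, KrotovsRule(delta=0.4, k=2), learning_rate=1e-2)

model.train()
# outputs = model(batch)  # weights update during forward passes while in training mode

# Undo the conversion: hooks are removed and gradients are re-enabled.
model = convert_to_backprop(model)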