Skip to content

Commit

Permalink
feat: ⚡️ Vectorize Ojas rule.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 19, 2024
1 parent 02321eb commit 8966fd9
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 33 deletions.
8 changes: 8 additions & 0 deletions src/leibnetz/leibnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,14 @@ def output_shapes(self):
def array_shapes(self):
return self._get_shapes(self._array_shapes)

@property
def param_num(self):
    """Total number of scalar elements across the network's parameters.

    Returns:
        int: Sum of ``numel()`` over every tensor yielded by
        ``self.named_parameters()``; no filtering is applied. ``0`` when
        there are no parameters.
    """
    # The parameter names are not needed for the count, only the tensors.
    return sum(param.numel() for _, param in self.named_parameters())

def mps(self):
if torch.backends.mps.is_available():
self.to("mps")
Expand Down
7 changes: 6 additions & 1 deletion src/leibnetz/nets/attentive_scalenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,12 @@ def build_subnet(


# %%
def build_attentive_scale_net(subnet_dict_list: list[dict]):
def build_attentive_scale_net(
subnet_dict_list: list[dict] = [
{"top_resolution": (32, 32, 32)},
{"top_resolution": (8, 8, 8)},
]
):
nodes = []
outputs = {}
bottleneck_input_dict = None
Expand Down
178 changes: 150 additions & 28 deletions src/leibnetz/nets/bio.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,93 @@ def update(self, inputs: torch.Tensor, weights: torch.Tensor):
return torch.mean(d_ws, dim=0)


class RhoadesRule(LearningRule):
    """Work-in-progress variant of the Krotov-Hopfield Hebbian learning rule.

    The ranking parameter is expressed as a fraction of the hidden units
    (``k_ratio``) instead of an absolute rank.

    Args:
        k_ratio: Ranking parameter as a fraction (<= 1) of the number of
            hidden units; the unit ranked ``int(k_ratio * n_hidden)`` (at
            least 1) receives the anti-hebbian activation.
        delta: Anti-hebbian learning strength.
        norm: Lebesgue norm of the weights.
        normalize: If True, every input row is divided by its ``torch.norm``
            (rows with zero norm are left unchanged) before the update.
        precision: Numerical precision of the weight updates; lower bound of
            the value the raw update is divided by.
    """

    def __init__(
        self, k_ratio=0.5, delta=0.4, norm=2, normalize=False, precision=1e-30
    ):
        super().__init__()
        self.precision = precision
        self.delta = delta
        self.norm = norm
        assert k_ratio <= 1, "k_ratio should be smaller or equal to 1"
        self.k_ratio = k_ratio
        self.normalize = normalize

    def __str__(self):
        return f"RhoadesRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize={self.normalize})"

    def init_layers(self, layer):
        """Re-initialize ``layer.weight`` in place from N(0, 1), if present."""
        if hasattr(layer, "weight"):
            layer.weight.data.normal_(mean=0.0, std=1.0)

    def update(self, inputs: torch.Tensor, module: torch.nn.Module):
        """Compute the weight update for ``module`` given its ``inputs``.

        Args:
            inputs: Input batch seen by ``module``.
            module: Layer whose ``weight`` is being updated. Layers exposing
                ``kernel_size`` go through ``extract_image_patches``.

        Returns:
            Tensor shaped like the 2-D weight matrix used internally
            (hidden units x input size), scaled so that its largest absolute
            entry is 1 unless that maximum is below ``precision``.
        """
        # TODO: WIP
        if hasattr(module, "kernel_size"):
            # Extract patches for convolutional layers
            # NOTE(review): the code below treats `inputs` as 2-D
            # (samples x patch size) - confirm extract_image_patches' layout.
            inputs = extract_image_patches(
                inputs, module.kernel_size, module.stride, module.dilation
            )
            weights = module.weight.view(
                -1, torch.prod(torch.as_tensor(module.kernel_size))
            )
        else:
            weights = module.weight
            inputs = inputs.view(inputs.size(0), -1)

        batch_size = inputs.shape[0]
        # Hidden units are the rows of the weight matrix (same convention as
        # KrotovsRule). Using the product of all weight dimensions here made
        # k exceed the ranked dimension and gave `activations` a shape that
        # cannot be multiplied with `tot_input`.
        num_hidden_units = weights.shape[0]
        input_size = inputs[0].shape[0]
        # Clamp to 1 so small layers / ratios do not truncate k to 0, which
        # would leave `indices` empty and make `indices[0]` fail.
        k = max(1, int(self.k_ratio * num_hidden_units))

        # TODO: WIP
        if self.normalize:
            norm = torch.norm(inputs, dim=1)
            norm[norm == 0] = 1  # avoid dividing all-zero rows by zero
            inputs = torch.div(inputs, norm.view(-1, 1))

        inputs = torch.t(inputs)

        # Calculate overlap for each hidden unit and input sample
        tot_input = torch.matmul(
            torch.sign(weights) * torch.abs(weights) ** (self.norm - 1), inputs
        )

        # Get the top k activations for each input sample (hidden units ranked per input sample)
        _, indices = torch.topk(tot_input, k=k, dim=0)

        # Apply the activation function for each input sample:
        # +1 for the top-ranked unit, -delta for the k-th ranked unit
        # (when k == 1 both are the same unit and -delta wins).
        activations = torch.zeros((num_hidden_units, batch_size), device=weights.device)
        activations[indices[0], torch.arange(batch_size)] = 1.0
        activations[indices[k - 1], torch.arange(batch_size)] = -self.delta

        # Sum the activations for each hidden unit, the batch dimension is removed here
        xx = torch.sum(torch.mul(activations, tot_input), 1)

        # Apply the actual learning rule, from here on the tensor has the same dimension as the weights
        norm_factor = torch.mul(
            xx.view(xx.shape[0], 1).repeat((1, input_size)), weights
        )
        ds = torch.matmul(activations, torch.t(inputs)) - norm_factor

        # Normalize the weight updates so that the largest update is 1 (which is then multiplied by the learning rate)
        nc = torch.max(torch.abs(ds))
        if nc < self.precision:
            nc = self.precision
        d_w = torch.true_divide(ds, nc)

        return d_w


class KrotovsRule(LearningRule):
"""Krotov-Hopfield Hebbian learning rule fast implementation.
Expand All @@ -69,7 +156,7 @@ class KrotovsRule(LearningRule):
precision: Numerical precision of the weight updates.
delta: Anti-hebbian learning strength.
norm: Lebesgue norm of the weights.
k: Ranking parameter
k_ratio: Ranking parameter
"""

def __init__(
Expand All @@ -90,7 +177,19 @@ def init_layers(self, layer):
if hasattr(layer, "weight"):
layer.weight.data.normal_(mean=0.0, std=1.0)

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
def update(self, inputs: torch.Tensor, module: torch.nn.Module):
if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
inputs, module.kernel_size, module.stride, module.dilation
)
weights = module.weight.view(
-1, torch.prod(torch.as_tensor(module.kernel_size))
)
else:
weights = module.weight
inputs = inputs.view(inputs.size(0), -1)

batch_size = inputs.shape[0]
num_hidden_units = weights.shape[0]
input_size = inputs[0].shape[0]
Expand Down Expand Up @@ -144,20 +243,17 @@ def __init__(self, c=0.1):
def __str__(self):
return f"OjasRule(c={self.c})"

def update(self, inputs: torch.Tensor, module: torch.nn.Module):
    """Compute the batch-averaged Oja's-rule weight update for ``module``.

    For every sample ``x`` with response ``y = W x`` the update is
    ``dW[i, j] = c * y[i] * (x[j] - y[i] * W[i, j])``; the result is the
    mean over samples, evaluated with matrix products instead of Python
    loops.

    Args:
        inputs: Input batch seen by ``module``.
        module: Layer whose ``weight`` is being updated. Layers exposing
            ``kernel_size`` go through ``extract_image_patches``.

    Returns:
        Tensor shaped like the 2-D weight matrix used internally
        (units x input size), holding the mean update over all samples.
    """
    if hasattr(module, "kernel_size"):
        # Extract patches for convolutional layers
        # NOTE(review): the products below need `inputs` to be 2-D
        # (samples x patch size) - confirm extract_image_patches' layout.
        inputs = extract_image_patches(
            inputs, module.kernel_size, module.stride, module.dilation
        )
        weights = module.weight.view(
            -1, torch.prod(torch.as_tensor(module.kernel_size))
        )
    else:
        weights = module.weight
        inputs = inputs.view(inputs.size(0), -1)

    # y[n, i] = sum_j W[i, j] * x[n, j]
    responses = torch.matmul(inputs, weights.t())
    # Hebbian term summed over samples: sum_n y[n, i] * x[n, j]
    hebbian = torch.matmul(responses.t(), inputs)
    # Decay term summed over samples: (sum_n y[n, i]**2) * W[i, j]
    decay = (responses**2).sum(dim=0).unsqueeze(1) * weights

    return self.c * (hebbian - decay) / inputs.shape[0]


def extract_image_patches(x, kernel_size, stride, dilation, padding=0):
Expand Down Expand Up @@ -198,36 +294,38 @@ def convert_to_bio(
LeibNet: Model with local learning rules applied in forward hooks.
"""

@torch.no_grad()
def hook(module, args, kwargs, output):
if module.training:
with torch.no_grad():
if hasattr(module, "kernel_size"):
# Extract patches for convolutional layers
inputs = extract_image_patches(
args[0], module.kernel_size, module.stride, module.dilation
)
weights = module.weight.view(
-1, torch.prod(torch.as_tensor(module.kernel_size))
)
inputs = args[0]
out = learning_rule.update(inputs, module)
if isinstance(out, tuple):
d_w = out[0]
output = out[1]
else:
inputs = args[0]
weights = module.weight
inputs = inputs.view(inputs.size(0), -1)
d_w = learning_rule.update(inputs, weights)
d_w = out
d_w = d_w.view(module.weight.size())
module.weight.data += d_w * learning_rate
output = output.detach().requires_grad_(False)

return output

hooks = []
for module in model.modules():
if hasattr(module, "weight"):
hooks.append(module.register_forward_hook(hook, with_kwargs=True))
module.weight.requires_grad = False
if init_layers:
learning_rule.init_layers(module)
module.weight.requires_grad = False
setattr(module, "learning_rule", learning_rule)
setattr(module, "learning_rate", learning_rate)
setattr(module, "learning_hook", hooks[-1])

setattr(model, "learning_rule", learning_rule)
setattr(model, "learning_rate", learning_rate)
setattr(model, "learning_hooks", hooks)
model.requires_grad_(False)

return model

Expand Down Expand Up @@ -255,3 +353,27 @@ def convert_to_backprop(model: LeibNet):
)

return model


# %%
# Scratch cell: builds a U-Net, picks the first module exposing a `weight`,
# extracts image patches from the example input, and prints the patch and
# weight shapes to explore how they broadcast.
# NOTE(review): these statements appear to sit at module level, so they would
# run whenever this file is imported (network construction + prints), and the
# import below pulls in `leibnetz.nets` from a file that the page header
# places inside that package - consider removing the cell or guarding it
# before merging.
from leibnetz.nets import build_unet

unet = build_unet()
inputs = unet.get_example_inputs()["input"]
# Stop at the first module that has a `weight`; `module` keeps that value
# after the loop.
for module in unet.modules():
if hasattr(module, "weight"):
break
# NOTE(review): `output` is not used by any later statement in this cell.
output = module(inputs)
inputs = extract_image_patches(
inputs, module.kernel_size, module.stride, module.dilation
)
# inputs = inputs.view(inputs.size(0), -1)
weights = module.weight
# weights = weights.view(-1, torch.prod(torch.as_tensor(module.kernel_size)))

print(inputs.shape)
print(weights.shape)
# Bare expression: only displays a value in an interactive cell; it has no
# effect when the file is executed as a script or imported.
((inputs**2) * weights).shape
# # # %%

# %%
11 changes: 10 additions & 1 deletion src/leibnetz/nets/scalenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ def build_subnet(


# %%
def build_scale_net(subnet_dict_list: list[dict]):
def build_scale_net(
subnet_dict_list: list[dict] = [
{"top_resolution": (32, 32, 32)},
{"top_resolution": (8, 8, 8)},
]
):
nodes = []
outputs = {}
bottleneck_input_dict = None
Expand All @@ -158,6 +163,10 @@ def testing():
{"top_resolution": (8, 8, 8)},
]
scalenet = build_scale_net(subnet_dict_list)
param_num = 0
for key, val in scalenet.named_parameters():
print(f"{key}: {val.shape}")
param_num += val.numel()
scalenet.array_shapes
# %%
inputs = scalenet.get_example_inputs()
Expand Down
1 change: 1 addition & 0 deletions src/leibnetz/nodes/additive_attention_gate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from leibnetz.nodes.node_ops import ConvPass


# TODO: Not 2D compatible
class AdditiveAttentionGateNode(Node):
def __init__(
self,
Expand Down
7 changes: 4 additions & 3 deletions src/leibnetz/nodes/node_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def __init__(
if norm_layer is not None:
layers.append(norm_layer(input_nc))

layers.append(self.activation)

if dropout_prob is not None and dropout_prob > 0:
layers.append(nn.Dropout(dropout_prob))

Expand Down Expand Up @@ -112,6 +110,9 @@ def __init__(
bias=False,
groups=groups,
)
else:
layers.append(self.activation)

except KeyError:
raise RuntimeError("%dD convolution not implemented" % self.dims)

Expand Down Expand Up @@ -139,4 +140,4 @@ def forward(self, x):
init_x = self.crop(self.x_init_map(x), res.size()[-self.dims :])
else:
init_x = self.x_init_map(x)
return res + init_x
return self.activation(res + init_x)

0 comments on commit 8966fd9

Please sign in to comment.