Commit

feat: ✨ Add logging properties.
Added model names and string representations of local learning rules to improve experiment logging.
rhoadesScholar committed Apr 17, 2024
1 parent 5674d58 commit bd55051
Showing 6 changed files with 21 additions and 6 deletions.
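
In short, every LeibNet now carries a name and every learning rule renders as a readable string, so experiment logs can record exactly which model and rule were used. A minimal, self-contained sketch of the pattern (ToyNet and ToyRule are illustrative stand-ins, not the real classes):

# Illustrative sketch only: ToyRule/ToyNet stand in for the real classes.
class ToyRule:
    def __init__(self, c=0.1):
        self.c = c

    def __str__(self):
        # Mirrors the __str__ methods added in bio.py below
        return f"ToyRule(c={self.c})"


class ToyNet:
    def __init__(self, name="ToyNet"):
        # Mirrors the name parameter added to LeibNet.__init__ below
        self.name = name


net, rule = ToyNet(name="UNet"), ToyRule(c=0.2)
print(f"model={net.name} rule={rule}")  # model=UNet rule=ToyRule(c=0.2)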
3 changes: 3 additions & 0 deletions src/leibnetz/leibnet.py
@@ -21,6 +21,7 @@ def __init__(
         outputs: dict[str, Sequence[Tuple]],
         retain_buffer=True,
         initialization="kaiming",
+        name="LeibNet",
     ):
         super().__init__()
         full_node_list = []
@@ -81,6 +82,8 @@ def __init__(
         else:
             self.cpu()

+        self.name = name
+
     def assemble(self, outputs: dict[str, Sequence[Tuple]]):
         """
         NOTE: If your scales are non-integer realworld units, you need to treat the scale as integer factors instead.
2 changes: 1 addition & 1 deletion src/leibnetz/nets/attentive_scalenet.py
@@ -166,7 +166,7 @@ def build_attentive_scale_net(subnet_dict_list: list[dict]):
         output = subnet_outputs.pop(f"{subnet_id}_output")
         outputs[f"{subnet_id}_output"] = output
         bottleneck_input_dict = subnet_outputs
-    network = LeibNet(nodes, outputs=outputs)
+    network = LeibNet(nodes, outputs=outputs, name="AttentiveScaleNet")
     return network
15 changes: 12 additions & 3 deletions src/leibnetz/nets/bio.py
@@ -27,6 +27,9 @@ def __init__(self):
     def init_layers(self, model):
         pass

+    def __str__(self):
+        return self.__class__.__name__
+
     @abstractmethod
     def update(self, x, w):
         pass
@@ -38,6 +41,9 @@ def __init__(self, c=0.1):
         super().__init__()
         self.c = c

+    def __str__(self):
+        return f"HebbsRule(c={self.c})"
+
     def update(self, inputs: torch.Tensor, weights: torch.Tensor):
         # TODO: Needs re-implementation
         d_ws = torch.zeros(inputs.size(0))
@@ -74,6 +80,9 @@ def __init__(self, precision=1e-30, delta=0.4, norm=2, k=2, normalize=False):
         self.k = k
         self.normalize = normalize

+    def __str__(self):
+        return f"KrotovsRule(precision={self.precision}, delta={self.delta}, norm={self.norm}, k={self.k})"
+
     def init_layers(self, layer):
         if hasattr(layer, "weight"):
             layer.weight.data.normal_(mean=0.0, std=1.0)
super().__init__()
self.c = c

def __str__(self):
return f"OjasRule(c={self.c})"

def update(self, inputs: torch.Tensor, weights: torch.Tensor):
# TODO: needs re-implementation
d_ws = torch.zeros(inputs.size(0), *weights.shape)
Expand Down Expand Up @@ -242,6 +254,3 @@ def convert_to_backprop(model: LeibNet):
)

return model


# %%
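
With the __str__ methods above, each rule prints as a readable configuration string. A short usage sketch (the import path is assumed from the file location src/leibnetz/nets/bio.py and is not shown in this diff):

# Assumed import path, inferred from src/leibnetz/nets/bio.py:
from leibnetz.nets.bio import HebbsRule, KrotovsRule, OjasRule

for rule in (HebbsRule(c=0.1), KrotovsRule(k=3), OjasRule(c=0.2)):
    print(rule)
# HebbsRule(c=0.1)
# KrotovsRule(precision=1e-30, delta=0.4, norm=2, k=3)
# OjasRule(c=0.2)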
2 changes: 1 addition & 1 deletion src/leibnetz/nets/scalenet.py
@@ -147,7 +147,7 @@ def build_scale_net(subnet_dict_list: list[dict]):
         output = subnet_outputs.pop(f"{subnet_id}_output")
         outputs[f"{subnet_id}_output"] = output
         bottleneck_input_dict = subnet_outputs
-    network = LeibNet(nodes, outputs=outputs)
+    network = LeibNet(nodes, outputs=outputs, name="ScaleNet")
     return network
4 changes: 3 additions & 1 deletion src/leibnetz/nets/unet.py
@@ -114,7 +114,9 @@ def build_unet(

     # define network
     network = LeibNet(
-        nodes, outputs={"output": [tuple(np.ones(len(top_resolution))), top_resolution]}
+        nodes,
+        outputs={"output": [tuple(np.ones(len(top_resolution))), top_resolution]},
+        name="UNet",
     )

     return network
1 change: 1 addition & 0 deletions src/leibnetz/nodes/node_ops.py
@@ -75,6 +75,7 @@ def __init__(
         layers = []

         for i, kernel_size in enumerate(kernel_sizes):
+            # TODO: Use of BatchNorm does not work with bio-inspired learning rules
             if norm_layer is not None:
                 layers.append(norm_layer(input_nc))
