feat: ⚡️ Change Krotov's Rule to use topK as ratio of feature maps
rhoadesScholar committed Apr 18, 2024
1 parent bd55051 commit 02321eb
Showing 2 changed files with 11 additions and 8 deletions.
src/leibnetz/nets/attentive_scalenet.py (2 additions, 0 deletions)

@@ -180,6 +180,8 @@ def testing():
     # leibnet.array_shapes
     # %%
     inputs = leibnet.get_example_inputs()
+    for key, val in inputs.items():
+        print(f"{key}: {val.shape}")
     outputs = leibnet(inputs)
     # %%
     for key, val in outputs.items():
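For context, a minimal stand-in for the added debug loop; the input dict and its keys here are made up, since leibnet.get_example_inputs() is defined elsewhere in the repository:

import torch

# Stand-in for leibnet.get_example_inputs(): a dict mapping array names to tensors.
inputs = {
    "input_0": torch.rand(1, 1, 64, 64, 64),
    "input_1": torch.rand(1, 1, 32, 32, 32),
}

# Same loop as the added lines: print each input array's name and shape.
for key, val in inputs.items():
    print(f"{key}: {val.shape}")
# input_0: torch.Size([1, 1, 64, 64, 64])
# input_1: torch.Size([1, 1, 32, 32, 32])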
src/leibnetz/nets/bio.py (9 additions, 8 deletions)

@@ -72,16 +72,19 @@ class KrotovsRule(LearningRule):
         k: Ranking parameter
     """

-    def __init__(self, precision=1e-30, delta=0.4, norm=2, k=2, normalize=False):
+    def __init__(
+        self, k_ratio=0.5, delta=0.4, norm=2, normalize=False, precision=1e-30
+    ):
         super().__init__()
         self.precision = precision
         self.delta = delta
         self.norm = norm
-        self.k = k
+        assert k_ratio <= 1, "k_ratio should be smaller or equal to 1"
+        self.k_ratio = k_ratio
         self.normalize = normalize

     def __str__(self):
-        return f"KrotovsRule(precision={self.precision}, delta={self.delta}, norm={self.norm}, k={self.k})"
+        return f"KrotovsRule(k_ratio={self.k_ratio}, delta={self.delta}, norm={self.norm}, normalize={self.normalize})"

     def init_layers(self, layer):
         if hasattr(layer, "weight"):
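A quick usage sketch of the new constructor, assuming the package is importable as leibnetz (the class lives in src/leibnetz/nets/bio.py):

from leibnetz.nets.bio import KrotovsRule

# k is no longer passed directly; k_ratio scales with the width of each layer instead.
rule = KrotovsRule(k_ratio=0.5, delta=0.4, norm=2, normalize=False)
print(rule)
# KrotovsRule(k_ratio=0.5, delta=0.4, norm=2, normalize=False)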
@@ -91,9 +94,7 @@ def update(self, inputs: torch.Tensor, weights: torch.Tensor):
         batch_size = inputs.shape[0]
         num_hidden_units = weights.shape[0]
         input_size = inputs[0].shape[0]
-        assert (
-            self.k <= num_hidden_units
-        ), "The amount of hidden units should be larger or equal to k!"
+        k = int(self.k_ratio * num_hidden_units)

         # TODO: WIP
         if self.normalize:
@@ -109,12 +110,12 @@ def update(self, inputs: torch.Tensor, weights: torch.Tensor):
         )

         # Get the top k activations for each input sample (hidden units ranked per input sample)
-        _, indices = torch.topk(tot_input, k=self.k, dim=0)
+        _, indices = torch.topk(tot_input, k=k, dim=0)

         # Apply the activation function for each input sample
         activations = torch.zeros((num_hidden_units, batch_size), device=weights.device)
         activations[indices[0], torch.arange(batch_size)] = 1.0
-        activations[indices[self.k - 1], torch.arange(batch_size)] = -self.delta
+        activations[indices[k - 1], torch.arange(batch_size)] = -self.delta

         # Sum the activations for each hidden unit, the batch dimension is removed here
         xx = torch.sum(torch.mul(activations, tot_input), 1)
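To see how the ratio-based k behaves, here is a self-contained sketch of the ranking step from update() above, using made-up shapes; it illustrates the same torch.topk logic but is not the library code itself:

import torch

k_ratio, delta = 0.5, 0.4
num_hidden_units, batch_size = 8, 4

# tot_input stands in for the (num_hidden_units, batch_size) drive computed in update().
tot_input = torch.rand(num_hidden_units, batch_size)

# k now scales with the number of hidden units (feature maps): 8 units * 0.5 -> k = 4.
k = int(k_ratio * num_hidden_units)

# Rank hidden units per input sample and keep the top k.
_, indices = torch.topk(tot_input, k=k, dim=0)

# The top-ranked unit gets +1, the k-th ranked unit gets -delta, all others stay 0.
activations = torch.zeros((num_hidden_units, batch_size))
activations[indices[0], torch.arange(batch_size)] = 1.0
activations[indices[k - 1], torch.arange(batch_size)] = -delta

print(activations)

Deriving k from a ratio keeps the ranking rule meaningful as the number of feature maps varies from layer to layer, which is presumably the motivation for this change.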
