fixed skip and added parameter choice
peach-lucien committed Nov 27, 2023
1 parent 6d8d587 commit f8551f4
Showing 3 changed files with 28 additions and 28 deletions.
1 change: 1 addition & 0 deletions MARBLE/default_params.yaml
@@ -20,6 +20,7 @@ bias: True # learn bias parameters in MLP
 vec_norm: False
 batch_norm: False # batch normalisation
 emb_norm: False # spherical output
+skip_connections: True # use skips in MLP

 # other params
 seed: 0 # seed for reproducibility
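For reference, the new flag behaves like any other entry in this file; below is a minimal sketch of overriding it in code rather than editing the defaults (only the YAML handling is shown, since how the resulting dict reaches the model is not part of this commit):

# Minimal sketch: load the shipped defaults and flip the new flag in code.
# Only the YAML side is shown; passing the dict on to the model is out of scope here.
import yaml

with open("MARBLE/default_params.yaml") as f:
    params = yaml.safe_load(f)

params["skip_connections"] = False  # fall back to the plain MLP encoder
print(params["skip_connections"])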
29 changes: 13 additions & 16 deletions MARBLE/layers.py
@@ -1,38 +1,35 @@
 """Layer module."""
 import torch
 from torch import nn
-from torch.nn.functional import normalize
+from torch.nn.functional import normalize, relu
 from torch_geometric.nn.conv import MessagePassing

 from MARBLE import geometry as g

 class SkipMLP(nn.Module):
     """ MLP with skip connections """

     def __init__(self, channel_list, dropout=0.0, bias=True):
         super(SkipMLP, self).__init__()
         assert len(channel_list) > 1, "Channel list must have at least two elements for an MLP."
         self.layers = nn.ModuleList()
         self.dropout = dropout
         self.in_channels = channel_list[0]
         for i in range(len(channel_list) - 1):
-            self.layers.append(nn.Linear(channel_list[i], channel_list[i+1], bias=bias))
-            if i < len(channel_list) - 2: # Don't add activation or dropout to the last layer
-                self.layers.append(nn.ReLU(inplace=True))
-                if dropout > 0:
-                    self.layers.append(nn.Dropout(dropout))
+            self.layers.append(nn.Linear(channel_list[i], channel_list[i + 1], bias=bias))
+            self.layers.append(nn.Dropout(dropout))

+        # Output layer adjustment for concatenated skip connection
+        final_out_features = channel_list[-1] + channel_list[0]
+        self.output_layer = nn.Linear(final_out_features, channel_list[-1], bias=bias)

     def forward(self, x):
+        identity = x
         for layer in self.layers:
             if isinstance(layer, nn.Linear):
-                if x.shape[1] == layer.weight.shape[1]: # Check if skip connection is possible
-                    identity = x # Save identity for skip connection
-                x = layer(x)
-                if x.shape[1] == identity.shape[1]: # Apply skip connection if shapes match
-                    x += identity
+                x = relu(layer(x))
             else:
-                x = layer(x) # Apply activation or dropout
+                x = layer(x)

+        # Concatenate the input (identity) with the output
+        x = torch.cat([identity, x], dim=1)
+        x = self.output_layer(x)
         return x
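For orientation, here is a short standalone sketch of the reworked module (widths and batch size are arbitrary illustration values): every Linear in the stack is followed by ReLU and Dropout, the raw input is concatenated onto the final activation, and output_layer maps that concatenation back to channel_list[-1].

# Illustration only: exercise the reworked SkipMLP with arbitrary shapes.
import torch

from MARBLE.layers import SkipMLP

mlp = SkipMLP(channel_list=[16, 64, 32], dropout=0.1, bias=True)

x = torch.randn(8, 16)  # batch of 8 inputs of width 16
out = mlp(x)            # ReLU/Dropout stack, then cat([x, h], dim=1) -> output_layer
print(out.shape)        # torch.Size([8, 32]), i.e. channel_list[-1]

Compared with the previous version, the skip is applied exactly once, from the raw input to the last hidden activation, instead of being added in place whenever consecutive layer widths happened to match.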
26 changes: 14 additions & 12 deletions MARBLE/main.py
@@ -147,6 +147,7 @@ def check_parameters(self, data):
             "batch_norm",
             "vec_norm",
             "emb_norm",
+            "skip_connections",
             "seed",
             "n_sampled_nb",
             "processes",
@@ -203,19 +204,20 @@ def setup_layers(self):
             + [self.params["out_channels"]]
         )

-        # self.enc = MLP(
-        #     channel_list=channel_list,
-        #     dropout=self.params["dropout"],
-        #     #norm=self.params["batch_norm"],
-        #     bias=self.params["bias"],
-        # )
+        if self.params['skip_connections']:
+            self.enc = layers.SkipMLP(
+                channel_list=channel_list,
+                dropout=self.params["dropout"],
+                bias=self.params["bias"],
+            )
+        else:
+            self.enc = MLP(
+                channel_list=channel_list,
+                dropout=self.params["dropout"],
+                bias=self.params["bias"],
+            )

-        self.enc = layers.SkipMLP(
-            channel_list=channel_list,
-            dropout=self.params["dropout"],
-            #norm=self.params["batch_norm"],
-            bias=self.params["bias"],
-        )
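Either branch leaves the encoder output width at channel_list[-1], so the rest of setup_layers is unaffected by the flag. A quick illustrative check (placeholder widths; MLP is assumed to be torch_geometric.nn.MLP, as the keyword arguments in the diff suggest):

# Illustrative check that both encoder choices expose the same output width.
# channel_list values are placeholders; MLP is assumed to be torch_geometric.nn.MLP.
import torch
from torch_geometric.nn import MLP

from MARBLE.layers import SkipMLP

channel_list = [16, 64, 32]
x = torch.randn(4, 16)

skip_enc = SkipMLP(channel_list=channel_list, dropout=0.0, bias=True)
plain_enc = MLP(channel_list=channel_list, dropout=0.0, bias=True)

assert skip_enc(x).shape == plain_enc(x).shape == (4, 32)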
