Commit 1c78d7c: simplify main
agosztolai committed Nov 30, 2023 (1 parent: faef61c)
Showing 6 changed files with 86 additions and 156 deletions.
8 changes: 3 additions & 5 deletions MARBLE/default_params.yaml
@@ -13,15 +13,13 @@ include_positions: False # include positions as features (warning: this is untested)
include_self: True # include vector at the center of feature

# network parameters
dropout: 0. #d ropout in the MLP
dropout: 0. # dropout in the MLP
hidden_channels: [16] # number of hidden channels
out_channels: 3 # number of output channels (if null, then =hidden_channels)
vec_norm: False # normalise features at each order of derivatives
bias: True # learn bias parameters in MLP
vec_norm: False
batch_norm: True # batch normalisation
emb_norm: False # spherical output
skip_connections: False # use skips in MLP

# other params
seed: 0 # seed for reproducibility
processes: 1
seed: 0 # seed for reproducibility
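
As a side note on how these defaults are consumed: a typical pattern is to read the YAML into a dict and let user-supplied values overwrite it. The snippet below is only an illustrative sketch; the file path and the exact merge logic used by MARBLE are assumptions, not part of this diff.

import yaml

with open("MARBLE/default_params.yaml", "r") as f:
    params = yaml.safe_load(f)                      # package defaults, as listed above

user_params = {"out_channels": 5, "dropout": 0.2}   # hypothetical user overrides
params.update(user_params)                          # user values take precedence

print(params["hidden_channels"], params["out_channels"])   # -> [16] 5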
2 changes: 1 addition & 1 deletion MARBLE/geometry.py
@@ -460,7 +460,7 @@ def compute_laplacian(data, normalization="rw"):


def compute_connection_laplacian(data, R, normalization="rw"):
r"""Connection Laplacian
"""Connection Laplacian
Args:
data: Pytorch geometric data object.
34 changes: 1 addition & 33 deletions MARBLE/layers.py
@@ -6,33 +6,6 @@

from MARBLE import geometry as g

class SkipMLP(nn.Module):
def __init__(self, channel_list, dropout=0.0, bias=True):
super(SkipMLP, self).__init__()
self.layers = nn.ModuleList()
self.in_channels = channel_list[0]
for i in range(len(channel_list) - 1):
self.layers.append(nn.Linear(channel_list[i], channel_list[i + 1], bias=bias))
self.layers.append(nn.Dropout(dropout))

# Output layer adjustment for concatenated skip connection
final_out_features = channel_list[-1] + channel_list[0]
self.output_layer = nn.Linear(final_out_features, channel_list[-1], bias=bias)

def forward(self, x):
identity = x
for layer in self.layers:
if isinstance(layer, nn.Linear):
x = relu(layer(x))
else:
x = layer(x)

# Concatenate the input (identity) with the output
x = torch.cat([identity, x], dim=1)
x = self.output_layer(x)
return x


class Diffusion(nn.Module):
"""Diffusion with learned t."""

@@ -68,12 +41,10 @@ def forward(self, x, L, Lc=None, method="spectral"):
class AnisoConv(MessagePassing):
"""Anisotropic Convolution"""

def __init__(self, vec_norm=False, **kwargs):
def __init__(self, **kwargs):
"""Initialize."""
super().__init__(aggr="add", **kwargs)

self.vec_norm = vec_norm

def forward(self, x, kernels):
"""Forward."""
out = []
@@ -84,9 +55,6 @@ def forward(self, x, kernels):
out = torch.stack(out, axis=2)
out = out.view(out.shape[0], -1)

# if self.vec_norm:
# out = normalize(out, dim=-1, p=2)

return out

def message_and_aggregate(self, K_t, x):
55 changes: 21 additions & 34 deletions MARBLE/main.py
@@ -3,20 +3,16 @@
import os
from datetime import datetime
from pathlib import Path
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.optim as opt
import yaml
from tensorboardX import SummaryWriter
from torch import nn
from torch_geometric.nn import MLP
from tqdm import tqdm

from MARBLE import dataloader
from MARBLE import geometry
from MARBLE import layers
from MARBLE import utils
from MARBLE import dataloader, geometry, layers, utils

import warnings

@@ -42,11 +38,10 @@ class net(nn.Module):
hidden_channels: number of hidden channels (default=16). If list, then adds multiple layers.
out_channels: number of output channels (if null, then =hidden_channels) (default=3)
bias: learn bias parameters in MLP (default=True)
vec_norm: normalise features to unit length (default=False)
vec_norm: normalise features at each derivative order to unit length (default=False)
emb_norm: normalise MLP output to unit length (default=False)
batch_norm: batch normalisation (default=False)
seed: seed for reproducibility (default=0)
processes: number of cpus (default=1)
"""

def __init__(self, data, loadpath=None, params=None, verbose=True):
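
To make the docstring's parameter list concrete, an illustrative construction might look like the following. The keyword names come from the docstring above, but the import path and the contents of data are assumptions, not part of this commit.

from MARBLE import main as marble_main   # assumed import path for the net class

params = {
    "hidden_channels": [16],   # one hidden layer of width 16
    "out_channels": 3,         # embedding dimension
    "bias": True,
    "vec_norm": False,         # normalise features at each derivative order
    "emb_norm": False,         # spherical (unit-norm) output
    "batch_norm": True,
    "seed": 0,
}
model = marble_main.net(data, params=params)   # data: a preprocessed PyG data object (assumed)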
@@ -149,10 +144,8 @@ def check_parameters(self, data):
"batch_norm",
"vec_norm",
"emb_norm",
"skip_connections",
"seed",
"n_sampled_nb",
"processes",
"include_positions",
"include_self",
]
@@ -177,7 +170,7 @@ def setup_layers(self):
self.diffusion = layers.Diffusion()

# gradient features
self.grad = nn.ModuleList(layers.AnisoConv(self.params["vec_norm"]) for i in range(o))
self.grad = nn.ModuleList(layers.AnisoConv() for i in range(o))

# cumulated number of channels after gradient features
cum_channels = s * (1 - d ** (o + 1)) // (1 - d)
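
For readers puzzling over the cum_channels line: it is the closed form of a geometric series. If s is the number of signal channels, d the local dimension and o the derivative order (these meanings are inferred from context, not shown in the hunk), each gradient layer multiplies the channel count by d, so the concatenated features have s * (1 + d + ... + d**o) = s * (1 - d**(o+1)) / (1 - d) channels for d > 1. A quick numerical check of the identity:

s, d, o = 2, 3, 2                                # e.g. 2 channels, dimension 3, 2nd-order features
direct = s * sum(d**k for k in range(o + 1))     # 2 * (1 + 3 + 9) = 26
closed = s * (1 - d ** (o + 1)) // (1 - d)       # same expression as in setup_layers
assert direct == closed == 26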
@@ -206,18 +199,11 @@ def setup_layers(self):
+ [self.params["out_channels"]]
)

if self.params['skip_connections']:
self.enc = layers.SkipMLP(
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
)
else:
self.enc = MLP(
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
)
self.enc = MLP(
channel_list=channel_list,
dropout=self.params["dropout"],
bias=self.params["bias"],
)


def forward(self, data, n_id, adjs=None):
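
With SkipMLP removed, the encoder is always a plain torch_geometric MLP built from channel_list. Purely as an illustration (the full composition of channel_list is partly hidden by the hunk above), with 26 gradient-feature channels, hidden_channels = [16] and out_channels = 3 the encoder would be:

from torch_geometric.nn import MLP

channel_list = [26, 16, 3]   # [cum_channels] + hidden_channels + [out_channels] (assumed layout)
enc = MLP(channel_list=channel_list, dropout=0.0, bias=True)
# i.e. Linear(26 -> 16) and Linear(16 -> 3), with MLP's default activation/norm/dropout in between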
@@ -377,12 +363,12 @@ def fit(self, data, outdir=None, verbose=False):
)
if hasattr(self, "optimizer_state_dict"):
optimizer.load_state_dict(self.optimizer_state_dict)
writer = SummaryWriter("./log/")

# training scheduler
scheduler = opt.lr_scheduler.ReduceLROnPlateau(optimizer)

best_loss = -1
self.losses = {'train_loss': [], 'val_loss': [], 'test_loss': []}
for epoch in range(
self.params.get("epoch", 0), self.params.get("epoch", 0) + self.params["epochs"]
):
@@ -394,26 +380,25 @@
val_loss, _ = self.batch_loss(data, val_loader, verbose=verbose)
scheduler.step(train_loss)

writer.add_scalar("Loss/train", train_loss, self._epoch)
writer.add_scalar("Loss/validation", val_loss, self._epoch)
writer.flush()
print(
f"\nEpoch: {self._epoch}, Training loss: {train_loss:4f}, Validation loss: {val_loss:.4f}, lr: {scheduler._last_lr[0]:.4f}", # noqa, pylint: disable=line-too-long,protected-access
end="",
)

if best_loss == -1 or (val_loss < best_loss):
outdir = self.save_model(optimizer, outdir, best=True, timestamp=time)
outdir = self.save_model(optimizer, self.losses, outdir=outdir, best=True, timestamp=time)
best_loss = val_loss
print(" *", end="")

self.losses['train_loss'].append(train_loss)
self.losses['val_loss'].append(val_loss)

test_loss, _ = self.batch_loss(data, test_loader)
writer.add_scalar("Loss/test", test_loss)
writer.close()
print(f"\nFinal test loss: {test_loss:.4f}")

self.save_model(optimizer, outdir, best=False, timestamp=time)

self.losses['test_loss'].append(test_loss)

self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=time)
self.load_model(os.path.join(outdir, f"best_model_{time}.pth"))

def load_model(self, loadpath):
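
Since fit now stores the per-epoch losses on the model (and save_model writes them into the checkpoint), training curves can be read back directly after training. A hedged usage sketch; the training data and output directory are placeholders:

model.fit(data, outdir="./outputs")          # trains, then reloads the best checkpoint
train_curve = model.losses["train_loss"]     # one entry per epoch recorded in the loaded checkpoint
val_curve = model.losses["val_loss"]         # likewise for the validation loss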
@@ -426,8 +411,9 @@ def load_model(self, loadpath):
self._epoch = checkpoint["epoch"]
self.load_state_dict(checkpoint["model_state_dict"])
self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
self.losses = checkpoint['losses']

def save_model(self, optimizer, outdir=None, best=False, timestamp=""):
def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""):
"""Save model."""
if outdir is None:
outdir = "./outputs/"
@@ -441,6 +427,7 @@ def save_model(self, optimizer, outdir=None, best=False, timestamp=""):
"optimizer_state_dict": optimizer.state_dict(),
"time": timestamp,
"params": self.params,
"losses": losses,
}

if best:
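
Because the checkpoint dictionary now also carries the losses, a saved model can be inspected without re-running training. A minimal sketch, assuming a file of the best_model_<timestamp>.pth form produced by save_model (the path and timestamp below are hypothetical):

import torch

ckpt = torch.load("./outputs/best_model_20231130.pth")   # hypothetical checkpoint path
print(ckpt["params"]["out_channels"])                     # hyperparameters used for this run
print(ckpt["losses"]["val_loss"])                         # validation loss per recorded epoch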
Diffs for the remaining two changed files are not shown.
