Skip to content

Commit

Permalink
harmonised notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Dec 1, 2023
1 parent 11a35df commit 30acab3
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 371,596 deletions.
2 changes: 1 addition & 1 deletion MARBLE/default_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include_self: True # include vector at the center of feature

# network parameters
dropout: 0. # dropout in the MLP
hidden_channels: [16] # number of hidden channels
hidden_channels: [32] # number of hidden channels
out_channels: 3 # number of output channels (if null, then =hidden_channels)
vec_norm: False # normalise features at each order of derivatives
bias: True # learn bias parameters in MLP
Expand Down
3 changes: 2 additions & 1 deletion MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def check_parameters(self, data):
"n_sampled_nb",
"include_positions",
"include_self",
"processes",
]

for p in self.params.keys():
Expand Down Expand Up @@ -279,7 +280,7 @@ def forward(self, data, n_id, adjs=None):
def evaluate(self, data):
"""Evaluate."""
warnings.warn("MARBLE.evaluate() is deprecated. Use MARBLE.transform() instead.")
self.transform(data)
return self.transform(data)

def transform(self, data):
"""Forward pass @ evaluation (no minibatches)"""
Expand Down
5 changes: 4 additions & 1 deletion MARBLE/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from matplotlib import gridspec
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from scipy.spatial import Voronoi
from scipy.spatial import voronoi_plot_2d
Expand Down Expand Up @@ -216,7 +217,7 @@ def embedding(
plot_trajectories = False

if plot_trajectories:
l_ = data.l[mask * (labels == typ)]
l_ = data.label[mask * (labels == typ)]
if len(l_) == 0:
continue
end = np.where(np.diff(l_) < 0)[0] + 1
Expand Down Expand Up @@ -680,6 +681,8 @@ def create_axis(*args, fig=None):
ax = fig.add_subplot(*args)
elif dim == 3:
ax = fig.add_subplot(*args, projection="3d")
else:
raise Exception('Data dimension is {}. We can only plot 2D or 3D data.'.format(dim))

return fig, ax

Expand Down
Loading

0 comments on commit 30acab3

Please sign in to comment.