Skip to content

Commit

Permalink
partial fix
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed May 22, 2024
1 parent 5ae3570 commit f848403
Show file tree
Hide file tree
Showing 24 changed files with 148 additions and 2,377 deletions.
1 change: 1 addition & 0 deletions MARBLE/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""MARBLE main functions."""

from MARBLE.main import net
from MARBLE.postprocessing import distribution_distances
from MARBLE.postprocessing import embed_in_2D
Expand Down
1 change: 1 addition & 0 deletions MARBLE/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data loader module."""

import torch
from torch_cluster import random_walk
from torch_geometric.loader import NeighborSampler as NS
Expand Down
1 change: 1 addition & 0 deletions MARBLE/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TODO: clean this up
"""

import sys

import numpy as np
Expand Down
7 changes: 4 additions & 3 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Geometry module."""

import numpy as np
import ot
import scipy.sparse as sp
Expand Down Expand Up @@ -196,7 +197,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None):
centroid_distances: distances between cluster centroids
"""
s = slices

pdists, cdists = None, None
if clusters is not None:
# compute discrete measures supported on cluster centroids
labels = clusters["labels"]
Expand Down Expand Up @@ -231,7 +232,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None):
for j in range(i + 1, nl):
mu, nu = bins_dataset[i], bins_dataset[j]

if data is not None:
if data is not None and pdists is not None:
cdists = pdists[s[i] : s[i + 1], s[j] : s[j + 1]]

dist[i, j] = ot.emd2(mu, nu, cdists)
Expand Down Expand Up @@ -355,7 +356,7 @@ def manifold_dimension(Sigma, frac_explained=0.9):
return int(dim_man)


def fit_graph(x, graph_type="cknn", par=1, delta=1.0, metric='euclidean'):
def fit_graph(x, graph_type="cknn", par=1, delta=1.0, metric="euclidean"):
"""Fit graph to node positions"""

if graph_type == "cknn":
Expand Down
1 change: 1 addition & 0 deletions MARBLE/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Layer module."""

import torch
from torch import nn
from torch_geometric.nn.conv import MessagePassing
Expand Down
1 change: 1 addition & 0 deletions MARBLE/lib/cknn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module imported and adapted from https://github.com/chlorochrule/cknn."""

import numpy as np
from scipy.sparse import csr_matrix
from scipy.spatial.distance import pdist
Expand Down
6 changes: 4 additions & 2 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main network"""

import glob
import os
import warnings
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
self.setup_layers()
self.loss = loss_fun()
self.reset_parameters()
self.timestamp = None

if verbose:
utils.print_settings(self)
Expand Down Expand Up @@ -172,7 +174,7 @@ def setup_layers(self):
cum_channels = s * (1 - d ** (o + 1)) // (1 - d)
if not self.params["include_self"]:
cum_channels -= s

if self.params["inner_product_features"]:
cum_channels //= s
if s == 1:
Expand Down Expand Up @@ -350,7 +352,7 @@ def fit(self, data, outdir=None, verbose=False):

self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

print("\n---- Timestamp: {}".format(self.timestamp))
print(f"\n---- Timestamp: {self.timestamp}")

# load to gpu (if possible)
# pylint: disable=self-cls-assignment
Expand Down
11 changes: 6 additions & 5 deletions MARBLE/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plotting module."""

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -182,6 +183,8 @@ def embedding(
emb = data.emb_2D
elif isinstance(data, np.ndarray) or torch.is_tensor(data):
emb = data
else:
raise TypeError

dim = emb.shape[1]
assert dim in [2, 3], f"Embedding dimension is {dim} which cannot be displayed."
Expand Down Expand Up @@ -531,7 +534,7 @@ def trajectories(
Args:
X (np array): Positions
V (np array): Velocities
ax (matplotlib axes object): If specificed, it will plot on existing axes. The default is None
ax (matplotlib axes object): If specificed, it will plot on existing axes. Default is None
style (string): Plotting style. 'o' for scatter plot or '-' for line plot
node_feature: Color lines. The default is None
lw (int): Line width
Expand Down Expand Up @@ -605,9 +608,7 @@ def trajectories(
X, V = X[skip], V[skip]
plot_arrows(X, V, ax, c, width=lw, scale=scale)
else:
raise Exception(
"Data dimension is: {}. It needs to be 2 or 3 to allow plotting.".format(dim)
)
raise Exception(f"Data dimension is: {dim}. It needs to be 2 or 3 to allow plotting.")

set_axes(ax, axes_visible=axes_visible)

Expand Down Expand Up @@ -690,7 +691,7 @@ def create_axis(*args, fig=None):
elif dim == 3:
ax = fig.add_subplot(*args, projection="3d")
else:
raise Exception("Data dimension is {}. We can only plot 2D or 3D data.".format(dim))
raise Exception(f"Data dimension is {dim}. We can only plot 2D or 3D data.")

return fig, ax

Expand Down
1 change: 1 addition & 0 deletions MARBLE/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Postprocessing module."""

import numpy as np

from MARBLE import geometry as g
Expand Down
16 changes: 8 additions & 8 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Preprocessing module."""

import torch
from torch_geometric.data import Batch
from torch_geometric.data import Data
Expand All @@ -16,14 +17,13 @@ def construct_dataset(
graph_type="cknn",
k=20,
delta=1.0,
n_eigenvalues=None,
frac_geodesic_nb=1.5,
spacing=0.0,
number_of_resamples=1,
var_explained=0.9,
local_gauges=False,
seed=None,
metric='euclidean'
metric="euclidean",
):
"""Construct PyG dataset from node positions and features.
Expand All @@ -35,16 +35,14 @@ def construct_dataset(
graph_type: type of nearest-neighbours graph: cknn (default), knn or radius
k: number of nearest-neighbours to construct the graph
delta: argument for cknn graph construction to decide the radius for each points.
n_eigenvalues: number of eigenvalue/eigenvector pairs to compute (None means all,
but this can be slow)
frac_geodesic_nb: number of geodesic neighbours to fit the gauges to
to map to tangent space k*frac_geodesic_nb
stop_crit: stopping criterion for furthest point sampling
number_of_resamples: number of furthest point sampling runs to prevent bias (experimental)
var_explained: fraction of variance explained by the local gauges
local_gauges: is True, it will try to compute local gauges if it can (signal dim is > 2,
embedding dimension is > 2 or dim embedding is not dim of manifold)
seed: Specify for reproducibility in the furthest point sampling.
seed: Specify for reproducibility in the furthest point sampling.
The default is None, which means a random starting vertex.
"""

Expand Down Expand Up @@ -78,9 +76,11 @@ def construct_dataset(
sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
sample_ind, _ = torch.sort(sample_ind) # this will make postprocessing easier
a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]

# fit graph to point cloud
edge_index, edge_weight = g.fit_graph(a_, graph_type=graph_type, par=k, delta=delta, metric=metric)
edge_index, edge_weight = g.fit_graph(
a_, graph_type=graph_type, par=k, delta=delta, metric=metric
)

# define data object
data_ = Data(
Expand Down Expand Up @@ -113,7 +113,7 @@ def construct_dataset(
n_geodesic_nb=k * frac_geodesic_nb,
var_explained=var_explained,
)


def _compute_geometric_objects(
data,
Expand Down
3 changes: 1 addition & 2 deletions MARBLE/smoothing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Smoothing module."""

import torch

Expand Down
1 change: 1 addition & 0 deletions MARBLE/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils module."""

import multiprocessing
from functools import partial
from typing import NamedTuple
Expand Down
Loading

0 comments on commit f848403

Please sign in to comment.