Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/agosztolai/MARBLE into main
Browse files Browse the repository at this point in the history
  • Loading branch information
peach-lucien committed Dec 6, 2023
2 parents 0844b59 + 67f1bc8 commit 4ef6755
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
13 changes: 8 additions & 5 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ def fit(self, data, outdir=None, verbose=False):

print("\n---- Training network ...")

time = datetime.now().strftime("%Y%m%d-%H%M%S")
self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

print("\n---- Timestamp: {}".format(self.timestamp))

# load to gpu (if possible)
# pylint: disable=self-cls-assignment
Expand Down Expand Up @@ -385,7 +387,7 @@ def fit(self, data, outdir=None, verbose=False):

if best_loss == -1 or (val_loss < best_loss):
outdir = self.save_model(
optimizer, self.losses, outdir=outdir, best=True, timestamp=time
optimizer, self.losses, outdir=outdir, best=True, timestamp=self.timestamp
)
best_loss = val_loss
print(" *", end="")
Expand All @@ -398,8 +400,8 @@ def fit(self, data, outdir=None, verbose=False):

self.losses["test_loss"].append(test_loss)

self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=time)
self.load_model(os.path.join(outdir, f"best_model_{time}.pth"))
self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=self.timestamp)
self.load_model(os.path.join(outdir, f"best_model_{self.timestamp}.pth"))

def load_model(self, loadpath):
"""Load model.
Expand All @@ -413,7 +415,8 @@ def load_model(self, loadpath):
self._epoch = checkpoint["epoch"]
self.load_state_dict(checkpoint["model_state_dict"])
self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
self.losses = checkpoint["losses"]
if hasattr(self, 'losses'):
self.losses = checkpoint["losses"]

def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""):
"""Save model."""
Expand Down
2 changes: 2 additions & 0 deletions examples/RNN/RNN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@
"#params = {'epochs': 40, \n",
"# 'hidden_channels': 64, \n",
"# 'out_channels': 5,\n",
"# 'diffusion': False,\n",
"# 'inner_product_features': False, #geometry-aware for maximal expressivity\n",
"# }\n",
"#model = MARBLE.net(data, params=params)\n",
Expand Down Expand Up @@ -764,6 +765,7 @@
"# 'order': 2,\n",
"# 'hidden_channels': 64,\n",
"# 'out_channels': 5,\n",
"# 'diffusion': False,\n",
"# 'inner_product_features': True, #geometry-agnostic as manifolds are differently oriented across networks\n",
"# }\n",
"\n",
Expand Down
Binary file not shown.

0 comments on commit 4ef6755

Please sign in to comment.