Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai authored Dec 6, 2023
1 parent 8a1fbe4 commit 67f1bc8
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ def fit(self, data, outdir=None, verbose=False):

print("\n---- Training network ...")

time = datetime.now().strftime("%Y%m%d-%H%M%S")
self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

print("\n---- Timestamp: {}".format(self.timestamp))

# load to gpu (if possible)
# pylint: disable=self-cls-assignment
Expand Down Expand Up @@ -385,7 +387,7 @@ def fit(self, data, outdir=None, verbose=False):

if best_loss == -1 or (val_loss < best_loss):
outdir = self.save_model(
optimizer, self.losses, outdir=outdir, best=True, timestamp=time
optimizer, self.losses, outdir=outdir, best=True, timestamp=self.timestamp
)
best_loss = val_loss
print(" *", end="")
Expand All @@ -398,8 +400,8 @@ def fit(self, data, outdir=None, verbose=False):

self.losses["test_loss"].append(test_loss)

self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=time)
self.load_model(os.path.join(outdir, f"best_model_{time}.pth"))
self.save_model(optimizer, self.losses, outdir=outdir, best=False, timestamp=self.timestamp)
self.load_model(os.path.join(outdir, f"best_model_{self.timestamp}.pth"))

def load_model(self, loadpath):
"""Load model.
Expand Down

0 comments on commit 67f1bc8

Please sign in to comment.