Skip to content

Commit

Permalink
FIX - logging working for callback refactoring advances #17
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed May 28, 2024
1 parent 798ba24 commit 7abf4f9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions callbacks/csv_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class CSVLogging(Callback):
def __init__(self, csv_path):
self.csv_path = csv_path
self.headers_written = False
self.headers = []

def on_epoch_end(self, epoch, logs=None):
"""
Expand All @@ -31,7 +32,7 @@ def on_epoch_end(self, epoch, logs=None):
if logs is None:
return

epoch_data = logs.get('epoch')
epoch_data = logs.get('epoch', epoch)
train_loss = logs.get('train_loss')
val_loss = logs.get('val_loss')
train_metrics = logs.get('train_metrics', {})
Expand All @@ -42,13 +43,13 @@ def on_epoch_end(self, epoch, logs=None):
metrics.update({f'val_{key}': value for key, value in val_metrics.items()})

if not self.headers_written:
headers = ['epoch'] + list(metrics.keys())
self.headers = ['epoch'] + list(metrics.keys())
with open(self.csv_path, 'w', newline='') as file:
writer = csv.writer(file)
writer.writerow(headers)
writer.writerow(self.headers)
self.headers_written = True

values = [epoch_data] + [metrics[key] for key in headers[1:]] # Ensure the order matches headers
values = [epoch_data] + [metrics.get(key) for key in self.headers[1:]]
with open(self.csv_path, 'a', newline='') as file:
writer = csv.writer(file)
writer.writerow(values)

0 comments on commit 7abf4f9

Please sign in to comment.