Skip to content

Commit

Permalink
add difference between val and test
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandra Antonova <[email protected]>
  • Loading branch information
bene-ges committed Dec 1, 2023
1 parent 5ed2e7f commit e196bca
Showing 1 changed file with 17 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def training_step(self, batch, batch_idx):
return {'loss': loss, 'lr': lr}

# Validation and Testing
def validation_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx, split="val"):
"""
Lightning calls this inside the validation loop with the data from the validation dataloader
passed in as `batch`.
Expand Down Expand Up @@ -271,16 +271,25 @@ def validation_step(self, batch, batch_idx):
torch.tensor(span_predictions).to(self.device), torch.tensor(span_labels).to(self.device)
)

val_loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)
self.validation_step_outputs.append({'val_loss': val_loss})
return {'val_loss': val_loss}
loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)

if split == 'val':
self.validation_step_outputs.append({f'{split}_loss': loss})
elif split == 'test':
self.test_step_outputs.append({f'{split}_loss': loss})

return {f'{split}_loss': loss}

def on_validation_epoch_end(self):
"""
Called at the end of validation to aggregate outputs.
:param outputs: list of individual outputs of each validation step.
"""
avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
split = "test" if self.trainer.testing else "val"
if split == 'val':
avg_loss = torch.stack([x[f'{split}_loss'] for x in self.validation_step_outputs]).mean()
elif split == 'test':
avg_loss = torch.stack([x[f'{split}_loss'] for x in self.test_step_outputs]).mean()

# Calculate metrics and classification report
# Note that in our task recall = accuracy, and the recall column is the per class accuracy
Expand All @@ -289,8 +298,8 @@ def on_validation_epoch_end(self):
logging.info("Total tag accuracy: " + str(tag_accuracy))
logging.info(tag_report)

self.log('val_loss', avg_loss, prog_bar=True)
self.log('tag accuracy', tag_accuracy)
self.log(f"{split}_loss", avg_loss, prog_bar=True)
self.log(f"{split}_tag_accuracy", tag_accuracy)

self.tag_classification_report.reset()

Expand All @@ -299,7 +308,7 @@ def test_step(self, batch, batch_idx):
Lightning calls this inside the test loop with the data from the test dataloader
passed in as `batch`.
"""
return self.validation_step(batch, batch_idx)
return self.validation_step(batch, batch_idx, split="test")

def on_test_epoch_end(self):
"""
Expand Down

0 comments on commit e196bca

Please sign in to comment.