Skip to content

Commit

Permalink
migrate to PTL 2.0
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandra Antonova <[email protected]>
  • Loading branch information
bene-ges committed Nov 21, 2023
1 parent 08937c8 commit 3c7981d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@

@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config")
def main(cfg: DictConfig) -> None:
# PTL 2.0 has find_unused_parameters as False by default, so its required to set it to True
# when there are unused parameters like here
if cfg.trainer.strategy == 'ddp':
cfg.trainer.strategy = "ddp_find_unused_parameters_true"
logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

# Train the model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,15 @@ def validation_step(self, batch, batch_idx):
)

val_loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)
self.validation_step_outputs.append({'val_loss': val_loss})
return {'val_loss': val_loss}

def validation_epoch_end(self, outputs):
def on_validation_epoch_end(self):
"""
Called at the end of validation to aggregate outputs.
:param outputs: list of individual outputs of each validation step.
"""
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()

# Calculate metrics and classification report
# Note that in our task recall = accuracy, and the recall column is the per class accuracy
Expand All @@ -300,12 +301,12 @@ def test_step(self, batch, batch_idx):
"""
return self.validation_step(batch, batch_idx)

def test_epoch_end(self, outputs):
def on_test_epoch_end(self):
"""
Called at the end of test to aggregate outputs.
:param outputs: list of individual outputs of each test step.
"""
return self.validation_epoch_end(outputs)
return self.on_validation_epoch_end()

# Functions for inference

Expand Down

0 comments on commit 3c7981d

Please sign in to comment.