diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index fc094e4..9684043 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -269,7 +269,7 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg return { key: value.item() for key, value in trainer.callback_metrics.items() - if any(key.startswith(key) for key in ["training/", "validation/"]) + if any(key.startswith(k) for k in ["training/", "validation/"]) } @torch.enable_grad()