diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 8924a9865..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -253,7 +253,7 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3729bc53b..3549911fa 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -310,7 +310,7 @@ def _eval_model_on_split(self, weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 0cec4116b..11d6a67e8 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -331,7 +331,7 @@ def _eval_model_on_split(self, 'num_words': num_words, } total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 5407e8a35..dcc195170 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -215,7 +215,7 @@ def _eval_model_on_split(self, model_state, per_device_model_rngs) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } return self._normalize_eval_metrics(num_examples, total_metrics)