diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index c36835704..034d5a2bf 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -135,11 +135,15 @@ def _eval_model_on_split(self, repeat_final_dataset=False) loss = 0.0 size = 0 - for eval_batch in self._eval_iters[split]: - # if i == (num_batches - 1): - # print(eval_batch.get('weights')) - loss += self._eval_batch(params, eval_batch) - size += eval_batch.get('weights').sum() + try: + for eval_batch in self._eval_iters[split]: + # if i == (num_batches - 1): + # print(eval_batch.get('weights')) + loss += self._eval_batch(params, eval_batch) + size += eval_batch.get('weights').sum() + except: + pass + if USE_PYTORCH_DDP: dist.all_reduce(loss) dist.all_reduce(size)