diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index c1d6dfb4a..7c73e1863 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -134,8 +134,10 @@ def _eval_model_on_split(self, global_batch_size=global_batch_size, repeat_final_dataset=True) loss = 0.0 - for _ in range(num_batches): + for i in range(num_batches): eval_batch = next(self._eval_iters[split]) + if i == (num_batches - 1): + print(eval_batch.get('weights')) loss += self._eval_batch(params, eval_batch) if USE_PYTORCH_DDP: dist.all_reduce(loss)