diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 3ba247360..3ac6c292c 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -346,6 +346,8 @@ def data_selection(workload: spec.Workload, del global_step del rng batch = next(input_queue) + breakpoint() print('BATCH STATS') - print(sum(batch['weights'])) + + print((batch['weights']).shape) return batch