diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py index 0569e41..961b485 100644 --- a/kronfluence/score/dot_product.py +++ b/kronfluence/score/dot_product.py @@ -82,7 +82,7 @@ def compute_dot_products_with_loader( for batch in train_loader: batch = send_to_device(tensor=batch, device=state.device) - print(cached_module_lst[0]) + print(cached_module_lst[0].storage) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True)