diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 06ec347..88c6af2 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -592,6 +592,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None: storage=self._storage, damping=self.score_args.damping, ) + assert preconditioned_gradient.is_contiguous() self._cached_per_sample_gradient = None preconditioned_gradient = preconditioned_gradient.to(dtype=self.score_args.score_dtype)