diff --git a/tests/test_grad.py b/tests/test_grad.py index 2631134..d3775f7 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -63,5 +63,5 @@ def test_low_order_cpu( at.requires_grad = at_requires_grad rt.requires_grad = rt_requires_grad - assert gradcheck(compressor_core, (x, zi, at, rt), check_forward_ad=True, check_backward_ad=False) + assert gradcheck(compressor_core, (x, zi, at, rt), check_forward_ad=True) assert gradgradcheck(compressor_core, (x, zi, at, rt))