diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index 3320faf7..f40ad214 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -190,7 +190,7 @@ def test_feature_matrix_dtype(dtype, expected_dtype): ) def test_apply_constraint_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" - jax.config.update('jax_enable_x64', True) + jax.config.update("jax_enable_x64", True) x = np.random.randn(10, 5) # add invalid x[:2, 2] = invalid_entries