Skip to content

Commit

Permalink
fix: removed the jax.config import from the ivy_tests init as it's de…
Browse files Browse the repository at this point in the history
…precated in the recent release (ivy-llc#28444)
  • Loading branch information
vedpatwardhan authored Feb 28, 2024
1 parent 61e3532 commit 8367486
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
25 changes: 11 additions & 14 deletions ivy/utils/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,14 @@ def check_dev_correct_formatting(device):


def _check_jax_x64_flag(dtype):
if ivy.backend == "jax":
import jax

if not jax.config.x64_enabled:
ivy.utils.assertions.check_elem_in_list(
dtype,
["float64", "int64", "uint64", "complex128"],
inverse=True,
message=(
f"{dtype} output not supported while jax_enable_x64"
" is set to False, please import jax and enable the flag using "
"jax.config.update('jax_enable_x64', True)"
),
)
if ivy.backend == "jax" and not ivy.functional.backends.jax.jax.config.x64_enabled:
ivy.utils.assertions.check_elem_in_list(
dtype,
["float64", "int64", "uint64", "complex128"],
inverse=True,
message=(
f"{dtype} output not supported while jax_enable_x64"
" is set to False, please import jax and enable the flag using "
"jax.config.update('jax_enable_x64', True)"
),
)
4 changes: 2 additions & 2 deletions ivy_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
try:
from jax.config import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
except (ImportError, RuntimeError):
pass

0 comments on commit 8367486

Please sign in to comment.