Skip to content

Commit

Permalink
Add version limit for jaxlib (#362)
Browse files Browse the repository at this point in the history
* add version limit for jaxlib

* inject jaxlib version in setup.py instead
  • Loading branch information
dionhaefner authored Aug 1, 2022
1 parent 3584066 commit 90f5add
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
1 change: 0 additions & 1 deletion requirements_jax.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
jax==0.3.14
jaxlib
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ def parse_requirements(reqfile):

INSTALL_REQUIRES = parse_requirements("requirements.txt")


jax_req = parse_requirements("requirements_jax.txt")
for line in jax_req: # inject jaxlib requirement
if line.startswith("jax"):
jax_req.append(line.replace("jax", "jaxlib"))
break

EXTRAS_REQUIRE = {
"test": ["pytest", "pytest-cov", "pytest-forked", "codecov", "xarray"],
"jax": jax_req,
}
EXTRAS_REQUIRE["jax"] = parse_requirements("requirements_jax.txt")


def get_extensions(require_cython_ext, require_cuda_ext):
Expand Down
14 changes: 2 additions & 12 deletions veros/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,13 @@ def get_backend_module(backend_name):
if backend_name not in BACKENDS:
raise ValueError(f"unrecognized backend {backend_name} (must be either of: {list(BACKENDS.keys())!r})")

backend_module = None

if backend_name == "jax":
try:
import jax # noqa: F401
except ImportError:
pass
else:
init_jax_config()
import jax.numpy as backend_module
init_jax_config()
import jax.numpy as backend_module

elif backend_name == "numpy":
import numpy as backend_module

if backend_module is None:
raise ValueError(f'backend "{backend_name}" failed to import')

return backend_module


Expand Down

0 comments on commit 90f5add

Please sign in to comment.