From fe26536211654e2417fb6f6d606da2c7f5ad5e70 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 24 Oct 2024 03:22:52 -0400 Subject: [PATCH] JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU) --- CHANGELOG.md | 1 + setup.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b25184..1d036de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - mkdocs is now configured correctly for the new project structure +- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU) ## [0.2.0] - 2024-10-22 diff --git a/setup.py b/setup.py index a7f06f9..d893d39 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='xlb', - version='0.2.0', + version='0.2.1', description='XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML', long_description=open('README.md').read(), long_description_content_type='text/markdown', @@ -11,7 +11,6 @@ license='Apache License 2.0', packages=find_packages(), install_requires=[ - 'jax[cuda]>=0.4.34', 'matplotlib>=3.9.2', 'numpy>=2.1.2', 'pyvista>=0.44.1', @@ -19,7 +18,15 @@ 'warp-lang>=1.4.0', 'numpy-stl>=3.1.2', 'pydantic>=2.9.1', - 'ruff>=0.6.5' + 'ruff>=0.6.5', + 'jax>=0.4.34' # Base JAX CPU-only requirement ], + extras_require={ + 'cuda': ['jax[cuda12]>=0.4.34'], # For CUDA installations + 'tpu': ['jax[tpu]>=0.4.34'], # For TPU installations + }, python_requires='>=3.10', + dependency_links=[ + 'https://storage.googleapis.com/jax-releases/libtpu_releases.html' + ], )