Skip to content

Commit

Permalink
Merge pull request #95 from mehdiataei/main
Browse files Browse the repository at this point in the history
Get ready for 0.2.1 release. Fixed a bug running JAX on CPU.
  • Loading branch information
mehdiataei authored Dec 6, 2024
2 parents f7bed81 + 2f8ffe8 commit 237ac22
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
18 changes: 13 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- _No changes yet_ <!-- Placeholder for future changes -->


## [0.2.1] - 2024-12-05

### Fixed
- mkdocs is now configured correctly for the new project structure
- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU)
- mkdocs is now configured correctly for the new project structure
- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU)
- Fixed a couple of bugs in 2D regularied_bc and kbc (Warp) that emerged after merging 2d and 3d kernels

### Added

- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
- Added prepare_fields method to the Stepper class to allow for more automatic preparation of fields

## [0.2.0] - 2024-10-22

### Added
- XLB is now installable via pip
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
- Added NVIDIA's Warp backend for state-of-the-art performance
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
- Added NVIDIA's Warp backend for state-of-the-art performance
22 changes: 15 additions & 7 deletions xlb/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def init(velocity_set, default_backend, default_precision_policy):

wp.init() # TODO: Must be removed in the future versions of WARP
elif default_backend == ComputeBackend.JAX:
check_multi_gpu_support()
check_backend_support()
else:
raise ValueError(f"Unsupported compute backend: {default_backend}")

Expand All @@ -30,11 +30,19 @@ def default_backend() -> ComputeBackend:
return DefaultConfig.default_backend


def check_multi_gpu_support():
gpus = jax.devices("gpu")
if len(gpus) > 1:
print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus)))
elif len(gpus) == 1:
print("Single-GPU support is available: 1 GPU detected.")
def check_backend_support():
if jax.devices()[0].device_kind == "gpu":
gpus = jax.devices("gpu")
if len(gpus) > 1:
print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus)))
elif len(gpus) == 1:
print("Single-GPU support is available: 1 GPU detected.")

if jax.devices()[0].device_kind == "tpu":
tpus = jax.devices("tpu")
if len(tpus) > 1:
print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus)))
elif len(tpus) == 1:
print("Single-TPU support is available: 1 TPU detected.")
else:
print("No GPU support is available; CPU fallback will be used.")

0 comments on commit 237ac22

Please sign in to comment.