diff --git a/CHANGELOG.md b/CHANGELOG.md index 69ad1f4..48f9cc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,15 +8,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - _No changes yet_ + +## [0.2.1] - 2024-12-05 + ### Fixed -- mkdocs is now configured correctly for the new project structure -- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU) +- mkdocs is now configured correctly for the new project structure +- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU) +- Fixed a couple of bugs in 2D regularied_bc and kbc (Warp) that emerged after merging 2d and 3d kernels + +### Added + +- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data +- Added the capability to add profiles to boundary conditions +- Added prepare_fields method to the Stepper class to allow for more automatic preparation of fields ## [0.2.0] - 2024-10-22 ### Added - XLB is now installable via pip - Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern -- Added NVIDIA's Warp backend for state-of-the-art performance -- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data -- Added the capability to add profiles to boundary conditions \ No newline at end of file +- Added NVIDIA's Warp backend for state-of-the-art performance \ No newline at end of file diff --git a/xlb/default_config.py b/xlb/default_config.py index 20eac44..ff54483 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -21,7 +21,7 @@ def init(velocity_set, default_backend, default_precision_policy): wp.init() # TODO: Must be removed in the future versions of WARP elif default_backend == ComputeBackend.JAX: - check_multi_gpu_support() + check_backend_support() else: raise ValueError(f"Unsupported compute backend: {default_backend}") @@ -30,11 +30,19 @@ def default_backend() -> ComputeBackend: return DefaultConfig.default_backend -def check_multi_gpu_support(): - gpus = jax.devices("gpu") - if len(gpus) > 1: - print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) - elif len(gpus) == 1: - print("Single-GPU support is available: 1 GPU detected.") +def check_backend_support(): + if jax.devices()[0].device_kind == "gpu": + gpus = jax.devices("gpu") + if len(gpus) > 1: + print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) + elif len(gpus) == 1: + print("Single-GPU support is available: 1 GPU detected.") + + if jax.devices()[0].device_kind == "tpu": + tpus = jax.devices("tpu") + if len(tpus) > 1: + print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus))) + elif len(tpus) == 1: + print("Single-TPU support is available: 1 TPU detected.") else: print("No GPU support is available; CPU fallback will be used.")