diff --git a/xlb/default_config.py b/xlb/default_config.py index ff54483..8b60e96 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -31,14 +31,14 @@ def default_backend() -> ComputeBackend: def check_backend_support(): - if jax.devices()[0].device_kind == "gpu": + if jax.devices()[0].platform == "gpu": gpus = jax.devices("gpu") if len(gpus) > 1: print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) elif len(gpus) == 1: print("Single-GPU support is available: 1 GPU detected.") - if jax.devices()[0].device_kind == "tpu": + if jax.devices()[0].platform == "tpu": tpus = jax.devices("tpu") if len(tpus) > 1: print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus))) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index ff014f3..304ae99 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -35,8 +35,9 @@ def are_indices_in_interior(self, indices, shape): :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. :return: Array of boolean flags where each flag indicates whether the corresponding index is inside the bounds. """ + d = self.velocity_set.d shape_array = np.array(shape) - return np.all((indices > 0) & (indices < shape_array[:, np.newaxis] - 1), axis=0) + return np.all((indices[:d] > 0) & (indices[:d] < shape_array[:d, np.newaxis] - 1), axis=0) @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators!