From e2028cbce6e107cfbad475595f09e224fdd0c11e Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 13 Sep 2024 18:29:36 -0400 Subject: [PATCH] Operators work in fp64 and fp64 with no issues --- examples/cfd/flow_past_sphere_3d.py | 5 +++++ examples/cfd/lid_driven_cavity_2d.py | 5 +++++ examples/cfd/lid_driven_cavity_2d_distributed.py | 5 +++++ xlb/default_config.py | 6 +++--- xlb/operator/equilibrium/quadratic_equilibrium.py | 4 ++-- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index e09a4bf..5a70d20 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -19,6 +19,7 @@ import numpy as np import jax.numpy as jnp import time +import jax class FlowOverSphere: @@ -141,6 +142,10 @@ def post_process(self, i): grid_shape = (512 // 2, 128 // 2, 128 // 2) backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 837aa01..f73fb0e 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -8,6 +8,7 @@ from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp +import jax import jax.numpy as jnp import xlb.velocity_set @@ -106,6 +107,10 @@ def post_process(self, i): grid_shape = (grid_size, grid_size) backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 397d476..7efe907 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -1,4 +1,5 @@ import xlb +import jax from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy from xlb.operator.stepper import IncompressibleNavierStokesStepper @@ -27,6 +28,10 @@ def setup_stepper(self, omega): grid_shape = (grid_size, grid_size) backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 diff --git a/xlb/default_config.py b/xlb/default_config.py index f1ca25f..20eac44 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -1,5 +1,7 @@ +import jax from xlb.compute_backend import ComputeBackend from dataclasses import dataclass +from xlb.precision_policy import PrecisionPolicy @dataclass @@ -17,7 +19,7 @@ def init(velocity_set, default_backend, default_precision_policy): if default_backend == ComputeBackend.WARP: import warp as wp - wp.init() + wp.init() # TODO: Must be removed in the future versions of WARP elif default_backend == ComputeBackend.JAX: check_multi_gpu_support() else: @@ -29,8 +31,6 @@ def default_backend() -> ComputeBackend: def check_multi_gpu_support(): - import jax - gpus = jax.devices("gpu") if len(gpus) > 1: print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 2fe1526..0cce91b 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -52,10 +52,10 @@ def functional( cu *= self.compute_dtype(3.0) # Compute usqr - usqr = 1.5 * wp.dot(u, u) + usqr = self.compute_dtype(1.5) * wp.dot(u, u) # Compute feq - feq[l] = rho * _w[l] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + feq[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu * (self.compute_dtype(1.0) + self.compute_dtype(0.5) * cu) - usqr) return feq