Skip to content

Commit

Permalink
Operators work in fp64 and fp64 with no issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Sep 13, 2024
1 parent 2fd397e commit e2028cb
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
5 changes: 5 additions & 0 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import jax.numpy as jnp
import time
import jax


class FlowOverSphere:
Expand Down Expand Up @@ -141,6 +142,10 @@ def post_process(self, i):
grid_shape = (512 // 2, 128 // 2, 128 // 2)
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
omega = 1.6

Expand Down
5 changes: 5 additions & 0 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import warp as wp
import jax
import jax.numpy as jnp
import xlb.velocity_set

Expand Down Expand Up @@ -106,6 +107,10 @@ def post_process(self, i):
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

Expand Down
5 changes: 5 additions & 0 deletions examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import xlb
import jax
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.operator.stepper import IncompressibleNavierStokesStepper
Expand Down Expand Up @@ -27,6 +28,10 @@ def setup_stepper(self, omega):
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

Expand Down
6 changes: 3 additions & 3 deletions xlb/default_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax
from xlb.compute_backend import ComputeBackend
from dataclasses import dataclass
from xlb.precision_policy import PrecisionPolicy


@dataclass
Expand All @@ -17,7 +19,7 @@ def init(velocity_set, default_backend, default_precision_policy):
if default_backend == ComputeBackend.WARP:
import warp as wp

wp.init()
wp.init() # TODO: Must be removed in the future versions of WARP
elif default_backend == ComputeBackend.JAX:
check_multi_gpu_support()
else:
Expand All @@ -29,8 +31,6 @@ def default_backend() -> ComputeBackend:


def check_multi_gpu_support():
import jax

gpus = jax.devices("gpu")
if len(gpus) > 1:
print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus)))
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def functional(
cu *= self.compute_dtype(3.0)

# Compute usqr
usqr = 1.5 * wp.dot(u, u)
usqr = self.compute_dtype(1.5) * wp.dot(u, u)

# Compute feq
feq[l] = rho * _w[l] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
feq[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu * (self.compute_dtype(1.0) + self.compute_dtype(0.5) * cu) - usqr)

return feq

Expand Down

0 comments on commit e2028cb

Please sign in to comment.