Skip to content

Commit

Permalink
Removed the need to have separate JAX/Warp constants
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Sep 13, 2024
1 parent fe8c945 commit 84e1d46
Show file tree
Hide file tree
Showing 36 changed files with 179 additions and 114 deletions.
8 changes: 6 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)
rho, u = macro(f_0)

# remove boundary cells
Expand All @@ -135,9 +139,9 @@ def post_process(self, i):
if __name__ == "__main__":
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
velocity_set = xlb.velocity_set.D3Q19()
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
9 changes: 7 additions & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xlb.utils import save_fields_vtk, save_image
import warp as wp
import jax.numpy as jnp
import xlb.velocity_set


class LidDrivenCavity2D:
Expand Down Expand Up @@ -80,7 +81,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand All @@ -100,8 +105,8 @@ def post_process(self, i):
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def setup_stepper(self, omega):
grid_size = 512
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
8 changes: 6 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand Down Expand Up @@ -215,8 +219,8 @@ def plot_drag_coefficient(self):

# Configuration
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D3Q27()
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend)
wind_speed = 0.02
num_steps = 100000
print_interval = 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
13 changes: 7 additions & 6 deletions tests/grids/test_grid_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import jax.numpy as jnp


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=xlb.velocity_set.D2Q9, # does not affect the test
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_2d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -34,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size):

@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_3d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
grid_shape = (grid_size, grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -54,7 +55,7 @@ def test_jax_3d_grid_initialization(grid_size):


def test_jax_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -66,7 +67,7 @@ def test_jax_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions tests/grids/test_grid_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from xlb.precision_policy import Precision


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=xlb.velocity_set.D2Q9,
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_warp_grid_create_field(grid_size):
for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]:
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9, dtype=Precision.FP32)

Expand All @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size):


def test_warp_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape,omega",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/equilibrium/test_equilibrium_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
7 changes: 3 additions & 4 deletions tests/kernels/equilibrium/test_equilibrium_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/macroscopic/test_macroscopic_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/macroscopic/test_macroscopic_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def _construct_warp(self):
# Set local constants
sound_speed = 1.0 / wp.sqrt(3.0)
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_c = self.velocity_set.wp_c
_c = self.velocity_set.c
_q = self.velocity_set.q
_opp_indices = self.velocity_set.wp_opp_indices
_opp_indices = self.velocity_set.opp_indices

@wp.func
def get_normal_vectors_2d(
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask):

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_opp_indices = self.velocity_set.wp_opp_indices
_opp_indices = self.velocity_set.opp_indices
_q = wp.constant(self.velocity_set.q)
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask):

def _construct_warp(self):
# Set local constants
_opp_indices = self.velocity_set.wp_opp_indices
_opp_indices = self.velocity_set.opp_indices

# Construct the functional for this BC
@wp.func
Expand Down
10 changes: 5 additions & 5 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def _construct_warp(self):
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = wp.float32(rho)
_u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1])
_opp_indices = self.velocity_set.wp_opp_indices
_w = self.velocity_set.wp_w
_c = self.velocity_set.wp_c
_c32 = self.velocity_set.wp_c32
_qi = self.velocity_set.wp_qi
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c32 = self.velocity_set.c32
_qi = self.velocity_set.qi
# TODO: related to _c32: this is way less than ideal. we should not be making new types

@wp.func
Expand Down
Loading

0 comments on commit 84e1d46

Please sign in to comment.