Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Completed Regularized bc and debugged ZouHe in both JAX and Warp #59

Merged
merged 11 commits into from
Aug 20, 2024
20 changes: 14 additions & 6 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
ZouHeBC,
RegularizedBC,
EquilibriumBC,
DoNothingBC,
)
Expand Down Expand Up @@ -67,11 +69,16 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=inlet)
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = DoNothingBC(indices=outlet)
bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
# bc_outlet = DoNothingBC(indices=outlet)
bc_sphere = FullwayBounceBackBC(indices=sphere)
self.boundary_conditions = [bc_left, bc_walls, bc_do_nothing, bc_sphere]
self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls]
# Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because
# of the corner nodes. This way the corners are treated as wall and not inlet/outlet.
# TODO: how to ensure about this behind in the src code?

def setup_boundary_masks(self):
indices_boundary_masker = IndicesBoundaryMasker(
Expand All @@ -85,7 +92,7 @@ def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
Expand All @@ -107,17 +114,18 @@ def post_process(self, i):

# remove boundary cells
u = u[:, 1:-1, 1:-1, 1:-1]
rho = rho[:, 1:-1, 1:-1, 1:-1]
u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5

fields = {"u_magnitude": u_magnitude}
fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]}

save_fields_vtk(fields, timestep=i)
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)


if __name__ == "__main__":
# Running the simulation
grid_shape = (512, 128, 128)
grid_shape = (512 // 2, 128 // 2, 128 // 2)
velocity_set = xlb.velocity_set.D3Q19()
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32
Expand Down
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC
from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC
4 changes: 3 additions & 1 deletion xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -47,7 +48,8 @@ def __init__(
@partial(jit, static_argnums=(0))
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post)

def _construct_warp(self):
Expand Down
4 changes: 3 additions & 1 deletion xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(
@partial(jit, static_argnums=(0))
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
return jnp.where(
jnp.logical_and(missing_mask, boundary),
f_pre[self.velocity_set.opp_indices],
Expand Down
Loading
Loading