Skip to content

Commit

Permalink
Merge pull request #68 from hsalehipour/major-refactoring
Browse files Browse the repository at this point in the history
A stable and robust BC added for stationary and moving curved BCs
  • Loading branch information
mehdiataei authored Oct 5, 2024
2 parents 60e95f7 + a786ccd commit a50fbbd
Show file tree
Hide file tree
Showing 22 changed files with 843 additions and 245 deletions.
2 changes: 1 addition & 1 deletion examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def setup_stepper(self, omega):
def run(self, num_steps, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setup_stepper(self, omega):

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def setup_stepper(self):
def run(self, num_steps, print_interval, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
27 changes: 19 additions & 8 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
RegularizedBC,
HalfwayBounceBackBC,
ExtrapolationOutflowBC,
GradsApproximationBC,
)
from xlb.operator.force.momentum_transfer import MomentumTransfer
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker
from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker, MeshDistanceBoundaryMasker
from xlb.utils import save_fields_vtk, save_image
import warp as wp
import numpy as np
Expand Down Expand Up @@ -51,9 +52,10 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi
self.lift_coefficients = []

def _setup(self):
# NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information
self.initialize_fields()
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper()

def voxelize_stl(self, stl_filename, length_lbm_unit):
Expand Down Expand Up @@ -99,7 +101,8 @@ def setup_boundary_conditions(self):
# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
bc_car = HalfwayBounceBackBC(mesh_vertices=car)
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
bc_car = GradsApproximationBC(mesh_vertices=car)
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car]

Expand All @@ -109,7 +112,12 @@ def setup_boundary_masker(self):
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
mesh_boundary_masker = MeshBoundaryMasker(
# mesh_boundary_masker = MeshBoundaryMasker(
# velocity_set=self.velocity_set,
# precision_policy=self.precision_policy,
# compute_backend=self.backend,
# )
mesh_distance_boundary_masker = MeshDistanceBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
Expand All @@ -119,10 +127,12 @@ def setup_boundary_masker(self):
dx = self.grid_spacing
origin, spacing = (0, 0, 0), (dx, dx, dx)
self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask)
self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)
# self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)
self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)
self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC")
Expand All @@ -134,7 +144,7 @@ def run(self, num_steps, print_interval, post_process_interval=100):

start_time = time.time()
for i in range(num_steps):
self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down Expand Up @@ -169,7 +179,7 @@ def post_process(self, i):
save_image(fields["u_magnitude"][:, grid_size_y // 2, :], timestep=i)

# Compute lift and drag
boundary_force = self.momentum_transfer(self.f_0, self.bc_mask, self.missing_mask)
boundary_force = self.momentum_transfer(self.f_0, self.f_1, self.bc_mask, self.missing_mask)
drag = np.sqrt(boundary_force[0] ** 2 + boundary_force[1] ** 2) # xy-plane
lift = boundary_force[2]
c_d = 2.0 * drag / (self.wind_speed**2 * self.car_cross_section)
Expand Down Expand Up @@ -226,7 +236,8 @@ def plot_drag_coefficient(self):
print_interval = 1000

# Set up Reynolds number and deduce relaxation time (omega)
Re = 50000.0
# Re = 50000.0
Re = 500000000000.0
clength = grid_size_x - 1
visc = wind_speed * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
Expand Down
2 changes: 1 addition & 1 deletion examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, num_st
start_time = time.time()

for i in range(num_steps):
f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i)
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i)
f_0, f_1 = f_1, f_0
wp.synchronize()

Expand Down
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC
from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC
from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC
from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC as GradsApproximationBC
15 changes: 9 additions & 6 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def _construct_warp(self):
# Construct the functional for this BC
@wp.func
def functional(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
f_aux: Any,
missing_mask: Any,
):
return f_pre

Expand All @@ -79,8 +82,8 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f_aux = _f_post
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand All @@ -105,8 +108,8 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f_aux = _f_post
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand Down
15 changes: 9 additions & 6 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,13 @@ def _construct_warp(self):
# Construct the functional for this BC
@wp.func
def functional(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
f_aux: Any,
missing_mask: Any,
):
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f
Expand All @@ -104,8 +107,8 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f_aux = _f_post
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand All @@ -130,8 +133,8 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f_aux = _f_post
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
timestep = 0
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand Down
86 changes: 55 additions & 31 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def prepare_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def apply_jax(self, f_pre, f_post, bc_mask, missing_mask):
def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
boundary = bc_mask == self.id
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
Expand Down Expand Up @@ -160,34 +160,74 @@ def get_normal_vectors_3d(
# Construct the functionals for this BC
@wp.func
def functional(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
f_aux: Any,
missing_mask: Any,
):
# Post-streaming values are only modified at missing direction
_f = f_post
for l in range(self.velocity_set.q):
# If the mask is missing then take the opposite index
if missing_mask[l] == wp.uint8(1):
_f[l] = f_pre[_opp_indices[l]]

return _f

@wp.func
def prepare_bc_auxilary_data(
def prepare_bc_auxilary_data_2d(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
f_aux: Any,
):
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
nv = get_normal_vectors_2d(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
# f_0 is the post-collision values of the current time-step
# Get pull index associated with the "neighbours" pull_index
pull_index = type(index)()
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]])
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
return _f

@wp.func
def prepare_bc_auxilary_data_3d(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
# Preparing the formulation for this BC using the neighbour's populations stored in f_aux and
# f_pre (posti-streaming values of the current voxel). We use directions that leave the domain
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
nv = get_normal_vectors_3d(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux[l]
# f_0 is the post-collision values of the current time-step
# Get pull index associated with the "neighbours" pull_index
pull_index = type(index)()
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]])
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
return _f

# Construct the warp kernel
Expand All @@ -201,29 +241,20 @@ def kernel2d(
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)
timestep = 0

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, bc_mask, missing_mask, index)
_f_aux = _f_vec()

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
nv = get_normal_vectors_2d(_missing_mask)
for l in range(self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1):
# f_0 is the post-collision values of the current time-step
# Get pull index associated with the "neighbours" pull_index
pull_index = type(index)()
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
_f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]]
_f_pre = prepare_bc_auxilary_data_2d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post)

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both
# collision and streaming?
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand All @@ -242,29 +273,21 @@ def kernel3d(
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)
timestep = 0

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, bc_mask, missing_mask, index)
_f_aux = _f_vec()

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
nv = get_normal_vectors_3d(_missing_mask)
for l in range(self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1):
# f_0 is the post-collision values of the current time-step
# Get pull index associated with the "neighbours" pull_index
pull_index = type(index)()
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
_f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]]
_f_pre = prepare_bc_auxilary_data_3d(index, timestep, missing_mask, f_pre, f_post, f_pre, f_post)

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both
# collision and streaming?
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
_f = functional(index, timestep, _missing_mask, f_pre, f_post, _f_pre, _f_post)
else:
_f = _f_post

Expand All @@ -273,6 +296,7 @@ def kernel3d(
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
prepare_bc_auxilary_data = prepare_bc_auxilary_data_3d if self.velocity_set.d == 3 else prepare_bc_auxilary_data_2d

return (functional, prepare_bc_auxilary_data), kernel

Expand Down
Loading

0 comments on commit a50fbbd

Please sign in to comment.