Skip to content

Commit

Permalink
Added abstraction layer for boundary condition aux data and implement…
Browse files Browse the repository at this point in the history
…aiton, and the capability to add profiles to boundary conditions
  • Loading branch information
mehdiataei committed Nov 28, 2024
1 parent 2b6355b commit 49cb066
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 194 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- XLB is now installable via pip
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
- Added NVIDIA's Warp backend for state-of-the-art performance
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
- Added the capability to add profiles to boundary conditions
71 changes: 57 additions & 14 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
import jax.numpy as jnp
import time
from functools import partial
from jax import jit


class FlowOverSphere:
Expand All @@ -37,13 +39,13 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.grid, self.f_0, self.f_1, self.missing_mask, self.bc_mask = create_nse_fields(grid_shape)
self.stepper = None
self.boundary_conditions = []
self.u_max = 0.04

# Setup the simulation BC, its initial conditions, and the stepper
self._setup(omega)

def _setup(self, omega):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper(omega)

Expand All @@ -69,7 +71,7 @@ def define_boundary_indices(self):

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = RegularizedBC("velocity", (0.04, 0.0, 0.0), indices=inlet)
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
# bc_left = EquilibriumBC(rho = 1, u=(0.04, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
# bc_outlet = RegularizedBC("pressure", 1.0, indices=outlet)
Expand All @@ -78,22 +80,63 @@ def setup_boundary_conditions(self):
bc_sphere = HalfwayBounceBackBC(indices=sphere)
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0))

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
f_0=self.f_0,
f_1=self.f_1,
bc_mask=self.bc_mask,
missing_mask=self.missing_mask,
omega=omega,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def bc_profile(self):
u_max = self.u_max # u_max = 0.04
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)
# return u_max

# @partial(jit, inline=True)
def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
y_center = Y - (H_y / 2.0)
z_center = Z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile for x velocity, zero for y and z
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
u_y = jnp.zeros_like(u_x)
u_z = jnp.zeros_like(u_x)

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def run(self, num_steps, post_process_interval=100):
start_time = time.time()
Expand Down
23 changes: 10 additions & 13 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre

def _setup(self, omega):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper(omega)

Expand All @@ -54,21 +53,19 @@ def setup_boundary_conditions(self):
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_walls, bc_top]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
f_0=self.f_0,
f_1=self.f_1,
bc_mask=self.bc_mask,
missing_mask=self.missing_mask,
omega=omega,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
Expand Down Expand Up @@ -109,7 +106,7 @@ def post_process(self, i):
# Running the simulation
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
Expand Down
13 changes: 10 additions & 3 deletions examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)

def setup_stepper(self, omega):
stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
distributed_stepper = distribute(
stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
f_0=self.f_0,
f_1=self.f_1,
bc_mask=self.bc_mask,
missing_mask=self.missing_mask,
omega=omega,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)
self.stepper = distribute(
stepper,
self.grid,
self.velocity_set,
)
self.stepper = distributed_stepper
return


Expand Down
23 changes: 10 additions & 13 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def get_force(self):

def _setup(self):
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.initialize_fields()
self.setup_stepper()

Expand All @@ -86,14 +85,6 @@ def setup_boundary_conditions(self):
bc_walls = RegularizedBC("velocity", (0.0, 0.0, 0.0), indices=walls)
self.boundary_conditions = [bc_walls]

def setup_boundary_masker(self):
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
shape = (self.velocity_set.d,) + (self.grid_shape)
np.random.seed(0)
Expand All @@ -104,10 +95,16 @@ def initialize_fields(self):
u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init)

def setup_stepper(self):
force = self.get_force()
self.stepper = IncompressibleNavierStokesStepper(
self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC", forcing_scheme="exact_difference", force_vector=force
def setup_stepper(self, omega):
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
f_0=self.f_0,
f_1=self.f_1,
bc_mask=self.bc_mask,
missing_mask=self.missing_mask,
omega=omega,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
force=self.get_force(),
)

def run(self, num_steps, print_interval, post_process_interval=100):
Expand Down
100 changes: 60 additions & 40 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from jax import jit


class WindTunnel3D:
Expand Down Expand Up @@ -55,8 +57,7 @@ def _setup(self):
# NOTE: it is important to initialize fields before setup_boundary_masker is called because f_0 or f_1 might be used to store BC information
self.initialize_fields()
self.setup_boundary_conditions()
self.setup_boundary_masker()
self.setup_stepper()
self.setup_stepper(self.omega)

def voxelize_stl(self, stl_filename, length_lbm_unit):
mesh = trimesh.load_mesh(stl_filename, process=False)
Expand Down Expand Up @@ -85,57 +86,77 @@ def define_boundary_indices(self):
length_phys_unit = mesh_extents.max()
length_lbm_unit = self.grid_shape[0] / 4
dx = length_phys_unit / length_lbm_unit
shift = np.array([self.grid_shape[0] * dx / 4, (self.grid_shape[1] * dx - mesh_extents[1]) / 2, 0.0])
mesh_vertices = mesh_vertices / dx
shift = np.array([self.grid_shape[0] / 4, (self.grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0])
car = mesh_vertices + shift
self.grid_spacing = dx
self.car_cross_section = np.prod(mesh_extents[1:]) / dx**2

return inlet, outlet, walls, car

def setup_boundary_conditions(self):
inlet, outlet, walls, car = self.define_boundary_indices()
bc_left = EquilibriumBC(rho=1.0, u=(self.wind_speed, 0.0, 0.0), indices=inlet)
# bc_left = RegularizedBC('velocity', (self.wind_speed, 0.0, 0.0), indices=inlet)
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
bc_car = GradsApproximationBC(mesh_vertices=car)
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
bc_car = FullwayBounceBackBC(mesh_vertices=car)
self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
# mesh_boundary_masker = MeshBoundaryMasker(
# velocity_set=self.velocity_set,
# precision_policy=self.precision_policy,
# compute_backend=self.backend,
# )
mesh_distance_boundary_masker = MeshDistanceBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
compute_backend=self.backend,
)
bclist_other = self.boundary_conditions[:-1]
bc_mesh = self.boundary_conditions[-1]
dx = self.grid_spacing
origin, spacing = (0, 0, 0), (dx, dx, dx)
self.bc_mask, self.missing_mask = indices_boundary_masker(bclist_other, self.bc_mask, self.missing_mask)
# self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)
self.bc_mask, self.missing_mask, self.f_1 = mesh_distance_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask, self.f_1)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)
self.f_1 = initialize_eq(self.f_1, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC")
def bc_profile(self):
u_max = self.wind_speed
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), 0.0, 0.0, 0.0, 0.0, length=5)

def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
y_center = Y - (H_y / 2.0)
z_center = Z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile for x velocity, zero for y and z
u_x = u_max * jnp.maximum(0.0, 1.0 - r_squared)
u_y = jnp.zeros_like(u_x)
u_z = jnp.zeros_like(u_x)

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp

def setup_stepper(self, omega):
self.stepper, self.f_0, self.f_1, self.bc_mask, self.missing_mask = IncompressibleNavierStokesStepper(
f_0=self.f_0,
f_1=self.f_1,
bc_mask=self.bc_mask,
missing_mask=self.missing_mask,
omega=omega,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def run(self, num_steps, print_interval, post_process_interval=100):
# Setup the operator for computing surface forces at the interface of the specified BC
Expand Down Expand Up @@ -236,8 +257,7 @@ def plot_drag_coefficient(self):
print_interval = 1000

# Set up Reynolds number and deduce relaxation time (omega)
# Re = 50000.0
Re = 500000000000.0
Re = 5000.0
clength = grid_size_x - 1
visc = wind_speed * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
Expand Down
Loading

0 comments on commit 49cb066

Please sign in to comment.