Skip to content

Commit

Permalink
Finalized changes
Browse files Browse the repository at this point in the history
  • Loading branch information
loliverhennigh committed Mar 30, 2024
1 parent 0387247 commit 18abd12
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 34 deletions.
29 changes: 20 additions & 9 deletions examples/interfaces/ldc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,18 @@ def run_ldc(backend, compute_mlup=True):
precision_policy=precision_policy,
compute_backend=compute_backend,
)
full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper(
collision=collision,
equilibrium=equilibrium,
macroscopic=macroscopic,
stream=stream,
boundary_conditions=[equilibrium_bc, half_way_bc],
#boundary_conditions=[equilibrium_bc, half_way_bc, full_way_bc],
boundary_conditions=[half_way_bc, full_way_bc, equilibrium_bc],
)
planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker(
velocity_set=velocity_set,
Expand All @@ -148,70 +154,75 @@ def run_ldc(backend, compute_mlup=True):
)

# Set outlet bc (top x face)
lower_bound = (nr-1, 1, 1)
upper_bound = (nr-1, nr-1, nr-1)
lower_bound = (nr-1, 0, 0)
upper_bound = (nr-1, nr, nr)
direction = (-1, 0, 0)
boundary_id, missing_mask = planar_boundary_masker(
lower_bound,
upper_bound,
direction,
half_way_bc.id,
#full_way_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)

# Set half way bc (bottom y face)
lower_bound = (1, 0, 1)
lower_bound = (0, 0, 0)
upper_bound = (nr, 0, nr)
direction = (0, 1, 0)
boundary_id, missing_mask = planar_boundary_masker(
lower_bound,
upper_bound,
direction,
half_way_bc.id,
#full_way_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)

# Set half way bc (top y face)
lower_bound = (1, nr-1, 1)
lower_bound = (0, nr-1, 0)
upper_bound = (nr, nr-1, nr)
direction = (0, -1, 0)
boundary_id, missing_mask = planar_boundary_masker(
lower_bound,
upper_bound,
direction,
half_way_bc.id,
#full_way_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)

# Set half way bc (bottom z face)
lower_bound = (1, 1, 0)
lower_bound = (0, 0, 0)
upper_bound = (nr, nr, 0)
direction = (0, 0, 1)
boundary_id, missing_mask = planar_boundary_masker(
lower_bound,
upper_bound,
direction,
half_way_bc.id,
#full_way_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
)

# Set half way bc (top z face)
lower_bound = (1, 1, nr-1)
lower_bound = (0, 0, nr-1)
upper_bound = (nr, nr, nr-1)
direction = (0, 0, -1)
boundary_id, missing_mask = planar_boundary_masker(
lower_bound,
upper_bound,
direction,
half_way_bc.id,
#full_way_bc.id,
boundary_id,
missing_mask,
(0, 0, 0)
Expand All @@ -226,10 +237,10 @@ def run_ldc(backend, compute_mlup=True):
f0 = equilibrium(rho, u)

# Time stepping
plot_freq = 512
plot_freq = 128
save_dir = "ldc"
os.makedirs(save_dir, exist_ok=True)
num_steps = nr * 512
num_steps = nr * 16
start = time.time()

for _ in tqdm(range(num_steps)):
Expand Down
6 changes: 3 additions & 3 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import jit, device_count
from functools import partial
import numpy as np
from enum import Enum
from enum import Enum, auto

from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
Expand All @@ -16,8 +16,8 @@

# Enum for implementation step
class ImplementationStep(Enum):
COLLISION = 1
STREAMING = 2
COLLISION = auto()
STREAMING = auto()


class BoundaryCondition(Operator):
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/boundary_condition/halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
def apply_jax(self, f_pre, f_post, boundary_id, missing_mask):
boundary = boundary_id == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
return lax.select(missing_mask & boundary, f_pre[self.velocity_set.opp_indices], f_post)
return lax.select(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices], f_post)

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
Expand Down
36 changes: 18 additions & 18 deletions xlb/operator/boundary_masker/planar_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def jax_implementation(
if direction[0] != 0:

# Set boundary id
boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(id_number)
boundary_id = boundary_id.at[0, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(id_number)

# Set mask
for l in range(self.velocity_set.q):
Expand All @@ -57,13 +57,13 @@ def jax_implementation(
+ direction[2] * self.velocity_set.c[2, l]
)
if d_dot_c >= 0:
mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1] + 1, lower_bound[2] : upper_bound[2] + 1].set(True)
mask = mask.at[l, lower_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2] : upper_bound[2]].set(True)

# y plane
elif direction[1] != 0:

# Set boundary id
boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(id_number)
boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(id_number)

# Set mask
for l in range(self.velocity_set.q):
Expand All @@ -73,13 +73,13 @@ def jax_implementation(
+ direction[2] * self.velocity_set.c[2, l]
)
if d_dot_c >= 0:
mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1], lower_bound[2] : upper_bound[2] + 1].set(True)
mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1], lower_bound[2] : upper_bound[2]].set(True)

# z plane
elif direction[2] != 0:

# Set boundary id
boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(id_number)
boundary_id = boundary_id.at[0, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(id_number)

# Set mask
for l in range(self.velocity_set.q):
Expand All @@ -89,7 +89,7 @@ def jax_implementation(
+ direction[2] * self.velocity_set.c[2, l]
)
if d_dot_c >= 0:
mask = mask.at[l, lower_bound[0] : upper_bound[0] + 1, lower_bound[1] : upper_bound[1] + 1, lower_bound[2]].set(True)
mask = mask.at[l, lower_bound[0] : upper_bound[0], lower_bound[1] : upper_bound[1], lower_bound[2]].set(True)

return boundary_id, mask

Expand All @@ -116,15 +116,15 @@ def kernel(
# Get local indices
if direction[0] != 0:
i = lower_bound[0] - start_index[0]
j = plane_i - start_index[1]
k = plane_j - start_index[2]
j = plane_i + lower_bound[1] - start_index[1]
k = plane_j + lower_bound[2] - start_index[2]
elif direction[1] != 0:
i = plane_i - start_index[0]
i = plane_i + lower_bound[0] - start_index[0]
j = lower_bound[1] - start_index[1]
k = plane_j - start_index[2]
k = plane_j + lower_bound[2] - start_index[2]
elif direction[2] != 0:
i = plane_i - start_index[0]
j = plane_j - start_index[1]
i = plane_i + lower_bound[0] - start_index[0]
j = plane_j + lower_bound[1] - start_index[1]
k = lower_bound[2] - start_index[2]

# Check if in bounds
Expand Down Expand Up @@ -165,18 +165,18 @@ def warp_implementation(
# Get plane dimensions
if direction[0] != 0:
dim = (
upper_bound[1] - lower_bound[1] + 1,
upper_bound[2] - lower_bound[2] + 1,
upper_bound[1] - lower_bound[1],
upper_bound[2] - lower_bound[2],
)
elif direction[1] != 0:
dim = (
upper_bound[0] - lower_bound[0] + 1,
upper_bound[2] - lower_bound[2] + 1,
upper_bound[0] - lower_bound[0],
upper_bound[2] - lower_bound[2],
)
elif direction[2] != 0:
dim = (
upper_bound[0] - lower_bound[0] + 1,
upper_bound[1] - lower_bound[1] + 1,
upper_bound[0] - lower_bound[0],
upper_bound[1] - lower_bound[1],
)

# Launch the warp kernel
Expand Down
6 changes: 3 additions & 3 deletions xlb/operator/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def _streaming_jax_i(f, c):
"""
if self.velocity_set.d == 2:
return jnp.roll(
f, (-c[0], -c[1]), axis=(0, 1)
) # Negative sign is used to pull the distribution instead of pushing
f, (c[0], c[1]), axis=(0, 1)
)
elif self.velocity_set.d == 3:
return jnp.roll(f, (-c[0], -c[1], -c[2]), axis=(0, 1, 2))
return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2))

return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)(
f, jnp.array(self.velocity_set.c).T
Expand Down

0 comments on commit 18abd12

Please sign in to comment.