Skip to content

Commit

Permalink
Fixed 2d bugs in Warp backeneds in pytest, KBC and bc_regularized
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Oct 25, 2024
1 parent 00c24e3 commit e04f6eb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
8 changes: 4 additions & 4 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type='KBC')

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
Expand Down Expand Up @@ -100,7 +100,7 @@ def post_process(self, i):

fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude}

save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity")
# save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity")
save_image(fields["u_magnitude"], timestep=i, prefix="lid_driven_cavity")


Expand All @@ -112,7 +112,7 @@ def post_process(self, i):
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6
omega = 1.99

simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=5000, post_process_interval=1000)
simulation.run(num_steps=50000, post_process_interval=1000)
10 changes: 6 additions & 4 deletions tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
[test_bc],
bc_mask,
missing_mask,
start_index=(0, 0, 0) if dim == 3 else (0, 0),
)
assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype

Expand All @@ -69,9 +68,12 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
bc_mask = bc_mask.numpy()
missing_mask = missing_mask.numpy()

assert bc_mask.shape == (1,) + grid_shape

assert missing_mask.shape == (velocity_set.q,) + grid_shape
if len(grid_shape) == 2:
assert bc_mask.shape == (1,) + grid_shape + (1,), "bc_mask shape is incorrect got {}".format(bc_mask.shape)
assert missing_mask.shape == (velocity_set.q,) + grid_shape + (1,), "missing_mask shape is incorrect got {}".format(missing_mask.shape)
else:
assert bc_mask.shape == (1,) + grid_shape, "bc_mask shape is incorrect got {}".format(bc_mask.shape)
assert missing_mask.shape == (velocity_set.q,) + grid_shape, "missing_mask shape is incorrect got {}".format(missing_mask.shape)

if dim == 2:
assert np.all(bc_mask[0, indices[0], indices[1]] == test_bc.id)
Expand Down
11 changes: 8 additions & 3 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,14 @@ def _get_fsum(
def get_normal_vectors(
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
Expand Down
20 changes: 12 additions & 8 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def jax_implementation(
fneq = f - feq
if isinstance(self.velocity_set, D2Q9):
shear = self.decompose_shear_d2q9_jax(fneq)
delta_s = shear * rho / 4.0 # TODO: Check this
delta_s = shear * rho / 4.0
elif isinstance(self.velocity_set, D3Q27):
shear = self.decompose_shear_d3q27_jax(fneq)
delta_s = shear * rho
Expand Down Expand Up @@ -191,16 +191,16 @@ def _construct_warp(self):
@wp.func
def decompose_shear_d2q9(fneq: Any):
pi = self.momentum_flux.warp_functional(fneq)
N = pi[0] - pi[1]
N = pi[0] - pi[2]
s = _f_vec()
s[3] = N
s[6] = N
s[2] = -N
s[1] = -N
s[8] = pi[2]
s[4] = -pi[2]
s[5] = -pi[2]
s[7] = pi[2]
s[8] = pi[1]
s[4] = -pi[1]
s[5] = -pi[1]
s[7] = pi[1]
return s

# Construct functional for decomposing shear
Expand Down Expand Up @@ -271,8 +271,12 @@ def functional(
):
# Compute shear and delta_s
fneq = f - feq
shear = decompose_shear_d3q27(fneq)
delta_s = shear * rho # TODO: Check this
if wp.static(self.velocity_set.d == 3):
shear = decompose_shear_d3q27(fneq)
delta_s = shear * rho
else:
shear = decompose_shear_d2q9(fneq)
delta_s = shear * rho / self.compute_dtype(4.0)

# Perform collision
delta_h = fneq - delta_s
Expand Down

0 comments on commit e04f6eb

Please sign in to comment.