Skip to content

Commit

Permalink
Fixed conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Oct 18, 2024
2 parents de67c09 + cee77b9 commit 3a19a42
Show file tree
Hide file tree
Showing 26 changed files with 124 additions and 130 deletions.
26 changes: 11 additions & 15 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
Expand Down Expand Up @@ -48,15 +48,12 @@ def _setup(self, omega):
self.setup_stepper(omega)

def define_boundary_indices(self):
inlet = self.grid.boundingBoxIndices["left"]
outlet = self.grid.boundingBoxIndices["right"]
walls = [
self.grid.boundingBoxIndices["bottom"][i]
+ self.grid.boundingBoxIndices["top"][i]
+ self.grid.boundingBoxIndices["front"][i]
+ self.grid.boundingBoxIndices["back"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_no_edge = self.grid.bounding_box_indices(remove_edges=True)
inlet = box_no_edge["left"]
outlet = box_no_edge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

sphere_radius = self.grid_shape[1] // 12
x = np.arange(self.grid_shape[0])
Expand All @@ -79,13 +76,12 @@ def setup_boundary_conditions(self):
# bc_outlet = DoNothingBC(indices=outlet)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)

self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls]
# Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because
# of the corner nodes. This way the corners are treated as wall and not inlet/outlet.
# TODO: how to ensure about this behind in the src code?
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand Down
19 changes: 11 additions & 8 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import xlb.velocity_set
import warp as wp
import jax.numpy as jnp
import xlb.velocity_set
import numpy as np


class LidDrivenCavity2D:
Expand Down Expand Up @@ -39,20 +40,22 @@ def _setup(self, omega):
self.setup_stepper(omega)

def define_boundary_indices(self):
lid = self.grid.boundingBoxIndices["top"]
walls = [
self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["left"][i] + self.grid.boundingBoxIndices["right"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_no_edge = self.grid.bounding_box_indices(remove_edges=True)
lid = box_no_edge["top"]
walls = [box["bottom"][i] + box["left"][i] + box["right"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()
return lid, walls

def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_top, bc_walls]
self.boundary_conditions = [bc_walls, bc_top]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _setup(self):

def define_boundary_indices(self):
# top and bottom sides of the channel are no-slip and the other directions are periodic
boundingBoxIndices = self.grid.bounding_box_indices(remove_edges=True)
walls = [boundingBoxIndices["bottom"][i] + boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)]
box = self.grid.bounding_box_indices(remove_edges=True)
walls = [box["bottom"][i] + box["top"][i] for i in range(self.velocity_set.d)]
return walls

def setup_boundary_conditions(self):
Expand Down
22 changes: 11 additions & 11 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
Expand Down Expand Up @@ -67,15 +67,12 @@ def voxelize_stl(self, stl_filename, length_lbm_unit):
return mesh_matrix, pitch

def define_boundary_indices(self):
inlet = self.grid.boundingBoxIndices["left"]
outlet = self.grid.boundingBoxIndices["right"]
walls = [
self.grid.boundingBoxIndices["bottom"][i]
+ self.grid.boundingBoxIndices["top"][i]
+ self.grid.boundingBoxIndices["front"][i]
+ self.grid.boundingBoxIndices["back"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_no_edge = self.grid.bounding_box_indices(remove_edges=True)
inlet = box_no_edge["left"]
outlet = box_no_edge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

# Load the mesh
stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl"
Expand Down Expand Up @@ -104,9 +101,12 @@ def setup_boundary_conditions(self):
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
bc_car = GradsApproximationBC(mesh_vertices=car)
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car]
self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand Down
13 changes: 4 additions & 9 deletions examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,10 @@ def create_grid_and_fields(cube_edge):


def define_boundary_indices(grid):
lid = grid.boundingBoxIndices["top"]
walls = [
grid.boundingBoxIndices["bottom"][i]
+ grid.boundingBoxIndices["left"][i]
+ grid.boundingBoxIndices["right"][i]
+ grid.boundingBoxIndices["front"][i]
+ grid.boundingBoxIndices["back"][i]
for i in range(len(grid.shape))
]
box = grid.bounding_box_indices()
box_no_edge = grid.bounding_box_indices(remove_edges=True)
lid = box_no_edge["top"]
walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))]
return lid, walls


Expand Down
2 changes: 1 addition & 1 deletion xlb/distribute/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .distribute import distribute as distribute
from .distribute import distribute
4 changes: 2 additions & 2 deletions xlb/experimental/ooc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from xlb.experimental.ooc.out_of_core import OOCmap as OOCmap
from xlb.experimental.ooc.ooc_array import OOCArray as OOCArray
from xlb.experimental.ooc.out_of_core import OOCmap
from xlb.experimental.ooc.ooc_array import OOCArray
1 change: 0 additions & 1 deletion xlb/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend):
self.shape = shape
self.dim = len(shape)
self.compute_backend = compute_backend
self.boundingBoxIndices = self.bounding_box_indices()
self._initialize_backend()

@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions xlb/helper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from xlb.helper.nse_solver import create_nse_fields as create_nse_fields
from xlb.helper.initializers import initialize_eq as initialize_eq
from xlb.helper.nse_solver import create_nse_fields
from xlb.helper.initializers import initialize_eq
from xlb.helper.check_boundary_overlaps import check_bc_overlaps
24 changes: 24 additions & 0 deletions xlb/helper/check_boundary_overlaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
from xlb.compute_backend import ComputeBackend


def check_bc_overlaps(bclist, dim, backend):
index_list = [[] for _ in range(dim)]
for bc in bclist:
if bc.indices is None:
continue
# Detect duplicates within bc.indices
index_arr = np.unique(bc.indices, axis=-1)
if index_arr.shape[-1] != len(bc.indices[0]):
if backend == ComputeBackend.WARP:
raise ValueError(f"Boundary condition {bc.__class__.__name__} has duplicate indices!")
print(f"WARNING: there are duplicate indices in {bc.__class__.__name__} and hence the order in bc list matters!")
for d in range(dim):
index_list[d] += bc.indices[d]

# Detect duplicates within bclist
index_arr = np.unique(index_list, axis=-1)
if index_arr.shape[-1] != len(index_list[0]):
if backend == ComputeBackend.WARP:
raise ValueError("Boundary condition list containes duplicate indices!")
print("WARNING: there are duplicate indices in the boundary condition list and hence the order in this list matters!")
4 changes: 2 additions & 2 deletions xlb/operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from xlb.operator.operator import Operator as Operator
from xlb.operator.parallel_operator import ParallelOperator as ParallelOperator
from xlb.operator.operator import Operator
from xlb.operator.parallel_operator import ParallelOperator
22 changes: 10 additions & 12 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition as BoundaryCondition
from xlb.operator.boundary_condition.boundary_condition_registry import (
BoundaryConditionRegistry as BoundaryConditionRegistry,
)
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC as EquilibriumBC
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC
from xlb.operator.boundary_condition.bc_regularized import RegularizedBC as RegularizedBC
from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC as ExtrapolationOutflowBC
from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC as GradsApproximationBC
from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition
from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.bc_regularized import RegularizedBC
from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC
from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC
1 change: 0 additions & 1 deletion xlb/operator/boundary_condition/bc_grads_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
indices=None,
mesh_vertices=None,
):

# TODO: the input velocity must be suitably stored elesewhere when mesh is moving.
self.u = (0, 0, 0)

Expand Down
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from xlb import DefaultConfig
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry


# Enum for implementation step
class ImplementationStep(Enum):
COLLISION = auto()
Expand Down
12 changes: 3 additions & 9 deletions xlb/operator/boundary_masker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
from xlb.operator.boundary_masker.indices_boundary_masker import (
IndicesBoundaryMasker as IndicesBoundaryMasker,
)
from xlb.operator.boundary_masker.mesh_boundary_masker import (
MeshBoundaryMasker as MeshBoundaryMasker,
)
from xlb.operator.boundary_masker.mesh_distance_boundary_masker import (
MeshDistanceBoundaryMasker as MeshDistanceBoundaryMasker,
)
from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker
from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
from xlb.operator.boundary_masker.mesh_distance_boundary_masker import MeshDistanceBoundaryMasker
35 changes: 10 additions & 25 deletions xlb/operator/boundary_masker/indices_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,17 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z])
return bc_mask, missing_mask


def _construct_warp(self):
# Make constants for warp
_c = self.velocity_set.c
_q = wp.constant(self.velocity_set.q)

@wp.func
def check_index_bounds(index: wp.vec3i, shape: wp.vec3i):
is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2]
return is_in_bounds

# Construct the warp 3D kernel
@wp.kernel
def kernel(
Expand All @@ -118,14 +124,8 @@ def kernel(
index[2] = indices[2, ii] - start_index[2]

# Check if index is in bounds
if (
index[0] >= 0
and index[0] < missing_mask.shape[1]
and index[1] >= 0
and index[1] < missing_mask.shape[2]
and index[2] >= 0
and index[2] < missing_mask.shape[3]
):
shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3])
if check_index_bounds(index, shape):
# Stream indices
for l in range(_q):
# Get the index of the streaming direction
Expand All @@ -140,27 +140,12 @@ def kernel(

# check if pull index is out of bound
# These directions will have missing information after streaming
if (
pull_index[0] < 0
or pull_index[0] >= missing_mask.shape[1]
or pull_index[1] < 0
or pull_index[1] >= missing_mask.shape[2]
or pull_index[2] < 0
or pull_index[2] >= missing_mask.shape[3]
):
if not check_index_bounds(pull_index, shape):
# Set the missing mask
missing_mask[l, index[0], index[1], index[2]] = True

# handling geometries in the interior of the computational domain
elif (
is_interior[ii]
and push_index[0] >= 0
and push_index[0] < missing_mask.shape[1]
and push_index[1] >= 0
and push_index[1] < missing_mask.shape[2]
and push_index[2] >= 0
and push_index[2] < missing_mask.shape[3]
):
elif check_index_bounds(pull_index, shape) and is_interior[ii]:
# Set the missing mask
missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii]
Expand Down
8 changes: 4 additions & 4 deletions xlb/operator/collision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xlb.operator.collision.collision import Collision as Collision
from xlb.operator.collision.bgk import BGK as BGK
from xlb.operator.collision.kbc import KBC as KBC
from xlb.operator.collision.forced_collision import ForcedCollision as ForcedCollision
from xlb.operator.collision.collision import Collision
from xlb.operator.collision.bgk import BGK
from xlb.operator.collision.kbc import KBC
from xlb.operator.collision.forced_collision import ForcedCollision
4 changes: 2 additions & 2 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def decompose_shear_d2q9_jax(self, fneq):

def _construct_warp(self):
# Raise error if velocity set is not supported
if not isinstance(self.velocity_set, D3Q27):
if not (isinstance(self.velocity_set, D3Q27) or isinstance(self.velocity_set, D2Q9)):
raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set)))

# Set local constants TODO: This is a hack and should be fixed with warp update
Expand All @@ -192,7 +192,7 @@ def _construct_warp(self):
def decompose_shear_d2q9(fneq: Any):
pi = self.momentum_flux.warp_functional(fneq)
N = pi[0] - pi[1]
s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
s = _f_vec()
s[3] = N
s[6] = N
s[2] = -N
Expand Down
5 changes: 1 addition & 4 deletions xlb/operator/equilibrium/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
from xlb.operator.equilibrium.quadratic_equilibrium import (
Equilibrium as Equilibrium,
QuadraticEquilibrium as QuadraticEquilibrium,
)
from xlb.operator.equilibrium.quadratic_equilibrium import Equilibrium, QuadraticEquilibrium
4 changes: 2 additions & 2 deletions xlb/operator/force/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from xlb.operator.force.momentum_transfer import MomentumTransfer as MomentumTransfer
from xlb.operator.force.exact_difference_force import ExactDifference as ExactDifference
from xlb.operator.force.momentum_transfer import MomentumTransfer
from xlb.operator.force.exact_difference_force import ExactDifference
2 changes: 1 addition & 1 deletion xlb/operator/precision_caster/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from xlb.operator.precision_caster.precision_caster import PrecisionCaster as PrecisionCaster
from xlb.operator.precision_caster.precision_caster import PrecisionCaster
4 changes: 2 additions & 2 deletions xlb/operator/stepper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from xlb.operator.stepper.stepper import Stepper as Stepper
from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper as IncompressibleNavierStokesStepper
from xlb.operator.stepper.stepper import Stepper
from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper
2 changes: 1 addition & 1 deletion xlb/operator/stream/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from xlb.operator.stream.stream import Stream as Stream
from xlb.operator.stream.stream import Stream
Loading

0 comments on commit 3a19a42

Please sign in to comment.