diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py index f4af5a8..3d60879 100644 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ b/xlb/operator/boundary_condition/bc_grads_approximation.py @@ -47,7 +47,7 @@ def __init__( indices=None, mesh_vertices=None, ): - + # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. self.u = (0, 0, 0) diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index f90ca60..065a0b0 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -62,8 +62,6 @@ def __init__( mesh_vertices, ) # Overwrite the boundary condition registry id with the bc_type in the name - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) - # The operator to compute the momentum flux self.momentum_flux = MomentumFlux() @partial(jit, static_argnums=(0,), inline=True) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index a1b79c2..4be2cf2 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -35,8 +35,6 @@ class ZouHeBC(BoundaryCondition): Reynolds numbers. One needs to use "Regularized" BC at higher Reynolds. """ - - def __init__( self, bc_type, @@ -64,9 +62,6 @@ def __init__( mesh_vertices, ) - # Overwrite the boundary condition registry id with the bc_type in the name - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + bc_type) - # Set the prescribed value for pressure or velocity dim = self.velocity_set.d if self.compute_backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 17f8226..6d72fc0 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -35,7 +35,7 @@ def __init__( indices=None, mesh_vertices=None, ): - self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__) + self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self))) velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy compute_backend = compute_backend or DefaultConfig.default_backend diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 5b1e092..6238fc5 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -23,6 +23,7 @@ def register_boundary_condition(self, boundary_condition): self.next_id += 1 self.id_to_bc[_id] = boundary_condition self.bc_to_id[boundary_condition] = _id + print(f"registered bc {boundary_condition} with id {_id}") return _id diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index b52b721..99431eb 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -100,65 +100,16 @@ def _construct_warp(self): bc_to_id = boundary_condition_registry.bc_to_id id_to_bc = boundary_condition_registry.id_to_bc + # Gather IDs of ExtrapolationOutflowBC boundary conditions + extrapolation_outflow_bc_ids = [] + for bc_name, bc_id in bc_to_id.items(): + if bc_name.startswith("ExtrapolationOutflowBC"): + extrapolation_outflow_bc_ids.append(bc_id) + # Group active boundary conditions active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) - for bc in self.boundary_conditions: - bc_name = id_to_bc[bc.id] - setattr(self, bc_name, bc) - - @wp.func - def apply_post_streaming_bc( - index: Any, - timestep: Any, - _boundary_id: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - f_result = f_post - - if wp.static("EquilibriumBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["EquilibriumBC"]): - f_result = self.EquilibriumBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("DoNothingBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["DoNothingBC"]): - f_result = self.DoNothingBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("HalfwayBounceBackBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["HalfwayBounceBackBC"]): - f_result = self.HalfwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ZouHeBC_pressure" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ZouHeBC_pressure"]): - f_result = self.ZouHeBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ZouHeBC_velocity" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ZouHeBC_velocity"]): - f_result = self.ZouHeBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("RegularizedBC_pressure" in active_bcs): - if _boundary_id == wp.static(bc_to_id["RegularizedBC_pressure"]): - f_result = self.RegularizedBC_pressure.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("RegularizedBC_velocity" in active_bcs): - if _boundary_id == wp.static(bc_to_id["RegularizedBC_velocity"]): - f_result = self.RegularizedBC_velocity.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ExtrapolationOutflowBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): - f_result = self.ExtrapolationOutflowBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("GradsApproximationBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["GradsApproximationBC"]): - f_result = self.GradsApproximationBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - return f_result - @wp.func - def apply_post_collision_bc( + def apply_bc( index: Any, timestep: Any, _boundary_id: Any, @@ -167,17 +118,25 @@ def apply_post_collision_bc( f_1: Any, f_pre: Any, f_post: Any, + is_post_streaming: bool, ): f_result = f_post - if wp.static("FullwayBounceBackBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["FullwayBounceBackBC"]): - f_result = self.FullwayBounceBackBC.warp_functional(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - - if wp.static("ExtrapolationOutflowBC" in active_bcs): - if _boundary_id == wp.static(bc_to_id["ExtrapolationOutflowBC"]): - f_result = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) - + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if is_post_streaming: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + else: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].prepare_bc_auxilary_data)( + index, timestep, missing_mask, f_0, f_1, f_pre, f_post + ) return f_result @wp.func @@ -244,7 +203,7 @@ def kernel2d( _f_post_collision = _f0_thread # Apply post-streaming boundary conditions - _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) # Compute rho and u _rho, _u = self.macroscopic.warp_functional(_f_post_stream) @@ -256,7 +215,7 @@ def kernel2d( _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision boundary conditions - _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) # Store the result in f_1 for l in range(self.velocity_set.q): @@ -284,17 +243,18 @@ def kernel3d( _f_post_collision = _f0_thread # Apply post-streaming boundary conditions - _f_post_stream = apply_post_streaming_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream) + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_collision, _f_post_stream, True) _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision boundary conditions - _f_post_collision = apply_post_collision_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision) + _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) # Store the result in f_1 for l in range(self.velocity_set.q): + # TODO: Improve this later if wp.static("GradsApproximationBC" in active_bcs): if _boundary_id == wp.static(boundary_condition_registry.bc_to_id["GradsApproximationBC"]): if _missing_mask[l] == wp.uint8(1):