From 94fd5f0ed409b4fb4996b849616c31867668caa1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 8 Apr 2024 09:57:52 -0700 Subject: [PATCH] added kbc --- examples/CFD_refactor/windtunnel3d.py | 592 ++++++++++++++---- xlb/operator/boundary_masker/__init__.py | 3 + .../boundary_masker/stl_boundary_masker.py | 126 ++-- xlb/operator/collision/kbc.py | 178 +++++- xlb/velocity_set/velocity_set.py | 13 +- 5 files changed, 716 insertions(+), 196 deletions(-) diff --git a/examples/CFD_refactor/windtunnel3d.py b/examples/CFD_refactor/windtunnel3d.py index 9af75f3..156a219 100644 --- a/examples/CFD_refactor/windtunnel3d.py +++ b/examples/CFD_refactor/windtunnel3d.py @@ -1,168 +1,512 @@ +# Wind tunnel simulation using the XLB library + +from typing import Any import os import jax import trimesh from time import time import numpy as np -import jax.numpy as jnp -from jax import config +import warp as wp +import pyvista as pv +import tqdm +import matplotlib.pyplot as plt + +wp.init() + +import xlb +from xlb.operator import Operator + +class UniformInitializer(Operator): + + def _construct_warp(self): + # Construct the warp kernel + @wp.kernel + def kernel( + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + vel: float, + ): + # Get the global index + i, j, k = wp.tid() + + # Set the velocity + u[0, i, j, k] = vel + u[1, i, j, k] = 0.0 + u[2, i, j, k] = 0.0 + + # Set the density + rho[0, i, j, k] = 1.0 + + return None, kernel + + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, rho, u, vel): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + rho, + u, + vel, + ], + dim=rho.shape[1:], + ) + return rho, u + +class MomentumTransfer(Operator): + + def _construct_warp(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.wp_c + _opp_indices = self.velocity_set.wp_opp_indices + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec( + self.velocity_set.q, dtype=wp.uint8 + ) # TODO fix vec bool + + # Find velocity index for 0, 0, 0 + for l in range(self.velocity_set.q): + if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: + zero_index = l + _zero_index = wp.int32(zero_index) + print(f"Zero index: {_zero_index}") + + # Construct the warp kernel + @wp.kernel + def kernel( + f: wp.array4d(dtype=Any), + boundary_id: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.bool), + momentum: wp.array(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Get the boundary id + _boundary_id = boundary_id[0, index[0], index[1], index[2]] + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # TODO fix vec bool + if missing_mask[l, index[0], index[1], index[2]]: + _missing_mask[l] = wp.uint8(1) + else: + _missing_mask[l] = wp.uint8(0) + + # Determin if boundary is an edge by checking if center is missing + is_edge = wp.bool(False) + if _boundary_id == wp.uint8(xlb.operator.boundary_condition.HalfwayBounceBackBC.id): + if _missing_mask[_zero_index] != wp.uint8(1): + is_edge = wp.bool(True) + + # If the boundary is an edge then add the momentum transfer + m = wp.vec3() + if is_edge: + for l in range(self.velocity_set.q): + if _missing_mask[l] == wp.uint8(1): + phi = 2.0 * f[_opp_indices[l], index[0], index[1], index[2]] + + # Compute the momentum transfer + for d in range(self.velocity_set.d): + m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) -from xlb.solver import IncompressibleNavierStokesSolver -from xlb.velocity_set import D3Q27, D3Q19 -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy -from xlb.grid_backend import GridBackend -from xlb.operator.boundary_condition import BounceBack, BounceBackHalfway, DoNothing, EquilibriumBC + wp.atomic_add(momentum, 0, m) + return None, kernel + @Operator.register_backend(xlb.ComputeBackend.WARP) + def warp_implementation(self, f, boundary_id, missing_mask): -class WindTunnel(IncompressibleNavierStokesSolver): + # Allocate the momentum field + momentum = wp.zeros((1), dtype=wp.vec3) + + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f, boundary_id, missing_mask, momentum], + dim=f.shape[1:], + ) + return momentum.numpy() + + +class WindTunnel: """ - This class extends the IncompressibleNavierStokesSolver class to define the boundary conditions for the wind tunnel simulation. - Units are in meters, seconds, and kilograms. + Wind tunnel simulation using the XLB library """ def __init__( self, - stl_filename: str - stl_center: tuple[float, float, float] = (0.0, 0.0, 0.0), # m - inlet_velocity: float = 27.78 # m/s + stl_filename: str, + inlet_velocity: float = 27.78, # m/s lower_bounds: tuple[float, float, float] = (0.0, 0.0, 0.0), # m upper_bounds: tuple[float, float, float] = (1.0, 0.5, 0.5), # m dx: float = 0.01, # m viscosity: float = 1.42e-5, # air at 20 degrees Celsius density: float = 1.2754, # kg/m^3 - collision="BGK", + solve_time: float = 1.0, # s + #collision="BGK", + collision="KBC", equilibrium="Quadratic", - velocity_set=D3Q27(), - precision_policy=PrecisionPolicy.FP32FP32, - compute_backend=ComputeBackend.JAX, - grid_backend=GridBackend.JAX, + velocity_set="D3Q27", + precision_policy=xlb.PrecisionPolicy.FP32FP32, + compute_backend=xlb.ComputeBackend.WARP, grid_configs={}, + save_state_frequency=1024, + monitor_frequency=32, ): # Set parameters + self.stl_filename = stl_filename self.inlet_velocity = inlet_velocity self.lower_bounds = lower_bounds self.upper_bounds = upper_bounds self.dx = dx + self.solve_time = solve_time self.viscosity = viscosity self.density = density + self.save_state_frequency = save_state_frequency + self.monitor_frequency = monitor_frequency # Get fluid properties needed for the simulation - self.velocity_conversion = 0.05 / inlet_velocity + self.base_velocity = 0.05 # LBM units + self.velocity_conversion = self.base_velocity / inlet_velocity self.dt = self.dx * self.velocity_conversion self.lbm_viscosity = self.viscosity * self.dt / (self.dx ** 2) self.tau = 0.5 + self.lbm_viscosity + self.omega = 1.0 / self.tau + print(f"tau: {self.tau}") + print(f"omega: {self.omega}") self.lbm_density = 1.0 self.mass_conversion = self.dx ** 3 * (self.density / self.lbm_density) + self.nr_steps = int(solve_time / self.dt) - # Make boundary conditions - - - # Initialize the IncompressibleNavierStokesSolver - super().__init__( - omega=self.tau, - shape=shape, - collision=collision, - equilibrium=equilibrium, - boundary_conditions=boundary_conditions, - initializer=initializer, - forcing=forcing, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - grid_backend=grid_backend, - grid_configs=grid_configs, - ) - - def voxelize_stl(self, stl_filename, length_lbm_unit): - mesh = trimesh.load_mesh(stl_filename, process=False) - length_phys_unit = mesh.extents.max() - pitch = length_phys_unit/length_lbm_unit - mesh_voxelized = mesh.voxelized(pitch=pitch) - mesh_matrix = mesh_voxelized.matrix - return mesh_matrix, pitch - - def set_boundary_conditions(self): - print('Voxelizing mesh...') - time_start = time() - stl_filename = 'stl-files/DrivAer-Notchback.stl' - car_length_lbm_unit = self.nx / 4 - car_voxelized, pitch = voxelize_stl(stl_filename, car_length_lbm_unit) - car_matrix = car_voxelized.matrix - print('Voxelization time for pitch={}: {} seconds'.format(pitch, time() - time_start)) - print("Car matrix shape: ", car_matrix.shape) - - self.car_area = np.prod(car_matrix.shape[1:]) - tx, ty, tz = np.array([nx, ny, nz]) - car_matrix.shape - shift = [tx//4, ty//2, 0] - car_indices = np.argwhere(car_matrix) + shift - self.BCs.append(BounceBackHalfway(tuple(car_indices.T), self.gridInfo, self.precisionPolicy)) - - wall = np.concatenate((self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'], - self.boundingBoxIndices['front'], self.boundingBoxIndices['back'])) - self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - - doNothing = self.boundingBoxIndices['right'] - self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) - self.BCs[-1].implementationStep = 'PostCollision' - # rho_outlet = np.ones(doNothing.shape[0], dtype=self.precisionPolicy.compute_dtype) - # self.BCs.append(ZouHe(tuple(doNothing.T), - # self.gridInfo, - # self.precisionPolicy, - # 'pressure', rho_outlet)) - - inlet = self.boundingBoxIndices['left'] - rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype) - vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype) - - vel_inlet[:, 0] = prescribed_vel - self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) - # self.BCs.append(ZouHe(tuple(inlet.T), - # self.gridInfo, - # self.precisionPolicy, - # 'velocity', vel_inlet)) - - def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][..., 1:-1, 1:-1, :]) - u = np.array(kwargs['u'][..., 1:-1, 1:-1, :]) - timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][..., 1:-1, 1:-1, :] - - # compute lift and drag over the car - car = self.BCs[0] - boundary_force = car.momentum_exchange_force(kwargs['f_poststreaming'], kwargs['f_postcollision']) - boundary_force = np.sum(boundary_force, axis=0) - drag = np.sqrt(boundary_force[0]**2 + boundary_force[1]**2) #xy-plane - lift = boundary_force[2] #z-direction - cd = 2. * drag / (prescribed_vel ** 2 * self.car_area) - cl = 2. * lift / (prescribed_vel ** 2 * self.car_area) - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) - - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}, CL = {:07.6f}, CD = {:07.6f}'.format(err, cl, cd)) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) + # Get the grid shape + self.nx = int((upper_bounds[0] - lower_bounds[0]) / dx) + self.ny = int((upper_bounds[1] - lower_bounds[1]) / dx) + self.nz = int((upper_bounds[2] - lower_bounds[2]) / dx) + self.shape = (self.nx, self.ny, self.nz) -if __name__ == '__main__': - precision = 'f32/f32' - lattice = LatticeD3Q27(precision) + # Set the compute backend + self.compute_backend = xlb.ComputeBackend.WARP + + # Set the precision policy + self.precision_policy = xlb.PrecisionPolicy.FP32FP32 + + # Set the velocity set + if velocity_set == "D3Q27": + self.velocity_set = xlb.velocity_set.D3Q27() + elif velocity_set == "D3Q19": + self.velocity_set = xlb.velocity_set.D3Q19() + else: + raise ValueError("Invalid velocity set") + + # Make grid + self.grid = xlb.grid.WarpGrid(shape=self.shape) + + # Make feilds + self.rho = self.grid.create_field(cardinality=1, precision=xlb.Precision.FP32) + self.u = self.grid.create_field(cardinality=self.velocity_set.d, precision=xlb.Precision.FP32) + self.f0 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) + self.f1 = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.FP32) + self.boundary_id = self.grid.create_field(cardinality=1, precision=xlb.Precision.UINT8) + self.missing_mask = self.grid.create_field(cardinality=self.velocity_set.q, precision=xlb.Precision.BOOL) + + # Make operators + self.initializer = UniformInitializer( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.momentum_transfer = MomentumTransfer( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + if collision == "BGK": + self.collision = xlb.operator.collision.BGK( + omega=self.omega, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + elif collision == "KBC": + self.collision = xlb.operator.collision.KBC( + omega=self.omega, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.macroscopic = xlb.operator.macroscopic.Macroscopic( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stream = xlb.operator.stream.Stream( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.equilibrium_bc = xlb.operator.boundary_condition.EquilibriumBC( + rho=self.lbm_density, + u=(self.base_velocity, 0.0, 0.0), + equilibrium_operator=self.equilibrium, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.half_way_bc = xlb.operator.boundary_condition.HalfwayBounceBackBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.full_way_bc = xlb.operator.boundary_condition.FullwayBounceBackBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.do_nothing_bc = xlb.operator.boundary_condition.DoNothingBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( + collision=self.collision, + equilibrium=self.equilibrium, + macroscopic=self.macroscopic, + stream=self.stream, + boundary_conditions=[ + self.half_way_bc, + self.full_way_bc, + self.equilibrium_bc, + self.do_nothing_bc + ], + ) + self.planar_boundary_masker = xlb.operator.boundary_masker.PlanarBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + self.stl_boundary_masker = xlb.operator.boundary_masker.STLBoundaryMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + + # Make list to store drag coefficients + self.drag_coefficients = [] + + def initialize_flow(self): + """ + Initialize the flow field + """ + + # Set initial conditions + self.rho, self.u = self.initializer(self.rho, self.u, self.base_velocity) + self.f0 = self.equilibrium(self.rho, self.u, self.f0) + + def initialize_boundary_conditions(self): + """ + Initialize the boundary conditions + """ + + # Set inlet bc (bottom x face) + lower_bound = (0, 1, 1) # no edges + upper_bound = (0, self.ny-1, self.nz-1) + direction = (1, 0, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.equilibrium_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set outlet bc (top x face) + lower_bound = (self.nx-1, 1, 1) + upper_bound = (self.nx-1, self.ny-1, self.nz-1) + direction = (-1, 0, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.do_nothing_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (bottom y face) + lower_bound = (0, 0, 0) + upper_bound = (self.nx, 0, self.nz) + direction = (0, 1, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (top y face) + lower_bound = (0, self.ny-1, 0) + upper_bound = (self.nx, self.ny-1, self.nz) + direction = (0, -1, 0) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (bottom z face) + lower_bound = (0, 0, 0) + upper_bound = (self.nx, self.ny, 0) + direction = (0, 0, 1) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set full way bc (top z face) + lower_bound = (0, 0, self.nz-1) + upper_bound = (self.nx, self.ny, self.nz-1) + direction = (0, 0, -1) + self.boundary_id, self.missing_mask = self.planar_boundary_masker( + lower_bound, + upper_bound, + direction, + self.full_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + # Set stl half way bc + self.boundary_id, self.missing_mask = self.stl_boundary_masker( + self.stl_filename, + self.lower_bounds, + (self.dx, self.dx, self.dx), + self.half_way_bc.id, + self.boundary_id, + self.missing_mask, + (0, 0, 0) + ) + + def save_state( + self, + postfix: str, + save_velocity_distribution: bool = False, + ): + """ + Save the solid id array. + """ + + # Create grid + grid = pv.RectilinearGrid( + np.linspace(self.lower_bounds[0], self.upper_bounds[0], self.nx, endpoint=False), + np.linspace(self.lower_bounds[1], self.upper_bounds[1], self.ny, endpoint=False), + np.linspace(self.lower_bounds[2], self.upper_bounds[2], self.nz, endpoint=False), + ) # TODO off by one? + grid["boundary_id"] = self.boundary_id.numpy().flatten("F") + grid["u"] = self.u.numpy().transpose(1, 2, 3, 0).reshape(-1, 3, order="F") + grid["rho"] = self.rho.numpy().flatten("F") + if save_velocity_distribution: + grid["f0"] = self.f0.numpy().transpose(1, 2, 3, 0).reshape(-1, self.velocity_set.q, order="F") + grid.save(f"state_{postfix}.vtk") + + def step(self): + self.f1 = self.stepper(self.f0, self.f1, self.boundary_id, self.missing_mask, 0) + self.f0, self.f1 = self.f1, self.f0 + + def compute_rho_u(self): + self.rho, self.u = self.macroscopic(self.f0, self.rho, self.u) + + def monitor(self): + # Compute the momentum transfer + momentum = self.momentum_transfer(self.f0, self.boundary_id, self.missing_mask)[0] + drag = momentum[0] + lift = momentum[2] + c_d = 2.0 * drag / (self.base_velocity ** 2 * self.cross_section) + c_l = 2.0 * lift / (self.base_velocity ** 2 * self.cross_section) + self.drag_coefficients.append(c_d) + + def plot_drag_coefficient(self): + plt.plot(self.drag_coefficients[-30:]) + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.savefig("drag_coefficient.png") + plt.close() + + def run(self): - nx = 601 - ny = 351 - nz = 251 + # Initialize the flow field + self.initialize_flow() + + # Initialize the boundary conditions + self.initialize_boundary_conditions() + + # Compute cross section + np_boundary_id = self.boundary_id.numpy() + cross_section = np.sum(np_boundary_id == self.half_way_bc.id, axis=(0, 1)) + self.cross_section = np.sum(cross_section > 0) + + # Run the simulation + for i in tqdm.tqdm(range(self.nr_steps)): + + # Step + self.step() + + # Monitor + if i % self.monitor_frequency == 0: + self.monitor() + + # Save monitor plot + if i % (self.monitor_frequency * 10) == 0: + self.plot_drag_coefficient() + + # Save state + if i % self.save_state_frequency == 0: + self.compute_rho_u() + self.save_state(str(i).zfill(8)) + +if __name__ == '__main__': - Re = 50000.0 - prescribed_vel = 0.05 - clength = nx - 1 + # Parameters + inlet_velocity = 0.01 # m/s + stl_filename = "fastback_baseline.stl" + lower_bounds = (-4.0, -2.5, -1.5) + upper_bounds = (12.0, 2.5, 2.5) + dx = 0.03 + solve_time = 10000.0 - visc = prescribed_vel * clength / Re - omega = 1.0 / (3. * visc + 0.5) + # Make wind tunnel + wind_tunnel = WindTunnel( + stl_filename=stl_filename, + inlet_velocity=inlet_velocity, + lower_bounds=lower_bounds, + upper_bounds=upper_bounds, + solve_time=solve_time, + dx=dx, + ) - os.system('rm -rf ./*.vtk && rm -rf ./*.png') + # Run the simulation + wind_tunnel.run() + wind_tunnel.save_state("final", save_velocity_distribution=True) - sim = Car(**kwargs) - sim.run(200000) diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index a9069c6..f69252f 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -7,3 +7,6 @@ from xlb.operator.boundary_masker.planar_boundary_masker import ( PlanarBoundaryMasker, ) +from xlb.operator.boundary_masker.stl_boundary_masker import ( + STLBoundaryMasker, +) diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index cda8c00..148e9b8 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -2,6 +2,7 @@ from functools import partial import numpy as np +from stl import mesh as np_mesh import jax.numpy as jnp from jax import jit import warp as wp @@ -26,89 +27,106 @@ def __init__( precision_policy: PrecisionPolicy, compute_backend: ComputeBackend.JAX, ): + # Call super super().__init__(velocity_set, precision_policy, compute_backend) - # TODO: Implement this - raise NotImplementedError - - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - - @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation( - self, mesh, id_number, boundary_id, mask, start_index=(0, 0, 0) - ): - # TODO: Implement this - raise NotImplementedError - def _construct_warp(self): # Make constants for warp - _opp_indices = wp.constant( - self._warp_int_lattice_vec(self.velocity_set.opp_indices) - ) + _c = self.velocity_set.wp_c _q = wp.constant(self.velocity_set.q) - _d = wp.constant(self.velocity_set.d) - _id = wp.constant(self.id) # Construct the warp kernel @wp.kernel - def _voxelize_mesh( - voxels: wp.array3d(dtype=wp.uint8), + def kernel( mesh: wp.uint64, - spacing: wp.vec3, origin: wp.vec3, - shape: wp.vec(3, wp.uint32), - max_length: float, - material_id: int, + spacing: wp.vec3, + id_number: wp.int32, + boundary_id: wp.array4d(dtype=wp.uint8), + mask: wp.array4d(dtype=wp.bool), + start_index: wp.vec3i, ): - # get index of voxel + # get index i, j, k = wp.tid() - # position of voxel - ijk = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) + # Get local indices + index = wp.vec3i() + index[0] = i - start_index[0] + index[1] = j - start_index[1] + index[2] = k - start_index[2] + + # position of the point + ijk = wp.vec3( + wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]) + ) ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center pos = wp.cw_mul(ijk, spacing) + origin - # Only evaluate voxel if not set yet - if voxels[i, j, k] != wp.uint8(0): - return + # Compute the maximum length + max_length = wp.sqrt( + (spacing[0] * wp.float32(boundary_id.shape[1])) ** 2.0 + + (spacing[1] * wp.float32(boundary_id.shape[2])) ** 2.0 + + (spacing[2] * wp.float32(boundary_id.shape[3])) ** 2.0 + ) - # evaluate distance of point + # evaluate if point is inside mesh face_index = int(0) face_u = float(0.0) face_v = float(0.0) sign = float(0.0) - if wp.mesh_query_point( + if wp.mesh_query_point_sign_winding_number( mesh, pos, max_length, sign, face_index, face_u, face_v ): - p = wp.mesh_eval_position(mesh, face_index, face_u, face_v) - delta = pos - p - norm = wp.sqrt(wp.dot(delta, delta)) - # set point to be solid - if norm < wp.min(spacing): - voxels[i, j, k] = wp.uint8(255) - elif sign < 0: # TODO: fix this - voxels[i, j, k] = wp.uint8(material_id) - else: - pass + if sign <= 0: # TODO: fix this + # Stream indices + for l in range(_q): + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, l] + + # Set the boundary id and mask + boundary_id[ + 0, push_index[0], push_index[1], push_index[2] + ] = wp.uint8(id_number) + mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, start_index, boundary_id, mask, id_number): - # Reuse the jax implementation, TODO: implement a warp version - # Convert to jax - boundary_id = wp.jax.to_jax(boundary_id) - mask = wp.jax.to_jax(mask) - - # Call jax implementation - boundary_id, mask = self.jax_implementation( - start_index, boundary_id, mask, id_number + def warp_implementation( + self, + stl_file, + origin, + spacing, + id_number, + boundary_id, + mask, + start_index=(0, 0, 0), + ): + # Load the mesh + mesh = np_mesh.Mesh.from_file(stl_file) + mesh_points = mesh.points.reshape(-1, 3) + mesh_indices = np.arange(mesh_points.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_points, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=int), ) - # Convert back to warp - boundary_id = wp.jax.to_warp(boundary_id) - mask = wp.jax.to_warp(mask) + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[ + mesh.id, + origin, + spacing, + id_number, + boundary_id, + mask, + start_index, + ], + dim=boundary_id.shape[1:], + ) return boundary_id, mask diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 8302978..f3c996b 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -4,11 +4,14 @@ import jax.numpy as jnp from jax import jit -from functools import partial -from xlb.operator import Operator +import warp as wp +from typing import Any + from xlb.velocity_set import VelocitySet, D2Q9, D3Q27 from xlb.compute_backend import ComputeBackend from xlb.operator.collision.collision import Collision +from xlb.operator import Operator +from functools import partial class KBC(Collision): @@ -25,15 +28,16 @@ def __init__( precision_policy=None, compute_backend=None, ): + self.epsilon = 1e-32 + self.beta = omega * 0.5 + self.inv_beta = 1.0 / self.beta + super().__init__( omega=omega, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) - self.epsilon = 1e-32 - self.beta = self.omega * 0.5 - self.inv_beta = 1.0 / self.beta @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) @@ -80,16 +84,6 @@ def jax_implementation( return fout - @Operator.register_backend(ComputeBackend.WARP) - @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) - def warp_implementation( - self, - f: jnp.ndarray, - feq: jnp.ndarray, - rho: jnp.ndarray, - ): - raise NotImplementedError("Warp implementation not yet implemented") - @partial(jit, static_argnums=(0,), inline=True) def entropic_scalar_product(self, x: jnp.ndarray, y: jnp.ndarray, feq: jnp.ndarray): """ @@ -208,3 +202,157 @@ def decompose_shear_d2q9_jax(self, fneq): s = s.at[7, ...].set(Pi[1, ...]) return s + + def _construct_warp(self): + # Raise error if velocity set is not supported + if not isinstance(self.velocity_set, D3Q27): + raise NotImplementedError( + "Velocity set not supported for warp backend: {}".format( + type(self.velocity_set) + ) + ) + + # Set local constants TODO: This is a hack and should be fixed with warp update + _w = self.velocity_set.wp_w + _cc = self.velocity_set.wp_cc + _omega = wp.constant(self.compute_dtype(self.omega)) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _pi_vec = wp.vec( + self.velocity_set.d * (self.velocity_set.d + 1) // 2, + dtype=self.compute_dtype, + ) + _epsilon = wp.constant(self.compute_dtype(self.epsilon)) + _beta = wp.constant(self.compute_dtype(self.beta)) + _inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta)) + + # Construct functional for computing momentum flux + @wp.func + def momentum_flux( + fneq: Any, + ): + # Get momentum flux + pi = _pi_vec() + for d in range(6): + pi[d] = 0.0 + for q in range(self.velocity_set.q): + pi[d] += _cc[q, d] * fneq[q] + return pi + + # Construct functional for decomposing shear + @wp.func + def decompose_shear_d3q27( + fneq: Any, + ): + # Get momentum flux + pi = momentum_flux(fneq) + nxz = pi[0] - pi[5] + nyz = pi[3] - pi[5] + + # set shear components + s = _f_vec() + + # For c = (i, 0, 0), c = (0, j, 0) and c = (0, 0, k) + s[9] = (2.0 * nxz - nyz) / 6.0 + s[18] = (2.0 * nxz - nyz) / 6.0 + s[3] = (-nxz + 2.0 * nyz) / 6.0 + s[6] = (-nxz + 2.0 * nyz) / 6.0 + s[1] = (-nxz - nyz) / 6.0 + s[2] = (-nxz - nyz) / 6.0 + + # For c = (i, j, 0) + s[12] = pi[1] / 4.0 + s[24] = pi[1] / 4.0 + s[21] = -pi[1] / 4.0 + s[15] = -pi[1] / 4.0 + + # For c = (i, 0, k) + s[10] = pi[2] / 4.0 + s[20] = pi[2] / 4.0 + s[19] = -pi[2] / 4.0 + s[11] = -pi[2] / 4.0 + + # For c = (0, j, k) + s[8] = pi[4] / 4.0 + s[4] = pi[4] / 4.0 + s[7] = -pi[4] / 4.0 + s[5] = -pi[4] / 4.0 + + return s + + # Construct functional for computing entropic scalar product + @wp.func + def entropic_scalar_product( + x: Any, + y: Any, + feq: Any, + ): + e = wp.cw_div(wp.cw_mul(x, y), feq) + e_sum = wp.float32(0.0) + for i in range(self.velocity_set.q): + e_sum += e[i] + return e_sum + + # Construct the functional + @wp.func + def functional( + f: Any, + feq: Any, + rho: Any, + u: Any, + ): + # Compute shear and delta_s + fneq = f - feq + shear = decompose_shear_d3q27(fneq) + delta_s = shear * rho # TODO: Check this + + # Perform collision + delta_h = fneq - delta_s + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( + delta_s, delta_h, feq + ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + fout = f - _beta * (2.0 * delta_s + gamma * delta_h) + + return fout + + # Construct the warp kernel + @wp.kernel + def kernel( + f: wp.array4d(dtype=Any), + feq: wp.array4d(dtype=Any), + rho: wp.array4d(dtype=Any), + u: wp.array4d(dtype=Any), + fout: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) # TODO: Warp needs to fix this + + # Load needed values + _f = _f_vec() + _feq = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = f[l, index[0], index[1], index[2]] + _feq[l] = feq[l, index[0], index[1], index[2]] + _u = self._warp_u_vec() + for l in range(_d): + _u[l] = u[l, index[0], index[1], index[2]] + _rho = rho[0, index[0], index[1], index[2]] + + # Compute the collision + _fout = functional(_f, _feq, _rho, _u) + + # Write the result + for l in range(self.velocity_set.q): + fout[l, index[0], index[1], index[2]] = _fout[l] + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + def warp_implementation( + self, + f: jnp.ndarray, + feq: jnp.ndarray, + rho: jnp.ndarray, + ): + raise NotImplementedError("Warp implementation not yet implemented") diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index a137b87..03395c8 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -8,6 +8,7 @@ import warp as wp + class VelocitySet(object): """ Base class for the velocity set of the Lattice Boltzmann Method (LBM), e.g. D2Q9, D3Q27, etc. @@ -46,9 +47,15 @@ def __init__(self, d, q, c, w): # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) - + self.wp_w = wp.constant( + wp.vec(self.q, dtype=wp.float32)(self.w) + ) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant( + wp.vec(self.q, dtype=wp.int32)(self.opp_indices) + ) + self.wp_cc = wp.constant( + wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc) + ) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype)