diff --git a/README.md b/README.md index 6bb7727..dcb67ee 100644 --- a/README.md +++ b/README.md @@ -10,18 +10,21 @@ XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that ## Accompanying Paper -Please refer to the [accompanying paper](https://arxiv.org/abs/2311.16080) for benchmarks, validation, and more details about the library. +Please refer to the [accompanying paper](https://doi.org/10.1016/j.cpc.2024.109187) for benchmarks, validation, and more details about the library. ## Citing XLB If you use XLB in your research, please cite the following paper: ``` -@article{ataei2023xlb, - title={{XLB}: A Differentiable Massively Parallel Lattice Boltzmann Library in Python}, - author={Ataei, Mohammadmehdi and Salehipour, Hesam}, - journal={arXiv preprint arXiv:2311.16080}, - year={2023}, +@article{ataei2024xlb, + title={{XLB}: A differentiable massively parallel lattice {Boltzmann} library in {Python}}, + author={Ataei, Mohammadmehdi and Salehipour, Hesam}, + journal={Computer Physics Communications}, + volume={300}, + pages={109187}, + year={2024}, + publisher={Elsevier} } ``` @@ -153,4 +156,47 @@ git clone https://github.com/Autodesk/XLB cd XLB export PYTHONPATH=. python3 examples/CFD/cavity2d.py -``` \ No newline at end of file +``` +## Roadmap + +### Work in Progress (WIP) +*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.* + +- 🚀 **Warp Backend:** Achieving state-of-the-art performance by leveraging the [Warp](https://github.com/NVIDIA/warp) framework in combination with JAX. + + - 🌐 **Grid Refinement:** Implementing adaptive mesh refinement techniques for enhanced simulation accuracy. + +- ⚡ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Using Neon's data structure for improved scaling. + +- 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips. + +- 🗜ī¸ **GPU Accelerated Lossless Compression and Decompression**: Implementing high-performance lossless compression and decompression techniques for larger-scale simulations and improved performance. + +- 🌡ī¸ **Fluid-Thermal Simulation Capabilities:** Incorporating heat transfer and thermal effects into fluid simulations. + +- đŸŽ¯ **Adjoint-based Shape and Topology Optimization:** Implementing gradient-based optimization techniques for design optimization. + +- 🧠 **Machine Learning Accelerated Simulations:** Leveraging machine learning to speed up simulations and improve accuracy. + +- 📉 **Reduced Order Modeling using Machine Learning:** Developing data-driven reduced-order models for efficient and accurate simulations. + + +### Wishlist +*Contributions to these features are welcome. Please submit PRs for the Wishlist items.* + +- 🌊 **Free Surface Flows:** Simulating flows with free surfaces, such as water waves and droplets. + +- 📡 **Electromagnetic Wave Propagation:** Simulating the propagation of electromagnetic waves. + +- 🛩ī¸ **Supersonic Flows:** Simulating supersonic flows. + +- 🌊🧱 **Fluid-Solid Interaction:** Modeling the interaction between fluids and solid objects. + +- 🧩 **Multiphase Flow Simulation:** Simulating flows with multiple immiscible fluids. + +- đŸ”Ĩ **Combustion:** Simulating combustion processes and reactive flows. + +- đŸĒ¨ **Particle Flows and Discrete Element Method:** Incorporating particle-based methods for granular and particulate flows. + +- 🔧 **Better Geometry Processing Pipelines:** Improving the handling and preprocessing of complex geometries for simulations. + diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 8c99dff..2487919 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -94,7 +94,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index c67ce8e..f94e209 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -61,7 +61,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index ea8cad7..65b56bf 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -102,7 +102,7 @@ def initialize_fields(self): u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) else: u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype) - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init) def setup_stepper(self): force = self.get_force() diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 92c94aa..077ae98 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -122,7 +122,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 32c3d6d..907c1f2 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -100,7 +100,7 @@ def main(): backend, precision_policy = setup_simulation(args) velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge) - f_0 = initialize_eq(f_0, grid, velocity_set, backend) + f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend) elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py index c8439d9..ccb4a82 100644 --- a/xlb/helper/initializers.py +++ b/xlb/helper/initializers.py @@ -2,11 +2,11 @@ from xlb.operator.equilibrium import QuadraticEquilibrium -def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None): +def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None): if rho is None: - rho = grid.create_field(cardinality=1, fill_value=1.0) + rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision) if u is None: - u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0) + u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision) equilibrium = QuadraticEquilibrium() if backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index dcdc8fd..0ddbcfc 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -86,7 +86,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -112,7 +112,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 716fd8e..6d4e3ed 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -111,7 +111,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -137,7 +137,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 87c6850..55a5851 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -229,7 +229,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -270,7 +270,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85bffb1..57d29fd 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -95,7 +95,7 @@ def kernel2d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -121,7 +121,7 @@ def kernel3d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index c2483c9..e723570 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -110,7 +110,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -136,7 +136,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 7879137..a42b695 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -344,7 +344,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -370,7 +370,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 0bf68b8..4e9fe29 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -352,7 +352,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -378,7 +378,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 8d47929..9f6ef5d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -81,8 +81,8 @@ def _get_thread_data_2d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: @@ -106,8 +106,8 @@ def _get_thread_data_3d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 196e3ba..60f63ef 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -59,7 +59,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -86,7 +86,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 7984829..bc731c6 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -337,7 +337,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -369,7 +369,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 0cce91b..ba337f0 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -79,7 +79,7 @@ def kernel3d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = feq[l] + f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l]) @wp.kernel def kernel2d( @@ -100,7 +100,7 @@ def kernel2d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = feq[l] + f[l, index[0], index[1]] = self.store_dtype(feq[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py index f148e12..b4da602 100644 --- a/xlb/operator/force/exact_difference_force.py +++ b/xlb/operator/force/exact_difference_force.py @@ -108,7 +108,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -134,7 +134,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index 561fe7a..329a71f 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -53,7 +53,7 @@ def kernel3d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -71,7 +71,7 @@ def kernel2d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 495a6a0..b574436 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -50,9 +50,9 @@ def kernel3d( _f[l] = f[l, index[0], index[1], index[2]] _rho, _u = functional(_f) - rho[0, index[0], index[1], index[2]] = _rho + rho[0, index[0], index[1], index[2]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -68,9 +68,9 @@ def kernel2d( _f[l] = f[l, index[0], index[1]] _rho, _u = functional(_f) - rho[0, index[0], index[1]] = _rho + rho[0, index[0], index[1]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 917c86a..687b38a 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -95,7 +95,7 @@ def kernel3d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1], index[2]] = _pi[d] + pi[d, index[0], index[1], index[2]] = self.store_dtype(_pi[d]) @wp.kernel def kernel2d( @@ -114,7 +114,7 @@ def kernel2d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1]] = _pi[d] + pi[d, index[0], index[1]] = self.store_dtype(_pi[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 18b2167..05cee7b 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -190,18 +190,18 @@ def get_thread_data_2d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_thread_data_3d( @@ -210,18 +210,18 @@ def get_thread_data_3d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_bc_auxilary_data_2d( @@ -243,7 +243,7 @@ def get_bc_auxilary_data_2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) return f_auxiliary @wp.func @@ -266,7 +266,7 @@ def get_bc_auxilary_data_3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) return f_auxiliary @wp.kernel @@ -283,38 +283,33 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1]] - f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, - feq, - rho, - u, - ) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = f_post_collision[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) # Construct the kernel @wp.kernel @@ -331,33 +326,33 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1], index[2]] - f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = f_post_collision[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index d96c307..dc2417a 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -76,7 +76,7 @@ def functional2d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1]] + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1]]) return _f @@ -94,7 +94,7 @@ def kernel2d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = _f[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the funcional to get streamed indices @wp.func @@ -117,7 +117,8 @@ def functional3d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1], pull_index[2]] + # Unlike other functionals, we need to cast the type here since we read from the buffer + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1], pull_index[2]]) return _f @@ -136,7 +137,7 @@ def kernel3d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = _f[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 3b0f85f..d85deed 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -1,7 +1,6 @@ # Enum for precision policy from enum import Enum, auto - import jax.numpy as jnp import warp as wp @@ -87,12 +86,4 @@ def cast_to_compute_jax(self, array): def cast_to_store_jax(self, array): store_precision = self.store_precision - return jnp.array(array, dtype=store_precision.jax_dtype) - - def cast_to_compute_warp(self, array): - compute_precision = self.compute_precision - return wp.array(array, dtype=compute_precision.wp_dtype) - - def cast_to_store_warp(self, array): - store_precision = self.store_precision - return wp.array(array, dtype=store_precision.wp_dtype) + return jnp.array(array, dtype=store_precision.jax_dtype) \ No newline at end of file