diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 8c99dff..2487919 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -94,7 +94,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK") diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index c67ce8e..f94e209 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -61,7 +61,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self, omega): self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index ea8cad7..65b56bf 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -102,7 +102,7 @@ def initialize_fields(self): u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init) else: u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype) - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init) def setup_stepper(self): force = self.get_force() diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 92c94aa..077ae98 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -122,7 +122,7 @@ def setup_boundary_masker(self): self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask) def initialize_fields(self): - self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) + self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend) def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 32c3d6d..907c1f2 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -100,7 +100,7 @@ def main(): backend, precision_policy = setup_simulation(args) velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge) - f_0 = initialize_eq(f_0, grid, velocity_set, backend) + f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend) elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py index c8439d9..ccb4a82 100644 --- a/xlb/helper/initializers.py +++ b/xlb/helper/initializers.py @@ -2,11 +2,11 @@ from xlb.operator.equilibrium import QuadraticEquilibrium -def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None): +def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None): if rho is None: - rho = grid.create_field(cardinality=1, fill_value=1.0) + rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision) if u is None: - u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0) + u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision) equilibrium = QuadraticEquilibrium() if backend == ComputeBackend.JAX: diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index dcdc8fd..0ddbcfc 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -86,7 +86,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -112,7 +112,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 716fd8e..6d4e3ed 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -111,7 +111,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -137,7 +137,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 87c6850..55a5851 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -229,7 +229,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -270,7 +270,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 85bffb1..57d29fd 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -95,7 +95,7 @@ def kernel2d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -121,7 +121,7 @@ def kernel3d( # Write the result to the output for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index c2483c9..e723570 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -110,7 +110,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -136,7 +136,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 7879137..a42b695 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -344,7 +344,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -370,7 +370,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 0bf68b8..4e9fe29 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -352,7 +352,7 @@ def kernel2d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1]] = _f[l] + f_post[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the warp kernel @wp.kernel @@ -378,7 +378,7 @@ def kernel3d( # Write the distribution function for l in range(self.velocity_set.q): - f_post[l, index[0], index[1], index[2]] = _f[l] + f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d if self.velocity_set.d == 3 and self.bc_type == "velocity": diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 8d47929..9f6ef5d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -81,8 +81,8 @@ def _get_thread_data_2d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1]] - _f_post[l] = f_post[l, index[0], index[1]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: @@ -106,8 +106,8 @@ def _get_thread_data_3d( _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of populations - _f_pre[l] = f_pre[l, index[0], index[1], index[2]] - _f_post[l] = f_post[l, index[0], index[1], index[2]] + _f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]]) + _f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 196e3ba..60f63ef 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -59,7 +59,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -86,7 +86,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 7984829..bc731c6 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -337,7 +337,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -369,7 +369,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 0cce91b..ba337f0 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -79,7 +79,7 @@ def kernel3d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1], index[2]] = feq[l] + f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l]) @wp.kernel def kernel2d( @@ -100,7 +100,7 @@ def kernel2d( # Set the output for l in range(self.velocity_set.q): - f[l, index[0], index[1]] = feq[l] + f[l, index[0], index[1]] = self.store_dtype(feq[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/force/exact_difference_force.py b/xlb/operator/force/exact_difference_force.py index f148e12..b4da602 100644 --- a/xlb/operator/force/exact_difference_force.py +++ b/xlb/operator/force/exact_difference_force.py @@ -108,7 +108,7 @@ def kernel2d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1]] = _fout[l] + fout[l, index[0], index[1]] = self.store_dtype(_fout[l]) # Construct the warp kernel @wp.kernel @@ -134,7 +134,7 @@ def kernel3d( # Write the result for l in range(self.velocity_set.q): - fout[l, index[0], index[1], index[2]] = _fout[l] + fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d return functional, kernel diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index 561fe7a..329a71f 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -53,7 +53,7 @@ def kernel3d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -71,7 +71,7 @@ def kernel2d( _u = functional(_f, _rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 495a6a0..b574436 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -50,9 +50,9 @@ def kernel3d( _f[l] = f[l, index[0], index[1], index[2]] _rho, _u = functional(_f) - rho[0, index[0], index[1], index[2]] = _rho + rho[0, index[0], index[1], index[2]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1], index[2]] = _u[d] + u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d]) @wp.kernel def kernel2d( @@ -68,9 +68,9 @@ def kernel2d( _f[l] = f[l, index[0], index[1]] _rho, _u = functional(_f) - rho[0, index[0], index[1]] = _rho + rho[0, index[0], index[1]] = self.store_dtype(_rho) for d in range(self.velocity_set.d): - u[d, index[0], index[1]] = _u[d] + u[d, index[0], index[1]] = self.store_dtype(_u[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 917c86a..687b38a 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -95,7 +95,7 @@ def kernel3d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1], index[2]] = _pi[d] + pi[d, index[0], index[1], index[2]] = self.store_dtype(_pi[d]) @wp.kernel def kernel2d( @@ -114,7 +114,7 @@ def kernel2d( # Set the output for d in range(_pi_dim): - pi[d, index[0], index[1]] = _pi[d] + pi[d, index[0], index[1]] = self.store_dtype(_pi[d]) kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 18b2167..05cee7b 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -190,18 +190,18 @@ def get_thread_data_2d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1]]) # TODO fix vec bool if missing_mask[l, index[0], index[1]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_thread_data_3d( @@ -210,18 +210,18 @@ def get_thread_data_3d( index: Any, ): # Get the boundary id and missing mask - f_post_collision = _f_vec() + _f_post_collision = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] + _f_post_collision[l] = self.compute_dtype(f_0[l, index[0], index[1], index[2]]) # TODO fix vec bool if missing_mask[l, index[0], index[1], index[2]]: _missing_mask[l] = wp.uint8(1) else: _missing_mask[l] = wp.uint8(0) - return f_post_collision, _missing_mask + return _f_post_collision, _missing_mask @wp.func def get_bc_auxilary_data_2d( @@ -243,7 +243,7 @@ def get_bc_auxilary_data_2d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1]]) return f_auxiliary @wp.func @@ -266,7 +266,7 @@ def get_bc_auxilary_data_3d( for d in range(self.velocity_set.d): pull_index[d] = index[d] - (_c[d, l] + nv[d]) # The following is the post-streaming values of the neighbor cell - f_auxiliary[l] = f_0[l, pull_index[0], pull_index[1], pull_index[2]] + f_auxiliary[l] = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]]) return f_auxiliary @wp.kernel @@ -283,38 +283,33 @@ def kernel2d( index = wp.vec2i(i, j) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_2d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1]] - f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, - feq, - rho, - u, - ) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = f_post_collision[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f_post_collision[l]) # Construct the kernel @wp.kernel @@ -331,33 +326,33 @@ def kernel3d( index = wp.vec3i(i, j, k) # TODO warp should fix this # Read thread data for populations and missing mask - f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) + _f_post_collision, _missing_mask = get_thread_data_3d(f_0, missing_mask, index) # Apply streaming (pull method) - f_post_stream = self.stream.warp_functional(f_0, index) + _f_post_stream = self.stream.warp_functional(f_0, index) # Prepare auxilary data for BC (if applicable) _boundary_id = bc_mask[0, index[0], index[1], index[2]] - f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) + _f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct) # Apply post-streaming type boundary conditions - f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_stream = apply_post_streaming_bc(_f_post_collision, _f_post_stream, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Compute rho and u - rho, u = self.macroscopic.warp_functional(f_post_stream) + _rho, _u = self.macroscopic.warp_functional(_f_post_stream) # Compute equilibrium - feq = self.equilibrium.warp_functional(rho, u) + _feq = self.equilibrium.warp_functional(_rho, _u) # Apply collision - f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) # Apply post-collision type boundary conditions - f_post_collision = apply_post_collision_bc(f_post_stream, f_post_collision, f_auxiliary, _missing_mask, _boundary_id, bc_struct) + _f_post_collision = apply_post_collision_bc(_f_post_stream, _f_post_collision, _f_auxiliary, _missing_mask, _boundary_id, bc_struct) # Set the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = f_post_collision[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f_post_collision[l]) # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index d96c307..dc2417a 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -76,7 +76,7 @@ def functional2d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1]] + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1]]) return _f @@ -94,7 +94,7 @@ def kernel2d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1]] = _f[l] + f_1[l, index[0], index[1]] = self.store_dtype(_f[l]) # Construct the funcional to get streamed indices @wp.func @@ -117,7 +117,8 @@ def functional3d( pull_index[d] = 0 # Read the distribution function - _f[l] = f[l, pull_index[0], pull_index[1], pull_index[2]] + # Unlike other functionals, we need to cast the type here since we read from the buffer + _f[l] = self.compute_dtype(f[l, pull_index[0], pull_index[1], pull_index[2]]) return _f @@ -136,7 +137,7 @@ def kernel3d( # Write the output for l in range(self.velocity_set.q): - f_1[l, index[0], index[1], index[2]] = _f[l] + f_1[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) functional = functional3d if self.velocity_set.d == 3 else functional2d kernel = kernel3d if self.velocity_set.d == 3 else kernel2d diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 3b0f85f..d85deed 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -1,7 +1,6 @@ # Enum for precision policy from enum import Enum, auto - import jax.numpy as jnp import warp as wp @@ -87,12 +86,4 @@ def cast_to_compute_jax(self, array): def cast_to_store_jax(self, array): store_precision = self.store_precision - return jnp.array(array, dtype=store_precision.jax_dtype) - - def cast_to_compute_warp(self, array): - compute_precision = self.compute_precision - return wp.array(array, dtype=compute_precision.wp_dtype) - - def cast_to_store_warp(self, array): - store_precision = self.store_precision - return wp.array(array, dtype=store_precision.wp_dtype) + return jnp.array(array, dtype=store_precision.jax_dtype) \ No newline at end of file