From aeb17703941d336e71b33b15af317dcd8f2426b5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Fri, 2 Aug 2024 16:49:02 -0400 Subject: [PATCH] Added ruff --- .github/workflows/lint.yml | 27 ++++++ .pre-commit-config.yaml | 6 ++ examples/cfd/flow_past_sphere_3d.py | 25 ++---- examples/cfd/lid_driven_cavity_2d.py | 24 ++--- .../cfd/lid_driven_cavity_2d_distributed.py | 16 ++-- examples/cfd/windtunnel_3d.py | 34 ++----- .../flow_past_sphere.py | 39 ++------ .../cfd_old_to_be_migrated/taylor_green.py | 79 +++++------------ examples/performance/mlups_3d.py | 24 ++--- requirements.txt | 3 +- ruff.toml | 42 +++++++++ setup.py | 3 +- .../bc_equilibrium/test_bc_equilibrium_jax.py | 22 ++--- .../test_bc_equilibrium_warp.py | 23 ++--- .../test_bc_fullway_bounce_back_jax.py | 18 ++-- .../test_bc_fullway_bounce_back_warp.py | 17 ++-- .../mask/test_bc_indices_masker_jax.py | 21 ++--- .../mask/test_bc_indices_masker_warp.py | 17 ++-- tests/grids/test_grid_jax.py | 3 +- .../collision/test_bgk_collision_jax.py | 1 + .../collision/test_bgk_collision_warp.py | 7 +- .../equilibrium/test_equilibrium_jax.py | 9 +- .../equilibrium/test_equilibrium_warp.py | 5 +- .../macroscopic/test_macroscopic_jax.py | 2 +- .../macroscopic/test_macroscopic_warp.py | 13 +-- tests/kernels/stream/test_stream_warp.py | 4 +- xlb/__init__.py | 11 +-- xlb/distribute/__init__.py | 2 +- xlb/distribute/distribute.py | 26 ++---- xlb/experimental/ooc/__init__.py | 4 +- xlb/experimental/ooc/ooc_array.py | 88 ++++--------------- xlb/experimental/ooc/out_of_core.py | 14 +-- xlb/experimental/ooc/tiles/compressed_tile.py | 47 +++------- xlb/experimental/ooc/tiles/dense_tile.py | 14 +-- xlb/experimental/ooc/tiles/dynamic_array.py | 10 +-- xlb/experimental/ooc/tiles/tile.py | 10 +-- xlb/experimental/ooc/utils.py | 2 +- xlb/grid/__init__.py | 2 +- xlb/grid/grid.py | 25 +++--- xlb/grid/jax_grid.py | 17 +--- xlb/grid/warp_grid.py | 9 +- xlb/helper/__init__.py | 4 +- xlb/helper/nse_solver.py | 17 +--- xlb/operator/__init__.py | 4 +- xlb/operator/boundary_condition/__init__.py | 12 +-- .../boundary_condition/bc_do_nothing.py | 8 +- .../boundary_condition/bc_equilibrium.py | 14 +-- .../bc_fullway_bounce_back.py | 10 +-- .../bc_halfway_bounce_back.py | 12 +-- .../boundary_condition/boundary_condition.py | 8 +- .../boundary_condition_registry.py | 8 +- xlb/operator/boundary_masker/__init__.py | 4 +- .../indices_boundary_masker.py | 45 +++------- .../boundary_masker/stl_boundary_masker.py | 18 +--- xlb/operator/collision/__init__.py | 6 +- xlb/operator/collision/kbc.py | 32 +++---- xlb/operator/equilibrium/__init__.py | 4 +- .../equilibrium/quadratic_equilibrium.py | 3 +- xlb/operator/macroscopic/__init__.py | 2 +- xlb/operator/macroscopic/macroscopic.py | 4 +- xlb/operator/operator.py | 21 +---- xlb/operator/parallel_operator.py | 8 +- xlb/operator/precision_caster/__init__.py | 2 +- .../precision_caster/precision_caster.py | 5 +- xlb/operator/stepper/__init__.py | 4 +- xlb/operator/stepper/nse_stepper.py | 38 ++------ xlb/operator/stepper/stepper.py | 55 +++--------- xlb/operator/stream/__init__.py | 2 +- xlb/operator/stream/stream.py | 5 +- xlb/precision_policy.py | 4 +- xlb/precision_policy/precision_policy.py | 21 ++--- xlb/utils/__init__.py | 14 +-- xlb/utils/utils.py | 37 +++----- xlb/velocity_set/__init__.py | 8 +- xlb/velocity_set/d2q9.py | 5 +- xlb/velocity_set/d3q19.py | 9 +- xlb/velocity_set/d3q27.py | 1 + xlb/velocity_set/velocity_set.py | 19 +--- 78 files changed, 415 insertions(+), 823 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml create mode 100644 ruff.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..1b44c5a --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,27 @@ +name: Lint + +on: + pull_request: + branches: + - major-refactoring # Remember to add main branch later + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Run Ruff + run: ruff check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6a2bd2f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.6 + hooks: + - id: ruff + args: [--fix] diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index a7d9dac..2a580aa 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -29,9 +29,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -61,10 +59,7 @@ def define_boundary_indices(self): z = np.arange(self.grid_shape[2]) X, Y, Z = np.meshgrid(x, y, z, indexing="ij") indices = np.where( - (X - self.grid_shape[0] // 6) ** 2 - + (Y - self.grid_shape[1] // 2) ** 2 - + (Z - self.grid_shape[2] // 2) ** 2 - < sphere_radius**2 + (X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 + (Z - self.grid_shape[2] // 2) ** 2 < sphere_radius**2 ) sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)] @@ -84,23 +79,17 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: @@ -134,7 +123,5 @@ def post_process(self, i): precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = FlowOverSphere( - omega, grid_shape, velocity_set, backend, precision_policy - ) + simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=10000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index ecc6f0b..488ebc1 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -24,9 +24,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -42,9 +40,7 @@ def _setup(self, omega): def define_boundary_indices(self): lid = self.grid.boundingBoxIndices["top"] walls = [ - self.grid.boundingBoxIndices["bottom"][i] - + self.grid.boundingBoxIndices["left"][i] - + self.grid.boundingBoxIndices["right"][i] + self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["left"][i] + self.grid.boundingBoxIndices["right"][i] for i in range(self.velocity_set.d) ] return lid, walls @@ -61,23 +57,17 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: @@ -114,7 +104,5 @@ def post_process(self, i): precision_policy = PrecisionPolicy.FP32FP32 omega = 1.6 - simulation = LidDrivenCavity2D( - omega, grid_shape, velocity_set, backend, precision_policy - ) + simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 7a43a14..225d6bd 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -11,24 +11,24 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): super().__init__(omega, grid_shape, velocity_set, backend, precision_policy) def setup_stepper(self, omega): - stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions - ) + stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions) distributed_stepper = distribute( - stepper, self.grid, self.velocity_set, - ) + stepper, + self.grid, + self.velocity_set, + ) self.stepper = distributed_stepper return - + if __name__ == "__main__": # Running the simulation grid_size = 512 grid_shape = (grid_size, grid_size) - backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! + backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 - omega=1.6 + omega = 1.6 simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps=5000, post_process_interval=1000) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 35140ee..e76b303 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -19,9 +19,7 @@ class WindTunnel3D: - def __init__( - self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy - ): + def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precision_policy): # initialize backend xlb.init( velocity_set=velocity_set, @@ -33,9 +31,7 @@ def __init__( self.velocity_set = velocity_set self.backend = backend self.precision_policy = precision_policy - self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = ( - create_nse_fields(grid_shape) - ) + self.grid, self.f_0, self.f_1, self.missing_mask, self.boundary_mask = create_nse_fields(grid_shape) self.stepper = None self.boundary_conditions = [] @@ -72,10 +68,8 @@ def define_boundary_indices(self): car_length_lbm_unit = grid_size_x / 4 car_voxelized, pitch = self.voxelize_stl(stl_filename, car_length_lbm_unit) - car_area = np.prod(car_voxelized.shape[1:]) - tx, ty, tz = ( - np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape - ) + # car_area = np.prod(car_voxelized.shape[1:]) + tx, ty, _ = np.array([grid_size_x, grid_size_y, grid_size_z]) - car_voxelized.shape shift = [tx // 4, ty // 2, 0] car = np.argwhere(car_voxelized) + shift car = np.array(car).T @@ -97,31 +91,23 @@ def setup_boundary_masks(self): precision_policy=self.precision_policy, compute_backend=self.backend, ) - self.boundary_mask, self.missing_mask = indices_boundary_masker( - self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0) - ) + self.boundary_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.boundary_mask, self.missing_mask, (0, 0, 0)) def initialize_fields(self): self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend) def setup_stepper(self, omega): - self.stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=self.boundary_conditions, collision_type="KBC" - ) + self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="KBC") def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_1 = self.stepper( - self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i - ) + self.f_1 = self.stepper(self.f_0, self.f_1, self.boundary_mask, self.missing_mask, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: elapsed_time = time.time() - start_time - print( - f"Iteration: {i+1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s" - ) + print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") if i % post_process_interval == 0 or i == num_steps - 1: self.post_process(i) @@ -178,7 +164,5 @@ def post_process(self, i): print(f"Max iterations: {num_steps}") print("\n" + "=" * 50 + "\n") - simulation = WindTunnel3D( - omega, wind_speed, grid_shape, velocity_set, backend, precision_policy - ) + simulation = WindTunnel3D(omega, wind_speed, grid_shape, velocity_set, backend, precision_policy) simulation.run(num_steps, print_interval, post_process_interval=1000) diff --git a/examples/cfd_old_to_be_migrated/flow_past_sphere.py b/examples/cfd_old_to_be_migrated/flow_past_sphere.py index 7e8af30..68d1c2b 100644 --- a/examples/cfd_old_to_be_migrated/flow_past_sphere.py +++ b/examples/cfd_old_to_be_migrated/flow_past_sphere.py @@ -22,8 +22,8 @@ from xlb.operator import Operator -class UniformInitializer(Operator): +class UniformInitializer(Operator): def _construct_warp(self): # Construct the warp kernel @wp.kernel @@ -149,48 +149,27 @@ def warp_implementation(self, rho, u, vel): y = np.arange(nr) z = np.arange(nr) X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = np.array(indices).T indices = wp.from_numpy(indices, dtype=wp.int32) # Set boundary conditions on the indices - boundary_mask, missing_mask = indices_boundary_masker( - indices, - half_way_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) - ) + boundary_mask, missing_mask = indices_boundary_masker(indices, half_way_bc.id, boundary_mask, missing_mask, (0, 0, 0)) # Set inlet bc lower_bound = (0, 0, 0) upper_bound = (0, nr, nr) direction = (1, 0, 0) boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - equilibrium_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) + lower_bound, upper_bound, direction, equilibrium_bc.id, boundary_mask, missing_mask, (0, 0, 0) ) # Set outlet bc - lower_bound = (nr-1, 0, 0) - upper_bound = (nr-1, nr, nr) + lower_bound = (nr - 1, 0, 0) + upper_bound = (nr - 1, nr, nr) direction = (-1, 0, 0) boundary_mask, missing_mask = planar_boundary_masker( - lower_bound, - upper_bound, - direction, - do_nothing_bc.id, - boundary_mask, - missing_mask, - (0, 0, 0) + lower_bound, upper_bound, direction, do_nothing_bc.id, boundary_mask, missing_mask, (0, 0, 0) ) # Set initial conditions @@ -201,7 +180,7 @@ def warp_implementation(self, rho, u, vel): plot_freq = 512 save_dir = "flow_past_sphere" os.makedirs(save_dir, exist_ok=True) - #compute_mlup = False # Plotting results + # compute_mlup = False # Plotting results compute_mlup = True num_steps = 1024 * 8 start = time.time() @@ -225,4 +204,4 @@ def warp_implementation(self, rho, u, vel): end = time.time() # Print MLUPS - print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") diff --git a/examples/cfd_old_to_be_migrated/taylor_green.py b/examples/cfd_old_to_be_migrated/taylor_green.py index 10eb54f..9ed7fa6 100644 --- a/examples/cfd_old_to_be_migrated/taylor_green.py +++ b/examples/cfd_old_to_be_migrated/taylor_green.py @@ -4,10 +4,8 @@ from tqdm import tqdm import os import matplotlib.pyplot as plt -from functools import partial from typing import Any import jax.numpy as jnp -from jax import jit import warp as wp wp.init() @@ -15,13 +13,14 @@ import xlb from xlb.operator import Operator + class TaylorGreenInitializer(Operator): """ Initialize the Taylor-Green vortex. """ @Operator.register_backend(xlb.ComputeBackend.JAX) - #@partial(jit, static_argnums=(0)) + # @partial(jit, static_argnums=(0)) def jax_implementation(self, vel, nr): # Make meshgrid x = jnp.linspace(0, 2 * jnp.pi, nr) @@ -33,24 +32,14 @@ def jax_implementation(self, vel, nr): u = jnp.stack( [ vel * jnp.sin(X) * jnp.cos(Y) * jnp.cos(Z), - - vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), + -vel * jnp.cos(X) * jnp.sin(Y) * jnp.cos(Z), jnp.zeros_like(X), ], axis=0, ) # Compute rho - rho = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - jnp.cos(2.0 * X) - + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0)) - ) - + 1.0 - ) + rho = 3.0 * vel * vel * (1.0 / 16.0) * (jnp.cos(2.0 * X) + (jnp.cos(2.0 * Y) * (jnp.cos(2.0 * Z) + 2.0))) + 1.0 rho = jnp.expand_dims(rho, axis=0) return rho, u @@ -74,22 +63,11 @@ def kernel( # Compute u u[0, i, j, k] = vel * wp.sin(x) * wp.cos(y) * wp.cos(z) - u[1, i, j, k] = - vel * wp.cos(x) * wp.sin(y) * wp.cos(z) + u[1, i, j, k] = -vel * wp.cos(x) * wp.sin(y) * wp.cos(z) u[2, i, j, k] = 0.0 # Compute rho - rho[0, i, j, k] = ( - 3.0 - * vel - * vel - * (1.0 / 16.0) - * ( - wp.cos(2.0 * x) - + (wp.cos(2.0 * y) - * (wp.cos(2.0 * z) + 2.0)) - ) - + 1.0 - ) + rho[0, i, j, k] = 3.0 * vel * vel * (1.0 / 16.0) * (wp.cos(2.0 * x) + (wp.cos(2.0 * y) * (wp.cos(2.0 * z) + 2.0))) + 1.0 return None, kernel @@ -108,8 +86,8 @@ def warp_implementation(self, rho, u, vel, nr): ) return rho, u -def run_taylor_green(backend, compute_mlup=True): +def run_taylor_green(backend, compute_mlup=True): # Set the compute backend if backend == "warp": compute_backend = xlb.ComputeBackend.WARP @@ -139,35 +117,19 @@ def run_taylor_green(backend, compute_mlup=True): missing_mask = grid.create_field(cardinality=velocity_set.q, precision=xlb.Precision.BOOL) # Make operators - initializer = TaylorGreenInitializer( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - collision = xlb.operator.collision.BGK( - omega=1.9, - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + initializer = TaylorGreenInitializer(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + collision = xlb.operator.collision.BGK(omega=1.9, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - macroscopic = xlb.operator.macroscopic.Macroscopic( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) - stream = xlb.operator.stream.Stream( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend) + velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend + ) + macroscopic = xlb.operator.macroscopic.Macroscopic(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) + stream = xlb.operator.stream.Stream(velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend) stepper = xlb.operator.stepper.IncompressibleNavierStokesStepper( - collision=collision, - equilibrium=equilibrium, - macroscopic=macroscopic, - stream=stream) + collision=collision, equilibrium=equilibrium, macroscopic=macroscopic, stream=stream + ) # Parrallelize the stepper TODO: Add this functionality - #stepper = grid.parallelize_operator(stepper) + # stepper = grid.parallelize_operator(stepper) # Set initial conditions if backend == "warp": @@ -200,8 +162,7 @@ def run_taylor_green(backend, compute_mlup=True): elif backend == "jax": rho, local_u = macroscopic(f0) - - plt.imshow(local_u[0, :, nr//2, :]) + plt.imshow(local_u[0, :, nr // 2, :]) plt.colorbar() plt.savefig(f"{save_dir}/{str(_).zfill(6)}.png") plt.close() @@ -209,12 +170,12 @@ def run_taylor_green(backend, compute_mlup=True): end = time.time() # Print MLUPS - print(f"MLUPS: {num_steps*nr**3/(end-start)/1e6}") + print(f"MLUPS: {num_steps * nr**3 / (end - start) / 1e6}") -if __name__ == "__main__": +if __name__ == "__main__": # Run Taylor-Green vortex on different backends backends = ["warp", "jax"] - #backends = ["jax"] + # backends = ["jax"] for backend in backends: run_taylor_green(backend, compute_mlup=True) diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 75ecccc..74bfa04 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -11,19 +11,11 @@ def parse_arguments(): - parser = argparse.ArgumentParser( - description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)" - ) - parser.add_argument( - "cube_edge", type=int, help="Length of the edge of the cubic grid" - ) + parser = argparse.ArgumentParser(description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)") + parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") parser.add_argument("num_steps", type=int, help="Timestep for the simulation") - parser.add_argument( - "backend", type=str, help="Backend for the simulation (jax or warp)" - ) - parser.add_argument( - "precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)" - ) + parser.add_argument("backend", type=str, help="Backend for the simulation (jax or warp)") + parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") return parser.parse_args() @@ -77,9 +69,7 @@ def setup_boundary_conditions(grid): def run(f_0, f_1, backend, grid, boundary_mask, missing_mask, num_steps): omega = 1.0 - stepper = IncompressibleNavierStokesStepper( - omega, boundary_conditions=setup_boundary_conditions(grid) - ) + stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=setup_boundary_conditions(grid)) if backend == ComputeBackend.JAX: stepper = distribute( @@ -111,9 +101,7 @@ def main(): grid, f_0, f_1, missing_mask, boundary_mask = create_grid_and_fields(args.cube_edge) f_0 = initialize_eq(f_0, grid, xlb.velocity_set.D3Q19(), backend) - elapsed_time = run( - f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps - ) + elapsed_time = run(f_0, f_1, backend, grid, boundary_mask, missing_mask, args.num_steps) mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) print(f"Simulation completed in {elapsed_time:.2f} seconds") diff --git a/requirements.txt b/requirements.txt index ebae946..ee107af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git tqdm==4.66.2 warp-lang==1.0.2 numpy-stl==3.1.1 -pydantic==2.7.0 \ No newline at end of file +pydantic==2.7.0 +ruff==0.5.6 \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..83c6370 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,42 @@ +# Adopted from tinygrad's ruff.toml thanks @geohot +indent-width = 4 +preview = true +target-version = "py38" + +lint.select = [ + "F", # Pyflakes + "W6", + "E71", + "E72", + "E112", # no-indented-block + "E113", # unexpected-indentation + # "E124", + "E203", # whitespace-before-punctuation + "E272", # multiple-spaces-before-keyword + "E303", # too-many-blank-lines + "E304", # blank-line-after-decorator + "E501", # line-too-long + # "E502", + "E702", # multiple-statements-on-one-line-semicolon + "E703", # useless-semicolon + "E731", # lambda-assignment + "W191", # tab-indentation + "W291", # trailing-whitespace + "W293", # blank-line-with-whitespace + "UP039", # unnecessary-class-parentheses + "C416", # unnecessary-comprehension + "RET506", # superfluous-else-raise + "RET507", # superfluous-else-continue + "A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing + "SIM105", # suppressible-exception + "FURB110",# if-exp-instead-of-or-operator +] + +# unused-variable, shadowing a Python builtin module, Module imported but unused +lint.ignore = ["F841", "A005", "F401"] +line-length = 150 + +exclude = [ + "docs/", + "xlb/experimental/", +] \ No newline at end of file diff --git a/setup.py b/setup.py index 5ef3ed7..6f9780a 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ version="0.0.1", author="", packages=find_packages(), - install_requires=[ - ], + install_requires=[], include_package_data=True, ) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 9017a9c..3e50fdb 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -7,6 +7,7 @@ from xlb import DefaultConfig from xlb.operator.boundary_masker import IndicesBoundaryMasker + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -29,9 +30,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -48,10 +47,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] @@ -62,9 +58,7 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker( - [equilibrium_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -80,13 +74,9 @@ def test_bc_equilibrium_jax(dim, velocity_set, grid_shape): weights = velocity_set.w for i, weight in enumerate(weights): if dim == 2: - assert jnp.allclose( - f[i, indices[0], indices[1]], weight - ), f"Direction {i} in f does not match the expected weight" + assert jnp.allclose(f[i, indices[0], indices[1]], weight), f"Direction {i} in f does not match the expected weight" else: - assert jnp.allclose( - f[i, indices[0], indices[1], indices[2]], weight - ), f"Direction {i} in f does not match the expected weight" + assert jnp.allclose(f[i, indices[0], indices[1], indices[2]], weight), f"Direction {i} in f does not match the expected weight" # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values mask_outside = np.ones(grid_shape, dtype=bool) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 7bb78cf..e319dbd 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -1,12 +1,12 @@ import pytest import numpy as np -import warp as wp import xlb from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig from xlb.operator.boundary_masker import IndicesBoundaryMasker + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -29,9 +29,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -48,10 +46,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] equilibrium = xlb.operator.equilibrium.QuadraticEquilibrium() @@ -63,9 +58,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): indices=indices, ) - boundary_mask, missing_mask = indices_boundary_masker( - [equilibrium_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([equilibrium_bc], boundary_mask, missing_mask, start_index=None) f = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32) @@ -84,13 +77,9 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape): weights = velocity_set.w for i, weight in enumerate(weights): if dim == 2: - assert np.allclose( - f[i, indices[0], indices[1]], weight - ), f"Direction {i} in f does not match the expected weight" + assert np.allclose(f[i, indices[0], indices[1]], weight), f"Direction {i} in f does not match the expected weight" else: - assert np.allclose( - f[i, indices[0], indices[1], indices[2]], weight - ), f"Direction {i} in f does not match the expected weight" + assert np.allclose(f[i, indices[0], indices[1], indices[2]], weight), f"Direction {i} in f does not match the expected weight" # Make sure that everywhere else the values are the same as f_post. Note that indices are just int values mask_outside = np.ones(grid_shape, dtype=bool) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index b6ce4c3..1b7edc2 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -7,6 +7,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -31,9 +32,7 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,21 +49,14 @@ def test_fullway_bounce_back_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - boundary_mask, missing_mask = indices_boundary_masker( - [fullway_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) - f_pre = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0 - ) + f_pre = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=0.0) # Generate a random field with the same shape key = jax.random.PRNGKey(0) random_field = jax.random.uniform(key, f_pre.shape) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 3f8f0d0..963e081 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -2,11 +2,11 @@ import numpy as np import warp as wp import xlb -import jax from xlb.compute_backend import ComputeBackend from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -31,9 +31,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,17 +48,12 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] fullway_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) - - boundary_mask, missing_mask = indices_boundary_masker( - [fullway_bc], boundary_mask, missing_mask, start_index=None - ) + + boundary_mask, missing_mask = indices_boundary_masker([fullway_bc], boundary_mask, missing_mask, start_index=None) # Generate a random field with the same shape random_field = np.random.rand(velocity_set.q, *grid_shape).astype(np.float32) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index 0de8805..ddbc761 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -32,9 +32,7 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -51,19 +49,14 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] assert len(indices) == dim test_bc = xlb.operator.boundary_condition.FullwayBounceBackBC(indices=indices) test_bc.id = 5 - boundary_mask, missing_mask = indices_boundary_masker( - [test_bc], boundary_mask, missing_mask, start_index=None - ) + boundary_mask, missing_mask = indices_boundary_masker([test_bc], boundary_mask, missing_mask, start_index=None) assert missing_mask.dtype == xlb.Precision.BOOL.jax_dtype @@ -79,13 +72,9 @@ def test_indices_masker_jax(dim, velocity_set, grid_shape): boundary_mask = boundary_mask.at[0, indices[0], indices[1]].set(0) assert jnp.all(boundary_mask == 0) if dim == 3: - assert jnp.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id - ) + assert jnp.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) # assert that the rest of the boundary_mask is zero - boundary_mask = boundary_mask.at[ - 0, indices[0], indices[1], indices[2] - ].set(0) + boundary_mask = boundary_mask.at[0, indices[0], indices[1], indices[2]].set(0) assert jnp.all(boundary_mask == 0) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 43911f6..6919ba9 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -1,5 +1,4 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend @@ -31,9 +30,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field( - cardinality=velocity_set.q, dtype=xlb.Precision.BOOL - ) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -50,10 +47,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 < sphere_radius**2) else: X, Y, Z = np.meshgrid(x, y, z) - indices = np.where( - (X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 - < sphere_radius**2 - ) + indices = np.where((X - nr // 2) ** 2 + (Y - nr // 2) ** 2 + (Z - nr // 2) ** 2 < sphere_radius**2) indices = [tuple(indices[i]) for i in range(velocity_set.d)] @@ -80,15 +74,14 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): if dim == 2: assert np.all(boundary_mask[0, indices[0], indices[1]] == test_bc.id) # assert that the rest of the boundary_mask is zero - boundary_mask[0, indices[0], indices[1]]= 0 + boundary_mask[0, indices[0], indices[1]] = 0 assert np.all(boundary_mask == 0) if dim == 3: - assert np.all( - boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id - ) + assert np.all(boundary_mask[0, indices[0], indices[1], indices[2]] == test_bc.id) # assert that the rest of the boundary_mask is zero boundary_mask[0, indices[0], indices[1], indices[2]] = 0 assert np.all(boundary_mask == 0) + if __name__ == "__main__": pytest.main() diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index ce4bc70..edd9dd0 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -7,6 +7,7 @@ from jax.experimental import mesh_utils import jax.numpy as jnp + def init_xlb_env(): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -51,6 +52,7 @@ def test_jax_3d_grid_initialization(grid_size): "z", ), "PartitionSpec is incorrect" + def test_jax_grid_create_field_fill_value(): init_xlb_env() grid_shape = (100, 100) @@ -62,7 +64,6 @@ def test_jax_grid_create_field_fill_value(): assert jnp.allclose(f, fill_value), "Field not properly initialized with fill_value" - @pytest.fixture(autouse=True) def setup_xlb_env(): init_xlb_env() diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index cce5ca4..5a400e0 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -7,6 +7,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 7509c1d..522ea33 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -1,14 +1,12 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.macroscopic import Macroscopic from xlb.operator.collision import BGK from xlb.grid import grid_factory from xlb import DefaultConfig -from xlb.precision_policy import Precision + def init_xlb_env(velocity_set): xlb.init( @@ -17,6 +15,7 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set(), ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape,omega", [ @@ -40,7 +39,6 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_eq = compute_macro(rho, u, f_eq) - compute_collision = BGK(omega=omega) f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) @@ -53,5 +51,6 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): assert np.allclose(f_out, f_orig - omega * (f_orig - f_eq), atol=1e-5) + if __name__ == "__main__": pytest.main() diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index fbdadb6..07bafe7 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -6,6 +6,7 @@ from xlb.grid import grid_factory from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -38,16 +39,12 @@ def test_quadratic_equilibrium_jax(dim, velocity_set, grid_shape): # Test sum of f_eq across cardinality at each point sum_f_eq = np.sum(f_eq, axis=0) - assert np.allclose( - sum_f_eq, 1.0 - ), f"Sum of f_eq should be 1.0 across all directions at each grid point" + assert np.allclose(sum_f_eq, 1.0), "Sum of f_eq should be 1.0 across all directions at each grid point" # Test that each direction matches the expected weights weights = DefaultConfig.velocity_set.w for i, weight in enumerate(weights): - assert np.allclose( - f_eq[i, ...], weight - ), f"Direction {i} in f_eq does not match the expected weight" + assert np.allclose(f_eq[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" if __name__ == "__main__": diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index ef2287f..063a723 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -1,11 +1,12 @@ import pytest -import warp as wp import numpy as np import xlb from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.grid import grid_factory from xlb import DefaultConfig + + def init_xlb_env(velocity_set): xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, @@ -13,6 +14,7 @@ def init_xlb_env(velocity_set): velocity_set=velocity_set(), ) + @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ @@ -45,6 +47,7 @@ def test_quadratic_equilibrium_warp(dim, velocity_set, grid_shape): for i, weight in enumerate(weights): assert np.allclose(f_eq_np[i, ...], weight), f"Direction {i} in f_eq does not match the expected weight" + # @pytest.fixture(autouse=True) # def setup_xlb_env(request): # dim, velocity_set, grid_shape = request.param diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 89ef393..50d1735 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -5,7 +5,7 @@ from xlb.operator.equilibrium import QuadraticEquilibrium from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory -from xlb import DefaultConfig + def init_xlb_env(velocity_set): xlb.init( diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index 7a4a8cd..d98a014 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -6,7 +6,6 @@ from xlb.operator.macroscopic import Macroscopic from xlb.grid import grid_factory from xlb import DefaultConfig -import warp as wp def init_xlb_env(velocity_set): @@ -25,8 +24,8 @@ def init_xlb_env(velocity_set): (2, xlb.velocity_set.D2Q9, (100, 100), 1.1, 2.0), (2, xlb.velocity_set.D2Q9, (50, 50), 1.1, 2.0), (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.0, 0.0), - (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. - (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), #TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 1.0), # TODO: Uncommenting will cause a Warp error. Needs investigation. + (3, xlb.velocity_set.D3Q19, (50, 50, 50), 1.1, 2.0), # TODO: Uncommenting will cause a Warp error. Needs investigation. ], ) def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): @@ -45,12 +44,8 @@ def test_macroscopic_warp(dim, velocity_set, grid_shape, rho, velocity): rho_calc, u_calc = compute_macro(f_eq, rho_calc, u_calc) - assert np.allclose( - rho_calc.numpy(), rho - ), f"Computed density should be close to initialized density {rho}" - assert np.allclose( - u_calc.numpy(), velocity - ), f"Computed velocity should be close to initialized velocity {velocity}" + assert np.allclose(rho_calc.numpy(), rho), f"Computed density should be close to initialized density {rho}" + assert np.allclose(u_calc.numpy(), velocity), f"Computed velocity should be close to initialized velocity {velocity}" if __name__ == "__main__": diff --git a/tests/kernels/stream/test_stream_warp.py b/tests/kernels/stream/test_stream_warp.py index af70b4c..b83368d 100644 --- a/tests/kernels/stream/test_stream_warp.py +++ b/tests/kernels/stream/test_stream_warp.py @@ -70,9 +70,7 @@ def test_stream_operator_warp(dim, velocity_set, grid_shape): f_streamed = my_grid_warp.create_field(cardinality=velocity_set.q) f_streamed = stream_op(f_initial_warp, f_streamed) - assert jnp.allclose( - f_streamed.numpy(), np.array(expected) - ), "Streaming did not occur as expected" + assert jnp.allclose(f_streamed.numpy(), np.array(expected)), "Streaming did not occur as expected" if __name__ == "__main__": diff --git a/xlb/__init__.py b/xlb/__init__.py index be63d06..b58db3b 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -1,10 +1,10 @@ # Enum classes -from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy, Precision -from xlb.physics_type import PhysicsType +from xlb.compute_backend import ComputeBackend as ComputeBackend +from xlb.precision_policy import PrecisionPolicy as PrecisionPolicy, Precision as Precision +from xlb.physics_type import PhysicsType as PhysicsType # Config -from .default_config import init, DefaultConfig +from .default_config import init as init, DefaultConfig as DefaultConfig # Velocity Set import xlb.velocity_set @@ -15,6 +15,7 @@ import xlb.operator.stream import xlb.operator.boundary_condition import xlb.operator.macroscopic + # Grids import xlb.grid @@ -25,4 +26,4 @@ import xlb.utils # Distributed computing -import xlb.distribute \ No newline at end of file +import xlb.distribute diff --git a/xlb/distribute/__init__.py b/xlb/distribute/__init__.py index 33e0d2b..25fa0af 100644 --- a/xlb/distribute/__init__.py +++ b/xlb/distribute/__init__.py @@ -1 +1 @@ -from .distribute import distribute \ No newline at end of file +from .distribute import distribute as distribute diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index bcee7dd..c62b915 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -57,13 +57,8 @@ def build_specs(grid, *args): else: sharding_flags.append(False) - in_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() - for flag in sharding_flags - ) - out_specs = tuple( - P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results) - ) + in_specs = tuple(P(*((None, "x") + (grid.dim - 1) * (None,))) if flag else P() for flag in sharding_flags) + out_specs = tuple(P(*((None, "x") + (grid.dim - 1) * (None,))) for _ in range(num_results)) return tuple(sharding_flags), in_specs, out_specs def _wrapped_operator(*args): @@ -92,28 +87,19 @@ def distribute(operator, grid, velocity_set, num_results=1, ops="permute"): """ if isinstance(operator, IncompressibleNavierStokesStepper): # Check for post-streaming boundary conditions - has_post_streaming_bc = any( - bc.implementation_step == ImplementationStep.STREAMING - for bc in operator.boundary_conditions - ) + has_post_streaming_bc = any(bc.implementation_step == ImplementationStep.STREAMING for bc in operator.boundary_conditions) if has_post_streaming_bc: # If there are post-streaming BCs, only distribute the stream operator - distributed_stream = distribute_operator( - operator.stream, grid, velocity_set - ) + distributed_stream = distribute_operator(operator.stream, grid, velocity_set) operator.stream = distributed_stream else: # If no post-streaming BCs, distribute the whole operator - distributed_op = distribute_operator( - operator, grid, velocity_set, num_results=num_results, ops=ops - ) + distributed_op = distribute_operator(operator, grid, velocity_set, num_results=num_results, ops=ops) return distributed_op return operator else: # For other operators, apply the original distribution logic - distributed_op = distribute_operator( - operator, grid, velocity_set, num_results=num_results, ops=ops - ) + distributed_op = distribute_operator(operator, grid, velocity_set, num_results=num_results, ops=ops) return distributed_op diff --git a/xlb/experimental/ooc/__init__.py b/xlb/experimental/ooc/__init__.py index 5206cc1..801683d 100644 --- a/xlb/experimental/ooc/__init__.py +++ b/xlb/experimental/ooc/__init__.py @@ -1,2 +1,2 @@ -from xlb.experimental.ooc.out_of_core import OOCmap -from xlb.experimental.ooc.ooc_array import OOCArray +from xlb.experimental.ooc.out_of_core import OOCmap as OOCmap +from xlb.experimental.ooc.ooc_array import OOCArray as OOCArray diff --git a/xlb/experimental/ooc/ooc_array.py b/xlb/experimental/ooc/ooc_array.py index 6effde6..11aedc3 100644 --- a/xlb/experimental/ooc/ooc_array.py +++ b/xlb/experimental/ooc/ooc_array.py @@ -3,7 +3,6 @@ # from mpi4py import MPI import itertools -from dataclasses import dataclass from xlb.experimental.ooc.tiles.dense_tile import DenseTile, DenseGPUTile, DenseCPUTile from xlb.experimental.ooc.tiles.compressed_tile import ( @@ -63,9 +62,7 @@ def __init__( if self.codec is None: self.Tile = DenseTile self.DeviceTile = DenseGPUTile - self.HostTile = ( - DenseCPUTile # TODO: Possibly make HardDiskTile or something - ) + self.HostTile = DenseCPUTile # TODO: Possibly make HardDiskTile or something else: self.Tile = CompressedTile @@ -84,45 +81,33 @@ def __init__( # Get number of tiles per process if self.nr_tiles % self.nr_proc != 0: - raise ValueError( - f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}." - ) + raise ValueError(f"Number of tiles {self.nr_tiles} does not divide number of processes {self.nr_proc}.") self.nr_tiles_per_proc = self.nr_tiles // self.nr_proc # Make the tile mapppings self.tile_process_map = {} self.tile_device_map = {} - for i, tile_index in enumerate( - itertools.product(*[range(n) for n in self.tile_dims]) - ): + for i, tile_index in enumerate(itertools.product(*[range(n) for n in self.tile_dims])): self.tile_process_map[tile_index] = i % self.nr_proc - self.tile_device_map[tile_index] = devices[ - i % len(devices) - ] # Checkoboard pattern, TODO: may not be optimal + self.tile_device_map[tile_index] = devices[i % len(devices)] # Checkoboard pattern, TODO: may not be optimal # Get my device if self.nr_proc != len(self.devices): - raise ValueError( - f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}." - ) + raise ValueError(f"Number of processes {self.nr_proc} does not equal number of devices {len(self.devices)}.") self.device = self.devices[self.pid] # Make the tiles self.tiles = {} for tile_index in self.tile_process_map.keys(): if self.pid == self.tile_process_map[tile_index]: - self.tiles[tile_index] = self.HostTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + self.tiles[tile_index] = self.HostTile(self.tile_shape, self.dtype, self.padding, self.codec) # Make GPU tiles for copying data between CPU and GPU if self.nr_tiles % self.nr_compute_tiles != 0: raise ValueError( f"Number of tiles {self.nr_tiles} does not divide number of compute tiles {self.nr_compute_tiles}. This is used for asynchronous copies." ) - compute_array_shape = [ - s + 2 * p for (s, p) in zip(self.tile_shape, self.padding) - ] + compute_array_shape = [s + 2 * p for (s, p) in zip(self.tile_shape, self.padding)] self.compute_tiles_htd = [] self.compute_tiles_dth = [] self.compute_streams_htd = [] @@ -132,13 +117,9 @@ def __init__( with cp.cuda.Device(self.device): for i in range(self.nr_compute_tiles): # Make compute tiles for copying data - compute_tile = self.DeviceTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + compute_tile = self.DeviceTile(self.tile_shape, self.dtype, self.padding, self.codec) self.compute_tiles_htd.append(compute_tile) - compute_tile = self.DeviceTile( - self.tile_shape, self.dtype, self.padding, self.codec - ) + compute_tile = self.DeviceTile(self.tile_shape, self.dtype, self.padding, self.codec) self.compute_tiles_dth.append(compute_tile) # Make cupy stream @@ -185,9 +166,7 @@ def compression_ratio(self): def update_compute_index(self): """Update the current compute index.""" - self.current_compute_index = ( - self.current_compute_index + 1 - ) % self.nr_compute_tiles + self.current_compute_index = (self.current_compute_index + 1) % self.nr_compute_tiles def _guess_next_tile_index(self, tile_index): """Guess the next tile index to use for the compute array.""" @@ -291,9 +270,7 @@ def get_compute_array(self, tile_index): compute_tile.to_array(self.compute_arrays[self.current_compute_index]) # Return the compute array index in global array - global_index = tuple( - [i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)] - ) + global_index = tuple([i * s - p for (i, s, p) in zip(tile_index, self.tile_shape, self.padding)]) return self.compute_arrays[self.current_compute_index], global_index @@ -339,12 +316,7 @@ def update_padding(self): # Loop over all padding for pad_index in pad_ind: # Get neighboring tile index - neigh_tile_index = tuple( - [ - (i + p) % s - for (i, p, s) in zip(tile_index, pad_index, self.tile_dims) - ] - ) + neigh_tile_index = tuple([(i + p) % s for (i, p, s) in zip(tile_index, pad_index, self.tile_dims)]) neigh_pad_index = tuple([-p for p in pad_index]) # flip # 4 cases: @@ -354,10 +326,7 @@ def update_padding(self): # 4. the tile and neighboring tile are on different processes # Case 1: the tile and neighboring tile are on the same process - if ( - self.pid == self.tile_process_map[tile_index] - and self.pid == self.tile_process_map[neigh_tile_index] - ): + if self.pid == self.tile_process_map[tile_index] and self.pid == self.tile_process_map[neigh_tile_index]: # Get the tile and neighboring tile tile = self.tiles[tile_index] neigh_tile = self.tiles[neigh_tile_index] @@ -371,10 +340,7 @@ def update_padding(self): neigh_tile._buf_padding[neigh_pad_index] = padding # Case 2: the tile is on this process and the neighboring tile is on another process - if ( - self.pid == self.tile_process_map[tile_index] - and self.pid != self.tile_process_map[neigh_tile_index] - ): + if self.pid == self.tile_process_map[tile_index] and self.pid != self.tile_process_map[neigh_tile_index]: # Get the tile and padding tile = self.tiles[tile_index] padding = tile._padding[pad_index] @@ -387,10 +353,7 @@ def update_padding(self): ) # Case 3: the tile is on another process and the neighboring tile is on this process - if ( - self.pid != self.tile_process_map[tile_index] - and self.pid == self.tile_process_map[neigh_tile_index] - ): + if self.pid != self.tile_process_map[tile_index] and self.pid == self.tile_process_map[neigh_tile_index]: # Get the neighboring tile and padding neigh_tile = self.tiles[neigh_tile_index] neigh_padding = neigh_tile._buf_padding[neigh_pad_index] @@ -403,10 +366,7 @@ def update_padding(self): ) # Case 4: the tile and neighboring tile are on different processes - if ( - self.pid != self.tile_process_map[tile_index] - and self.pid != self.tile_process_map[neigh_tile_index] - ): + if self.pid != self.tile_process_map[tile_index] and self.pid != self.tile_process_map[neigh_tile_index]: pass # Increment the communication tag @@ -429,12 +389,7 @@ def get_array(self): comm_tag = 0 for tile_index in self.tile_process_map.keys(): # Set the center array in the full array - slice_index = tuple( - [ - slice(i * s, (i + 1) * s) - for (i, s) in zip(tile_index, self.tile_shape) - ] - ) + slice_index = tuple([slice(i * s, (i + 1) * s) for (i, s) in zip(tile_index, self.tile_shape)]) # if tile on this process compute the center array if self.comm.rank == self.tile_process_map[tile_index]: @@ -465,18 +420,13 @@ def get_array(self): if self.comm.rank == 0 and self.tile_process_map[tile_index] != 0: # Get the data from the other rank center_array = np.empty(self.tile_shape, dtype=self.dtype) - self.comm.Recv( - center_array, source=self.tile_process_map[tile_index], tag=comm_tag - ) + self.comm.Recv(center_array, source=self.tile_process_map[tile_index], tag=comm_tag) # Set the center array in the full array array[slice_index] = center_array # Case 3: the tile is on this rank and this process is not rank 0 - if ( - self.comm.rank != 0 - and self.tile_process_map[tile_index] == self.comm.rank - ): + if self.comm.rank != 0 and self.tile_process_map[tile_index] == self.comm.rank: # Send the data to rank 0 self.comm.Send(center_array, dest=0, tag=comm_tag) diff --git a/xlb/experimental/ooc/out_of_core.py b/xlb/experimental/ooc/out_of_core.py index 01851e8..bc42fab 100644 --- a/xlb/experimental/ooc/out_of_core.py +++ b/xlb/experimental/ooc/out_of_core.py @@ -1,17 +1,11 @@ # Out-of-core decorator for functions that take a lot of memory -import functools -import warp as wp import cupy as cp -import jax.dlpack as jdlpack -import jax -import numpy as np from xlb.experimental.ooc.ooc_array import OOCArray from xlb.experimental.ooc.utils import ( _cupy_to_backend, _backend_to_cupy, - _stream_to_backend, ) @@ -47,9 +41,7 @@ def wrapper(*args): # TODO: Add better checks for ooc_array in ooc_array_args: if ooc_array_args[0].tile_dims != ooc_array.tile_dims: - raise ValueError( - f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}" - ) + raise ValueError(f"Tile dimensions of ooc arrays do not match. {ooc_array_args[0].tile_dims} != {ooc_array.tile_dims}") # Apply the function to each of the ooc arrays for tile_index in ooc_array_args[0].tiles.keys(): @@ -79,9 +71,7 @@ def wrapper(*args): results = (results,) # Convert the results back to cupy arrays - results = tuple( - [_backend_to_cupy(result, backend) for result in results] - ) + results = tuple([_backend_to_cupy(result, backend) for result in results]) # Write the results back to the ooc array for arg_index, result in zip(ref_args, results): diff --git a/xlb/experimental/ooc/tiles/compressed_tile.py b/xlb/experimental/ooc/tiles/compressed_tile.py index 415f83b..ccdd2bb 100644 --- a/xlb/experimental/ooc/tiles/compressed_tile.py +++ b/xlb/experimental/ooc/tiles/compressed_tile.py @@ -1,9 +1,6 @@ import numpy as np import cupy as cp -import itertools -from dataclasses import dataclass import warnings -import time try: from kvikio._lib.arr import asarray @@ -98,9 +95,7 @@ def compression_ratio(self): # Get total number of bytes in uncompressed tile total_bytes_uncompressed = np.prod(self.shape) * self.dtype_itemsize for pad_ind in self.pad_ind: - total_bytes_uncompressed += ( - np.prod(self._padding_shape[pad_ind]) * self.dtype_itemsize - ) + total_bytes_uncompressed += np.prod(self._padding_shape[pad_ind]) * self.dtype_itemsize # Return compression ratio return total_bytes_uncompressed, total_bytes @@ -147,9 +142,7 @@ def to_gpu_tile(self, dst_gpu_tile): """Copy tile to a GPU tile.""" # Check tile is Compressed - assert isinstance( - dst_gpu_tile, CompressedGPUTile - ), "Destination tile must be a CompressedGPUTile" + assert isinstance(dst_gpu_tile, CompressedGPUTile), "Destination tile must be a CompressedGPUTile" # Copy array dst_gpu_tile._array[: len(self._array.array)].set(self._array.array) @@ -157,9 +150,7 @@ def to_gpu_tile(self, dst_gpu_tile): # Copy padding for pad_ind in self.pad_ind: - dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set( - self._padding[pad_ind].array - ) + dst_gpu_tile._padding[pad_ind][: len(self._padding[pad_ind].array)].set(self._padding[pad_ind].array) dst_gpu_tile._padding_bytes[pad_ind] = self._padding[pad_ind].nbytes @@ -186,9 +177,7 @@ def allocate_array(self, shape): """Returns a cupy array with the given shape.""" nbytes = np.prod(shape) * self.dtype_itemsize codec = self.codec() - max_compressed_buffer = codec._manager.configure_compression(nbytes)[ - "max_compressed_buffer_size" - ] + max_compressed_buffer = codec._manager.configure_compression(nbytes)["max_compressed_buffer_size"] array = cp.zeros((max_compressed_buffer,), dtype=np.uint8) return array @@ -198,9 +187,7 @@ def to_array(self, array): # Copy center array if self._array_codec is None: self._array_codec = self.codec() - self._array_codec._manager.configure_decompression_with_compressed_buffer( - asarray(self._array[: self._array_bytes]) - ) + self._array_codec._manager.configure_decompression_with_compressed_buffer(asarray(self._array[: self._array_bytes])) self._array_codec.decompression_config = self._array_codec._manager.configure_decompression_with_compressed_buffer( asarray(self._array[: self._array_bytes]) ) @@ -217,17 +204,13 @@ def to_array(self, array): self._padding_codec[pad_ind] = self.codec() self._padding_codec[pad_ind].decompression_config = self._padding_codec[ pad_ind - ]._manager.configure_decompression_with_compressed_buffer( - asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]]) - ) + ]._manager.configure_decompression_with_compressed_buffer(asarray(self._padding[pad_ind][: self._padding_bytes[pad_ind]])) self.dense_gpu_tile._padding[pad_ind] = _decode( self._padding[pad_ind][: self._padding_bytes[pad_ind]], self.dense_gpu_tile._padding[pad_ind], self._padding_codec[pad_ind], ) - array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[ - pad_ind - ] + array[self._slice_padding_to_array[pad_ind]] = self.dense_gpu_tile._padding[pad_ind] def from_array(self, array): """Copy a full array to tile.""" @@ -236,17 +219,13 @@ def from_array(self, array): if self._array_codec is None: self._array_codec = self.codec() self._array_codec.configure_compression(self._array.nbytes) - self._array_bytes = _encode( - array[self._slice_center], self._array, self._array_codec - ) + self._array_bytes = _encode(array[self._slice_center], self._array, self._array_codec) # Copy padding for pad_ind in self.pad_ind: if pad_ind not in self._padding_codec: self._padding_codec[pad_ind] = self.codec() - self._padding_codec[pad_ind].configure_compression( - self._padding[pad_ind].nbytes - ) + self._padding_codec[pad_ind].configure_compression(self._padding[pad_ind].nbytes) self._padding_bytes[pad_ind] = _encode( array[self._slice_array_to_padding[pad_ind]], self._padding[pad_ind], @@ -257,9 +236,7 @@ def to_cpu_tile(self, dst_cpu_tile): """Copy tile to a CPU tile.""" # Check tile is Compressed - assert isinstance( - dst_cpu_tile, CompressedCPUTile - ), "Destination tile must be a CompressedCPUTile" + assert isinstance(dst_cpu_tile, CompressedCPUTile), "Destination tile must be a CompressedCPUTile" # Copy array dst_cpu_tile._array.resize(self._array_bytes) @@ -268,6 +245,4 @@ def to_cpu_tile(self, dst_cpu_tile): # Copy padding for pad_ind in self.pad_ind: dst_cpu_tile._padding[pad_ind].resize(self._padding_bytes[pad_ind]) - self._padding[pad_ind][: self._padding_bytes[pad_ind]].get( - out=dst_cpu_tile._padding[pad_ind].array - ) + self._padding[pad_ind][: self._padding_bytes[pad_ind]].get(out=dst_cpu_tile._padding[pad_ind].array) diff --git a/xlb/experimental/ooc/tiles/dense_tile.py b/xlb/experimental/ooc/tiles/dense_tile.py index 8a303e4..41fc129 100644 --- a/xlb/experimental/ooc/tiles/dense_tile.py +++ b/xlb/experimental/ooc/tiles/dense_tile.py @@ -1,7 +1,5 @@ import numpy as np import cupy as cp -import itertools -from dataclasses import dataclass from xlb.experimental.ooc.tiles.tile import Tile @@ -46,9 +44,7 @@ def allocate_array(self, shape): """Returns a cupy array with the given shape.""" # TODO: Seems hacky, but it works. Is there a better way? mem = cp.cuda.alloc_pinned_memory(np.prod(shape) * self.dtype_itemsize) - array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape( - shape - ) + array = np.frombuffer(mem, dtype=self.dtype, count=np.prod(shape)).reshape(shape) self.nbytes += mem.size() return array @@ -62,9 +58,7 @@ def to_gpu_tile(self, dst_gpu_tile): dst_gpu_tile._array.set(self._array) # Copy padding - for src_array, dst_gpu_array in zip( - self._padding.values(), dst_gpu_tile._padding.values() - ): + for src_array, dst_gpu_array in zip(self._padding.values(), dst_gpu_tile._padding.values()): dst_gpu_array.set(src_array) @@ -90,7 +84,5 @@ def to_cpu_tile(self, dst_cpu_tile): self._array.get(out=dst_cpu_tile._array) # Copy padding - for src_array, dst_array in zip( - self._padding.values(), dst_cpu_tile._padding.values() - ): + for src_array, dst_array in zip(self._padding.values(), dst_cpu_tile._padding.values()): src_array.get(out=dst_array) diff --git a/xlb/experimental/ooc/tiles/dynamic_array.py b/xlb/experimental/ooc/tiles/dynamic_array.py index 2b05b2e..403d164 100644 --- a/xlb/experimental/ooc/tiles/dynamic_array.py +++ b/xlb/experimental/ooc/tiles/dynamic_array.py @@ -3,7 +3,6 @@ import math import cupy as cp import numpy as np -import time class DynamicArray: @@ -46,17 +45,12 @@ def resize(self, nbytes): self.nbytes = nbytes # Check if the number of bytes requested is less than 2xbytes_resize or if the number of bytes requested exceeds the allocated number of bytes - if ( - nbytes < (self.allocated_bytes - 2 * self.bytes_resize) - or nbytes > self.allocated_bytes - ): + if nbytes < (self.allocated_bytes - 2 * self.bytes_resize) or nbytes > self.allocated_bytes: ## Free the memory # del self.mem # Set the new number of allocated bytes - self.allocated_bytes = ( - math.ceil(nbytes / self.bytes_resize) * self.bytes_resize - ) + self.allocated_bytes = math.ceil(nbytes / self.bytes_resize) * self.bytes_resize # Allocate the memory self.mem = cp.cuda.alloc_pinned_memory(self.allocated_bytes) diff --git a/xlb/experimental/ooc/tiles/tile.py b/xlb/experimental/ooc/tiles/tile.py index 90c3334..9bb347b 100644 --- a/xlb/experimental/ooc/tiles/tile.py +++ b/xlb/experimental/ooc/tiles/tile.py @@ -1,7 +1,5 @@ -import numpy as np import cupy as cp import itertools -from dataclasses import dataclass class Tile: @@ -25,9 +23,7 @@ def __init__(self, shape, dtype, padding, codec=None): self.padding = padding self.dtype_itemsize = cp.dtype(self.dtype).itemsize self.nbytes = 0 # Updated when array is allocated - self.codec = ( - codec # Codec to use for compression TODO: Find better abstraction for this - ) + self.codec = codec # Codec to use for compression TODO: Find better abstraction for this # Make center array self._array = self.allocate_array(self.shape) @@ -59,9 +55,7 @@ def __init__(self, shape, dtype, padding, codec=None): self._buf_padding[ind] = self.allocate_array(shape) # Get slicing for array copies - self._slice_center = tuple( - [slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)] - ) + self._slice_center = tuple([slice(pad, pad + shape) for (pad, shape) in zip(self.padding, self.shape)]) self._slice_padding_to_array = {} self._slice_array_to_padding = {} self._padding_shape = {} diff --git a/xlb/experimental/ooc/utils.py b/xlb/experimental/ooc/utils.py index f607128..1179c76 100644 --- a/xlb/experimental/ooc/utils.py +++ b/xlb/experimental/ooc/utils.py @@ -70,7 +70,7 @@ def _stream_to_backend(stream, backend): # Convert stream to backend stream if backend == "jax": raise ValueError("Jax currently does not support streams") - elif backend == "warp": + if backend == "warp": backend_stream = wp.Stream(cuda_stream=stream.ptr) elif backend == "cupy": backend_stream = stream diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index 692b453..7d9ec24 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1 +1 @@ -from xlb.grid.grid import grid_factory \ No newline at end of file +from xlb.grid.grid import grid_factory as grid_factory diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 483386d..7d8a678 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,14 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, Tuple +from typing import Tuple from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import Precision -def grid_factory( - shape: Tuple[int, ...], - compute_backend: ComputeBackend = None): +def grid_factory(shape: Tuple[int, ...], compute_backend: ComputeBackend = None): compute_backend = compute_backend or DefaultConfig.default_backend if compute_backend == ComputeBackend.WARP: from xlb.grid.warp_grid import WarpGrid @@ -38,16 +35,17 @@ def _bounding_box_indices(self): """ This function calculates the indices of the bounding box of a 2D or 3D grid. The bounding box is defined as the set of grid points on the outer edge of the grid. - + Returns ------- boundingBox (dict): A dictionary where keys are the names of the bounding box faces ("bottom", "top", "left", "right" for 2D; additional "front", "back" for 3D), and values are numpy arrays of indices corresponding to each face. """ - def to_tuple(list): - d = len(list[0]) - return [tuple([sublist[i] for sublist in list]) for i in range(d)] + + def to_tuple(lst): + d = len(lst[0]) + return [tuple([sublist[i] for sublist in lst]) for i in range(d)] if self.dim == 2: # For a 2D grid, the bounding box consists of four edges: bottom, top, left, and right. @@ -58,9 +56,9 @@ def to_tuple(list): "bottom": to_tuple([[i, 0] for i in range(nx)]), "top": to_tuple([[i, ny - 1] for i in range(nx)]), "left": to_tuple([[0, i] for i in range(ny)]), - "right": to_tuple([[nx - 1, i] for i in range(ny)]) + "right": to_tuple([[nx - 1, i] for i in range(ny)]), } - + elif self.dim == 3: # For a 3D grid, the bounding box consists of six faces: bottom, top, left, right, front, and back. # Each face is represented as an array of indices. For example, the bottom face includes all points @@ -72,7 +70,6 @@ def to_tuple(list): "left": to_tuple([[0, j, k] for j in range(ny) for k in range(nz)]), "right": to_tuple([[nx - 1, j, k] for j in range(ny) for k in range(nz)]), "front": to_tuple([[i, 0, k] for i in range(nx) for k in range(nz)]), - "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]) + "back": to_tuple([[i, ny - 1, k] for i in range(nx) for k in range(nz)]), } - return - + return diff --git a/xlb/grid/jax_grid.py b/xlb/grid/jax_grid.py index 289790a..24eeb03 100644 --- a/xlb/grid/jax_grid.py +++ b/xlb/grid/jax_grid.py @@ -1,19 +1,16 @@ -from typing import Any, Literal, Optional, Tuple +from typing import Literal from jax.sharding import PartitionSpec as P from jax.sharding import NamedSharding, Mesh from jax.experimental import mesh_utils -from jax.experimental.shard_map import shard_map from xlb.compute_backend import ComputeBackend import jax.numpy as jnp -from jax import lax import jax from xlb import DefaultConfig from .grid import Grid -from xlb.operator import Operator from xlb.precision_policy import Precision @@ -25,9 +22,7 @@ def _initialize_backend(self): self.nDevices = jax.device_count() self.backend = jax.default_backend() self.device_mesh = ( - mesh_utils.create_device_mesh((1, self.nDevices, 1)) - if self.dim == 2 - else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) + mesh_utils.create_device_mesh((1, self.nDevices, 1)) if self.dim == 2 else mesh_utils.create_device_mesh((1, self.nDevices, 1, 1)) ) self.global_mesh = ( Mesh(self.device_mesh, axis_names=("cardinality", "x", "y")) @@ -53,9 +48,7 @@ def create_field( dtype = dtype.jax_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.jax_dtype - for d, index in self.sharding.addressable_devices_indices_map( - full_shape - ).items(): + for d, index in self.sharding.addressable_devices_indices_map(full_shape).items(): jax.default_device = d if fill_value: x = jnp.full(device_shape, fill_value, dtype=dtype) @@ -63,6 +56,4 @@ def create_field( x = jnp.zeros(shape=device_shape, dtype=dtype) arrays += [jax.device_put(x, d)] jax.default_device = jax.devices()[0] - return jax.make_array_from_single_device_arrays( - full_shape, self.sharding, arrays - ) \ No newline at end of file + return jax.make_array_from_single_device_arrays(full_shape, self.sharding, arrays) diff --git a/xlb/grid/warp_grid.py b/xlb/grid/warp_grid.py index 75c3f14..5018962 100644 --- a/xlb/grid/warp_grid.py +++ b/xlb/grid/warp_grid.py @@ -1,13 +1,10 @@ -from dataclasses import field import warp as wp from .grid import Grid -from xlb.operator import Operator from xlb.precision_policy import Precision from xlb.compute_backend import ComputeBackend from typing import Literal from xlb import DefaultConfig -import numpy as np class WarpGrid(Grid): @@ -23,11 +20,7 @@ def create_field( dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, fill_value=None, ): - dtype = ( - dtype.wp_dtype - if dtype - else DefaultConfig.default_precision_policy.store_precision.wp_dtype - ) + dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype shape = (cardinality,) + (self.shape) if fill_value is None: diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py index 29ac3f6..92d3583 100644 --- a/xlb/helper/__init__.py +++ b/xlb/helper/__init__.py @@ -1,2 +1,2 @@ -from xlb.helper.nse_solver import create_nse_fields -from xlb.helper.initializers import initialize_eq \ No newline at end of file +from xlb.helper.nse_solver import create_nse_fields as create_nse_fields +from xlb.helper.initializers import initialize_eq as initialize_eq diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_solver.py index 4462edc..a42c6ac 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_solver.py @@ -1,21 +1,13 @@ -import xlb -from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig from xlb.grid import grid_factory from xlb.precision_policy import Precision from typing import Tuple -def create_nse_fields( - grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None -): - velocity_set = velocity_set if velocity_set else DefaultConfig.velocity_set - compute_backend = ( - compute_backend if compute_backend else DefaultConfig.default_backend - ) - precision_policy = ( - precision_policy if precision_policy else DefaultConfig.default_precision_policy - ) +def create_nse_fields(grid_shape: Tuple[int, int, int], velocity_set=None, compute_backend=None, precision_policy=None): + velocity_set = velocity_set or DefaultConfig.velocity_set + compute_backend = compute_backend or DefaultConfig.default_backend + precision_policy = precision_policy or DefaultConfig.default_precision_policy grid = grid_factory(grid_shape, compute_backend=compute_backend) # Create fields @@ -25,4 +17,3 @@ def create_nse_fields( boundary_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) return grid, f_0, f_1, missing_mask, boundary_mask - diff --git a/xlb/operator/__init__.py b/xlb/operator/__init__.py index 02b8a59..c88ef83 100644 --- a/xlb/operator/__init__.py +++ b/xlb/operator/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.operator import Operator -from xlb.operator.parallel_operator import ParallelOperator +from xlb.operator.operator import Operator as Operator +from xlb.operator.parallel_operator import ParallelOperator as ParallelOperator diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 275a074..1fd2152 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,8 +1,8 @@ -from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition +from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition as BoundaryCondition from xlb.operator.boundary_condition.boundary_condition_registry import ( - BoundaryConditionRegistry, + BoundaryConditionRegistry as BoundaryConditionRegistry, ) -from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC -from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC -from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC -from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC +from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC as EquilibriumBC +from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC +from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC +from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index 38697c3..a57e427 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -6,7 +6,7 @@ from jax import jit from functools import partial import warp as wp -from typing import Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -34,7 +34,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): super().__init__( ImplementationStep.STREAMING, @@ -53,9 +53,7 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 0af61e1..f1d85d3 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -35,11 +35,11 @@ def __init__( self, rho: float, u: Tuple[float, float, float], - equilibrium_operator : Operator = None, + equilibrium_operator: Operator = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): # Store the equilibrium information self.rho = rho @@ -73,14 +73,8 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _rho = wp.float32(self.rho) - _u = ( - _u_vec(self.u[0], self.u[1], self.u[2]) - if self.velocity_set.d == 3 - else _u_vec(self.u[0], self.u[1]) - ) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index a445e07..5b2f2c1 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -6,7 +6,7 @@ from jax import jit from functools import partial import warp as wp -from typing import Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -33,7 +33,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): super().__init__( ImplementationStep.COLLISION, @@ -48,16 +48,14 @@ def __init__( def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask): boundary = boundary_mask == self.id boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0) - return jnp.where(boundary, f_pre[self.velocity_set.opp_indices,...], f_post) + return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _opp_indices = self.velocity_set.wp_opp_indices _q = wp.constant(self.velocity_set.q) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Construct the funcional to get streamed indices @wp.func diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 55472a3..1fa4a7c 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -3,12 +3,10 @@ """ import jax.numpy as jnp -from jax import jit, device_count -import jax.lax as lax +from jax import jit from functools import partial -import numpy as np import warp as wp -from typing import Tuple, Any, List +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -37,7 +35,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): # Call the parent constructor super().__init__( @@ -64,9 +62,7 @@ def _construct_warp(self): _c = self.velocity_set.wp_c _opp_indices = self.velocity_set.wp_opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool @wp.func def functional2d( diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index dbeadbc..125e45d 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -2,12 +2,7 @@ Base class for boundary conditions in a LBM simulation. """ -import jax.numpy as jnp -from jax import jit, device_count -from functools import partial -import numpy as np from enum import Enum, auto -from typing import List from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -15,6 +10,7 @@ from xlb.operator.operator import Operator from xlb import DefaultConfig + # Enum for implementation step class ImplementationStep(Enum): COLLISION = auto() @@ -32,7 +28,7 @@ def __init__( velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, - indices = None, + indices=None, ): velocity_set = velocity_set or DefaultConfig.velocity_set precision_policy = precision_policy or DefaultConfig.default_precision_policy diff --git a/xlb/operator/boundary_condition/boundary_condition_registry.py b/xlb/operator/boundary_condition/boundary_condition_registry.py index 0a3b2c7..5b1e092 100644 --- a/xlb/operator/boundary_condition/boundary_condition_registry.py +++ b/xlb/operator/boundary_condition/boundary_condition_registry.py @@ -19,11 +19,11 @@ def register_boundary_condition(self, boundary_condition): """ Register a boundary condition. """ - id = self.next_id + _id = self.next_id self.next_id += 1 - self.id_to_bc[id] = boundary_condition - self.bc_to_id[boundary_condition] = id - return id + self.id_to_bc[_id] = boundary_condition + self.bc_to_id[boundary_condition] = _id + return _id boundary_condition_registry = BoundaryConditionRegistry() diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index cc80b85..262e638 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,6 +1,6 @@ from xlb.operator.boundary_masker.indices_boundary_masker import ( - IndicesBoundaryMasker, + IndicesBoundaryMasker as IndicesBoundaryMasker, ) from xlb.operator.boundary_masker.stl_boundary_masker import ( - STLBoundaryMasker, + STLBoundaryMasker as STLBoundaryMasker, ) diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 460bd3b..7960083 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -22,41 +22,32 @@ def __init__( # Call super super().__init__(velocity_set, precision_policy, compute_backend) - @Operator.register_backend(ComputeBackend.JAX) - # TODO HS: figure out why uncommenting the line below fails unlike other operators! + # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) - def jax_implementation( - self, bclist, boundary_mask, mask, start_index=None - ): + def jax_implementation(self, bclist, boundary_mask, mask, start_index=None): # define a helper function def compute_boundary_id_and_mask(boundary_mask, mask): if dim == 2: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1] - ].set(id_number) + boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1]].set(id_number) mask = mask.at[:, local_indices[0], local_indices[1]].set(True) if dim == 3: - boundary_mask = boundary_mask.at[ - 0, local_indices[0], local_indices[1], local_indices[2] - ].set(id_number) - mask = mask.at[ - :, local_indices[0], local_indices[1], local_indices[2] - ].set(True) + boundary_mask = boundary_mask.at[0, local_indices[0], local_indices[1], local_indices[2]].set(id_number) + mask = mask.at[:, local_indices[0], local_indices[1], local_indices[2]].set(True) return boundary_mask, mask - + dim = mask.ndim - 1 if start_index is None: start_index = (0,) * dim for bc in bclist: - assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC!' + assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" id_number = bc.id local_indices = np.array(bc.indices) - np.array(start_index)[:, np.newaxis] boundary_mask, mask = compute_boundary_id_and_mask(boundary_mask, mask) # We are done with bc.indices. Remove them from BC objects - bc.__dict__.pop('indices', None) + bc.__dict__.pop("indices", None) mask = self.stream(mask) return boundary_mask, mask @@ -84,12 +75,7 @@ def kernel2d( index[1] = indices[1, ii] - start_index[1] # Check if in bounds - if ( - index[0] >= 0 - and index[0] < mask.shape[1] - and index[1] >= 0 - and index[1] < mask.shape[2] - ): + if index[0] >= 0 and index[0] < mask.shape[1] and index[1] >= 0 and index[1] < mask.shape[2]: # Stream indices for l in range(_q): # Get the index of the streaming direction @@ -146,10 +132,7 @@ def kernel3d( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation( - self, bclist, boundary_mask, missing_mask, start_index=None - ): - + def warp_implementation(self, bclist, boundary_mask, missing_mask, start_index=None): dim = self.velocity_set.d index_list = [[] for _ in range(dim)] id_list = [] @@ -159,10 +142,10 @@ def warp_implementation( index_list[d] += bc.indices[d] id_list += [bc.id] * len(bc.indices[0]) # We are done with bc.indices. Remove them from BC objects - bc.__dict__.pop('indices', None) - - indices = wp.array2d(index_list, dtype = wp.int32) - id_number = wp.array1d(id_list, dtype = wp.uint8) + bc.__dict__.pop("indices", None) + + indices = wp.array2d(index_list, dtype=wp.int32) + id_number = wp.array1d(id_list, dtype=wp.uint8) if start_index is None: start_index = (0,) * dim diff --git a/xlb/operator/boundary_masker/stl_boundary_masker.py b/xlb/operator/boundary_masker/stl_boundary_masker.py index c2cfc30..b4ea8ca 100644 --- a/xlb/operator/boundary_masker/stl_boundary_masker.py +++ b/xlb/operator/boundary_masker/stl_boundary_masker.py @@ -1,19 +1,13 @@ # Base class for all equilibriums -from functools import partial import numpy as np from stl import mesh as np_mesh -import jax.numpy as jnp -from jax import jit import warp as wp -from typing import Tuple -from xlb import DefaultConfig from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator -from xlb.operator.stream.stream import Stream class STLBoundaryMasker(Operator): @@ -56,9 +50,7 @@ def kernel( index[2] = k - start_index[2] # position of the point - ijk = wp.vec3( - wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]) - ) + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) ijk = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center pos = wp.cw_mul(ijk, spacing) + origin @@ -74,9 +66,7 @@ def kernel( face_u = float(0.0) face_v = float(0.0) sign = float(0.0) - if wp.mesh_query_point_sign_winding_number( - mesh, pos, max_length, sign, face_index, face_u, face_v - ): + if wp.mesh_query_point_sign_winding_number(mesh, pos, max_length, sign, face_index, face_u, face_v): # set point to be solid if sign <= 0: # TODO: fix this # Stream indices @@ -87,9 +77,7 @@ def kernel( push_index[d] = index[d] + _c[d, l] # Set the boundary id and mask - boundary_mask[ - 0, push_index[0], push_index[1], push_index[2] - ] = wp.uint8(id_number) + boundary_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) mask[l, push_index[0], push_index[1], push_index[2]] = True return None, kernel diff --git a/xlb/operator/collision/__init__.py b/xlb/operator/collision/__init__.py index 77395e6..b48d0ce 100644 --- a/xlb/operator/collision/__init__.py +++ b/xlb/operator/collision/__init__.py @@ -1,3 +1,3 @@ -from xlb.operator.collision.collision import Collision -from xlb.operator.collision.bgk import BGK -from xlb.operator.collision.kbc import KBC +from xlb.operator.collision.collision import Collision as Collision +from xlb.operator.collision.bgk import BGK as BGK +from xlb.operator.collision.kbc import KBC as KBC diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index ec40b56..fa0857a 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -70,15 +70,13 @@ def jax_implementation( shear = self.decompose_shear_d3q27_jax(fneq) delta_s = shear * rho else: - raise NotImplementedError( - "Velocity set not supported: {}".format(type(self.velocity_set)) - ) + raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set))) # Perform collision delta_h = fneq - delta_s - gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product( - delta_s, delta_h, feq - ) / (self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq)) + gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / ( + self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - self.beta * (2.0 * delta_s + gamma[None, ...] * delta_h) @@ -206,11 +204,7 @@ def decompose_shear_d2q9_jax(self, fneq): def _construct_warp(self): # Raise error if velocity set is not supported if not isinstance(self.velocity_set, D3Q27): - raise NotImplementedError( - "Velocity set not supported for warp backend: {}".format( - type(self.velocity_set) - ) - ) + raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set))) # Set local constants TODO: This is a hack and should be fixed with warp update _w = self.velocity_set.wp_w @@ -323,9 +317,9 @@ def functional2d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( - delta_s, delta_h, feq - ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + _epsilon + entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - _beta * (2.0 * delta_s + gamma * delta_h) return fout @@ -345,9 +339,9 @@ def functional3d( # Perform collision delta_h = fneq - delta_s - gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product( - delta_s, delta_h, feq - ) / (_epsilon + entropic_scalar_product(delta_h, delta_h, feq)) + gamma = _inv_beta - (2.0 - _inv_beta) * entropic_scalar_product(delta_s, delta_h, feq) / ( + _epsilon + entropic_scalar_product(delta_h, delta_h, feq) + ) fout = f - _beta * (2.0 * delta_s + gamma * delta_h) return fout @@ -362,12 +356,13 @@ def kernel2d( fout: wp.array3d(dtype=Any), ): # Get the global index - i, j, k = wp.tid() + i, j = wp.tid() index = wp.vec3i(i, j) # TODO: Warp needs to fix this # Load needed values _f = _f_vec() _feq = _f_vec() + _d = self.velocity_set.d for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1]] _feq[l] = feq[l, index[0], index[1]] @@ -399,6 +394,7 @@ def kernel3d( # Load needed values _f = _f_vec() _feq = _f_vec() + _d = self.velocity_set.d for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index 42b601e..b9f9f08 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1,4 +1,4 @@ from xlb.operator.equilibrium.quadratic_equilibrium import ( - Equilibrium, - QuadraticEquilibrium, + Equilibrium as Equilibrium, + QuadraticEquilibrium as QuadraticEquilibrium, ) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 794d78d..3af6b4a 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -4,11 +4,10 @@ import warp as wp from typing import Any -from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium.equilibrium import Equilibrium from xlb.operator import Operator -from xlb import DefaultConfig + class QuadraticEquilibrium(Equilibrium): """ diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 91eb36c..3463078 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -1 +1 @@ -from xlb.operator.macroscopic.macroscopic import Macroscopic +from xlb.operator.macroscopic.macroscopic import Macroscopic as Macroscopic diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 7fa309f..13d3817 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -4,10 +4,8 @@ import jax.numpy as jnp from jax import jit import warp as wp -from typing import Tuple, Any +from typing import Any -from xlb import DefaultConfig -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index 38e8e15..83c6538 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,12 +1,7 @@ -# Base class for all operators, (collision, streaming, equilibrium, etc.) - import inspect -import warp as wp -from typing import Any import traceback from xlb.compute_backend import ComputeBackend -from xlb.precision_policy import PrecisionPolicy, Precision from xlb import DefaultConfig @@ -22,9 +17,7 @@ class Operator: def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): # Set the default values from the global config self.velocity_set = velocity_set or DefaultConfig.velocity_set - self.precision_policy = ( - precision_policy or DefaultConfig.default_precision_policy - ) + self.precision_policy = precision_policy or DefaultConfig.default_precision_policy self.compute_backend = compute_backend or DefaultConfig.default_backend # Check if the compute backend is supported @@ -52,17 +45,13 @@ def decorator(func): def __call__(self, *args, callback=None, **kwargs): method_candidates = [ - (key, method) - for key, method in self._backends.items() - if key[0] == self.__class__.__name__ and key[1] == self.compute_backend + (key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend ] bound_arguments = None for key, backend_method in method_candidates: try: # This attempts to bind the provided args and kwargs to the backend method's signature - bound_arguments = inspect.signature(backend_method).bind( - self, *args, **kwargs - ) + bound_arguments = inspect.signature(backend_method).bind(self, *args, **kwargs) bound_arguments.apply_defaults() # This fills in any default values result = backend_method(self, *args, **kwargs) callback_arg = result if result is not None else (args, kwargs) @@ -74,9 +63,7 @@ def __call__(self, *args, callback=None, **kwargs): traceback_str = traceback.format_exc() continue # This skips to the next candidate if binding fails - raise Exception( - f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}" - ) + raise Exception(f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}") @property def supported_compute_backend(self): diff --git a/xlb/operator/parallel_operator.py b/xlb/operator/parallel_operator.py index 9309b21..9f9b5c5 100644 --- a/xlb/operator/parallel_operator.py +++ b/xlb/operator/parallel_operator.py @@ -65,12 +65,8 @@ def _parallel_func(self, f): jax.numpy.ndarray The processed data. """ - rightPerm = [ - (i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices) - ] - leftPerm = [ - ((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices) - ] + rightPerm = [(i, (i + 1) % self.grid.nDevices) for i in range(self.grid.nDevices)] + leftPerm = [((i + 1) % self.grid.nDevices, i) for i in range(self.grid.nDevices)] f = self.func(f) left_comm, right_comm = ( f[self.velocity_set.right_indices, :1, ...], diff --git a/xlb/operator/precision_caster/__init__.py b/xlb/operator/precision_caster/__init__.py index a027c52..c333ab7 100644 --- a/xlb/operator/precision_caster/__init__.py +++ b/xlb/operator/precision_caster/__init__.py @@ -1 +1 @@ -from xlb.operator.precision_caster.precision_caster import PrecisionCaster +from xlb.operator.precision_caster.precision_caster import PrecisionCaster as PrecisionCaster diff --git a/xlb/operator/precision_caster/precision_caster.py b/xlb/operator/precision_caster/precision_caster.py index cb441c5..5427cba 100644 --- a/xlb/operator/precision_caster/precision_caster.py +++ b/xlb/operator/precision_caster/precision_caster.py @@ -3,10 +3,9 @@ """ import jax.numpy as jnp -from jax import jit, device_count +from jax import jit +import warp as wp from functools import partial -import numpy as np -from enum import Enum from xlb.operator.operator import Operator from xlb.velocity_set import VelocitySet diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py index e5d159c..528375d 100644 --- a/xlb/operator/stepper/__init__.py +++ b/xlb/operator/stepper/__init__.py @@ -1,2 +1,2 @@ -from xlb.operator.stepper.stepper import Stepper -from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper +from xlb.operator.stepper.stepper import Stepper as Stepper +from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper as IncompressibleNavierStokesStepper diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 3f986d3..3fad2b1 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -1,13 +1,11 @@ # Base class for all stepper operators -from logging import warning from functools import partial from jax import jit import warp as wp from typing import Any from xlb import DefaultConfig -from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator import Operator from xlb.operator.stream import Stream @@ -32,9 +30,7 @@ def __init__(self, omega, boundary_conditions=[], collision_type="BGK"): # Construct the operators self.stream = Stream(velocity_set, precision_policy, compute_backend) - self.equilibrium = QuadraticEquilibrium( - velocity_set, precision_policy, compute_backend - ) + self.equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend) self.macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend) operators = [self.macroscopic, self.equilibrium, self.collision, self.stream] @@ -91,9 +87,7 @@ def jax_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec( - self.velocity_set.q, dtype=wp.uint8 - ) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool # Get the boundary condition ids _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) @@ -129,19 +123,13 @@ def kernel2d( f_post_stream = self.stream.warp_functional(f_0, index) elif _boundary_id == _equilibrium_bc: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _do_nothing_bc: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _halfway_bounce_back_bc: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -199,19 +187,13 @@ def kernel3d( f_post_stream = self.stream.warp_functional(f_0, index) elif _boundary_id == _equilibrium_bc: # Equilibrium boundary condition - f_post_stream = self.equilibrium_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _do_nothing_bc: # Do nothing boundary condition - f_post_stream = self.do_nothing_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) elif _boundary_id == _halfway_bounce_back_bc: # Half way boundary condition - f_post_stream = self.halfway_bounce_back_bc.warp_functional( - f_0, _missing_mask, index - ) + f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) # Compute rho and u rho, u = self.macroscopic.warp_functional(f_post_stream) @@ -220,9 +202,7 @@ def kernel3d( feq = self.equilibrium.warp_functional(rho, u) # Apply collision - f_post_collision = self.collision.warp_functional( - f_post_stream, feq, rho, u - ) + f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions if _boundary_id == _fullway_bounce_back_bc: diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index c11b39b..adc2564 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,16 +1,5 @@ # Base class for all stepper operators - -from ast import Raise -from functools import partial -import jax.numpy as jnp -from jax import jit -import warp as wp - -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.velocity_set import VelocitySet -from xlb.compute_backend import ComputeBackend from xlb.operator import Operator -from xlb.operator.precision_caster import PrecisionCaster from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -24,39 +13,17 @@ def __init__(self, operators, boundary_conditions): self.operators = operators self.boundary_conditions = boundary_conditions # Get velocity set, precision policy, and compute backend - velocity_sets = set( - [op.velocity_set for op in self.operators if op is not None] - ) - assert ( - len(velocity_sets) < 2 - ), "All velocity sets must be the same. Got {}".format(velocity_sets) - velocity_set = ( - DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() - ) + velocity_sets = set([op.velocity_set for op in self.operators if op is not None]) + assert len(velocity_sets) < 2, "All velocity sets must be the same. Got {}".format(velocity_sets) + velocity_set = DefaultConfig.velocity_set if not velocity_sets else velocity_sets.pop() - precision_policies = set( - [op.precision_policy for op in self.operators if op is not None] - ) - assert ( - len(precision_policies) < 2 - ), "All precision policies must be the same. Got {}".format(precision_policies) - precision_policy = ( - DefaultConfig.default_precision_policy - if not precision_policies - else precision_policies.pop() - ) + precision_policies = set([op.precision_policy for op in self.operators if op is not None]) + assert len(precision_policies) < 2, "All precision policies must be the same. Got {}".format(precision_policies) + precision_policy = DefaultConfig.default_precision_policy if not precision_policies else precision_policies.pop() - compute_backends = set( - [op.compute_backend for op in self.operators if op is not None] - ) - assert ( - len(compute_backends) < 2 - ), "All compute backends must be the same. Got {}".format(compute_backends) - compute_backend = ( - DefaultConfig.default_backend - if not compute_backends - else compute_backends.pop() - ) + compute_backends = set([op.compute_backend for op in self.operators if op is not None]) + assert len(compute_backends) < 2, "All compute backends must be the same. Got {}".format(compute_backends) + compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() # Add boundary conditions # Warp cannot handle lists of functions currently @@ -93,9 +60,7 @@ def __init__(self, operators, boundary_conditions): self.equilibrium_bc = EquilibriumBC( rho=1.0, u=(0.0, 0.0, 0.0), - equilibrium_operator=next( - (op for op in self.operators if isinstance(op, Equilibrium)), None - ), + equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None), velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, diff --git a/xlb/operator/stream/__init__.py b/xlb/operator/stream/__init__.py index 9093da7..2f5b2f3 100644 --- a/xlb/operator/stream/__init__.py +++ b/xlb/operator/stream/__init__.py @@ -1 +1 @@ -from xlb.operator.stream.stream import Stream +from xlb.operator.stream.stream import Stream as Stream diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 77cf22d..da724c2 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -6,7 +6,6 @@ import warp as wp from typing import Any -from xlb.velocity_set.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator @@ -49,9 +48,7 @@ def _streaming_jax_i(f, c): elif self.velocity_set.d == 3: return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2)) - return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)( - f, jnp.array(self.velocity_set.c).T - ) + return vmap(_streaming_jax_i, in_axes=(0, 0), out_axes=0)(f, jnp.array(self.velocity_set.c).T) def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 5a59b97..3b0f85f 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import warp as wp + class Precision(Enum): FP64 = auto() FP32 = auto() @@ -42,6 +43,7 @@ def jax_dtype(self): else: raise ValueError("Invalid precision") + class PrecisionPolicy(Enum): FP64FP64 = auto() FP64FP32 = auto() @@ -93,4 +95,4 @@ def cast_to_compute_warp(self, array): def cast_to_store_warp(self, array): store_precision = self.store_precision - return wp.array(array, dtype=store_precision.wp_dtype) \ No newline at end of file + return wp.array(array, dtype=store_precision.wp_dtype) diff --git a/xlb/precision_policy/precision_policy.py b/xlb/precision_policy/precision_policy.py index 57f65c0..6b400ee 100644 --- a/xlb/precision_policy/precision_policy.py +++ b/xlb/precision_policy/precision_policy.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig from xlb.precision_policy.jax_precision_policy import ( @@ -15,9 +14,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp64() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp64Fp32: @@ -25,9 +22,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp32() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp32Fp32: @@ -35,9 +30,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp32() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp64Fp16: @@ -45,9 +38,7 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp64Fp16() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") class Fp32Fp16: @@ -55,6 +46,4 @@ def __new__(cls): if DefaultConfig.compute_backend == ComputeBackend.JAX: return JaxFp32Fp16() else: - raise ValueError( - f"Unsupported compute backend: {DefaultConfig.compute_backend}" - ) + raise ValueError(f"Unsupported compute backend: {DefaultConfig.compute_backend}") diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py index 3c8032e..6f1f61a 100644 --- a/xlb/utils/__init__.py +++ b/xlb/utils/__init__.py @@ -1,9 +1,9 @@ from .utils import ( - downsample_field, - save_image, - save_fields_vtk, - save_BCs_vtk, - rotate_geometry, - voxelize_stl, - axangle2mat, + downsample_field as downsample_field, + save_image as save_image, + save_fields_vtk as save_fields_vtk, + save_BCs_vtk as save_BCs_vtk, + rotate_geometry as rotate_geometry, + voxelize_stl as voxelize_stl, + axangle2mat as axangle2mat, ) diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index 7b7ac78..074177e 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -1,7 +1,6 @@ import numpy as np import matplotlib.pylab as plt from matplotlib import cm -import numpy as np from time import time import pyvista as pv from jax.image import resize @@ -38,9 +37,7 @@ def downsample_field(field, factor, method="bicubic"): else: new_shape = tuple(dim // factor for dim in field.shape[:-1]) downsampled_components = [] - for i in range( - field.shape[-1] - ): # Iterate over the last dimension (vector components) + for i in range(field.shape[-1]): # Iterate over the last dimension (vector components) resized = resize(field[..., i], new_shape, method=method) downsampled_components.append(resized) @@ -66,8 +63,10 @@ def save_image(fld, timestep, prefix=None): Notes ----- - This function saves the field as an image in the PNG format. The filename is based on the name of the main script file, the provided prefix, and the timestep number. - If the field is 3D, the magnitude of the field is calculated and saved. The image is saved with the 'nipy_spectral' colormap and the origin set to 'lower'. + This function saves the field as an image in the PNG format. + The filename is based on the name of the main script file, the provided prefix, and the timestep number. + If the field is 3D, the magnitude of the field is calculated and saved. + The image is saved with the 'nipy_spectral' colormap and the origin set to 'lower'. """ if prefix is None: fname = os.path.basename(__main__.__file__) @@ -79,7 +78,7 @@ def save_image(fld, timestep, prefix=None): if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") - elif len(fld.shape) == 3: + if len(fld.shape) == 3: fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) plt.clf() @@ -118,9 +117,7 @@ def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"): if key == list(fields.keys())[0]: dimensions = value.shape else: - assert ( - value.shape == dimensions - ), "All fields must have the same dimensions!" + assert value.shape == dimensions, "All fields must have the same dimensions!" output_filename = os.path.join(output_dir, prefix + "_" + f"{timestep:07d}.vtk") @@ -231,15 +228,11 @@ def rotate_geometry(indices, origin, axis, angle): This function rotates the mesh by applying a rotation matrix to the voxel indices. The rotation matrix is calculated using the axis-angle representation of rotations. The origin of the rotation axis is assumed to be at (0, 0, 0). """ - indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat( - axis, angle - ) + origin + indices_rotated = (jnp.array(indices).T - origin) @ axangle2mat(axis, angle) + origin return tuple(jnp.rint(indices_rotated).astype("int32").T) -def voxelize_stl( - stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None -): +def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None): """ Converts an STL file to a voxelized mesh. @@ -314,10 +307,8 @@ def axangle2mat(axis, angle, is_normalized=False): xyC = x * yC yzC = y * zC zxC = z * xC - return jnp.array( - [ - [x * xC + c, xyC - zs, zxC + ys], - [xyC + zs, y * yC + c, yzC - xs], - [zxC - ys, yzC + xs, z * zC + c], - ] - ) + return jnp.array([ + [x * xC + c, xyC - zs, zxC + ys], + [xyC + zs, y * yC + c, yzC - xs], + [zxC - ys, yzC + xs, z * zC + c], + ]) diff --git a/xlb/velocity_set/__init__.py b/xlb/velocity_set/__init__.py index 5b7b737..c1338db 100644 --- a/xlb/velocity_set/__init__.py +++ b/xlb/velocity_set/__init__.py @@ -1,4 +1,4 @@ -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.velocity_set.d2q9 import D2Q9 -from xlb.velocity_set.d3q19 import D3Q19 -from xlb.velocity_set.d3q27 import D3Q27 +from xlb.velocity_set.velocity_set import VelocitySet as VelocitySet +from xlb.velocity_set.d2q9 import D2Q9 as D2Q9 +from xlb.velocity_set.d3q19 import D3Q19 as D3Q19 +from xlb.velocity_set.d3q27 import D3Q27 as D3Q27 diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 178c89e..700806c 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -12,14 +12,13 @@ class D2Q9(VelocitySet): D2Q9 stands for two-dimensional nine-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ + def __init__(self): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] c = np.array(tuple(zip(cx, cy))).T - w = np.array( - [4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36] - ) + w = np.array([4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36]) # Call the parent constructor super().__init__(2, 9, c, w) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 7f69019..97db1d9 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -13,15 +13,10 @@ class D3Q19(VelocitySet): D3Q19 stands for three-dimensional nineteen-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ + def __init__(self): # Construct the velocity vectors and weights - c = np.array( - [ - ci - for ci in itertools.product([-1, 0, 1], repeat=3) - if np.sum(np.abs(ci)) <= 2 - ] - ).T + c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) for i in range(19): if np.sum(np.abs(c[:, i])) == 0: diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index ac908eb..702acf4 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -13,6 +13,7 @@ class D3Q27(VelocitySet): D3Q27 stands for three-dimensional twenty-seven-velocity model. It is a common model used in the Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ + def __init__(self): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index 6f2bf4e..47bbae4 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -2,9 +2,6 @@ import math import numpy as np -from functools import partial -import jax.numpy as jnp -from jax import jit, vmap import warp as wp @@ -48,15 +45,9 @@ def __init__(self, d, q, c, w): # Make warp constants for these vectors # TODO: Following warp updates these may not be necessary self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant( - wp.vec(self.q, dtype=wp.float32)(self.w) - ) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant( - wp.vec(self.q, dtype=wp.int32)(self.opp_indices) - ) - self.wp_cc = wp.constant( - wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc) - ) + self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow + self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) + self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -127,9 +118,7 @@ def _construct_main_indices(self): return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] elif self.d == 3: - return np.nonzero( - (np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1) - )[0] + return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] def _construct_right_indices(self): """