Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Nov 23, 2023
2 parents 8dced2b + 7b0555c commit ed256cf
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 108 deletions.
2 changes: 1 addition & 1 deletion examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def output_data(self, **kwargs):
'print_info_rate': 100,
'checkpoint_rate': checkpoint_rate,
'checkpoint_dir': checkpoint_dir,
'restore_checkpoint': True,
'restore_checkpoint': False,
}

sim = Cavity(**kwargs)
Expand Down
47 changes: 30 additions & 17 deletions examples/performance/MLUPS3d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@
import argparse
import os
import jax
# Initialize JAX distributed. The IP, number of processes and process id must be updated.
# Currently set on local host for testing purposes.
# Can be tested on a two GPU system as follows:
# (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &)
#IMPORTANT: jax distributed must be initialized before any jax computation is performed
jax.distributed.initialize(f'127.0.0.1:1234', 2, process_id=int(os.environ['CUDA_VISIBLE_DEVICES']))

print('Process id: ', jax.process_index())
print('Number of total devices (over all processes): ', jax.device_count())
print('Number of local devices:', jax.local_device_count())


import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -56,17 +45,41 @@ def set_boundary_conditions(self):
self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall))

if __name__ == '__main__':
precision = 'f32/f32'
lattice = LatticeD3Q19(precision)

# Create a parser that will read the command line arguments
parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation")
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN", default=100, type=int)
parser.add_argument("N_ITERS", help="Number of timesteps", default=10000, type=int)
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN",
default=100, type=int)
parser.add_argument("N_ITERS", help="Number of iterations", default=10000, type=int)
parser.add_argument("N_PROCESSES", help="Number of processes. If >1, call jax.distributed.initialize with that number of process. If -1 will call jax.distributed.initialize without any arsgument. So it should pick up the values from SLURM env variable.",
default=1, type=int)
parser.add_argument("IP", help="IP of the master node for multi-node. Useless if using SLURM.",
default='127.0.0.1', type=str, nargs='?')
parser.add_argument("PROCESS_ID_INCREMENT", help="For multi-node only. Useless if using SLURM.",
default=0, type=int, nargs='?')

args = parser.parse_args()
n = args.N
n_iters = args.N_ITERS
n_processes = args.N_PROCESSES
# Initialize JAX distributed. The IP, number of processes and process id must be set correctly.
print("N processes, ", n_processes)
print("N iter, ", n_iters)
if n_processes > 1:
process_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0)) + args.PROCESS_ID_INCREMENT
print("ip, num_processes, process_id, ", args.IP, n_processes, process_id)
jax.distributed.initialize(args.IP, num_processes=n_processes,
process_id=process_id)
elif n_processes == -1:
print("Will call jax.distributed.initialize()")
jax.distributed.initialize()
print("jax.distributed.initialize() ended")
else:
print("No call to jax.distributed.initialize")
print("JAX local devices: ", jax.local_devices())

precision = 'f32/f32'
# Create a 3D lattice with the D3Q19 scheme
lattice = LatticeD3Q19(precision)

# Store the Reynolds number in the variable Re
Re = 100.0
Expand All @@ -92,4 +105,4 @@ def set_boundary_conditions(self):
}

sim = Cavity(**kwargs) # Run the simulation
sim.run(n_iters)
sim.run(n_iters)
93 changes: 3 additions & 90 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ def __init__(self, **kwargs):
self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh,
in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False))

self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh,
in_specs=P("x", None, None), out_specs=P("x", None, None), check_rep=False))

# Set up the sharding and streaming for 2D and 3D simulations
elif self.dim == 3:
self.devices = mesh_utils.create_device_mesh((self.nDevices, 1, 1, 1))
Expand All @@ -138,9 +135,7 @@ def __init__(self, **kwargs):

self.streaming = jit(shard_map(self.streaming_m, mesh=self.mesh,
in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False))

self.compute_bitmask = jit(shard_map(self.compute_bitmask_m, mesh=self.mesh,
in_specs=P("x", None, None, None), out_specs=P("x", None, None, None), check_rep=False))

else:
raise ValueError(f"dim = {self.dim} not supported")

Expand Down Expand Up @@ -439,7 +434,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels):
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)

connectivity_bitmask = self.compute_bitmask(connectivity_bitmask)
connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)

elif self.dim == 3:
Expand All @@ -450,7 +445,7 @@ def create_grid_connectivity_bitmask(self, solid_halo_voxels):
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)
connectivity_bitmask = self.compute_bitmask(connectivity_bitmask)
connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)

def bounding_box_indices(self):
Expand Down Expand Up @@ -686,88 +681,6 @@ def streaming_i(f, c):
return jnp.roll(f, (c[0], c[1], c[2]), axis=(0, 1, 2))

return vmap(streaming_i, in_axes=(-1, 0), out_axes=-1)(f, self.c.T)

def compute_bitmask_m(self, b):
"""
This function computes a bitmask for each direction in the lattice. The bitmask is used to
determine which nodes are fluid nodes and which are boundary nodes.
To enable multi-GPU/TPU functionality, it extracts the left and right boundary slices of the
distribution functions that need to be communicated to the neighboring processes.
The function then sends the left boundary slice to the right neighboring process and the right
boundary slice to the left neighboring process. The received data is then set to the
corresponding indices in the receiving domain.
Parameters
----------
b: jax.numpy.ndarray
The array holding the bitmasks for the simulation.
Returns
-------
jax.numpy.ndarray
The bitmasks after the streaming operation.
"""
b = self.compute_bitmask_p(b)
left_comm, right_comm = b[:1, ..., self.lattice.right_indices], b[-1:, ..., self.lattice.left_indices]

left_comm, right_comm = self.send_right(left_comm, 'x'), self.send_left(right_comm, 'x')
b = b.at[:1, ..., self.lattice.right_indices].set(left_comm)
b = b.at[-1:, ..., self.lattice.left_indices].set(right_comm)
return b

def compute_bitmask_p(self, b):
"""
This function computes a bitmask for each direction in the lattice. The bitmask is used to
determine which nodes are fluid nodes and which are boundary nodes.
It does this by rolling the input bitmask (b) in the opposite direction of each lattice
direction. The rolling operation shifts the values of the bitmask along the specified axes.
The function uses the vmap operation provided by the JAX library to vectorize the computation
over all lattice directions.
Parameters
----------
b: ndarray
The input bitmask.
Returns
-------
jax.numpy.ndarray
The computed bitmask for each direction in the lattice.
"""
def compute_bitmask_i(b, i):
"""
This function computes the bitmask for a specific direction in the lattice.
It does this by rolling the input bitmask (b) in the opposite direction of the specified
lattice direction. The rolling operation shifts the values of the bitmask along the
specified axes.
Parameters
----------
b: jax.numpy.ndarray
The input bitmask.
i: int
The index of the lattice direction.
Returns
-------
jax.numpy.ndarray
The computed bitmask for the specified direction in the lattice.
"""
if self.dim == 2:
rolls = (self.c.T[i, 0], self.c.T[i, 1])
axes = (0, 1)
return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes)
elif self.dim == 3:
rolls = (self.c.T[i, 0], self.c.T[i, 1], self.c.T[i, 2])
axes = (0, 1, 2)
return jnp.roll(b[..., self.lattice.opp_indices[i]], rolls, axes)

return vmap(compute_bitmask_i, in_axes=(None, 0), out_axes=-1)(b, self.lattice.i_s)

@partial(jit, static_argnums=(0, 3), inline=True)
def equilibrium(self, rho, u, cast_output=True):
Expand Down

0 comments on commit ed256cf

Please sign in to comment.