Skip to content

Commit

Permalink
Merge pull request #20 from nouiz/init2
Browse files Browse the repository at this point in the history
Allow multiple way to initialize from the command line.
  • Loading branch information
mehdiataei authored Oct 30, 2023
2 parents 2180c56 + e71e50d commit 30f6def
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions examples/performance/MLUPS3d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@
import argparse
import os
import jax
# Initialize JAX distributed. The IP, number of processes and process id must be updated.
# Currently set on local host for testing purposes.
# Can be tested on a two GPU system as follows:
# (export PYTHONPATH=.; CUDA_VISIBLE_DEVICES=0 python3 examples/performance/MLUPS3d_distributed.py 100 100 & CUDA_VISIBLE_DEVICES=1 python3 examples/performance/MLUPS3d_distributed.py 100 100 &)
#IMPORTANT: jax distributed must be initialized before any jax computation is performed
jax.distributed.initialize(f'127.0.0.1:1234', 2, process_id=int(os.environ['CUDA_VISIBLE_DEVICES']))

print('Process id: ', jax.process_index())
print('Number of total devices (over all processes): ', jax.device_count())
print('Number of local devices:', jax.local_device_count())


import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -56,17 +45,41 @@ def set_boundary_conditions(self):
self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall))

if __name__ == '__main__':
precision = 'f32/f32'
lattice = LatticeD3Q19(precision)

# Create a parser that will read the command line arguments
parser = argparse.ArgumentParser("Calculate MLUPS for a 3D cavity flow simulation")
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN", default=100, type=int)
parser.add_argument("N_ITERS", help="Number of timesteps", default=10000, type=int)
parser.add_argument("N", help="The total number of voxels in one direction. The final dimension will be N*NxN",
default=100, type=int)
parser.add_argument("N_ITERS", help="Number of iterations", default=10000, type=int)
parser.add_argument("N_PROCESSES", help="Number of processes. If >1, call jax.distributed.initialize with that number of process. If -1 will call jax.distributed.initialize without any arsgument. So it should pick up the values from SLURM env variable.",
default=1, type=int)
parser.add_argument("IP", help="IP of the master node for multi-node. Useless if using SLURM.",
default='127.0.0.1', type=str, nargs='?')
parser.add_argument("PROCESS_ID_INCREMENT", help="For multi-node only. Useless if using SLURM.",
default=0, type=int, nargs='?')

args = parser.parse_args()
n = args.N
n_iters = args.N_ITERS
n_processes = args.N_PROCESSES
# Initialize JAX distributed. The IP, number of processes and process id must be set correctly.
print("N processes, ", n_processes)
print("N iter, ", n_iters)
if n_processes > 1:
process_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0)) + args.PROCESS_ID_INCREMENT
print("ip, num_processes, process_id, ", args.IP, n_processes, process_id)
jax.distributed.initialize(args.IP, num_processes=n_processes,
process_id=process_id)
elif n_processes == -1:
print("Will call jax.distributed.initialize()")
jax.distributed.initialize()
print("jax.distributed.initialize() ended")
else:
print("No call to jax.distributed.initialize")
print("JAX local devices: ", jax.local_devices())

precision = 'f32/f32'
# Create a 3D lattice with the D3Q19 scheme
lattice = LatticeD3Q19(precision)

# Store the Reynolds number in the variable Re
Re = 100.0
Expand All @@ -92,4 +105,4 @@ def set_boundary_conditions(self):
}

sim = Cavity(**kwargs) # Run the simulation
sim.run(n_iters)
sim.run(n_iters)

0 comments on commit 30f6def

Please sign in to comment.