-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
jaxdecomp proto #21
base: main
Are you sure you want to change the base?
jaxdecomp proto #21
Changes from 87 commits
a742065
6408aff
319942a
ac86468
e62cd84
5775a37
7501b5b
7f48cfa
c81d4d2
abde543
82be568
4f508b7
5f6d42e
1f6b9c3
ed8cf8e
0216837
5b7f595
1f20351
f25eb7d
8c5bd76
75604d2
ccbfee3
aebc3e7
9af4659
831291c
ece8c93
783a974
02754cf
8da3149
2ea05a1
ab86699
afecb13
01b9527
ff1c5e8
0ce7219
9c94f99
375f204
5a587fd
a160a3f
38714cf
591ee32
a5b267b
56ffd26
80c56dc
105568e
4d944f0
a8b194f
0433c61
2f50993
8623308
82b8f56
85cca44
d28982e
82f2987
31ca41b
cf799b6
0bb992f
45b2c7f
505f2ec
5d4f438
8e8e896
69c35d1
0f833f0
d2f1eb2
ff8856d
0c96a4d
49dd18a
11f7e90
4342279
cc4f310
d62c38f
b4fdb74
c93894f
19011d0
a757b62
f3b431a
2ad035a
b3a264a
b09580d
e9529d3
4da4c66
72457d6
a030ec4
f0c43f8
a067954
42d8e89
6256fba
2472a5d
ad45666
12c74e2
0946842
435c7c8
b32014b
c1b276d
e0c118a
21373b8
36ef18e
f70583b
7823fda
af29c40
97f39bd
ac4ef9e
5d34d3c
adaf7d2
d8c68ac
b264da5
47c69c6
8951f5c
3ce0be6
ae684c9
7384343
3be5dae
158478c
f91aa93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
name: Code Formatting | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip isort | ||
python -m pip install pre-commit | ||
- name: Run pre-commit | ||
run: python -m pre_commit run --all-files |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,4 @@ repos: | |
rev: 5.13.2 | ||
hooks: | ||
- id: isort | ||
name: isort (python) | ||
name: isort (python) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,16 +4,15 @@ | |||||
<!-- ALL-CONTRIBUTORS-BADGE:END --> | ||||||
JAX-powered Cosmological Particle-Mesh N-body Solver | ||||||
|
||||||
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)** | ||||||
|
||||||
## Goals | ||||||
|
||||||
Provide a modern infrastructure to support differentiable PM N-body simulations using JAX: | ||||||
- Keep implementation simple and readable, in pure NumPy API | ||||||
- Transparent distribution using builtin `xmap` | ||||||
- Any order forward and backward automatic differentiation | ||||||
- Support automated batching using `vmap` | ||||||
- Compatibility with external optimizer libraries like `optax` | ||||||
- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with the latex `JAX v0.4.35` | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line should not be in the Goals section, it's a feature now. |
||||||
|
||||||
|
||||||
## Open development and use | ||||||
|
||||||
|
@@ -23,6 +22,10 @@ Current expectations are: | |||||
- Everyone is welcome to contribute, and can join the JOSS publication (until it is submitted to the journal). | ||||||
- Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they *need* to extend authorship to all jaxpm developers. | ||||||
|
||||||
## Getting Started | ||||||
|
||||||
To dive into JaxPM’s capabilities, please explore the **notebook section** for detailed tutorials and examples on various setups, from single-device simulations to multi-host configurations. You can find the notebooks' [README here](notebooks/README.md) for a structured guide through each tutorial. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would put the link to the README around |
||||||
|
||||||
|
||||||
## Contributors ✨ | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
import os | ||
|
||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax | ||
import jax | ||
|
||
jax.distributed.initialize() | ||
|
||
rank = jax.process_index() | ||
size = jax.process_count() | ||
|
||
import argparse | ||
import time | ||
|
||
import jax.numpy as jnp | ||
import jax_cosmo as jc | ||
import numpy as np | ||
from cupy.cuda.nvtx import RangePop, RangePush | ||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, | ||
PIDController, SaveAt, Tsit5, diffeqsolve) | ||
from hpc_plotter.timer import Timer | ||
from jax.experimental import mesh_utils | ||
from jax.experimental.multihost_utils import sync_global_devices | ||
from jax.sharding import Mesh, NamedSharding | ||
from jax.sharding import PartitionSpec as P | ||
|
||
from jaxpm.kernels import interpolate_power_spectrum | ||
from jaxpm.painting import cic_paint_dx | ||
from jaxpm.pm import linear_field, lpt, make_ode_fn | ||
|
||
|
||
def run_simulation(mesh_shape, | ||
box_size, | ||
halo_size, | ||
solver_choice, | ||
iterations, | ||
hlo_print, | ||
trace, | ||
pdims=None, | ||
output_path="."): | ||
|
||
@jax.jit | ||
def simulate(omega_c, sigma8): | ||
# Create a small function to generate the matter power spectrum | ||
k = jnp.logspace(-4, 1, 128) | ||
pk = jc.power.linear_matter_power( | ||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) | ||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) | ||
|
||
# Create initial conditions | ||
initial_conditions = linear_field(mesh_shape, | ||
box_size, | ||
pk_fn, | ||
seed=jax.random.PRNGKey(0)) | ||
|
||
# Create particles | ||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) | ||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) | ||
if solver_choice == "Dopri5": | ||
solver = Dopri5() | ||
elif solver_choice == "LeapfrogMidpoint": | ||
solver = LeapfrogMidpoint() | ||
elif solver_choice == "Tsit5": | ||
solver = Tsit5() | ||
elif solver_choice == "lpt": | ||
lpt_field = cic_paint_dx(dx, halo_size=halo_size) | ||
return lpt_field, {"num_steps": 0} | ||
else: | ||
raise ValueError( | ||
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.") | ||
# Evolve the simulation forward | ||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) | ||
term = ODETerm( | ||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) | ||
|
||
if solver_choice == "Dopri5" or solver_choice == "Tsit5": | ||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) | ||
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": | ||
stepsize_controller = ConstantStepSize() | ||
res = diffeqsolve(term, | ||
solver, | ||
t0=0.1, | ||
t1=1., | ||
dt0=0.01, | ||
y0=jnp.stack([dx, p], axis=0), | ||
args=cosmo, | ||
saveat=SaveAt(t1=True), | ||
stepsize_controller=stepsize_controller) | ||
|
||
# Return the simulation volume at requested | ||
state = res.ys[-1] | ||
final_field = cic_paint_dx(state[0], halo_size=halo_size) | ||
|
||
return final_field, res.stats | ||
|
||
def run(): | ||
# Warm start | ||
chrono_fun = Timer() | ||
RangePush("warmup") | ||
final_field, stats = chrono_fun.chrono_jit(simulate, | ||
0.32, | ||
0.8, | ||
ndarray_arg=0) | ||
RangePop() | ||
sync_global_devices("warmup") | ||
for i in range(iterations): | ||
RangePush(f"sim iter {i}") | ||
final_field, stats = chrono_fun.chrono_fun(simulate, | ||
0.32, | ||
0.8, | ||
ndarray_arg=0) | ||
RangePop() | ||
return final_field, stats, chrono_fun | ||
|
||
if jax.device_count() > 1: | ||
devices = mesh_utils.create_device_mesh(pdims) | ||
mesh = Mesh(devices.T, axis_names=('x', 'y')) | ||
with mesh: | ||
# Warm start | ||
final_field, stats, chrono_fun = run() | ||
else: | ||
final_field, stats, chrono_fun = run() | ||
|
||
return final_field, stats, chrono_fun | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
parser = argparse.ArgumentParser( | ||
description='JAX Cosmo Simulation Benchmark') | ||
parser.add_argument('-m', | ||
'--mesh_size', | ||
type=int, | ||
help='Mesh size', | ||
required=True) | ||
parser.add_argument('-b', | ||
'--box_size', | ||
type=float, | ||
help='Box size', | ||
required=True) | ||
parser.add_argument('-p', | ||
'--pdims', | ||
type=str, | ||
help='Processor dimensions', | ||
default=None) | ||
parser.add_argument( | ||
'-pr', | ||
'--precision', | ||
type=str, | ||
help='Precision', | ||
choices=["float32", "float64"], | ||
) | ||
parser.add_argument('-hs', | ||
'--halo_size', | ||
type=int, | ||
help='Halo size', | ||
default=None) | ||
parser.add_argument('-s', | ||
'--solver', | ||
type=str, | ||
help='Solver', | ||
choices=[ | ||
"Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5", | ||
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm", | ||
"lpt" | ||
], | ||
default="lpt") | ||
parser.add_argument('-o', | ||
'--output_path', | ||
type=str, | ||
help='Output path', | ||
default=".") | ||
parser.add_argument('-f', | ||
'--save_fields', | ||
action='store_true', | ||
help='Save fields') | ||
parser.add_argument('-n', | ||
'--nodes', | ||
type=int, | ||
help='Number of nodes', | ||
default=1) | ||
args = parser.parse_args() | ||
mesh_size = args.mesh_size | ||
box_size = [args.box_size] * 3 | ||
halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size | ||
solver_choice = args.solver | ||
iterations = args.iterations | ||
output_path = args.output_path | ||
os.makedirs(output_path, exist_ok=True) | ||
|
||
print(f"solver choice: {solver_choice}") | ||
match solver_choice: | ||
case "Dopri5" | "dopri5" | "d5": | ||
solver_choice = "Dopri5" | ||
case "Tsit5" | "tsit5" | "t5": | ||
solver_choice = "Tsit5" | ||
case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm": | ||
solver_choice = "LeapfrogMidpoint" | ||
case "lpt": | ||
solver_choice = "lpt" | ||
case _: | ||
raise ValueError( | ||
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt" | ||
) | ||
if args.precision == "float32": | ||
jax.config.update("jax_enable_x64", False) | ||
elif args.precision == "float64": | ||
jax.config.update("jax_enable_x64", True) | ||
|
||
if args.pdims: | ||
pdims = tuple(map(int, args.pdims.split("x"))) | ||
else: | ||
pdims = (1, jax.device_count()) | ||
pdm_str = f"{pdims[0]}x{pdims[1]}" | ||
|
||
mesh_shape = [mesh_size] * 3 | ||
|
||
final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, | ||
halo_size, solver_choice, | ||
iterations, pdims) | ||
|
||
print( | ||
f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}" | ||
) | ||
|
||
metadata = { | ||
'rank': rank, | ||
'function_name': f'JAXPM-{solver_choice}', | ||
'precision': args.precision, | ||
'x': str(mesh_size), | ||
'y': str(mesh_size), | ||
'z': str(stats["num_steps"]), | ||
'px': str(pdims[0]), | ||
'py': str(pdims[1]), | ||
'backend': 'NCCL', | ||
'nodes': str(args.nodes) | ||
} | ||
# Print the results to a CSV file | ||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) | ||
|
||
# Save the final field | ||
nb_gpus = jax.device_count() | ||
pdm_str = f"{pdims[0]}x{pdims[1]}" | ||
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" | ||
os.makedirs(field_folder, exist_ok=True) | ||
with open(f'{field_folder}/jaxpm.log', 'w') as f: | ||
f.write(f"Args: {args}\n") | ||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") | ||
for i, time in enumerate(chrono_fun.times): | ||
f.write(f"Time {i}: {time:.4f} ms\n") | ||
f.write(f"Stats: {stats}\n") | ||
if args.save_fields: | ||
np.save(f'{field_folder}/final_field_0_{rank}.npy', | ||
final_field.addressable_data(0)) | ||
|
||
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" | ||
os.makedirs(field_folder, exist_ok=True) | ||
with open(f'{field_folder}/jaxpm.log', 'w') as f: | ||
f.write(f"Args: {args}\n") | ||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") | ||
for i, time in enumerate(chrono_fun.times): | ||
f.write(f"Time {i}: {time:.4f} ms\n") | ||
f.write(f"Stats: {stats}\n") | ||
if args.save_fields: | ||
np.save(f'{field_folder}/final_field_0_{rank}.npy', | ||
final_field.addressable_data(0)) | ||
|
||
print(f"Finished! ") | ||
print(f"Stats {stats}") | ||
print(f"Saving to {output_path}/jax_pm_benchmark.csv") | ||
print(f"Saving field and logs in {field_folder}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check out https://github.com/pre-commit/action