Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding fixamp ics #15

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
afdfcfa
Adds a trivial jaxpm implementation
EiffL Feb 13, 2022
c9a4268
adds growth functions from Chirag
EiffL Feb 13, 2022
25bde55
Adds demo and notebooks
EiffL Feb 14, 2022
6522c25
Merge pull request #9 from DifferentiableUniverseInitiative/experimen…
EiffL Feb 14, 2022
ebc9cf9
Created using Colaboratory
EiffL Feb 14, 2022
ae2be36
Update README.md
EiffL Feb 14, 2022
dbfecd6
Adding demo CAMELS notebook
EiffL Mar 21, 2022
b08641d
Add utility to compute the power spectrum
EiffL Mar 25, 2022
5549568
fix normalization of init cond
EiffL Mar 25, 2022
d8a5b70
adding cic compensation tools
EiffL Mar 25, 2022
907dc42
adding neural network
EiffL Mar 26, 2022
5195a28
fix minor issue
EiffL Mar 26, 2022
687aad6
fix power spec
EiffL Mar 27, 2022
1795319
adding hamiltonian gnn demo
EiffL Apr 27, 2022
d0a15e8
adding hamiltonian gnn demo
EiffL Apr 27, 2022
e8806e0
merging upsteam
EiffL Apr 27, 2022
e188d5e
adding function for doing 2d paintinng
EiffL May 17, 2022
e45bab2
adding density plane cutting code
EiffL May 17, 2022
d3026f7
PGD
May 17, 2022
0008f85
adds utilities for simple lensing
EiffL May 17, 2022
eabac32
adding notebook demo
EiffL May 17, 2022
80eb8cb
minor correction to gaussian smoothing
EiffL May 17, 2022
88cff99
adds fix to make code jittablel
EiffL May 17, 2022
a52287c
change impor
EiffL May 17, 2022
0f6fb39
small fix
EiffL May 17, 2022
b737631
Update jaxpm/pm.py
dlanzieri May 18, 2022
8266f11
Update jaxpm/pm.py
dlanzieri May 18, 2022
782c9db
Update jaxpm/pm.py
dlanzieri May 18, 2022
5394abb
Update jaxpm/pm.py
dlanzieri May 18, 2022
9967dff
Merge pull request #12 from DifferentiableUniverseInitiative/PGD
EiffL May 18, 2022
4a584b1
Update jaxpm/painting.py
EiffL May 18, 2022
e7fa6c9
Merge pull request #11 from DifferentiableUniverseInitiative/u/EiffL/…
EiffL May 18, 2022
1b7c797
changes definition of lensplanes
EiffL May 18, 2022
5c4f75a
fixes the lensing demo
EiffL May 18, 2022
101ef80
Merge pull request #13 from DifferentiableUniverseInitiative/u/EiffL/…
EiffL May 18, 2022
a67fc42
Adding fixamp ics
dforero0896 Sep 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<!-- ALL-CONTRIBUTORS-BADGE:END -->
JAX-powered Cosmological Particle-Mesh N-body Solver

**This project is currently in an early design phased. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**
**This project is currently in an early design phase. All inputs are welcome on the [design document](https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/main/design.md)**

## Goals

Expand Down
1,140 changes: 1,140 additions & 0 deletions dev/CAMELS-loss-several_steps-TimeDependent.ipynb

Large diffs are not rendered by default.

898 changes: 898 additions & 0 deletions dev/HamiltonianGNN.ipynb

Large diffs are not rendered by default.

913 changes: 913 additions & 0 deletions dev/JaxPM_ODE-tCOLA.ipynb

Large diffs are not rendered by default.

674 changes: 674 additions & 0 deletions dev/JaxPM_ODE.ipynb

Large diffs are not rendered by default.

63 changes: 63 additions & 0 deletions dev/test_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt
import jax
import numpy as np
import jax.numpy as jnp
import jax.lax as lax
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions

def cic_paint(mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])

neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
return mesh

# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))

f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None)

devices = np.array(jax.devices()).reshape((2, 2, 1))

# Let's import the mesh
m = jnp.zeros([32, 32, 32])

with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)

# Apply the sharded CiC function
res = f(m)

plt.imshow(res.sum(axis=2))
plt.show()
Empty file added jaxpm/__init__.py
Empty file.
Loading