Skip to content

Commit

Permalink
Add finite difference laplace kernel + powerspec functions from Hugo
Browse files Browse the repository at this point in the history
Co-authored-by: Hugo Simonfroy <[email protected]>
  • Loading branch information
ASKabalan and hsimonfroy committed Dec 5, 2024
1 parent 435c7c8 commit b32014b
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 77 deletions.
13 changes: 10 additions & 3 deletions jaxpm/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,28 @@ def gradient_kernel(kvec, direction, order=1):
return wts


def invlaplace_kernel(kvec):
def invlaplace_kernel(kvec, fd=False):
"""
Compute the inverse Laplace kernel
Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters
-----------
kvec: list
List of wave-vectors
fd: bool
Finite difference kernel
Returns
--------
wts: array
Complex kernel values
"""
kk = sum(ki**2 for ki in kvec)
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)

Expand Down
218 changes: 144 additions & 74 deletions jaxpm/utils.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,156 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre

from jaxpm.growth import growth_factor, growth_rate

__all__ = ['power_spectrum']
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]


def _initialize_pk(shape, boxsize, kmin, dk):
def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
Parameters
----------
mesh_shape : tuple of int
Shape of the mesh grid.
box_shape : tuple of float
Physical dimensions of the box.
kedges : None, int, float, or list
If None, set dk to twice the minimum.
If int, specifies number of edges.
If float, specifies dk.
los : array_like
Line-of-sight vector.
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
"""
I = np.eye(len(shape), dtype='int') * -2 + 1

W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0

kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)

k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5

xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)

dig = np.digitize(kmag.flat, kedges)

xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges


def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist

if isinstance(kedges, None | int | float):
if kedges is None:
dk = 2 * np.pi / np.min(
box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2

kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5

dig = np.digitize(kmesh.reshape(-1), kedges)
kcount = np.bincount(dig, minlength=len(kedges) + 1)

# Central value of each bin
# kavg = (kedges[1:] + kedges[:-1]) / 2
kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]

if los is None:
mumesh = 1.
else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)

return dig, kcount, kavg, mumesh


def power_spectrum(mesh,
mesh2=None,
box_shape=None,
kedges: int | float | list = None,
multipoles=0,
los=[0., 0., 1.]):
"""
Calculate the powerspectra given real space field
Args:
field: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
power: real valued array of power in each bin
"""
shape = field.shape
nx, ny, nz = shape

#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)

#fast fourier transform
fft_image = jnp.fft.fftn(field)

#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))

#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])

Psum = jnp.bincount(dig, weights=(W.flatten() * imag),
length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)

P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')

#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2

#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2

return kbins, P / norm
Compute the auto and cross spectrum of 3D fields, with multipoles.
"""
# Initialize
mesh_shape = np.array(mesh.shape)
if box_shape is None:
box_shape = mesh_shape
else:
box_shape = np.asarray(box_shape)

if multipoles == 0:
los = None
else:
los = np.asarray(los)
los = los / np.linalg.norm(los)
poles = np.atleast_1d(multipoles)
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
los)
n_bins = len(kavg) + 2

# FFTs
meshk = jnp.fft.fftn(mesh, norm='ortho')
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

# Sum powers
pk = jnp.empty((len(poles), n_bins))
for i_ell, ell in enumerate(poles):
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
if mesh2 is None:
psum = jnp.bincount(dig, weights=weights, length=n_bins)
else: # XXX: bincount is really slow with complex numbers
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
psum = (psum_real**2 + psum_imag**2)**.5
pk = pk.at[i_ell].set(psum)

# Normalization and conversion from cell units to [Mpc/h]^3
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

# pk = jnp.concatenate([kavg[None], pk])
if np.ndim(multipoles) == 0:
return kavg, pk[0]
else:
return kavg, pk


def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, (pk1 / pk0)**.5


def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk01 / (pk0 * pk1)**.5


def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
ks, pk01 = pk_fn(mesh0, mesh1)
ks, pk0 = pk_fn(mesh0)
ks, pk1 = pk_fn(mesh1)
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5


def cross_correlation_coefficients(field_a,
Expand Down

0 comments on commit b32014b

Please sign in to comment.