Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify probes to introduce scale cut ell max for each redshift bin #117

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a

radian_to_arcmin = 3437.7467707849396

__all__ = ["WeakLensing", "NumberCounts"]


Expand Down Expand Up @@ -71,7 +73,7 @@ def integrand_single(z_prime):

# Constant term
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c
# Ell dependent factor
# Ell-dependent factor
ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2
return constant_factor * ell_factor * radial_kernel

Expand Down Expand Up @@ -100,6 +102,20 @@ def density_kernel(cosmo, pzs, bias, z, ell):
ell_factor = 1.0
return constant_factor * ell_factor * radial_kernel

@jit
def cmb_lensing_kernel(cosmo, z, ell):
"""
Computes the CMB weak lensing kernel
"""
z_cmb=1100.
chi_cmb = bkgrd.radial_comoving_distance(cosmo, z2a(z_cmb))
chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))
radial_kernel = (chi / z2a(z)) * (np.clip(chi_cmb - chi, 0) / np.clip(chi_cmb, 1.0))
# Normalization,
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c
# Ell-dependent factors
ell_factor = 1.0
return constant_factor * ell_factor * radial_kernel

@jit
def nla_kernel(cosmo, pzs, bias, z, ell):
Expand Down Expand Up @@ -152,6 +168,7 @@ class WeakLensing(container):
def __init__(
self,
redshift_bins,
lmax=None,
ia_bias=None,
multiplicative_bias=0.0,
sigma_e=0.26,
Expand All @@ -161,10 +178,10 @@ def __init__(
# container
if ia_bias is None:
ia_enabled = False
args = (redshift_bins, multiplicative_bias)
args = (redshift_bins, lmax, multiplicative_bias)
else:
ia_enabled = True
args = (redshift_bins, multiplicative_bias, ia_bias)
args = (redshift_bins, lmax, multiplicative_bias, ia_bias)
if "ia_enabled" not in kwargs.keys():
kwargs["ia_enabled"] = ia_enabled
super(WeakLensing, self).__init__(*args, sigma_e=sigma_e, **kwargs)
Expand All @@ -181,12 +198,20 @@ def n_tracers(self):
@property
def zmax(self):
"""
Returns the maximum redsfhit probed by this probe
Returns the maximum redshift probed by this probe
"""
# Extract parameters
pzs = self.params[0]
return max([pz.zmax for pz in pzs])

@property
def lmax(self):
"""
Returns the maximum multipole probed by this probe
"""
# Extract parameters
return self.params[1]

def kernel(self, cosmo, z, ell):
"""
Compute the radial kernel for all nz bins in this probe.
Expand All @@ -197,16 +222,21 @@ def kernel(self, cosmo, z, ell):
"""
z = np.atleast_1d(z)
# Extract parameters
pzs, m = self.params[:2]
pzs, lmax, m = self.params[:3]
kernel = weak_lensing_kernel(cosmo, pzs, z, ell)
# If IA is enabled, we add the IA kernel
if self.config["ia_enabled"]:
bias = self.params[2]
bias = self.params[3]
kernel += nla_kernel(cosmo, pzs, bias, z, ell)
# Applies measurement systematics
if isinstance(m, list):
m = np.expand_dims(np.stack([mi for mi in m], axis=0), 1)
kernel *= 1.0 + m
if (lmax is not None):
if isinstance(lmax, list):
lmax = np.expand_dims(np.stack([bin_lmax for bin_lmax in lmax], axis=0), 1)
ell_weight = np.where(ell>lmax, 0., 1.)
kernel *= ell_weight
return kernel

def noise(self):
Expand Down Expand Up @@ -238,9 +268,15 @@ class NumberCounts(container):
has_rsd....
"""

def __init__(self, redshift_bins, bias, has_rsd=False, **kwargs):
def __init__(
self,
redshift_bins,
bias,
lmax=None,
has_rsd=False, **kwargs):
args = (redshift_bins, bias, lmax)
super(NumberCounts, self).__init__(
redshift_bins, bias, has_rsd=has_rsd, **kwargs
*args, has_rsd=has_rsd, **kwargs
)

@property
Expand All @@ -252,6 +288,14 @@ def zmax(self):
pzs = self.params[0]
return max([pz.zmax for pz in pzs])

@property
def lmax(self):
"""
Returns the maximum multipole probed by this probe
"""
# Extract parameters
return self.params[2]

@property
def n_tracers(self):
"""Returns the number of tracers for this probe, i.e. redshift bins"""
Expand All @@ -268,9 +312,14 @@ def kernel(self, cosmo, z, ell):
"""
z = np.atleast_1d(z)
# Extract parameters
pzs, bias = self.params
pzs, bias, lmax = self.params[:3]
# Retrieve density kernel
kernel = density_kernel(cosmo, pzs, bias, z, ell)
if (lmax is not None):
if isinstance(lmax, list):
lmax = np.expand_dims(np.stack([bin_lmax for bin_lmax in lmax], axis=0), 1)
ell_weight = np.where(ell>lmax, 0., 1.)
kernel *= ell_weight
return kernel

def noise(self):
Expand Down