Skip to content

Commit

Permalink
Merge pull request #7 from 21cmfast/add_get_var
Browse files Browse the repository at this point in the history
feat:add get_variance
  • Loading branch information
DanielaBreitman authored Dec 10, 2024
2 parents 6725cf6 + 609bb42 commit fd578af
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/py21cmfast_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""py21cmfast_tools package."""

from .lc2ps import bin_kpar as bin_kpar
from .lc2ps import calculate_ps as calculate_ps
from .lc2ps import log_bin as log_bin
from .lc2ps import cylindrical_to_spherical as cylindrical_to_spherical
from .lc2ps import postprocess_ps as postprocess_ps
230 changes: 206 additions & 24 deletions src/py21cmfast_tools/lc2ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from powerbox.tools import (
_magnitude_grid,
above_mu_min_angular_generator,
angular_average,
get_power,
ignore_zero_ki,
power2delta,
regular_angular_generator,
)
from scipy.interpolate import RegularGridInterpolator


def calculate_ps( # noqa: C901
Expand All @@ -34,6 +36,7 @@ def calculate_ps( # noqa: C901
interp=None,
prefactor_fnc=power2delta,
interp_points_generator=None,
get_variance=False,
):
r"""Calculate power spectra from a lightcone.
Expand Down Expand Up @@ -113,16 +116,22 @@ def calculate_ps( # noqa: C901
A function that generates the points at which to interpolate the PS.
See powerbox.tools.get_power documentation for more details.
"""
if not interp:
interp = None
# Split the lightcone into chunks for each redshift bin
# Infer HII_DIM from lc side shape
if box_side_shape is None:
box_side_shape = lc.shape[0]
if get_variance and interp is not None:
raise NotImplementedError("Cannot get variance while interpolating.")
if zs is None:
if chunk_size is None:
chunk_size = box_side_shape
n_slices = lc.shape[-1]
chunk_indices = list(range(0, n_slices - chunk_size, chunk_skip))
else:
if np.min(zs) < np.min(lc_redshifts) or np.max(zs) > np.max(lc_redshifts):
raise ValueError("zs should be within the range of lc_redshifts")
if chunk_size is None:
chunk_size = box_side_shape
chunk_indices = np.array(
Expand All @@ -139,6 +148,13 @@ def calculate_ps( # noqa: C901
zs = [] # all redshifts that will be computed
lc_ps_2d = []
clean_lc_ps_2d = []
if get_variance:
if calc_2d:
lc_var_2d = []
if postprocess:
clean_lc_var_2d = []
if calc_1d:
lc_var_1d = []
if calc_global:
tb = []
if calc_1d:
Expand All @@ -152,17 +168,26 @@ def calculate_ps( # noqa: C901
start = i
end = i + chunk_size
if end > len(lc_redshifts):
# Shift the chunk back if it goes over the edge of the lc
shift_it_back_by_a_few_bins = end - len(lc_redshifts)
start -= shift_it_back_by_a_few_bins
end = len(lc_redshifts)
if start < 0:
# Shift the chunk forward if it starts before the start of the lc
end += -start
start = 0
chunk = lc[..., start:end]
zs.append(lc_redshifts[(start + end) // 2])
if calc_global:
tb.append(np.mean(chunk))
if calc_2d:
ps_2d, kperp, nmodes, kpar = get_power(
results = get_power(
chunk,
(box_length, box_length, box_length * chunk.shape[-1] / box_side_shape),
(
box_length,
box_length,
box_length * chunk.shape[-1] / box_side_shape,
),
res_ndim=2,
bin_ave=bin_ave,
bins=nbins,
Expand All @@ -172,7 +197,15 @@ def calculate_ps( # noqa: C901
prefactor_fnc=prefactor_fnc,
interpolation_method=interp,
return_sumweights=True,
get_variance=get_variance,
)
if get_variance:
ps_2d, kperp, var, nmodes, kpar = results
lc_var_2d.append(var)
else:
ps_2d, kperp, nmodes, kpar = results

lc_ps_2d.append(ps_2d)
if postprocess:
clean_ps_2d, clean_kperp, clean_kpar, clean_nmodes = postprocess_ps(
ps_2d,
Expand All @@ -183,10 +216,22 @@ def calculate_ps( # noqa: C901
crop=crop.copy() if crop is not None else crop,
kperp_modes=nmodes,
return_modes=True,
interp=interp,
)
clean_lc_ps_2d.append(clean_ps_2d)

lc_ps_2d.append(ps_2d)
if get_variance:
clean_var_2d, _, _ = postprocess_ps(
var,
kperp,
kpar,
log_bins=log_bins,
kpar_bins=kpar_bins,
crop=crop.copy() if crop is not None else crop,
kperp_modes=nmodes,
return_modes=False,
interp=interp,
)
clean_lc_var_2d.append(clean_var_2d)

if calc_1d:
if mu is not None:
Expand All @@ -212,9 +257,14 @@ def mask_fnc(freq, absk):
k_weights1d = ignore_zero_ki
if interp is not None:
interp_points_generator = regular_angular_generator()
ps_1d, k, nmodes_1d = get_power(

results = get_power(
chunk,
(box_length, box_length, box_length * chunk.shape[-1] / box_side_shape),
(
box_length,
box_length,
box_length * chunk.shape[-1] / box_side_shape,
),
bin_ave=bin_ave,
bins=nbins_1d,
log_bins=log_bins,
Expand All @@ -223,39 +273,51 @@ def mask_fnc(freq, absk):
interpolation_method=interp,
interp_points_generator=interp_points_generator,
return_sumweights=True,
get_variance=get_variance,
)
if get_variance:
ps_1d, k, var_1d, nmodes_1d = results
lc_var_1d.append(var_1d)
else:
ps_1d, k, nmodes_1d = results
lc_ps_1d.append(ps_1d)

if calc_1d:
out["k"] = k
out["ps_1D"] = np.array(lc_ps_1d)
out["Nmodes_1D"] = nmodes_1d
out["mu"] = mu
if get_variance:
out["var_1D"] = np.array(lc_var_1d)
if calc_2d:
out["full_kperp"] = kperp
out["full_kpar"] = kpar
out["full_kpar"] = kpar[0]
out["full_ps_2D"] = np.array(lc_ps_2d)
out["full_Nmodes"] = nmodes
if get_variance:
out["full_var_2D"] = np.array(lc_var_2d)
if postprocess:
out["final_ps_2D"] = np.array(clean_lc_ps_2d)
out["final_kpar"] = clean_kpar
out["final_kperp"] = clean_kperp
out["final_Nmodes"] = clean_nmodes
if get_variance:
out["final_var_2D"] = np.array(clean_lc_var_2d)
if calc_global:
out["global_Tb"] = np.array(tb)
out["redshifts"] = np.array(zs)

return out


def log_bin(ps, kperp, kpar, bins=None):
def bin_kpar(ps, kperp, kpar, bins=None, interp=None, log=False, redshifts=None):
r"""
Log bin a 2D PS along the kpar axis and crop out empty bins in both axes.
Bin a 2D PS along the kpar axis and crop out empty bins in both axes.
Parameters
----------
ps : np.ndarray
The 2D power spectrum of shape [len(kperp), len(kpar)].
The 2D power spectrum of shape [len(redshifts), len(kperp), len(kpar)].
kperp : np.ndarray
Values of kperp.
kpar : np.ndarray
Expand All @@ -264,23 +326,68 @@ def log_bin(ps, kperp, kpar, bins=None):
The number of bins or the bin edges to use for binning the kpar axis.
If None, produces ``len(kpar) // 2`` bins between the first and last
`kpar` supplied, spaced logarithmically if `log` is True and linearly otherwise.
interp : str, optional
If 'linear', use linear interpolation to calculate the PS at the specified
kpar bins.
log : bool, optional
If False, kpar is binned linearly. If True, it is binned logarithmically.
redshifts : np.ndarray, optional
The redshifts at which the PS was calculated.
"""
ps = ps if len(ps.shape) == 3 else ps[np.newaxis, ...]
if bins is None:
bins = np.logspace(np.log10(kpar[0]), np.log10(kpar[-1]), 17)
if log:
bins = np.logspace(
np.log10(kpar[0]), np.log10(kpar[-1]), len(kpar) // 2 + 1
)
else:
bins = np.linspace(kpar[0], kpar[-1], len(kpar) // 2 + 1)
elif isinstance(bins, int):
bins = np.logspace(np.log10(kpar[0]), np.log10(kpar[-1]), bins + 1)
if log:
bins = np.logspace(np.log10(kpar[0]), np.log10(kpar[-1]), bins + 1)
else:
bins = np.linspace(kpar[0], kpar[-1], bins + 1)
elif isinstance(bins, (np.ndarray, list)):
bins = np.array(bins)
else:
raise ValueError("Bins should be np.ndarray or int")
modes = np.zeros(len(bins) - 1)
new_ps = np.zeros((len(kperp), len(bins) - 1))
for i in range(len(bins) - 1):
m = np.logical_and(kpar >= bins[i], kpar < bins[i + 1])
new_ps[:, i] = np.nanmean(ps[:, m], axis=1)
modes[i] = np.sum(m)
bin_centers = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
if log:
bin_centers = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
else:
bin_centers = (bins[1:] + bins[:-1]) / 2
if interp == "linear":
new_ps = np.zeros((ps.shape[0], len(kperp), len(bins)))
modes = np.zeros(len(bins))
interp_fnc = RegularGridInterpolator(
(redshifts, kperp, kpar) if redshifts is not None else (kperp, kpar),
ps.squeeze(),
bounds_error=False,
fill_value=np.nan,
)

if redshifts is None:
kperp_grid, kpar_grid = np.meshgrid(
kperp, bin_centers, indexing="ij", sparse=True
)
new_ps = interp_fnc((kperp_grid, kpar_grid))
else:
redshifts_grid, kperp_grid, kpar_grid = np.meshgrid(
redshifts, kperp, bin_centers, indexing="ij", sparse=True
)
new_ps = interp_fnc((redshifts_grid, kperp_grid, kpar_grid))

idxs = np.digitize(kpar, bins) - 1
for i in range(len(bins) - 1):
modes[i] = np.sum(idxs == i)
else:
new_ps = np.zeros((ps.shape[0], len(kperp), len(bins) - 1))
modes = np.zeros(len(bins) - 1)
idxs = np.digitize(kpar, bins) - 1
for i in range(len(bins) - 1):
m = idxs == i
new_ps[..., i] = np.nanmean(ps[..., m], axis=-1)
modes[i] = np.sum(m)

return new_ps, kperp, bin_centers, modes


Expand All @@ -293,8 +400,9 @@ def postprocess_ps(
crop=None,
kperp_modes=None,
return_modes=False,
interp=None,
):
"""
r"""
Postprocess a 2D PS by cropping out empty bins and log binning the kpar axis.
Parameters
Expand Down Expand Up @@ -333,9 +441,9 @@ def postprocess_ps(
kperp = kperp[mkperp]
ps = ps[mkperp, :]

# Bin kpar in log
rebinned_ps, kperp, log_kpar, kpar_weights = log_bin(
ps, kperp, kpar, bins=kpar_bins
# maybe rebin kpar in log
rebinned_ps, kperp, log_kpar, kpar_weights = bin_kpar(
ps, kperp, kpar, bins=kpar_bins, interp=interp, log=log_bins
)
if crop is None:
crop = [0, rebinned_ps.shape[0] + 1, 0, rebinned_ps.shape[1] + 1]
Expand Down Expand Up @@ -364,9 +472,83 @@ def postprocess_ps(
log_kpar[crop[2] : crop[3]],
nmodes,
)
else:
return (
rebinned_ps[crop[0] : crop[1]][:, crop[2] : crop[3]],
kperp[crop[0] : crop[1]],
log_kpar[crop[2] : crop[3]],
)
else:
return (
rebinned_ps[crop[0] : crop[1]][:, crop[2] : crop[3]],
kperp[crop[0] : crop[1]],
log_kpar[crop[2] : crop[3]],
)


def cylindrical_to_spherical(
    ps,
    kperp,
    kpar,
    nbins=16,
    weights=1,
    interp=False,
    mu=None,
    generator=None,
    bin_ave=True,
):
    r"""
    Angularly average 2D PS to 1D PS.

    Parameters
    ----------
    ps : np.ndarray
        The 2D power spectrum of shape [len(kperp), len(kpar)].
    kperp : np.ndarray
        Values of kperp.
    kpar : np.ndarray
        Values of kpar.
    nbins : int, optional
        The number of bins on which to calculate 1D PS. Default is 16
    weights : np.ndarray, optional
        Weights to apply to the PS before averaging.
        Note that to obtain a 1D PS from the 2D PS that is consistent with
        the 1D PS obtained directly from the 3D PS, the weights should be
        the number of modes in each bin of the 2D PS (`Nmodes`).
        When `mu` is given and `interp` is False, these weights are multiplied
        by the :math:`\mu \geq` `mu` mask rather than replaced by it.
    interp : bool, optional
        If True, use linear interpolation to calculate the 1D PS.
    mu : float, optional
        The minimum value of
        :math:`\cos(\theta), \theta = \arctan (k_\perp/k_\parallel)`
        for all calculated PS.
        If None, all modes are included.
    generator : callable, optional
        A function that generates the points at which to interpolate the PS.
        See powerbox.tools.get_power documentation for more details.
        Only defaulted here (to ``above_mu_min_angular_generator(mu=mu)``)
        when `interp` is True, `mu` is given and no generator is supplied.
    bin_ave : bool, optional
        If True, return the center value of each k bin
        i.e. len(k) = ps_1d.shape[0].
        If False, return the left edge of each bin
        i.e. len(k) = ps_1d.shape[0] + 1.

    Returns
    -------
    ps_1d : np.ndarray
        The angularly averaged 1D PS.
    k : np.ndarray
        The k values of the 1D PS (see `bin_ave`).
    sws : np.ndarray
        The sum of weights per bin, as returned by
        ``powerbox.tools.angular_average`` with ``return_sumweights=True``.
    """
    if mu is not None:
        if interp:
            # With interpolation the mu cut is applied through the sampling points.
            if generator is None:
                generator = above_mu_min_angular_generator(mu=mu)
        else:
            # Default 'xy' indexing gives meshes of shape [len(kperp), len(kpar)],
            # matching the documented shape of ``ps``.
            kpar_mesh, kperp_mesh = np.meshgrid(kpar, kperp)
            # kpar may contain 0: x/0 -> inf -> cos(pi/2) ~ 0 and 0/0 -> nan, which
            # compares False below. Silence the warnings; the mask is unchanged.
            with np.errstate(divide="ignore", invalid="ignore"):
                mu_mesh = np.cos(np.arctan(kperp_mesh / kpar_mesh))
            # Combine the mu cut with the caller's weights instead of discarding
            # them (previously ``weights`` was overwritten by the boolean mask).
            weights = np.asarray(weights) * (mu_mesh >= mu)

    ps_1d, k, sws = angular_average(
        ps,
        coords=[kperp, kpar],
        bins=nbins,
        weights=weights,
        bin_ave=bin_ave,
        log_bins=True,
        return_sumweights=True,
        interpolation_method="linear" if interp else None,
        interp_points_generator=generator,
    )
    return ps_1d, k, sws
Loading

0 comments on commit fd578af

Please sign in to comment.