Skip to content

Commit

Permalink
pep8_isort
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Aug 29, 2023
1 parent e498f14 commit 72c2c29
Show file tree
Hide file tree
Showing 16 changed files with 77 additions and 40 deletions.
3 changes: 1 addition & 2 deletions diffmah/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# flake8: noqa

from ._version import __version__

from .monte_carlo_halo_population import mc_halo_population
from .individual_halo_assembly import calc_halo_history
from .monte_carlo_diffmah_hiz import mc_diffmah_params_hiz
from .monte_carlo_halo_population import mc_halo_population
15 changes: 9 additions & 6 deletions diffmah/halo_population_assembly.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Calculate differentiable probabilistic history of an individual halo."""
import numpy as np
from jax import numpy as jnp
from jax import jit as jjit
from jax import numpy as jnp
from jax import vmap
from jax.scipy.stats import multivariate_normal as jnorm
from .individual_halo_assembly import _calc_halo_history, _get_early_late
from .individual_halo_assembly import DEFAULT_MAH_PARAMS
from .rockstar_pdf_model import _get_mah_means_and_covs
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, LGT0, _get_mah_means_and_covs

CLIP = -10.0

Expand Down Expand Up @@ -55,7 +58,7 @@ def _get_bimodal_halo_history_kern(
dmhdts, log_mahs = _halo_history_integrand(
logt, logtmp, logmp, lgtc_arr, k, early_arr, late_arr
)
mahs = 10 ** log_mahs
mahs = 10**log_mahs

weights_early = _get_mah_weights(ue_arr, ul_arr, lgtc_arr, mu_early, cov_early)
weights_late = _get_mah_weights(ue_arr, ul_arr, lgtc_arr, mu_late, cov_late)
Expand Down
7 changes: 5 additions & 2 deletions diffmah/individual_halo_assembly.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Model for individual halo mass assembly based on a power-law with rolling index."""
from collections import OrderedDict
from jax import numpy as jnp

from jax import grad
from jax import jit as jjit
from jax import vmap, grad
from jax import numpy as jnp
from jax import vmap

from .utils import get_1d_arrays

DEFAULT_MAH_PARAMS = OrderedDict(mah_logtc=0.05, mah_k=3.5, mah_ue=2.4, mah_ul=-2.0)
Expand Down
3 changes: 2 additions & 1 deletion diffmah/load_mah_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Load the diffmah data into memory."""
import numpy as np
import os
import warnings

import numpy as np

TASSO = "/Users/aphearin/work/DATA/diffmah_data/PUBLISHED_DATA"
BEBOP = "/lcrc/project/halotools/diffmah_data/PUBLISHED_DATA"

Expand Down
9 changes: 5 additions & 4 deletions diffmah/measure_mahs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Functions used to define the target data for fitting the halo population model."""
import numpy as np
import warnings

import numpy as np


def get_clean_sample_mask(log_mah_fit, logmp_sample, it_min, lim=0.01, z_cut=3):
"""Calculate mask to remove halos with outlier MAH behavior.
Expand Down Expand Up @@ -64,7 +65,7 @@ def measure_target_data(mah, dmhdt, lgt, lgt_target, logmp_sample):
"""
mah0 = mah[:, -1].reshape(-1, 1)
mp_sample = 10 ** logmp_sample
mp_sample = 10**logmp_sample
scaled_mah = mp_sample * mah / mah0
scaled_dmhdt = mp_sample * dmhdt / mah0
with warnings.catch_warnings():
Expand All @@ -84,6 +85,6 @@ def measure_target_data(mah, dmhdt, lgt, lgt_target, logmp_sample):
std_dmhdt = 10 ** np.interp(lgt_target, lgt, np.log10(std_dmhdt_table))
std_log_mah = np.interp(lgt_target, lgt, std_log_mah_table)

var_dmhdt = std_dmhdt ** 2
var_log_mah = std_log_mah ** 2
var_dmhdt = std_dmhdt**2
var_log_mah = std_log_mah**2
return mean_mah, mean_log_mah, var_log_mah, mean_dmhdt, var_dmhdt
4 changes: 2 additions & 2 deletions diffmah/monte_carlo_diffmah_hiz.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Generate Diffmah parameters for halos identified at higher redshift
"""
import numpy as np
from .monte_carlo_halo_population import mc_halo_population

from .halo_population_assembly import LGT0
from .individual_halo_assembly import calc_halo_history
from .monte_carlo_halo_population import mc_halo_population


def mc_diffmah_params_hiz(ran_key, t_obs, logmh, lgt0=LGT0, npop=int(1e5)):
Expand Down Expand Up @@ -59,7 +60,6 @@ def mc_diffmah_params_hiz(ran_key, t_obs, logmh, lgt0=LGT0, npop=int(1e5)):
late_index = np.array(halopop.late_index)
lgtc = np.array(halopop.lgtc)
else:

lgm0_guess = _guess_logmp_z0(t_obs, logmh, lgt0, npop)

tarr = np.array((t_obs, t0))
Expand Down
12 changes: 8 additions & 4 deletions diffmah/monte_carlo_halo_population.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""
"""
import typing

import numpy as np
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax import jit as jjit
from jax import vmap
from .rockstar_pdf_model import _get_mah_means_and_covs, DEFAULT_MAH_PDF_PARAMS
from .individual_halo_assembly import _get_early_late
from .individual_halo_assembly import DEFAULT_MAH_PARAMS, _calc_halo_history

from .individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_get_early_late,
)
from .rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS, _get_mah_means_and_covs

MAH_K = DEFAULT_MAH_PARAMS["mah_k"]

Expand Down
10 changes: 8 additions & 2 deletions diffmah/optimize_nbody.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""
"""
from collections import OrderedDict

from jax import jit as jjit
from jax import numpy as jnp
from .halo_population_assembly import _get_bimodal_halo_history
from .halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR

from .halo_population_assembly import (
LGTC_ARR,
UE_ARR,
UL_ARR,
_get_bimodal_halo_history,
)

BOUNDS = OrderedDict(
frac_late_ylo=(0.35, 0.45),
Expand Down
10 changes: 8 additions & 2 deletions diffmah/optimize_tng.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""
"""
from collections import OrderedDict

from jax import jit as jjit
from jax import numpy as jnp
from .halo_population_assembly import _get_bimodal_halo_history
from .halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR

from .halo_population_assembly import (
LGTC_ARR,
UE_ARR,
UL_ARR,
_get_bimodal_halo_history,
)

BOUNDS = OrderedDict(
frac_late_ylo=(0.3, 0.6),
Expand Down
7 changes: 3 additions & 4 deletions diffmah/tests/test_diff_mc_halos.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Unit test of the new differentiable Monte Carlo generator
"""
import numpy as np
from jax import random as jran
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax import value_and_grad
from ..monte_carlo_halo_population import mc_halo_population
from ..monte_carlo_halo_population import _mc_halo_mahs
from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS

from ..monte_carlo_halo_population import _mc_halo_mahs, mc_halo_population
from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS

SEED = 43

Expand Down
5 changes: 3 additions & 2 deletions diffmah/tests/test_fit_mah_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
"""
import numpy as np

from ..fit_mah_helpers import get_target_data


Expand All @@ -14,7 +15,7 @@ def test_get_target_data_no_cuts():
logt_target, log_mah_target = get_target_data(
t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min
)
assert np.allclose(10 ** logt_target, t_sim)
assert np.allclose(10**logt_target, t_sim)
assert np.allclose(log_mah_sim, log_mah_target, atol=0.01)


Expand All @@ -29,5 +30,5 @@ def test_get_target_data_lgm_cut():
t_sim, log_mah_sim, lgm_min, dlogm_cut, t_fit_min
)
assert logt_target.shape == log_mah_target.shape
assert np.allclose(t_sim[1:], 10 ** logt_target)
assert np.allclose(t_sim[1:], 10**logt_target)
assert np.allclose(log_mah_sim[1:], log_mah_target, atol=0.01)
10 changes: 8 additions & 2 deletions diffmah/tests/test_halo_population_assembly.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""
"""
import os

import numpy as np
from ..halo_population_assembly import _get_bimodal_halo_history
from ..halo_population_assembly import UE_ARR, UL_ARR, LGTC_ARR

from ..halo_population_assembly import (
LGTC_ARR,
UE_ARR,
UL_ARR,
_get_bimodal_halo_history,
)
from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_PARAMS

_THIS_DRNAME = os.path.dirname(os.path.abspath(__file__))
Expand Down
15 changes: 10 additions & 5 deletions diffmah/tests/test_individual_halo_assembly.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
"""
"""
import os
from jax import numpy as jnp

import numpy as np
from ..individual_halo_assembly import _calc_halo_history, DEFAULT_MAH_PARAMS
from ..individual_halo_assembly import _power_law_index_vs_logt, _get_early_late
from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late
from ..individual_halo_assembly import _calc_halo_history_scalar
from jax import numpy as jnp

from ..individual_halo_assembly import (
DEFAULT_MAH_PARAMS,
_calc_halo_history,
_calc_halo_history_scalar,
_get_early_late,
_power_law_index_vs_logt,
)
from ..rockstar_pdf_model import _get_mean_mah_params_early, _get_mean_mah_params_late

_THIS_DRNAME = os.path.dirname(os.path.abspath(__file__))
DDRN = os.path.join(_THIS_DRNAME, "testing_data")
Expand Down
1 change: 1 addition & 0 deletions diffmah/tests/test_mc_halos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import numpy as np

from ..monte_carlo_halo_population import mc_halo_population
from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_PDF_PARAMS

Expand Down
4 changes: 3 additions & 1 deletion diffmah/tests/test_monte_carlo_diffmah_hiz.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
"""
import warnings
from jax import random as jran

import numpy as np
from jax import random as jran

from ..individual_halo_assembly import calc_halo_history
from ..monte_carlo_diffmah_hiz import mc_diffmah_params_hiz

Expand Down
2 changes: 1 addition & 1 deletion diffmah/tests/test_opt_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
"""
from ..optimize_nbody import BOUNDS as NBODY_BOUNDS
from ..optimize_tng import BOUNDS as TNG_BOUNDS
from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_DEFAULTS
from ..rockstar_pdf_model import DEFAULT_MAH_PDF_PARAMS as NBODY_DEFAULTS
from ..tng_pdf_model import DEFAULT_MAH_PDF_PARAMS as TNG_DEFAULTS


def test_nbody_params_are_correctly_bounded():
Expand Down

0 comments on commit 72c2c29

Please sign in to comment.