Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Oct 16, 2023
1 parent 851cbbe commit c00c2a0
Showing 1 changed file with 142 additions and 45 deletions.
187 changes: 142 additions & 45 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,89 @@
import pytest

import arviz as az
import matplotlib
import matplotlib.pyplot as plt
import hssm
import pymc as pm

matplotlib.use("Agg")
hssm.set_floatX("float32")

parameter_names = "loglik_kind,backend,sampler,step,expected"
parameter_grid = [
("analytical", None, None, None, True), # Defaults should work
("analytical", None, "mcmc", None, True),
("analytical", None, "mcmc", "slice", True),
("analytical", None, "nuts_numpyro", None, True),
("analytical", None, "nuts_numpyro", "slice", TypeError),
("approx_differentiable", "pytensor", None, None, True), # Defaults should work
("approx_differentiable", "pytensor", "mcmc", None, True),
("approx_differentiable", "pytensor", "mcmc", "slice", True),
("approx_differentiable", "pytensor", "nuts_numpyro", None, True),
("approx_differentiable", "pytensor", "nuts_numpyro", "slice", TypeError),
("approx_differentiable", "jax", None, None, True), # Defaults should work
("approx_differentiable", "jax", "mcmc", None, True),
("approx_differentiable", "jax", "mcmc", "slice", True),
("approx_differentiable", "jax", "nuts_numpyro", None, True),
("approx_differentiable", "jax", "nuts_numpyro", "slice", TypeError),
("blackbox", None, None, None, True), # Defaults should work
("blackbox", None, "mcmc", None, True),
("blackbox", None, "mcmc", "slice", True),
("blackbox", None, "nuts_numpyro", None, ValueError),
("blackbox", None, "nuts_numpyro", "slice", ValueError),
]


def sample(model, sampler, step):
if step == "slice":
model.sample(
sampler=sampler,
cores=1,
chains=1,
tune=10,
draws=10,
step=pm.Slice(model=model.pymc_model),
)
else:
model.sample(
sampler=sampler,
cores=1,
chains=1,
tune=10,
draws=10,
)


def run_sample(model, sampler, step, expected):
if expected == True:
sample(model, sampler, step)
assert isinstance(model.traces, az.InferenceData)
else:
with pytest.raises(expected):
sample(model, sampler, step)


@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
model = hssm.HSSM(
data_ddm, loglik_kind=loglik_kind, model_config={"backend": backend}
)
run_sample(model, sampler, step, expected)

def test_non_reg_models(data_ddm):
model1 = hssm.HSSM(data_ddm)
model1.sample_prior_predictive(draws=10)

model1.sample(cores=1, chains=1, tune=10, draws=10)
model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

model2 = hssm.HSSM(data_ddm, loglik_kind="approx_differentiable")
model2.sample(cores=1, chains=1, tune=10, draws=10)
model2.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

model3 = hssm.HSSM(data_ddm, loglik_kind="blackbox")
model3.sample(cores=1, chains=1, tune=10, draws=10)
model3.sample(cores=1, chains=1, tune=10, draws=10)
# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert not model._get_deterministic_var_names()
# test summary:
summary = model.summary()
assert summary.shape[0] == 4

model1.sample_posterior_predictive(data=data_ddm.iloc[:10, :])
model.plot_trace(show=False)
fig = plt.gcf()
assert len(fig.axes) // 2 == 4


def test_reg_models(data_ddm_reg):
@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
param_reg = dict(
formula="v ~ 1 + x + y",
prior={
Expand All @@ -32,46 +92,83 @@ def test_reg_models(data_ddm_reg):
"y": {"name": "Uniform", "lower": -0.50, "upper": 0.50},
},
)
model = hssm.HSSM(
data_ddm_reg,
loglik_kind=loglik_kind,
model_config={"backend": backend},
v=param_reg,
)
run_sample(model, sampler, step, expected)

model1 = hssm.HSSM(data_ddm_reg, v=param_reg)
model1.sample_prior_predictive(draws=10)

model1.sample(cores=1, chains=1, tune=10, draws=10)
model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", v=param_reg)
model2.sample(cores=1, chains=1, tune=10, draws=10)
model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10)

model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", v=param_reg)
model3.sample(cores=1, chains=1, tune=10, draws=10)

with pytest.raises(ValueError):
model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)
# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert not model._get_deterministic_var_names()
# test summary:
summary = model.summary()
assert summary.shape[0] == 6

model1.sample_posterior_predictive(data=data_ddm_reg.iloc[:10, :])
model.plot_trace(show=False)
fig = plt.gcf()
assert len(fig.axes) // 2 == 6


def test_reg_models_a(data_ddm_reg):
param_reg = dict(
formula="a ~ 1 + x + y",
@pytest.mark.parametrize(parameter_names, parameter_grid)
def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
param_reg_v = dict(
formula="v ~ 1 + x + y",
prior={
"Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
"x": {"name": "Uniform", "lower": -0.50, "upper": 0.50},
"y": {"name": "Uniform", "lower": -0.50, "upper": 0.50},
},
)
param_reg_a = dict(
formula="v ~ 1 + x + y",
prior={
"Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0},
"x": {"name": "Uniform", "lower": -0.50, "upper": 0.50},
"y": {"name": "Uniform", "lower": -0.50, "upper": 0.50},
},
link="log",
)

model1 = hssm.HSSM(data_ddm_reg, a=param_reg)
model1.sample(cores=1, chains=1, tune=10, draws=10)
model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)
model = hssm.HSSM(
data_ddm_reg,
loglik_kind=loglik_kind,
model_config={"backend": backend},
v=param_reg_v,
a=param_reg_a,
)
run_sample(model, sampler, step, expected)

# Only runs once
if loglik_kind == "analytical" and sampler is None:
assert model._get_deterministic_var_names() == ["~a"]
# test summary:
summary = model.summary()
assert summary.shape[0] == 8

summary = model.summary(var_names=["~a"])
assert summary.shape[0] == 8

summary = model.summary(var_names=["~t"])
assert summary.shape[0] == 7

summary = model.summary(var_names=["~a", "~t"])
assert summary.shape[0] == 7

model.plot_trace(show=False)
fig = plt.gcf()
assert len(fig.axes) // 2 == 8

model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", a=param_reg)
model2.sample(cores=1, chains=1, tune=10, draws=10)
model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10)
model.plot_trace(show=False, var_names=["~a"])
fig = plt.gcf()
assert len(fig.axes) // 2 == 8

model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", a=param_reg)
model3.sample(cores=1, chains=1, tune=10, draws=10)
model.plot_trace(show=False, var_names=["~t"])
fig = plt.gcf()
assert len(fig.axes) // 2 == 7

with pytest.raises(ValueError):
model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)
model.plot_trace(show=False, var_names=["~a", "~t"])
fig = plt.gcf()
assert len(fig.axes) // 2 == 7

0 comments on commit c00c2a0

Please sign in to comment.