From c00c2a0ca75e1fc5e338a18d262a64d38fa993dc Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 16 Oct 2023 16:57:11 -0400 Subject: [PATCH] Update tests --- tests/test_mcmc.py | 187 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 142 insertions(+), 45 deletions(-) diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index f3caa2d4..ef58d522 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -1,29 +1,89 @@ import pytest +import arviz as az +import matplotlib +import matplotlib.pyplot as plt import hssm +import pymc as pm +matplotlib.use("Agg") hssm.set_floatX("float32") +parameter_names = "loglik_kind,backend,sampler,step,expected" +parameter_grid = [ + ("analytical", None, None, None, True), # Defaults should work + ("analytical", None, "mcmc", None, True), + ("analytical", None, "mcmc", "slice", True), + ("analytical", None, "nuts_numpyro", None, True), + ("analytical", None, "nuts_numpyro", "slice", TypeError), + ("approx_differentiable", "pytensor", None, None, True), # Defaults should work + ("approx_differentiable", "pytensor", "mcmc", None, True), + ("approx_differentiable", "pytensor", "mcmc", "slice", True), + ("approx_differentiable", "pytensor", "nuts_numpyro", None, True), + ("approx_differentiable", "pytensor", "nuts_numpyro", "slice", TypeError), + ("approx_differentiable", "jax", None, None, True), # Defaults should work + ("approx_differentiable", "jax", "mcmc", None, True), + ("approx_differentiable", "jax", "mcmc", "slice", True), + ("approx_differentiable", "jax", "nuts_numpyro", None, True), + ("approx_differentiable", "jax", "nuts_numpyro", "slice", TypeError), + ("blackbox", None, None, None, True), # Defaults should work + ("blackbox", None, "mcmc", None, True), + ("blackbox", None, "mcmc", "slice", True), + ("blackbox", None, "nuts_numpyro", None, ValueError), + ("blackbox", None, "nuts_numpyro", "slice", ValueError), +] + + +def sample(model, sampler, step): + if step == "slice": + model.sample( + sampler=sampler, + cores=1, + chains=1, + tune=10, + draws=10, + step=pm.Slice(model=model.pymc_model), + ) + else: + model.sample( + sampler=sampler, + cores=1, + chains=1, + tune=10, + draws=10, + ) + + +def run_sample(model, sampler, step, expected): + if expected == True: + sample(model, sampler, step) + assert isinstance(model.traces, az.InferenceData) + else: + with pytest.raises(expected): + sample(model, sampler, step) + + +@pytest.mark.parametrize(parameter_names, parameter_grid) +def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected): + model = hssm.HSSM( + data_ddm, loglik_kind=loglik_kind, model_config={"backend": backend} + ) + run_sample(model, sampler, step, expected) -def test_non_reg_models(data_ddm): - model1 = hssm.HSSM(data_ddm) - model1.sample_prior_predictive(draws=10) - - model1.sample(cores=1, chains=1, tune=10, draws=10) - model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) - - model2 = hssm.HSSM(data_ddm, loglik_kind="approx_differentiable") - model2.sample(cores=1, chains=1, tune=10, draws=10) - model2.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) - - model3 = hssm.HSSM(data_ddm, loglik_kind="blackbox") - model3.sample(cores=1, chains=1, tune=10, draws=10) - model3.sample(cores=1, chains=1, tune=10, draws=10) + # Only runs once + if loglik_kind == "analytical" and sampler is None: + assert not model._get_deterministic_var_names() + # test summary: + summary = model.summary() + assert summary.shape[0] == 4 - model1.sample_posterior_predictive(data=data_ddm.iloc[:10, :]) + model.plot_trace(show=False) + fig = plt.gcf() + assert len(fig.axes) // 2 == 4 -def test_reg_models(data_ddm_reg): +@pytest.mark.parametrize(parameter_names, parameter_grid) +def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected): param_reg = dict( formula="v ~ 1 + x + y", prior={ @@ -32,46 +92,83 @@ def test_reg_models(data_ddm_reg): "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, }, ) + model = hssm.HSSM( + data_ddm_reg, + loglik_kind=loglik_kind, + model_config={"backend": backend}, + v=param_reg, + ) + run_sample(model, sampler, step, expected) - model1 = hssm.HSSM(data_ddm_reg, v=param_reg) - model1.sample_prior_predictive(draws=10) - - model1.sample(cores=1, chains=1, tune=10, draws=10) - model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) - - model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", v=param_reg) - model2.sample(cores=1, chains=1, tune=10, draws=10) - model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10) - - model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", v=param_reg) - model3.sample(cores=1, chains=1, tune=10, draws=10) - - with pytest.raises(ValueError): - model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + # Only runs once + if loglik_kind == "analytical" and sampler is None: + assert not model._get_deterministic_var_names() + # test summary: + summary = model.summary() + assert summary.shape[0] == 6 - model1.sample_posterior_predictive(data=data_ddm_reg.iloc[:10, :]) + model.plot_trace(show=False) + fig = plt.gcf() + assert len(fig.axes) // 2 == 6 -def test_reg_models_a(data_ddm_reg): - param_reg = dict( - formula="a ~ 1 + x + y", +@pytest.mark.parametrize(parameter_names, parameter_grid) +def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expected): + param_reg_v = dict( + formula="v ~ 1 + x + y", + prior={ + "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, + "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + }, + ) + param_reg_a = dict( + formula="v ~ 1 + x + y", prior={ "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, }, + link="log", ) - model1 = hssm.HSSM(data_ddm_reg, a=param_reg) - model1.sample(cores=1, chains=1, tune=10, draws=10) - model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model = hssm.HSSM( + data_ddm_reg, + loglik_kind=loglik_kind, + model_config={"backend": backend}, + v=param_reg_v, + a=param_reg_a, + ) + run_sample(model, sampler, step, expected) + + # Only runs once + if loglik_kind == "analytical" and sampler is None: + assert model._get_deterministic_var_names() == ["~a"] + # test summary: + summary = model.summary() + assert summary.shape[0] == 8 + + summary = model.summary(var_names=["~a"]) + assert summary.shape[0] == 8 + + summary = model.summary(var_names=["~t"]) + assert summary.shape[0] == 7 + + summary = model.summary(var_names=["~a", "~t"]) + assert summary.shape[0] == 7 + + model.plot_trace(show=False) + fig = plt.gcf() + assert len(fig.axes) // 2 == 8 - model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", a=param_reg) - model2.sample(cores=1, chains=1, tune=10, draws=10) - model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10) + model.plot_trace(show=False, var_names=["~a"]) + fig = plt.gcf() + assert len(fig.axes) // 2 == 8 - model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", a=param_reg) - model3.sample(cores=1, chains=1, tune=10, draws=10) + model.plot_trace(show=False, var_names=["~t"]) + fig = plt.gcf() + assert len(fig.axes) // 2 == 7 - with pytest.raises(ValueError): - model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model.plot_trace(show=False, var_names=["~a", "~t"]) + fig = plt.gcf() + assert len(fig.axes) // 2 == 7