From 07988a63161c091f105e7045ff7ef31fa4a32b4f Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 6 Mar 2024 10:09:58 -0500 Subject: [PATCH] Update tests --- tests/test_hssm.py | 2 -- tests/test_missing_and_deadline.py | 37 ++++++++++++------------------ 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 840415ef..0a660303 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -1,5 +1,3 @@ -from pathlib import Path - import bambi as bmb import numpy as np import pytest diff --git a/tests/test_missing_and_deadline.py b/tests/test_missing_and_deadline.py index 36dff14c..1ca9cb5b 100644 --- a/tests/test_missing_and_deadline.py +++ b/tests/test_missing_and_deadline.py @@ -48,7 +48,7 @@ def data(): def test_make_missing_data_callable_cpn(data, ddm, cpn): - # Test corner case where data_dim == 0 (CPN case) + # Test corner case where no data is passed (CPN case) # Also needs to be careful when all parameters are scalar # In which case the cpn callable should return a scalar data = data[:, :-1] @@ -60,11 +60,11 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): # Test cpn when all inputs are scalars cpn_callable_jax = make_missing_data_callable( - cpn, is_cpn_only=True, backend="jax", params_is_reg=[False] * 4 + cpn, params_only=True, backend="jax", params_is_reg=[False] * 4 ) cpn_callable_pytensor = make_missing_data_callable( - cpn, is_cpn_only=True, backend="pytensor" + cpn, params_only=True, backend="pytensor" ) result_jax = cpn_callable_jax(None, *dist_params).eval() @@ -74,7 +74,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): # Test cpn when some inputs are vectors cpn_callable_jax_vector = make_missing_data_callable( - cpn, is_cpn_only=True, backend="jax", params_is_reg=[True] + [False] * 3 + cpn, params_only=True, backend="jax", params_is_reg=[True] + [False] * 3 ) result_jax = cpn_callable_jax_vector(None, *dist_params_vector).eval() @@ -88,7 +88,6 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): loglik_kind="approx_differentiable", backend="jax", params_is_reg=[False] * 4, - data_dim=2, ) n_missing = np.sum(data[:, 0] == -999.0).astype(int) @@ -96,7 +95,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): missing_eval = cpn_callable_jax(None, *dist_params).eval() assembled_loglik = assemble_callables( - logp_callable_jax, cpn_callable_jax, is_cpn_only=True, has_deadline=False + logp_callable_jax, cpn_callable_jax, params_only=True, has_deadline=False ) result_assembled = assembled_loglik(data, *dist_params).eval() @@ -117,7 +116,6 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): loglik_kind="approx_differentiable", backend="jax", params_is_reg=[True] + [False] * 3, - data_dim=2, ) result_data_jax = logp_callable_jax_vector( @@ -130,7 +128,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): assembled_loglik = assemble_callables( logp_callable_jax_vector, cpn_callable_jax_vector, - is_cpn_only=True, + params_only=True, has_deadline=False, ) @@ -151,13 +149,12 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): ddm, loglik_kind="approx_differentiable", backend="pytensor", - data_dim=2, ) assembled_loglik = assemble_callables( logp_callable_pytensor, cpn_callable_pytensor, - is_cpn_only=True, + params_only=True, has_deadline=False, ) @@ -173,7 +170,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): assembled_loglik = assemble_callables( logp_callable_pytensor, cpn_callable_pytensor, - is_cpn_only=True, + params_only=True, has_deadline=False, ) @@ -187,7 +184,6 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn): def test_make_missing_data_callable_opn(data, ddm, opn): - # Test edge case where data_dim == 0 (OPN case) # Also needs to be careful when all parameters are scalar # In which case the cpn callable should return a scalar @@ -198,11 +194,11 @@ def test_make_missing_data_callable_opn(data, ddm, opn): # Test cpn when all inputs are scalars opn_callable_jax = make_missing_data_callable( - opn, is_cpn_only=False, backend="jax", params_is_reg=[False] * 4 + opn, params_only=False, backend="jax", params_is_reg=[False] * 4 ) opn_callable_pytensor = make_missing_data_callable( - opn, is_cpn_only=False, backend="pytensor" + opn, params_only=False, backend="pytensor" ) result_jax = opn_callable_jax(data[:, -1].reshape((100, 1)), *dist_params).eval() @@ -212,7 +208,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn): # Test opn when some inputs are vectors opn_callable_jax_vector = make_missing_data_callable( - opn, is_cpn_only=False, backend="jax", params_is_reg=[True] + [False] * 3 + opn, params_only=False, backend="jax", params_is_reg=[True] + [False] * 3 ) result_jax = opn_callable_jax_vector(data[:, [-1]], *dist_params_vector).eval() @@ -226,7 +222,6 @@ def test_make_missing_data_callable_opn(data, ddm, opn): loglik_kind="approx_differentiable", backend="jax", params_is_reg=[False] * 4, - data_dim=2, ) n_missing = np.sum(data[:, 0] == -999.0).astype(int) @@ -235,7 +230,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn): missing_eval = opn_callable_jax(data[:n_missing, -1:], *dist_params).eval() assembled_loglik = assemble_callables( - logp_callable_jax, opn_callable_jax, is_cpn_only=False, has_deadline=True + logp_callable_jax, opn_callable_jax, params_only=False, has_deadline=True ) result_assembled = assembled_loglik(data, *dist_params).eval() @@ -256,7 +251,6 @@ def test_make_missing_data_callable_opn(data, ddm, opn): loglik_kind="approx_differentiable", backend="jax", params_is_reg=[True] + [False] * 3, - data_dim=2, ) result_data_jax = logp_callable_jax_vector( @@ -269,7 +263,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn): assembled_loglik = assemble_callables( logp_callable_jax_vector, opn_callable_jax_vector, - is_cpn_only=False, + params_only=False, has_deadline=True, ) @@ -290,13 +284,12 @@ def test_make_missing_data_callable_opn(data, ddm, opn): ddm, loglik_kind="approx_differentiable", backend="pytensor", - data_dim=2, ) assembled_loglik = assemble_callables( logp_callable_pytensor, opn_callable_pytensor, - is_cpn_only=False, + params_only=False, has_deadline=True, ) @@ -312,7 +305,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn): assembled_loglik = assemble_callables( logp_callable_pytensor, opn_callable_pytensor, - is_cpn_only=False, + params_only=False, has_deadline=True, )