Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Mar 6, 2024
1 parent c6f4bc0 commit 07988a6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
2 changes: 0 additions & 2 deletions tests/test_hssm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import bambi as bmb
import numpy as np
import pytest
Expand Down
37 changes: 15 additions & 22 deletions tests/test_missing_and_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def data():


def test_make_missing_data_callable_cpn(data, ddm, cpn):
# Test corner case where data_dim == 0 (CPN case)
# Test corner case where no data is passed (CPN case)
# Also needs to be careful when all parameters are scalar
# In which case the cpn callable should return a scalar
data = data[:, :-1]
Expand All @@ -60,11 +60,11 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):

# Test cpn when all inputs are scalars
cpn_callable_jax = make_missing_data_callable(
cpn, is_cpn_only=True, backend="jax", params_is_reg=[False] * 4
cpn, params_only=True, backend="jax", params_is_reg=[False] * 4
)

cpn_callable_pytensor = make_missing_data_callable(
cpn, is_cpn_only=True, backend="pytensor"
cpn, params_only=True, backend="pytensor"
)

result_jax = cpn_callable_jax(None, *dist_params).eval()
Expand All @@ -74,7 +74,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):

# Test cpn when some inputs are vectors
cpn_callable_jax_vector = make_missing_data_callable(
cpn, is_cpn_only=True, backend="jax", params_is_reg=[True] + [False] * 3
cpn, params_only=True, backend="jax", params_is_reg=[True] + [False] * 3
)

result_jax = cpn_callable_jax_vector(None, *dist_params_vector).eval()
Expand All @@ -88,15 +88,14 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=[False] * 4,
data_dim=2,
)
n_missing = np.sum(data[:, 0] == -999.0).astype(int)

result_data_jax = logp_callable_jax(data[n_missing:, :], *dist_params).eval()
missing_eval = cpn_callable_jax(None, *dist_params).eval()

assembled_loglik = assemble_callables(
logp_callable_jax, cpn_callable_jax, is_cpn_only=True, has_deadline=False
logp_callable_jax, cpn_callable_jax, params_only=True, has_deadline=False
)

result_assembled = assembled_loglik(data, *dist_params).eval()
Expand All @@ -117,7 +116,6 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=[True] + [False] * 3,
data_dim=2,
)

result_data_jax = logp_callable_jax_vector(
Expand All @@ -130,7 +128,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):
assembled_loglik = assemble_callables(
logp_callable_jax_vector,
cpn_callable_jax_vector,
is_cpn_only=True,
params_only=True,
has_deadline=False,
)

Expand All @@ -151,13 +149,12 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):
ddm,
loglik_kind="approx_differentiable",
backend="pytensor",
data_dim=2,
)

assembled_loglik = assemble_callables(
logp_callable_pytensor,
cpn_callable_pytensor,
is_cpn_only=True,
params_only=True,
has_deadline=False,
)

Expand All @@ -173,7 +170,7 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):
assembled_loglik = assemble_callables(
logp_callable_pytensor,
cpn_callable_pytensor,
is_cpn_only=True,
params_only=True,
has_deadline=False,
)

Expand All @@ -187,7 +184,6 @@ def test_make_missing_data_callable_cpn(data, ddm, cpn):


def test_make_missing_data_callable_opn(data, ddm, opn):
# Test edge case where data_dim == 0 (OPN case)
# Also needs to be careful when all parameters are scalar
# In which case the cpn callable should return a scalar

Expand All @@ -198,11 +194,11 @@ def test_make_missing_data_callable_opn(data, ddm, opn):

# Test cpn when all inputs are scalars
opn_callable_jax = make_missing_data_callable(
opn, is_cpn_only=False, backend="jax", params_is_reg=[False] * 4
opn, params_only=False, backend="jax", params_is_reg=[False] * 4
)

opn_callable_pytensor = make_missing_data_callable(
opn, is_cpn_only=False, backend="pytensor"
opn, params_only=False, backend="pytensor"
)

result_jax = opn_callable_jax(data[:, -1].reshape((100, 1)), *dist_params).eval()
Expand All @@ -212,7 +208,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn):

# Test opn when some inputs are vectors
opn_callable_jax_vector = make_missing_data_callable(
opn, is_cpn_only=False, backend="jax", params_is_reg=[True] + [False] * 3
opn, params_only=False, backend="jax", params_is_reg=[True] + [False] * 3
)

result_jax = opn_callable_jax_vector(data[:, [-1]], *dist_params_vector).eval()
Expand All @@ -226,7 +222,6 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=[False] * 4,
data_dim=2,
)

n_missing = np.sum(data[:, 0] == -999.0).astype(int)
Expand All @@ -235,7 +230,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
missing_eval = opn_callable_jax(data[:n_missing, -1:], *dist_params).eval()

assembled_loglik = assemble_callables(
logp_callable_jax, opn_callable_jax, is_cpn_only=False, has_deadline=True
logp_callable_jax, opn_callable_jax, params_only=False, has_deadline=True
)

result_assembled = assembled_loglik(data, *dist_params).eval()
Expand All @@ -256,7 +251,6 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
loglik_kind="approx_differentiable",
backend="jax",
params_is_reg=[True] + [False] * 3,
data_dim=2,
)

result_data_jax = logp_callable_jax_vector(
Expand All @@ -269,7 +263,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
assembled_loglik = assemble_callables(
logp_callable_jax_vector,
opn_callable_jax_vector,
is_cpn_only=False,
params_only=False,
has_deadline=True,
)

Expand All @@ -290,13 +284,12 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
ddm,
loglik_kind="approx_differentiable",
backend="pytensor",
data_dim=2,
)

assembled_loglik = assemble_callables(
logp_callable_pytensor,
opn_callable_pytensor,
is_cpn_only=False,
params_only=False,
has_deadline=True,
)

Expand All @@ -312,7 +305,7 @@ def test_make_missing_data_callable_opn(data, ddm, opn):
assembled_loglik = assemble_callables(
logp_callable_pytensor,
opn_callable_pytensor,
is_cpn_only=False,
params_only=False,
has_deadline=True,
)

Expand Down

0 comments on commit 07988a6

Please sign in to comment.