Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Sep 25, 2023
1 parent 85e6169 commit 9a56c56
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions tests/test_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pymc as pm
import pytest
import pytensor.tensor as pt

import hssm
from hssm import distribution_utils
Expand Down Expand Up @@ -158,9 +159,9 @@ def fake_logp_function(data, param1, param2):


def test_extra_fields(data_ddm):
ones = np.ones(len(data_ddm))
ones = np.ones(data_ddm.shape[0])
x = ones * 0.5
y = ones * 2
y = ones * 4.0

def logp_ddm_extra_fields(data, v, a, z, t, x, y):
return logp_ddm(data, v, a, z, t) * x * y
Expand All @@ -176,32 +177,48 @@ def logp_ddm_extra_fields(data, v, a, z, t, x, y):

np.testing.assert_almost_equal(
pm.logp(DDM.dist(**true_values), data_ddm).eval(),
pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval(),
pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval() / 2.0,
)

data_ddm_copy = data_ddm.copy()
data_ddm_copy["x"] = x
data_ddm_copy["y"] = y

ddm_model_xy = hssm.HSSM(
data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]), p_outlier=None
data=data_ddm_copy,
model_config=dict(extra_fields=["x", "y"]),
loglik=logp_ddm_extra_fields,
p_outlier=None,
lapse=None,
)

np.testing.assert_almost_equal(
pm.logp(DDM.dist(**true_values), data_ddm).eval(),
pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval(),
pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval()
/ 2.0,
)

ddm_model = hssm.HSSM(data=data_ddm)
ddm_model_p = hssm.HSSM(
data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"])
data=data_ddm_copy,
model_config=dict(extra_fields=["x", "y"]),
loglik=logp_ddm_extra_fields,
)
ddm_model_p_logp_without_lapse = (
pm.logp(
ddm_model_p.model_distribution.dist(**true_values, p_outlier=0),
data_ddm,
)
/ 2
)
ddm_model_p_logp_lapse = pt.log(
0.95 * pt.exp(ddm_model_p_logp_without_lapse)
+ 0.05
* pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), data_ddm["rt"].values))
)
np.testing.assert_almost_equal(
pm.logp(
ddm_model.model_distribution.dist(**true_values, p_outlier=0.05), data_ddm
).eval(),
pm.logp(
ddm_model_p.model_distribution.dist(**true_values, p_outlier=0.05),
data_ddm,
).eval(),
ddm_model_p_logp_lapse.eval(),
)

0 comments on commit 9a56c56

Please sign in to comment.