diff --git a/src/hssm/config.py b/src/hssm/config.py index fd6054e2..639f83be 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -132,6 +132,7 @@ def update_config(self, user_config: ModelConfig) -> None: self.default_priors |= user_config.default_priors self.bounds |= user_config.bounds + self.extra_fields = user_config.extra_fields def validate(self) -> None: """Ensure that mandatory fields are not None.""" diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 377864e6..e2020a2c 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -9,6 +9,7 @@ from __future__ import annotations import logging +from copy import deepcopy from inspect import isclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal @@ -835,7 +836,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: lapse=self.lapse, extra_fields=None if not self.extra_fields - else [self.data[field].values for field in self.extra_fields], + else [deepcopy(self.data[field].values) for field in self.extra_fields], ) # type: ignore # If the user has provided a callable (an arbitrary likelihood function) # If `loglik_kind` is `blackbox`, wrap it in an op and then a distribution @@ -852,7 +853,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: lapse=self.lapse, extra_fields=None if not self.extra_fields - else [self.data[field].values for field in self.extra_fields], + else [deepcopy(self.data[field].values) for field in self.extra_fields], ) # type: ignore # All other situations if self.loglik_kind != "approx_differentiable": @@ -881,7 +882,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: lapse=self.lapse, extra_fields=None if not self.extra_fields - else [self.data[field].values for field in self.extra_fields], + else [deepcopy(self.data[field].values) for field in self.extra_fields], ) def _check_extra_fields(self, data: pd.DataFrame | None = None) -> bool: diff --git a/tests/test_distribution_utils.py b/tests/test_distribution_utils.py index bef9e551..64a4f320 100644 --- a/tests/test_distribution_utils.py +++ b/tests/test_distribution_utils.py @@ -2,6 +2,7 @@ import numpy as np import pymc as pm import pytest +import pytensor.tensor as pt import hssm from hssm import distribution_utils @@ -158,9 +159,9 @@ def fake_logp_function(data, param1, param2): def test_extra_fields(data_ddm): - ones = np.ones(len(data_ddm)) + ones = np.ones(data_ddm.shape[0]) x = ones * 0.5 - y = ones * 2 + y = ones * 4.0 def logp_ddm_extra_fields(data, v, a, z, t, x, y): return logp_ddm(data, v, a, z, t) * x * y @@ -176,7 +177,7 @@ def logp_ddm_extra_fields(data, v, a, z, t, x, y): np.testing.assert_almost_equal( pm.logp(DDM.dist(**true_values), data_ddm).eval(), - pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval(), + pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval() / 2.0, ) data_ddm_copy = data_ddm.copy() @@ -184,24 +185,40 @@ def logp_ddm_extra_fields(data, v, a, z, t, x, y): data_ddm_copy["y"] = y ddm_model_xy = hssm.HSSM( - data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]), p_outlier=None + data=data_ddm_copy, + model_config=dict(extra_fields=["x", "y"]), + loglik=logp_ddm_extra_fields, + p_outlier=None, + lapse=None, ) np.testing.assert_almost_equal( pm.logp(DDM.dist(**true_values), data_ddm).eval(), - pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval(), + pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval() + / 2.0, ) ddm_model = hssm.HSSM(data=data_ddm) ddm_model_p = hssm.HSSM( - data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]) + data=data_ddm_copy, + model_config=dict(extra_fields=["x", "y"]), + loglik=logp_ddm_extra_fields, + ) + ddm_model_p_logp_without_lapse = ( + pm.logp( + ddm_model_p.model_distribution.dist(**true_values, p_outlier=0), + data_ddm, + ) + / 2 + ) + ddm_model_p_logp_lapse = pt.log( + 0.95 * pt.exp(ddm_model_p_logp_without_lapse) + + 0.05 + * pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), data_ddm["rt"].values)) ) np.testing.assert_almost_equal( pm.logp( ddm_model.model_distribution.dist(**true_values, p_outlier=0.05), data_ddm ).eval(), - pm.logp( - ddm_model_p.model_distribution.dist(**true_values, p_outlier=0.05), - data_ddm, - ).eval(), + ddm_model_p_logp_lapse.eval(), )