Skip to content

Commit

Permalink
Merge pull request #286 from lnccbrown/fix-extra-fields
Browse files Browse the repository at this point in the history
Fix bug in extra fields
  • Loading branch information
digicosmos86 authored Sep 29, 2023
2 parents c0b1501 + 9a56c56 commit c16cedc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def update_config(self, user_config: ModelConfig) -> None:

self.default_priors |= user_config.default_priors
self.bounds |= user_config.bounds
self.extra_fields = user_config.extra_fields

def validate(self) -> None:
"""Ensure that mandatory fields are not None."""
Expand Down
7 changes: 4 additions & 3 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import logging
from copy import deepcopy
from inspect import isclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal
Expand Down Expand Up @@ -835,7 +836,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
lapse=self.lapse,
extra_fields=None
if not self.extra_fields
else [self.data[field].values for field in self.extra_fields],
else [deepcopy(self.data[field].values) for field in self.extra_fields],
) # type: ignore
# If the user has provided a callable (an arbitrary likelihood function)
# If `loglik_kind` is `blackbox`, wrap it in an op and then a distribution
Expand All @@ -852,7 +853,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
lapse=self.lapse,
extra_fields=None
if not self.extra_fields
else [self.data[field].values for field in self.extra_fields],
else [deepcopy(self.data[field].values) for field in self.extra_fields],
) # type: ignore
# All other situations
if self.loglik_kind != "approx_differentiable":
Expand Down Expand Up @@ -881,7 +882,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
lapse=self.lapse,
extra_fields=None
if not self.extra_fields
else [self.data[field].values for field in self.extra_fields],
else [deepcopy(self.data[field].values) for field in self.extra_fields],
)

def _check_extra_fields(self, data: pd.DataFrame | None = None) -> bool:
Expand Down
37 changes: 27 additions & 10 deletions tests/test_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pymc as pm
import pytest
import pytensor.tensor as pt

import hssm
from hssm import distribution_utils
Expand Down Expand Up @@ -158,9 +159,9 @@ def fake_logp_function(data, param1, param2):


def test_extra_fields(data_ddm):
ones = np.ones(len(data_ddm))
ones = np.ones(data_ddm.shape[0])
x = ones * 0.5
y = ones * 2
y = ones * 4.0

def logp_ddm_extra_fields(data, v, a, z, t, x, y):
return logp_ddm(data, v, a, z, t) * x * y
Expand All @@ -176,32 +177,48 @@ def logp_ddm_extra_fields(data, v, a, z, t, x, y):

np.testing.assert_almost_equal(
pm.logp(DDM.dist(**true_values), data_ddm).eval(),
pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval(),
pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval() / 2.0,
)

data_ddm_copy = data_ddm.copy()
data_ddm_copy["x"] = x
data_ddm_copy["y"] = y

ddm_model_xy = hssm.HSSM(
data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]), p_outlier=None
data=data_ddm_copy,
model_config=dict(extra_fields=["x", "y"]),
loglik=logp_ddm_extra_fields,
p_outlier=None,
lapse=None,
)

np.testing.assert_almost_equal(
pm.logp(DDM.dist(**true_values), data_ddm).eval(),
pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval(),
pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval()
/ 2.0,
)

ddm_model = hssm.HSSM(data=data_ddm)
ddm_model_p = hssm.HSSM(
data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"])
data=data_ddm_copy,
model_config=dict(extra_fields=["x", "y"]),
loglik=logp_ddm_extra_fields,
)
ddm_model_p_logp_without_lapse = (
pm.logp(
ddm_model_p.model_distribution.dist(**true_values, p_outlier=0),
data_ddm,
)
/ 2
)
ddm_model_p_logp_lapse = pt.log(
0.95 * pt.exp(ddm_model_p_logp_without_lapse)
+ 0.05
* pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), data_ddm["rt"].values))
)
np.testing.assert_almost_equal(
pm.logp(
ddm_model.model_distribution.dist(**true_values, p_outlier=0.05), data_ddm
).eval(),
pm.logp(
ddm_model_p.model_distribution.dist(**true_values, p_outlier=0.05),
data_ddm,
).eval(),
ddm_model_p_logp_lapse.eval(),
)

0 comments on commit c16cedc

Please sign in to comment.