
Commit

Merge pull request #377 from lnccbrown/373-shift-prior-means-for-log_logit-settings

373 shift prior means for log logit settings
AlexanderFengler authored Apr 1, 2024
2 parents 77e3aad + 9313429 commit 4cfbd1a
Showing 10 changed files with 565 additions and 305 deletions.
490 changes: 254 additions & 236 deletions docs/tutorials/lapse_prob_and_dist.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
-pymc = "^5.10.4"
+pymc = ">=5.10.4, <5.11.0"
scipy = "^1.12.0"
arviz = "^0.17.0"
numpy = "^1.26.4"
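The tightened pin narrows Poetry's caret shorthand: "^5.10.4" permits any release up to (but excluding) 6.0.0, while ">=5.10.4, <5.11.0" stays within the 5.10 series. A quick illustration with the packaging library (illustrative only, not part of this commit):

from packaging.specifiers import SpecifierSet

old = SpecifierSet(">=5.10.4,<6.0.0")   # what "^5.10.4" expands to
new = SpecifierSet(">=5.10.4,<5.11.0")  # the explicit pin in this commit

print("5.10.4" in old, "5.10.4" in new)  # True True
print("5.11.0" in old, "5.11.0" in new)  # True False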
49 changes: 42 additions & 7 deletions src/hssm/defaults.py
@@ -1,4 +1,5 @@
"""Provide default configurations for models in the HSSM class."""

from enum import Enum
from os import PathLike
from typing import Callable, Literal, Optional, TypedDict, Union
@@ -91,15 +92,21 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
},
"approx_differentiable": {
"loglik": "ddm.onnx",
"backend": "jax",
"default_priors": {},
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 2.5),
@@ -116,7 +123,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -136,7 +143,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -148,7 +155,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"bounds": {
@@ -168,7 +175,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -188,7 +195,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -321,6 +328,34 @@ class DefaultConfig(TypedDict):
},
}

INITVAL_SETTINGS = {
# logit link function case
# should never use priors with bounds,
# so no need to take care of _log__, and _interval__ variables
"log_logit": {
"t": -4.0,
"t_Intercept": -4.0,
"v": 0.0,
"a": 0.0,
"a_Intercept": 0.0,
"v_Intercept": 0.0,
"p_outlier": -5.0,
},
# identity link function case,
# need to take care of _log__ and _interval__ variables
"None": {
"t": 0.025,
"t_Intercept": 0.025,
"a": 1.5,
"a_Intercept": 1.5,
"p_outlier": 0.001,
},
}

INITVAL_JITTER_SETTINGS = {
"jitter_epsilon": 0.01,
}
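
For orientation (not part of the diff): with link_settings="log_logit", these initial values are specified on the link scale and are mapped back through the inverse link before sampling starts. A minimal NumPy sketch, with illustrative bounds that are not read from HSSM's configs:

import numpy as np

def inverse_log(x):
    # Inverse of a log link: t = -4.0 maps to exp(-4.0) ~ 0.018 seconds.
    return np.exp(x)

def inverse_logit(x, lower, upper):
    # Inverse of a scaled logit link: squashes x into (lower, upper).
    return lower + (upper - lower) / (1.0 + np.exp(-x))

print(inverse_log(-4.0))              # ~0.018, close to the 0.025 used in the "None" case
print(inverse_logit(0.0, -3.0, 3.0))  # 0.0, the midpoint of illustrative v bounds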


def show_defaults(model: SupportedModels, loglik_kind: Optional[LoglikKind] = None) -> str:
"""Show the defaults for supported models.
7 changes: 6 additions & 1 deletion src/hssm/distribution_utils/dist.py
@@ -82,6 +82,9 @@ def apply_param_bounds_to_loglik(
return logp


# AF-TODO: define clip params


def ensure_positive_ndt(data, logp, list_params, dist_params):
"""Ensure that the non-decision time is always positive.
@@ -434,6 +437,8 @@ def dist(cls, **kwargs): # pylint: disable=arguments-renamed
return super().dist(dist_params, **other_kwargs)

def logp(data, *dist_params): # pylint: disable=E0213
# AF-TODO: Apply clipping here

num_params = len(list_params)
extra_fields = []

@@ -445,7 +450,7 @@ def logp(data, *dist_params): # pylint: disable=E0213
p_outlier = dist_params[-1]
dist_params = dist_params[:-1]
lapse_logp = lapse_func(data[:, 0].eval())

# AF-TODO potentially apply clipping here
logp = loglik(data, *dist_params, *extra_fields)
logp = pt.log(
(1.0 - p_outlier) * pt.exp(logp)
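The logp hunk above is cut off, but the visible lines define the lapse mixture: with probability p_outlier a trial comes from the lapse distribution, otherwise from the model likelihood. A numerically stable restatement in NumPy (a sketch of the formula, not HSSM's internals):

import numpy as np

def mixture_logp(logp_model, logp_lapse, p_outlier):
    # log((1 - p) * exp(logp_model) + p * exp(logp_lapse)),
    # written with logaddexp to avoid underflow of the exp terms.
    return np.logaddexp(
        np.log1p(-p_outlier) + logp_model,
        np.log(p_outlier) + logp_lapse,
    )

# A trial the model explains poorly is mostly absorbed by the lapse term.
print(mixture_logp(np.array([-1.0, -50.0]), np.array([-2.3, -2.3]), 0.05))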
130 changes: 112 additions & 18 deletions src/hssm/hssm.py
@@ -26,6 +28,8 @@
from bambi.transformations import transformations_namespace

from hssm.defaults import (
INITVAL_JITTER_SETTINGS,
INITVAL_SETTINGS,
LoglikKind,
MissingDataNetwork,
SupportedModels,
@@ -225,12 +227,9 @@ def __init__(
model: SupportedModels | str = "ddm",
include: list[dict | Param] | None = None,
model_config: ModelConfig | dict | None = None,
-loglik: str
-| PathLike
-| Callable
-| pytensor.graph.Op
-| type[pm.Distribution]
-| None = None,
+loglik: (
+    str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None
+) = None,
loglik_kind: LoglikKind | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0),
Expand All @@ -240,11 +239,9 @@ def __init__(
extra_namespace: dict[str, Any] | None = None,
missing_data: bool | float = False,
deadline: bool | str = False,
-loglik_missing_data: str
-| PathLike
-| Callable
-| pytensor.graph.Op
-| None = None,
+loglik_missing_data: (
+    str | PathLike | Callable | pytensor.graph.Op | None
+) = None,
**kwargs,
):
self.data = data.copy()
@@ -271,6 +268,7 @@ def __init__(
if isinstance(model_config, ModelConfig)
else ModelConfig(**model_config) # also serves as dict validation
)

# Update loglik with user-provided value
self.model_config.update_loglik(loglik)
# Ensure that all required fields are valid
@@ -348,6 +346,7 @@ def __init__(
# Get the bambi formula, priors, and link
self.formula, self.priors, self.link = self._parse_bambi()

# print(self.priors)
# For parameters that are regression, apply bounds at the likelihood level to
# ensure that the samples that are out of bounds are discarded (replaced with
# a large negative value).
@@ -385,11 +384,16 @@ def __init__(
self.model, self._parent_param, self.response_c, self.response_str
)
self.set_alias(self._aliases)
self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
self._jitter_initvals(
jitter_epsilon=INITVAL_JITTER_SETTINGS["jitter_epsilon"], vector_only=True
)

def sample(
self,
-sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]
-| None = None,
+sampler: (
+    Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"] | None
+) = None,
init: str | None = None,
**kwargs,
) -> az.InferenceData | pm.Approximation:
@@ -465,6 +469,30 @@ def sample(
else:
init = "auto"

# If the sampler ends up being `nuts_numpyro`, make sure
# the jitter argument is set to False
if sampler == "nuts_numpyro":
if "jitter" not in kwargs.keys():
kwargs["jitter"] = False
elif kwargs["jitter"]:
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["jitter"] = False
elif sampler != "nuts_numpyro":
if "jitter" in kwargs.keys():
_logger.warning(
"The jitter keyword argument is "
+ "supported only by the nuts_numpyro sampler. \n"
+ "The jitter argument will be ignored."
)
del kwargs["jitter"]
else:
pass

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)
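
A standalone restatement of the jitter handling above, to make the three branches explicit (illustrative helper, not HSSM code):

import warnings

def resolve_jitter(sampler: str, kwargs: dict) -> dict:
    # Mirrors the logic in sample(): nuts_numpyro forces jitter=False;
    # every other sampler drops a user-supplied jitter kwarg with a warning.
    kwargs = dict(kwargs)
    if sampler == "nuts_numpyro":
        if kwargs.get("jitter", False):
            warnings.warn("jitter=True is not supported by the numpyro backend; using False.")
        kwargs["jitter"] = False
    elif "jitter" in kwargs:
        warnings.warn("jitter is only honored by the nuts_numpyro sampler; ignoring it.")
        del kwargs["jitter"]
    return kwargs

print(resolve_jitter("nuts_numpyro", {"jitter": True}))        # {'jitter': False}
print(resolve_jitter("mcmc", {"jitter": True, "draws": 500}))  # {'draws': 500}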
@@ -859,6 +887,11 @@ def __repr__(self) -> str:
prior = param.prior
output.append(f" Prior: {prior}")
output.append(f" Explicit bounds: {param.bounds}")
output.append(
" (ignored due to link function)"
if self.link_settings is not None
else ""
)

if self.p_outlier is not None:
# TODO: Allow regression for self.p_outlier
@@ -1068,6 +1101,11 @@ def _override_defaults(self):
)
for param in self.list_params:
param_obj = self.params[param]
# print(param)
# print('bounds before prior settings: ', param_obj.bounds)

+if self.link_settings == "log_logit":
+    param_obj.override_default_link()
if self.prior_settings == "safe":
if is_ddm:
param_obj.override_default_priors_ddm(
Expand All @@ -1077,8 +1115,6 @@ def _override_defaults(self):
param_obj.override_default_priors(
self.data, self.additional_namespace
)
-if self.link_settings == "log_logit":
-    param_obj.override_default_link()

def _process_all(self):
"""Process all params."""
@@ -1231,9 +1267,11 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
list_params=self.list_params,
bounds=self.bounds,
lapse=self.lapse,
-extra_fields=None
-if not self.extra_fields
-else [deepcopy(self.data[field].values) for field in self.extra_fields],
+extra_fields=(
+    None
+    if not self.extra_fields
+    else [deepcopy(self.data[field].values) for field in self.extra_fields]
+),
)

def _check_extra_fields(self, data: pd.DataFrame | None = None) -> bool:
@@ -1369,6 +1407,62 @@ def _post_check_data_sanity(self):
+ "which is not allowed."
)

def _postprocess_initvals_deterministic(
self, initval_settings: dict = INITVAL_SETTINGS
) -> None:
"""Set initial values for subset of parameters."""
# Consider case where link functions are set to 'log_logit'
# or 'None'
if self.link_settings not in ["log_logit", "None", None]:
print(
"Not preprocessing initial values, "
+ "because none of the two standard link settings are chosen!"
)
return None

link_setting_str = str(self.link_settings)
# Set initial values for particular parameters
for name_, starting_value in self.pymc_model.initial_point().items():
name_tmp = name_
name_tmp = name_tmp.replace("_log__", "")
name_tmp = name_tmp.replace("_interval__", "")
if name_tmp in initval_settings[link_setting_str].keys():
# Apply specific settings from initval_settings dictionary
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_tmp],
initval_settings[link_setting_str][name_tmp],
)
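
The suffix stripping above exists because PyMC names transformed free variables "<name>_log__" or "<name>_interval__"; removing the transform suffix recovers the key used in INITVAL_SETTINGS. A tiny self-contained check:

def strip_transform_suffix(name: str) -> str:
    # Recover the untransformed variable name, e.g. "t_log__" -> "t".
    return name.replace("_log__", "").replace("_interval__", "")

assert strip_transform_suffix("t_log__") == "t"
assert strip_transform_suffix("p_outlier_interval__") == "p_outlier"
assert strip_transform_suffix("v_Intercept") == "v_Intercept"  # untouched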

def _jitter_initvals(
self, jitter_epsilon: float = 0.01, vector_only: bool = False
) -> None:
"""Apply controlled jitter to initial values."""
if vector_only:
self.__jitter_initvals_vector_only(jitter_epsilon)
else:
self.__jitter_initvals_all(jitter_epsilon)

def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None:
initial_point_dict = self.pymc_model.initial_point()
for name_, starting_value in initial_point_dict.items():
if starting_value.ndim != 0 and starting_value.shape[0] != 1:
starting_value_tmp = starting_value + np.random.uniform(
-jitter_epsilon, jitter_epsilon, starting_value.shape
).astype(np.float32)
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_], starting_value_tmp
)

def __jitter_initvals_all(self, jitter_epsilon: float) -> None:
initial_point_dict = self.pymc_model.initial_point()
for name_, starting_value in initial_point_dict.items():
starting_value_tmp = starting_value + np.random.uniform(
-jitter_epsilon, jitter_epsilon, starting_value.shape
).astype(np.float32)
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_], starting_value_tmp
)
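
In plain NumPy, the vector_only rule above amounts to the following: scalar initial values stay deterministic, while vector-valued ones (e.g. group-level offsets) receive uniform noise in [-epsilon, epsilon] so that chains do not start from identical points (a sketch with made-up names):

import numpy as np

rng = np.random.default_rng(0)
eps = 0.01
initial_point = {
    "t": np.float32(0.025),                      # scalar: left untouched
    "v_subject": np.zeros(5, dtype=np.float32),  # vector: jittered
}

for name, value in initial_point.items():
    if value.ndim != 0 and value.shape[0] != 1:  # the vector_only condition
        initial_point[name] = value + rng.uniform(-eps, eps, value.shape).astype(np.float32)

print(initial_point)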


def _set_missing_data_and_deadline(
missing_data: bool, deadline: bool, data: pd.DataFrame
