373 shift prior means for log logit settings #377

Merged Apr 1, 2024 (15 commits)

Changes from all commits
490 changes: 254 additions & 236 deletions docs/tutorials/lapse_prob_and_dist.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
pymc = "^5.10.4"
pymc = ">=5.10.4, <5.11.0"
scipy = "^1.12.0"
arviz = "^0.17.0"
numpy = "^1.26.4"
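
Note: Poetry's caret constraint "^5.10.4" would have allowed any 5.x release below 6.0.0; the new explicit bound restricts installs to the 5.10.x series.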
49 changes: 42 additions & 7 deletions src/hssm/defaults.py
@@ -1,4 +1,5 @@
"""Provide default configurations for models in the HSSM class."""

from enum import Enum
from os import PathLike
from typing import Callable, Literal, Optional, TypedDict, Union
@@ -91,15 +92,21 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
},
"approx_differentiable": {
"loglik": "ddm.onnx",
"backend": "jax",
"default_priors": {},
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.05,
},
},
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 2.5),
@@ -116,7 +123,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -136,7 +143,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -148,7 +155,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"bounds": {
@@ -168,7 +175,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -188,7 +195,7 @@ class DefaultConfig(TypedDict):
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
"initval": 0.05,
},
},
"extra_fields": None,
@@ -321,6 +328,34 @@ class DefaultConfig(TypedDict):
},
}

INITVAL_SETTINGS = {
    # log/logit link function case
    # should never use priors with bounds,
    # so no need to take care of _log__ and _interval__ variables
"log_logit": {
"t": -4.0,
"t_Intercept": -4.0,
"v": 0.0,
"a": 0.0,
"a_Intercept": 0.0,
"v_Intercept": 0.0,
"p_outlier": -5.0,
},
# identity link function case,
    # need to take care of _log__ and _interval__ variables
"None": {
"t": 0.025,
"t_Intercept": 0.025,
"a": 1.5,
"a_Intercept": 1.5,
"p_outlier": 0.001,
},
}

INITVAL_JITTER_SETTINGS = {
"jitter_epsilon": 0.01,
}


def show_defaults(model: SupportedModels, loglik_kind=Optional[LoglikKind]) -> str:
"""Show the defaults for supported models.
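
Note: the new INITVAL_SETTINGS table is keyed first by link setting ("log_logit", or the string "None" for identity links) and then by untransformed parameter name. Below is a minimal sketch of the intended lookup, mirroring _postprocess_initvals_deterministic in the hssm.py changes further down; the abbreviated dictionary and the lookup_initval helper are illustrative only, not part of the PR:

    # PyMC appends "_log__" / "_interval__" to the names of transformed free
    # variables, while INITVAL_SETTINGS stores untransformed names, so the
    # variable name is normalized before lookup.
    INITVAL_SETTINGS = {
        "None": {"t": 0.025, "a": 1.5, "p_outlier": 0.001},  # abbreviated
    }

    def lookup_initval(var_name: str, link_setting: str) -> float | None:
        key = var_name.replace("_log__", "").replace("_interval__", "")
        return INITVAL_SETTINGS.get(link_setting, {}).get(key)

    lookup_initval("t_log__", "None")      # 0.025
    lookup_initval("v_Intercept", "None")  # None -> initval left untouched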
7 changes: 6 additions & 1 deletion src/hssm/distribution_utils/dist.py
@@ -82,6 +82,9 @@ def apply_param_bounds_to_loglik(
return logp


# AF-TODO: define clip params


def ensure_positive_ndt(data, logp, list_params, dist_params):
"""Ensure that the non-decision time is always positive.

@@ -434,6 +437,8 @@ def dist(cls, **kwargs): # pylint: disable=arguments-renamed
return super().dist(dist_params, **other_kwargs)

def logp(data, *dist_params): # pylint: disable=E0213
# AF-TODO: Apply clipping here

num_params = len(list_params)
extra_fields = []

@@ -445,7 +450,7 @@ def logp(data, *dist_params): # pylint: disable=E0213
p_outlier = dist_params[-1]
dist_params = dist_params[:-1]
lapse_logp = lapse_func(data[:, 0].eval())

# AF-TODO potentially apply clipping here
logp = loglik(data, *dist_params, *extra_fields)
logp = pt.log(
(1.0 - p_outlier) * pt.exp(logp)
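
Note: the logp hunk above mixes the model likelihood with the lapse density exactly where the diff is truncated. A minimal numeric sketch of that mixture, assuming (as the cut-off pt.log(...) call suggests) that the elided term adds p_outlier * pt.exp(lapse_logp); all values are placeholders:

    import numpy as np
    import pytensor.tensor as pt

    # With probability (1 - p_outlier) a response comes from the model,
    # with probability p_outlier from the lapse distribution.
    p_outlier = 0.05
    logp = pt.as_tensor_variable(np.array([-1.2, -0.7]))        # model log-likelihoods
    lapse_logp = pt.as_tensor_variable(np.array([-2.3, -2.3]))  # lapse log-densities

    mixture_logp = pt.log(
        (1.0 - p_outlier) * pt.exp(logp) + p_outlier * pt.exp(lapse_logp)
    )
    mixture_logp.eval()  # per-trial mixture log-likelihoods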
130 changes: 112 additions & 18 deletions src/hssm/hssm.py
@@ -26,6 +26,8 @@
from bambi.transformations import transformations_namespace

from hssm.defaults import (
INITVAL_JITTER_SETTINGS,
INITVAL_SETTINGS,
LoglikKind,
MissingDataNetwork,
SupportedModels,
@@ -225,12 +227,9 @@
model: SupportedModels | str = "ddm",
include: list[dict | Param] | None = None,
model_config: ModelConfig | dict | None = None,
loglik: str
| PathLike
| Callable
| pytensor.graph.Op
| type[pm.Distribution]
| None = None,
loglik: (
str | PathLike | Callable | pytensor.graph.Op | type[pm.Distribution] | None
) = None,
loglik_kind: LoglikKind | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0),
@@ -240,11 +239,9 @@ def __init__(
extra_namespace: dict[str, Any] | None = None,
missing_data: bool | float = False,
deadline: bool | str = False,
loglik_missing_data: str
| PathLike
| Callable
| pytensor.graph.Op
| None = None,
loglik_missing_data: (
str | PathLike | Callable | pytensor.graph.Op | None
) = None,
**kwargs,
):
self.data = data.copy()
@@ -271,6 +268,7 @@ def __init__(
if isinstance(model_config, ModelConfig)
else ModelConfig(**model_config) # also serves as dict validation
)

# Update loglik with user-provided value
self.model_config.update_loglik(loglik)
# Ensure that all required fields are valid
@@ -348,6 +346,7 @@ def __init__(
# Get the bambi formula, priors, and link
self.formula, self.priors, self.link = self._parse_bambi()

# print(self.priors)
# For parameters that are regression, apply bounds at the likelihood level to
# ensure that the samples that are out of bounds are discarded (replaced with
# a large negative value).
@@ -385,11 +384,16 @@ def __init__(
self.model, self._parent_param, self.response_c, self.response_str
)
self.set_alias(self._aliases)
self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
self._jitter_initvals(
jitter_epsilon=INITVAL_JITTER_SETTINGS["jitter_epsilon"], vector_only=True
)

def sample(
self,
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]
| None = None,
sampler: (
Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"] | None
) = None,
init: str | None = None,
**kwargs,
) -> az.InferenceData | pm.Approximation:
@@ -465,6 +469,30 @@ def sample(
else:
init = "auto"

        # If the sampler ends up being `nuts_numpyro`, make sure
        # the jitter argument is set to False
if sampler == "nuts_numpyro":
if "jitter" not in kwargs.keys():
kwargs["jitter"] = False
elif kwargs["jitter"]:
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["jitter"] = False
elif sampler != "nuts_numpyro":
if "jitter" in kwargs.keys():
_logger.warning(
"The jitter keyword argument is "
+ "supported only by the nuts_numpyro sampler. \n"
+ "The jitter argument will be ignored."
)
del kwargs["jitter"]
else:
pass

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)
@@ -859,6 +887,11 @@ def __repr__(self) -> str:
prior = param.prior
output.append(f" Prior: {prior}")
output.append(f" Explicit bounds: {param.bounds}")
output.append(
" (ignored due to link function)"
if self.link_settings is not None
else ""
)

if self.p_outlier is not None:
# TODO: Allow regression for self.p_outlier
@@ -1068,6 +1101,11 @@ def _override_defaults(self):
)
for param in self.list_params:
param_obj = self.params[param]
# print(param)
# print('bounds before prior settings: ', param_obj.bounds)

if self.link_settings == "log_logit":
param_obj.override_default_link()
if self.prior_settings == "safe":
if is_ddm:
param_obj.override_default_priors_ddm(
@@ -1077,8 +1115,6 @@ def _override_defaults(self):
param_obj.override_default_priors(
self.data, self.additional_namespace
)
if self.link_settings == "log_logit":
param_obj.override_default_link()

def _process_all(self):
"""Process all params."""
@@ -1231,9 +1267,11 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
list_params=self.list_params,
bounds=self.bounds,
lapse=self.lapse,
extra_fields=None
if not self.extra_fields
else [deepcopy(self.data[field].values) for field in self.extra_fields],
extra_fields=(
None
if not self.extra_fields
else [deepcopy(self.data[field].values) for field in self.extra_fields]
),
)

def _check_extra_fields(self, data: pd.DataFrame | None = None) -> bool:
@@ -1369,6 +1407,62 @@ def _post_check_data_sanity(self):
+ "which is not allowed."
)

def _postprocess_initvals_deterministic(
self, initval_settings: dict = INITVAL_SETTINGS
) -> None:
"""Set initial values for subset of parameters."""
# Consider case where link functions are set to 'log_logit'
# or 'None'
if self.link_settings not in ["log_logit", "None", None]:
print(
"Not preprocessing initial values, "
+ "because none of the two standard link settings are chosen!"
)
return None

link_setting_str = str(self.link_settings)
# Set initial values for particular parameters
for name_, starting_value in self.pymc_model.initial_point().items():
name_tmp = name_
name_tmp = name_tmp.replace("_log__", "")
name_tmp = name_tmp.replace("_interval__", "")
if name_tmp in initval_settings[link_setting_str].keys():
# Apply specific settings from initval_settings dictionary
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_tmp],
initval_settings[link_setting_str][name_tmp],
)

def _jitter_initvals(
self, jitter_epsilon: float = 0.01, vector_only: bool = False
) -> None:
"""Apply controlled jitter to initial values."""
if vector_only:
self.__jitter_initvals_vector_only(jitter_epsilon)
else:
self.__jitter_initvals_all(jitter_epsilon)

def __jitter_initvals_vector_only(self, jitter_epsilon: float) -> None:
initial_point_dict = self.pymc_model.initial_point()
for name_, starting_value in initial_point_dict.items():
if starting_value.ndim != 0 and starting_value.shape[0] != 1:
starting_value_tmp = starting_value + np.random.uniform(
-jitter_epsilon, jitter_epsilon, starting_value.shape
).astype(np.float32)
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_], starting_value_tmp
)

def __jitter_initvals_all(self, jitter_epsilon: float) -> None:
initial_point_dict = self.pymc_model.initial_point()
for name_, starting_value in initial_point_dict.items():
starting_value_tmp = starting_value + np.random.uniform(
-jitter_epsilon, jitter_epsilon, starting_value.shape
).astype(np.float32)
self.pymc_model.set_initval(
self.pymc_model.named_vars[name_], starting_value_tmp
)


def _set_missing_data_and_deadline(
missing_data: bool, deadline: bool, data: pd.DataFrame
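
Note: a usage sketch of the combined effect of these changes. With link_settings="log_logit", default links are now overridden before the safe priors are applied, deterministic initial values from INITVAL_SETTINGS are set right after model construction, and vector-valued starting points are jittered by +/- 0.01; the jitter keyword is honored only by the nuts_numpyro sampler. This assumes hssm.simulate_data and the link_settings argument behave as in current HSSM releases:

    import hssm

    # Simulate a small DDM dataset (any dataframe with rt/response columns works).
    data = hssm.simulate_data(
        model="ddm", theta=dict(v=0.5, a=1.5, z=0.5, t=0.3), size=500
    )

    # Initial values are postprocessed and jittered at the end of __init__.
    model = hssm.HSSM(data=data, model="ddm", link_settings="log_logit")

    # For nuts_numpyro, jitter now defaults to False; passing jitter to any
    # other sampler is dropped with a warning.
    idata = model.sample(sampler="nuts_numpyro", draws=500, tune=500)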