diff --git a/pyproject.toml b/pyproject.toml index 5abca6a0..65c97446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ numpy = ">=1.23.4,<1.26" onnx = "^1.12.0" jax = "^0.4.0" jaxlib = "^0.4.0" -ssm-simulators = "^0.4.1" +ssm-simulators = "0.5.1" huggingface-hub = "^0.15.1" onnxruntime = "^1.15.0" bambi = "^0.12.0" diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index b549fab5..a6adda00 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -183,12 +183,18 @@ def __init__( loglik_kind: LoglikKind | None = None, p_outlier: float | dict | bmb.Prior | None = 0.05, lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0), - hierarchical: bool = True, + hierarchical: bool = False, **kwargs, ): self.data = data self._inference_obj = None - self.hierarchical = hierarchical and "participant_id" in data.columns + self.hierarchical = hierarchical + + if self.hierarchical and "participant_id" not in self.data.columns: + raise ValueError( + "You have specified a hierarchical model, but there is no " + + "`participant_id` field in the DataFrame that you have passed." + ) # Construct a model_config from defaults self.model_config = Config.from_defaults(model, loglik_kind) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 666c7fbe..65261f55 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -224,21 +224,27 @@ def test_hierarchical(data_ddm): data_ddm = data_ddm.iloc[:10, :].copy() data_ddm["participant_id"] = np.arange(10) - model = HSSM(data=data_ddm) + model = HSSM(data=data_ddm, hierarchical=True) assert all( param.is_regression for name, param in model.params.items() if name != "p_outlier" ) - model = HSSM(data=data_ddm, v=bmb.Prior("Uniform", lower=-10.0, upper=10.0)) + model = HSSM( + data=data_ddm, + v=bmb.Prior("Uniform", lower=-10.0, upper=10.0), + hierarchical=True, + ) assert all( param.is_regression for name, param in model.params.items() if name not in ["v", "p_outlier"] ) - model = HSSM(data=data_ddm, a=bmb.Prior("Uniform", lower=0.0, upper=10.0)) + model = HSSM( + data=data_ddm, a=bmb.Prior("Uniform", lower=0.0, upper=10.0), hierarchical=True + ) assert all( param.is_regression for name, param in model.params.items() @@ -249,6 +255,7 @@ def test_hierarchical(data_ddm): data=data_ddm, v=bmb.Prior("Uniform", lower=-10.0, upper=10.0), a=bmb.Prior("Uniform", lower=0.0, upper=10.0), + hierarchical=True, ) assert all( param.is_regression