From 19e77fa6fc92d0bc5788e7c79ddfd571e8e2f921 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 13 Mar 2024 21:39:41 -0400 Subject: [PATCH] Update failing tests --- src/hssm/distribution_utils/dist.py | 9 +++++---- src/hssm/plotting/posterior_predictive.py | 2 +- src/hssm/plotting/quantile_probability.py | 2 +- tests/{ => slow}/test_mcmc.py | 0 tests/test_utils.py | 6 ++---- 5 files changed, 9 insertions(+), 10 deletions(-) rename tests/{ => slow}/test_mcmc.py (100%) diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index 86b2162e..bdc1cb37 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -324,10 +324,11 @@ def rng_fn( + "distribution but did not specify the distribution." ) out_shape = sims_out.shape[:-1] - if p_outlier.shape[-1] == 1: - p_outlier = np.broadcast_to(p_outlier, out_shape) - else: - p_outlier = p_outlier.reshape(out_shape) + if not np.isscalar(p_outlier) and len(p_outlier.shape) > 0: + if p_outlier.shape[-1] == 1: + p_outlier = np.broadcast_to(p_outlier, out_shape) + else: + p_outlier = p_outlier.reshape(out_shape) replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool) replace_n = int(np.sum(replace, axis=None)) if replace_n == 0: diff --git a/src/hssm/plotting/posterior_predictive.py b/src/hssm/plotting/posterior_predictive.py index 2af5c956..57d84038 100644 --- a/src/hssm/plotting/posterior_predictive.py +++ b/src/hssm/plotting/posterior_predictive.py @@ -380,7 +380,7 @@ def plot_posterior_predictive( # Flip the rt values if necessary if np.any(plotting_df["response"] == 0): plotting_df["response"] = np.where(plotting_df["response"] == 0, -1, 1) - if model.n_responses == 2: + if model.n_choices == 2: plotting_df["rt"] = plotting_df["rt"] * plotting_df["response"] # Then, plot the posterior predictive distribution against the observed data diff --git a/src/hssm/plotting/quantile_probability.py b/src/hssm/plotting/quantile_probability.py index 8764ebdb..f81f9da3 100644 --- a/src/hssm/plotting/quantile_probability.py +++ b/src/hssm/plotting/quantile_probability.py @@ -346,7 +346,7 @@ def plot_quantile_probability( # Flip the rt values if necessary if np.any(plotting_df["response"] == 0): plotting_df["response"] = np.where(plotting_df["response"] == 0, -1, 1) - if model.n_responses == 2: + if model.n_choices == 2: plotting_df["rt"] = plotting_df["rt"] * plotting_df["response"] # If group is not provided, we are producing a single plot diff --git a/tests/test_mcmc.py b/tests/slow/test_mcmc.py similarity index 100% rename from tests/test_mcmc.py rename to tests/slow/test_mcmc.py diff --git a/tests/test_utils.py b/tests/test_utils.py index 6e629aa8..c63405e0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,10 +17,8 @@ def test_get_alias_dict(): # Simulate some data: - v_true, a_true, z_true, t_true, sv_true = [0.5, 1.5, 0.5, 0.5, 0.3] - obs_ddm = simulator( - [v_true, a_true, z_true, t_true, sv_true], model="ddm", n_samples=1000 - ) + v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.5] + obs_ddm = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=1000) obs_ddm = np.column_stack([obs_ddm["rts"][:, 0], obs_ddm["choices"][:, 0]]) dataset = pd.DataFrame(obs_ddm, columns=["rt", "response"])