Skip to content

Commit

Permalink
Update failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Mar 14, 2024
1 parent 2967911 commit 19e77fa
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 10 deletions.
9 changes: 5 additions & 4 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,11 @@ def rng_fn(
+ "distribution but did not specify the distribution."
)
out_shape = sims_out.shape[:-1]
if p_outlier.shape[-1] == 1:
p_outlier = np.broadcast_to(p_outlier, out_shape)
else:
p_outlier = p_outlier.reshape(out_shape)
if not np.isscalar(p_outlier) and len(p_outlier.shape) > 0:
if p_outlier.shape[-1] == 1:
p_outlier = np.broadcast_to(p_outlier, out_shape)
else:
p_outlier = p_outlier.reshape(out_shape)
replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool)
replace_n = int(np.sum(replace, axis=None))
if replace_n == 0:
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/plotting/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def plot_posterior_predictive(
# Flip the rt values if necessary
if np.any(plotting_df["response"] == 0):
plotting_df["response"] = np.where(plotting_df["response"] == 0, -1, 1)
if model.n_responses == 2:
if model.n_choices == 2:
plotting_df["rt"] = plotting_df["rt"] * plotting_df["response"]

# Then, plot the posterior predictive distribution against the observed data
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/plotting/quantile_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def plot_quantile_probability(
# Flip the rt values if necessary
if np.any(plotting_df["response"] == 0):
plotting_df["response"] = np.where(plotting_df["response"] == 0, -1, 1)
if model.n_responses == 2:
if model.n_choices == 2:
plotting_df["rt"] = plotting_df["rt"] * plotting_df["response"]

# If group is not provided, we are producing a single plot
Expand Down
File renamed without changes.
6 changes: 2 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

def test_get_alias_dict():
# Simulate some data:
v_true, a_true, z_true, t_true, sv_true = [0.5, 1.5, 0.5, 0.5, 0.3]
obs_ddm = simulator(
[v_true, a_true, z_true, t_true, sv_true], model="ddm", n_samples=1000
)
v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.5]
obs_ddm = simulator([v_true, a_true, z_true, t_true], model="ddm", n_samples=1000)
obs_ddm = np.column_stack([obs_ddm["rts"][:, 0], obs_ddm["choices"][:, 0]])

dataset = pd.DataFrame(obs_ddm, columns=["rt", "response"])
Expand Down

0 comments on commit 19e77fa

Please sign in to comment.