Skip to content

Commit

Permalink
Merge pull request #83 from mggg/datachecks
Browse files Browse the repository at this point in the history
Data checking
  • Loading branch information
gabeschoenbach authored Oct 21, 2021
2 parents c94f1c2 + 588d0d8 commit c359865
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
14 changes: 13 additions & 1 deletion pyei/r_by_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self, model_name, **additional_model_params):
self.credible_interval_95_mean_voting_prefs = None
self.num_groups_and_num_candidates = [None, None]

def fit(
def fit( # pylint: disable=too-many-branches
self,
group_fractions,
votes_fractions,
Expand Down Expand Up @@ -214,6 +214,18 @@ def fit(
raise ValueError("all elements of precinct_pops must be integer-valued")
self.precinct_pops = precinct_pops

# check that group_fractions and vote_fractions sum to 1 in each precinct
if not np.isclose(group_fractions.sum(axis=0), 1.0).all():
raise ValueError("group_fractions should sum to 1 within each precinct")
if not np.isclose(votes_fractions.sum(axis=0), 1.0).all():
raise ValueError("votes_fractions should sum to 1 within each precinct")

# check that group_fractions and vote_fractions are nonnegative
if not (group_fractions >= 0).all():
raise ValueError("group_fractions must be non-negative")
if not (votes_fractions >= 0).all():
raise ValueError("votes_fractions most be non-negative")

# give demographic groups, candidates 1-indexed numbers as names if names are not specified
if demographic_group_names is None:
demographic_group_names = [str(i) for i in range(1, group_fractions.shape[0] + 1)]
Expand Down
2 changes: 1 addition & 1 deletion pyei/two_by_two.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def fit(
if draw_samples:
# TODO: this workaround shouldn't be necessary. Modify the model so that the checks
# can run without error
if self.model_name == "wakefield_beta" or self.model_name == "wakefield_normal":
if self.model_name in ("wakefield_beta", "wakefield_normal"):
compute_convergence_checks = False
print("WARNING: some convergence checks currently disabled for wakefield model")
else:
Expand Down

0 comments on commit c359865

Please sign in to comment.