Skip to content

Commit

Permalink
[MAINT] raise errors properly for assertions in the main module
Browse files Browse the repository at this point in the history
Move test_all to the test script
  • Loading branch information
htwangtw committed Jul 10, 2024
1 parent a3184d2 commit a8bfce0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
42 changes: 19 additions & 23 deletions general_class_balancer/general_class_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def discretize_value(v, buckets):
return i
else:
return np.searchsorted(buckets, v)
assert False


def get_prime_form(confounds, n_buckets, sorted_confounds=None):
Expand Down Expand Up @@ -71,7 +70,8 @@ def get_prime_form(confounds, n_buckets, sorted_confounds=None):

# Given buckets, selects values that fall into each one
def get_class_selection(classes, primed, unique_classes=None):
assert len(classes) == len(primed)
if len(classes) != len(primed):
raise ValueError("Classes and primed must be the same length")
if unique_classes is None:
num_classes = len(np.unique(classes))
else:
Expand Down Expand Up @@ -106,7 +106,7 @@ def multi_mannwhitneyu(arr):
for j in range(i + 1, len(arr)):
try:
s, p = stats.ttest_ind(arr[i], arr[j])
except:
except: # HTW: ignore using bare except
p = 1
if p > max_p:
max_p = p
Expand All @@ -115,21 +115,11 @@ def multi_mannwhitneyu(arr):
return min_p, max_p


def test_all(classes, confounds):
unique_classes = np.unique(classes)
all_min_p = np.inf
for i in range(confounds.shape[0]):
if not isinstance(confounds[i, 0], str):
ts = [confounds[i, classes == j] for j in unique_classes]
min_p, max_p = multi_mannwhitneyu(ts)
if min_p < all_min_p:
all_min_p = min_p
return all_min_p


def integrate_arrs(S1, S2):
assert len(S1) >= len(S2)
assert np.sum(~S1) == len(S2)
if not len(S1) >= len(S2):
raise ValueError
if not np.sum(~S1) == len(S2):
raise ValueError
if len(S1) == len(S2):
return S2
i = 0
Expand All @@ -140,13 +130,16 @@ def integrate_arrs(S1, S2):
output[i] = S2[i2]
i2 += 1
i += 1
assert np.sum(output) == np.sum(S2)
if not np.sum(output) == np.sum(S2):
raise ValueError
return output


def integrate_arrs_none(S1, S2):
assert len(S1) >= len(S2)
assert np.sum(S1) == len(S2)
if not len(S1) >= len(S2):
raise ValueError
if not np.sum(S1) == len(S2):
raise ValueError
i = 0
i2 = 0
output = np.zeros(S1.shape, dtype=bool)
Expand Down Expand Up @@ -310,13 +303,16 @@ def class_balance(
selection = np.logical_or(selection, recurse_selection)
if exclude_none:
selection = integrate_arrs_none(~has_none, selection)
assert len(selection) == len(has_none)
assert np.sum(~has_none) == len(classes)
if not len(selection) == len(has_none):
raise ValueError
if not np.sum(~has_none) == len(classes):
raise ValueError
return selection


def separate_set(selections, set_divisions=[0.5, 0.5], IDs=None):
assert isinstance(set_divisions, list)
if not isinstance(set_divisions, list):
raise TypeError
set_divisions = [i / np.sum(set_divisions) for i in set_divisions]
rr = list(range(len(selections)))
random.shuffle(rr)
Expand Down
15 changes: 14 additions & 1 deletion general_class_balancer/tests/random_example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from general_class_balancer.general_class_balancer import (
class_balance,
test_all,
multi_mannwhitneyu,
)

# Sample script showing how this balances on simulated, random data.
Expand Down Expand Up @@ -52,6 +52,19 @@
+ [2 for x in range(int(N / 3))]
)


def test_all(classes, confounds):
unique_classes = np.unique(classes)
all_min_p = np.inf
for i in range(confounds.shape[0]):
if not isinstance(confounds[i, 0], str):
ts = [confounds[i, classes == j] for j in unique_classes]
min_p, max_p = multi_mannwhitneyu(ts)
if min_p < all_min_p:
all_min_p = min_p
return all_min_p


selection = class_balance(classes, confounds, plim=0.25)
print(np.sum(selection))
print(test_all(classes[selection], confounds[:, selection]))
Expand Down

0 comments on commit a8bfce0

Please sign in to comment.