diff --git a/tests/lsbase.py b/tests/lsbase.py index f4f1c48..6684585 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -14,7 +14,7 @@ def verify(self, ts): raise NotImplementedError() def assertAllClose(self, A, B): - assert np.allclose(A, B, rtol=1e-9, atol=0.0) + np.testing.assert_allclose(A, B, rtol=1e-9, atol=0.0) # Helper routine def get_num_alleles(self, ref_haps, query): @@ -134,6 +134,7 @@ def get_emission_matrix_diploid(self, mu, m): return e def get_examples_pars_diploid(self, ts, seed=42): + """Returns an iterator over combinations of examples and parameters.""" np.random.seed(seed) H, G, genotypes = self.get_examples_diploid(ts) m = ts.num_sites @@ -156,6 +157,7 @@ def get_examples_pars_diploid(self, ts, seed=42): yield n, m, G, s, e, r, mu def get_examples_pars_larger_diploid(self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42): + """Returns an iterator over combinations of examples and parameters.""" np.random.seed(seed) H, G, genotypes = self.get_examples_diploid(ts) m = H.shape[0]