From e26d2ae3c10a1c3d301554519892fd0b2ec507af Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 29 May 2023 13:04:28 +0200 Subject: [PATCH] support any ploidy + improve tests --- chromax/functional.py | 48 +++++++++++++++------------- chromax/simulator.py | 50 ++++++++++++++---------------- chromax/typing.py | 2 +- tests/mock_simulator.py | 4 +-- tests/test_functional.py | 53 +++++++++++++++++++++++++++++++ tests/test_simulator.py | 67 ++++++++++++++++++++++++++++------------ 6 files changed, 154 insertions(+), 70 deletions(-) create mode 100644 tests/test_functional.py diff --git a/chromax/functional.py b/chromax/functional.py index 5b2f86f..808bf29 100644 --- a/chromax/functional.py +++ b/chromax/functional.py @@ -15,22 +15,22 @@ def cross( Args: - parents (array): parents to compute the cross. The shape of - the parents is (n, 2, m, 2), where n is the number of parents - and m is the number of markers. + the parents is (n, 2, m, d), where n is the number of parents, + m is the number of markers, and d is the ploidy. - recombination_vec (array): array of m probabilities. The i-th value represent the probability to recombine before the marker i. - random_key (array): PRNGKey, for reproducibility purpose Returns: - - population (array): offspring population of shape (n, m, 2). + - population (array): offspring population of shape (n, m, d). Example: >>> from chromax import functional >>> import numpy as np >>> import jax - >>> n_chr, chr_len = 10, 100 + >>> n_chr, chr_len, ploidy = 10, 100, 2 >>> n_crosses = 50 - >>> parents_shape = (n_crosses, 2, n_chr * chr_len, 2) + >>> parents_shape = (n_crosses, 2, n_chr * chr_len, ploidy) >>> parents = np.random.choice([False, True], size=parents_shape) >>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len) >>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid @@ -40,13 +40,16 @@ def cross( >>> f2.shape (50, 1000, 2) """ - random_keys = jax.random.split(random_key, num=len(parents) * 2) - random_keys = random_keys.reshape(len(parents), 2, 2) - return _cross(parents, recombination_vec, random_keys) + parents = parents.reshape(*parents.shape[:3], -1, 2) + random_keys = jax.random.split(random_key, num=len(parents) * 2 * parents.shape[3]) + random_keys = random_keys.reshape(len(parents), 2, parents.shape[3], 2) + offsprings = _cross(parents, recombination_vec, random_keys) + return offsprings.reshape(*offsprings.shape[:-2], -1) + @jax.jit @partial(jax.vmap, in_axes=(0, None, 0)) # parallelize across individuals -@partial(jax.vmap, in_axes=(0, None, 0), out_axes=1) # parallelize parents +@partial(jax.vmap, in_axes=(0, None, 0), out_axes=2) # parallelize parents def _cross( parent: Individual, recombination_vec: Float[Array, N_MARKERS], @@ -68,22 +71,22 @@ def double_haploid( """Computes the double haploid of the input population. Args: - - population (array): input population of shape (n, m, 2). + - population (array): input population of shape (n, m, d). - n_offspring (int): number of offspring per plant. - recombination_vec (array): array of m probabilities. The i-th value represent the probability to recombine before the marker i. - random_key (array): array of n PRNGKey, one for each individual. Returns: - - population (array): output population of shape (n, n_offspring, m, 2). + - population (array): output population of shape (n, n_offspring, m, d). This population will be homozygote. Example: >>> from chromax import functional >>> import numpy as np >>> import jax - >>> n_chr, chr_len = 10, 100 - >>> pop_shape = (50, n_chr * chr_len, 2) + >>> n_chr, chr_len, ploidy = 10, 100, 2 + >>> pop_shape = (50, n_chr * chr_len, ploidy) >>> f1 = np.random.choice([False, True], size=pop_shape) >>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len) >>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid @@ -93,12 +96,14 @@ def double_haploid( >>> dh.shape (50, 10, 1000, 2) """ + population = population.reshape(*population.shape[:2], -1, 2) keys = jax.random.split( random_key, - num=len(population) * n_offspring - ).reshape(len(population), n_offspring, 2) - haploid = _double_haploid(population, recombination_vec, keys) - return jnp.broadcast_to(haploid[..., None], shape=(*haploid.shape, 2)) + num=len(population) * n_offspring * population.shape[2] + ).reshape(len(population), n_offspring, population.shape[2], 2) + haploids = _double_haploid(population, recombination_vec, keys) + dh_pop = jnp.broadcast_to(haploids[..., None], shape=(*haploids.shape, 2)) + return dh_pop.reshape(*dh_pop.shape[:-2], -1) @jax.jit @@ -117,6 +122,7 @@ def _double_haploid( @jax.jit +@partial(jax.vmap, in_axes=(1, None, 0), out_axes=1) # parallelize pair of chromosomes def _meiosis( individual: Individual, recombination_vec: Float[Array, N_MARKERS], @@ -144,22 +150,22 @@ def select( """Function to select individuals based on their score (index). Args: - - population (array): input grouped population of shape (n, m, 2) + - population (array): input grouped population of shape (n, m, d) - k (int): number of individual to select. - f_index (function): function that computes a score for each individual. The function accepts as input a population, i.e. and array of shape (n, m, 2) and returns an arrray of n float number. Returns: - - population (array): output population of (k, m, 2) + - population (array): output population of (k, m, d) Example: >>> from chromax import functional >>> from chromax.trait_model import TraitModel >>> from chromax.index_functions import conventional_index >>> import numpy as np - >>> n_chr, chr_len = 10, 100 - >>> pop_shape = (50, n_chr * chr_len, 2) + >>> n_chr, chr_len, ploidy = 10, 100, 2 + >>> pop_shape = (50, n_chr * chr_len, ploidy) >>> f1 = np.random.choice([False, True], size=pop_shape) >>> marker_effects = np.random.randn(n_chr * chr_len) >>> gebv_model = TraitModel(marker_effects[:, None]) diff --git a/chromax/simulator.py b/chromax/simulator.py index 6aaf45c..ecaa885 100644 --- a/chromax/simulator.py +++ b/chromax/simulator.py @@ -204,12 +204,12 @@ def cross(self, parents: Parents["n"]) -> Population["n"]: """Main function that computes crosses from a list of parents. :param parents: parents to compute the cross. The shape of - the parents is (n, 2, m, 2), where n is the number of parents - and m is the number of markers. + the parents is (n, 2, m, d), where n is the number of parents, + m is the number of markers, and d is the ploidy. :type parents: ndarray - :return: offspring population of shape (n, m, 2). + :return: offspring population of shape (n, m, d). :rtype: ndarray :Example: @@ -237,21 +237,21 @@ def differentiable_cross_func(self) -> Callable: The differentiable crossing function takes as input: - population (array): starting population from which performing the crosses. - The shape of the population is (n, m, 2). - - cross_weights (array): Array of shape (l, n, 2). It is used to compute + The shape of the population is (n, m, d). + - cross_weights (array): Array of shape (l, n, d). It is used to compute l crosses, starting from a weighted average of the n possible parents. When the n-axis has all zeros except of a single element equals to one, this function is equivalent to the cross function. - random_key (JAX random key): random key used for recombination sampling. - And returns a population of shape (l, m, 2). + And returns a population of shape (l, m, d). :Example: >>> from chromax import Simulator, sample_data >>> import numpy as np >>> import jax >>> simulator = Simulator(genetic_map=sample_data.genetic_map) - >>> diff_cross = simulator.differentiable_cross_func() + >>> diff_cross = simulator.differentiable_cross_func >>> def mean_gebv(pop, weights, random_key): new_pop = diff_cross(pop, weights, random_key) return simulator.GEBV(new_pop, raw_array=True).mean() @@ -279,10 +279,12 @@ def diff_cross_f( cross_weights: Float[Array, "m n 2"], random_key: jax.random.PRNGKeyArray ) -> Population["m"]: - num_keys = len(cross_weights) * len(population) * 2 - keys = jax.random.split(random_key, num=num_keys) - keys = keys.reshape(len(cross_weights), len(population), 2, 2) + population = population.reshape(*population.shape[:-1], -1, 2) + keys_shape = len(cross_weights), len(population), 2, population.shape[-2] + keys = jax.random.split(random_key, num=np.prod(keys_shape)) + keys = keys.reshape(*keys_shape, 2) outer_res = cross_pop(population, self.recombination_vec, keys) + outer_res = outer_res.reshape(*outer_res.shape[:-2], -1) return (cross_weights[:, :, None, :] * outer_res).sum(axis=1) return diff_cross_f @@ -332,13 +334,13 @@ def diallel( """Diallel crossing function, i.e. crossing between every possible couple, except self-crossing. - :param population: input population of shape (n, m, 2). + :param population: input population of shape (n, m, d). :type population: ndarray :param n_offspring: number of offspring per cross. The default value is 1. :type n_offspring: int - :return: output population of shape (l, n_offspring, m, 2), + :return: output population of shape (l, n_offspring, m, d), where l is the number of possible pair, i.e `n * (n-1) / 2`. :rtype: ndarray @@ -380,7 +382,7 @@ def random_crosses( ) -> Population["n_crosses n_offspring"]: """Computes random crosses on a population. - :param population: input population of shape (n, m, 2). + :param population: input population of shape (n, m, d). :type population: ndarray :param n_crosses: number of random crosses to perform. :type n_crosses: int @@ -388,7 +390,7 @@ def random_crosses( The default value is 1. :type n_offspring: int - :return: output population of shape (n_crosses, n_offspring, m, 2). + :return: output population of shape (n_crosses, n_offspring, m, d). :rtype: ndarray :Example: @@ -399,12 +401,6 @@ def random_crosses( >>> f2.shape (100, 10, 9839, 2) """ - - if n_crosses < 1: - raise ValueError("n_crosses must be higher or equal to 1") - if n_offspring < 1: - raise ValueError("n_offspring must be higher or equal to 1") - all_indices = np.arange(len(population)) diallel_indices = self._diallel_indices(all_indices) if n_crosses > len(diallel_indices): @@ -434,18 +430,18 @@ def select( ) -> Population["_g k"]: """Function to select individuals based on their score (index). - :param population: input population of shape (n, m, 2), - or shape (g, n, m, 2), to select k individual from each group population group g. + :param population: input population of shape (n, m, d), + or shape (g, n, m, d), to select k individual from each group population group g. :type population: ndarray :param k: number of individual to select. :type k: int :param f_index: function that computes a score from each individual. The function accepts as input the population, i.e. and array of shape - (n, m, 2) and returns a n float numbers. The default f_index is the conventional index, + (n, m, d) and returns a n float numbers. The default f_index is the conventional index, i.e. the sum of the marker effects masked with the SNPs from the genetic_map. :type f_index: Callable - :return: output population of shape (k, m, 2) or (g, k, m, 2), + :return: output population of shape (k, m, d) or (g, k, m, d), depending on the input population. :rtype: ndarray @@ -480,7 +476,7 @@ def GEBV( """Computes the Genomic Estimated Breeding Values using the marker effects from the genetic_map. - :param population: input population of shape (n, m, 2). + :param population: input population of shape (n, m, d). :type population: ndarray :param raw_array: whether to return a raw array or a DataFrame. Deafult value is False. @@ -541,7 +537,7 @@ def phenotype( This uses the Genotype-by-Environment model described in the following: https://cran.r-project.org/web/packages/AlphaSimR/vignettes/traits.pdf - :param population: input population of shape (n, m, 2) + :param population: input population of shape (n, m, d) :type population: ndarray :param num_environments: number of environments to test the population. Default value is 1. @@ -597,7 +593,7 @@ def corrcoef( """Computes the correlation coefficient of the population against its centroid. It can be used as an indicator of variance in the population. - :param population: input population of shape (n, m, 2) + :param population: input population of shape (n, m, d) :type population: ndarray :return: vector of length n, containing the correlation coefficient diff --git a/chromax/typing.py b/chromax/typing.py index 90fa99c..c42057b 100644 --- a/chromax/typing.py +++ b/chromax/typing.py @@ -1,7 +1,7 @@ from jaxtyping import Array, Bool N_MARKERS = "m" -DIPLOID_SHAPE = N_MARKERS + " 2" +DIPLOID_SHAPE = N_MARKERS + "d" Haploid = Bool[Array, N_MARKERS] Individual = Bool[Array, DIPLOID_SHAPE] diff --git a/tests/mock_simulator.py b/tests/mock_simulator.py index fe65eac..d855a8b 100644 --- a/tests/mock_simulator.py +++ b/tests/mock_simulator.py @@ -49,9 +49,9 @@ def __init__( super().__init__(genetic_map=genetic_map, **kwargs) self.recombination_vec = recombination_vec - def load_population(self, n_individual=100): + def load_population(self, n_individual=100, ploidy=2): return np.random.choice( a=[False, True], - size=(n_individual, self.n_markers, 2), + size=(n_individual, self.n_markers, ploidy), p=[0.5, 0.5] ) diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 0000000..4526b47 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,53 @@ +from chromax.index_functions import conventional_index +from chromax.trait_model import TraitModel +import numpy as np +import jax +from chromax import functional +import pytest + + +@pytest.mark.parametrize("idx", [0, 1]) +def test_cross(idx): + n_markers, ploidy = 1000, 4 + n_crosses = 50 + parents_shape = (n_crosses, 2, n_markers, ploidy) + parents = np.random.choice([False, True], size=parents_shape) + rec_vec = np.zeros(n_markers) + rec_vec[0] = idx + random_key = jax.random.PRNGKey(42) + new_pop = functional.cross(parents, rec_vec, random_key) + + for i in range(ploidy): + assert np.all(new_pop[..., i] == parents[:, i % 2, :, i - i % 2 + idx]) + + +def test_double_haploid(): + n_chr, chr_len, ploidy = 10, 100, 2 + n_offspring = 10 + pop_shape = (50, n_chr * chr_len, ploidy) + f1 = np.random.choice([False, True], size=pop_shape) + rec_vec = np.full((n_chr * chr_len,), 1.5 / chr_len) + random_key = jax.random.PRNGKey(42) + dh = functional.double_haploid(f1, n_offspring, rec_vec, random_key) + assert dh.shape == (len(f1), n_offspring, n_chr * chr_len, ploidy) + + for i in range(ploidy // 2): + assert np.all(dh[..., i * 2] == dh[..., i * 2 + 1]) + + +def test_select(): + n_markers, ploidy = 1000, 4 + k = 10 + pop_shape = (50, n_markers, ploidy) + f1 = np.random.choice([False, True], size=pop_shape) + marker_effects = np.random.randn(n_markers) + gebv_model = TraitModel(marker_effects[:, None]) + f_index = conventional_index(gebv_model) + f2 = functional.select(f1, k=k, f_index=f_index) + assert f2.shape == (k, *f1.shape[1:]) + + f1_gebv = gebv_model(f1) + f2_gebv = gebv_model(f2) + assert np.max(f2_gebv) == np.max(f1_gebv) + assert np.mean(f2_gebv) > np.mean(f1_gebv) + assert np.min(f2_gebv) > np.min(f1_gebv) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 983e837..b032f51 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -5,6 +5,7 @@ import numpy as np import pytest import warnings +import pandas as pd @pytest.mark.parametrize("idx", [0, 1]) @@ -14,26 +15,29 @@ def test_cross_r(idx): recombination_vec[0] = idx simulator = MockSimulator(recombination_vec=recombination_vec) - size = (1, 2, simulator.n_markers, 2) + ploidy = 4 + size = (1, 2, simulator.n_markers, ploidy) parents = np.random.choice(a=[False, True], size=size, p=[0.5, 0.5]) new_pop = simulator.cross(parents) - assert new_pop.shape == (1, simulator.n_markers, 2) + assert new_pop.shape == (1, simulator.n_markers, ploidy) ind = new_pop[0] - assert np.all(ind[:, 0] == parents[0, 0, :, idx]) - assert np.all(ind[:, 1] == parents[0, 1, :, idx]) + for i in range(ploidy): + pair_chr_idx = i % 2 + assert np.all(ind[:, i] == parents[0, pair_chr_idx, :, i - pair_chr_idx + idx]) def test_equal_parents(): simulator = Simulator(genetic_map=genetic_map) - parents = np.zeros((1, 2, simulator.n_markers, 2), dtype="bool") + ploidy = 4 + parents = np.zeros((1, 2, simulator.n_markers, ploidy), dtype="bool") child = simulator.cross(parents) assert np.all(child == 0) - parents = np.ones((1, 2, simulator.n_markers, 2), dtype="bool") + parents = np.ones((1, 2, simulator.n_markers, ploidy), dtype="bool") child = simulator.cross(parents) assert np.all(child == 1) @@ -45,7 +49,7 @@ def test_ad_hoc_cross(): ) simulator = MockSimulator(recombination_vec=rec_vec) - population = simulator.load_population(2) + population = simulator.load_population(2, ploidy=4) parents = population[np.array([[0, 1]])] child = simulator.cross(parents) @@ -57,13 +61,15 @@ def test_ad_hoc_cross(): chr_idx = 1 - chr_idx assert child[1, mrk_idx, 0] == population[0, mrk_idx, chr_idx] assert child[1, mrk_idx, 1] == population[1, mrk_idx, chr_idx] + assert child[1, mrk_idx, 2] == population[0, mrk_idx, 2 + chr_idx] + assert child[1, mrk_idx, 3] == population[1, mrk_idx, 2 + chr_idx] def test_cross_two_times(): n_markers = 100_000 n_ind = 2 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=4) parents = population[np.array([[0, 1], [0, 1]])] children = simulator.cross(parents) @@ -75,40 +81,45 @@ def test_double_haploid(): n_markers = 1000 n_ind = 100 n_offspring = 10 + ploidy = 4 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=ploidy) new_pop = simulator.double_haploid(population, n_offspring=n_offspring) assert new_pop.shape == (len(population), n_offspring, *population.shape[1:]) assert np.all(new_pop[..., 0] == new_pop[..., 1]) + assert np.all(new_pop[..., 2] == new_pop[..., 3]) new_pop = simulator.double_haploid(population) assert new_pop.shape == population.shape assert np.all(new_pop[..., 0] == new_pop[..., 1]) + assert np.all(new_pop[..., 2] == new_pop[..., 3]) def test_diallel(): n_markers = 1000 n_ind = 100 + ploidy = 4 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=ploidy) diallel_indices = simulator._diallel_indices(np.arange(10)) assert len(np.unique(diallel_indices, axis=0)) == 45 new_pop = simulator.diallel(population) - assert new_pop.shape == (n_ind * (n_ind - 1) // 2, n_markers, 2) + assert new_pop.shape == (n_ind * (n_ind - 1) // 2, n_markers, ploidy) new_pop = simulator.diallel(population, n_offspring=10) - assert new_pop.shape == (n_ind * (n_ind - 1) // 2, 10, n_markers, 2) + assert new_pop.shape == (n_ind * (n_ind - 1) // 2, 10, n_markers, ploidy) def test_select(): n_markers = 1000 n_ind = 100 + ploidy = 4 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=ploidy) pop_GEBV = simulator.GEBV(population) selected_pop = simulator.select(population, k=10) @@ -119,7 +130,7 @@ def test_select(): dh = simulator.double_haploid(population, n_offspring=100) selected_dh = simulator.select(dh, k=5) - assert selected_dh.shape == (n_ind, 5, n_markers, 2) + assert selected_dh.shape == (n_ind, 5, n_markers, ploidy) for i in range(n_ind): dh_GEBV = simulator.GEBV(dh[i]) selected_GEBV = simulator.GEBV(selected_dh[i]) @@ -131,12 +142,13 @@ def test_select(): def test_random_crosses(): n_markers = 1000 n_ind = 100 + ploidy = 4 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=ploidy) n_crosses = 300 new_pop = simulator.random_crosses(population, n_crosses=n_crosses) - assert new_pop.shape == (n_crosses, n_markers, 2) + assert new_pop.shape == (n_crosses, n_markers, ploidy) n_offspring = 10 new_pop = simulator.random_crosses( @@ -144,7 +156,7 @@ def test_random_crosses(): n_crosses=n_crosses, n_offspring=n_offspring ) - assert new_pop.shape == (n_crosses, n_offspring, n_markers, 2) + assert new_pop.shape == (n_crosses, n_offspring, n_markers, ploidy) def test_multi_trait(): @@ -199,10 +211,11 @@ def test_device(): def test_seed_deterministic(): n_ind = 100 + ploidy = 4 simulator1 = Simulator(genetic_map=genetic_map, seed=7) simulator2 = Simulator(genetic_map=genetic_map, seed=7) mock_simulator = MockSimulator(n_markers=simulator1.n_markers) - population = mock_simulator.load_population(n_ind) + population = mock_simulator.load_population(n_ind, ploidy=ploidy) new_pop1 = simulator1.random_crosses(population, n_crosses=10) new_pop2 = simulator2.random_crosses(population, n_crosses=10) @@ -210,10 +223,26 @@ def test_seed_deterministic(): assert np.all(new_pop1 == new_pop2) +def test_gebv(): + n_markers, n_ind = 100, 10 + ploidy = 4 + simulator = MockSimulator(n_markers=n_markers) + population = simulator.load_population(n_ind, ploidy=ploidy) + + gebv_pandas = simulator.GEBV(population) + assert len(gebv_pandas) == n_ind + assert isinstance(gebv_pandas, pd.DataFrame) + + gebv_array = simulator.GEBV(population, raw_array=True) + assert len(gebv_array) == n_ind + assert np.all(gebv_pandas.values == gebv_array) + + def test_phenotyping(): n_markers, n_ind = 100, 10 + ploidy = 4 simulator = MockSimulator(n_markers=n_markers) - population = simulator.load_population(n_ind) + population = simulator.load_population(n_ind, ploidy=ploidy) phenotype = simulator.phenotype(population, num_environments=4) assert len(phenotype) == n_ind