Skip to content

Commit

Permalink
support any ploidy + improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed May 29, 2023
1 parent c669ef7 commit e26d2ae
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 70 deletions.
48 changes: 27 additions & 21 deletions chromax/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ def cross(
Args:
- parents (array): parents to compute the cross. The shape of
the parents is (n, 2, m, 2), where n is the number of parents
and m is the number of markers.
the parents is (n, 2, m, d), where n is the number of parents,
m is the number of markers, and d is the ploidy.
- recombination_vec (array): array of m probabilities.
The i-th value represents the probability of recombining before marker i.
- random_key (array): PRNGKey, for reproducibility purpose
Returns:
- population (array): offspring population of shape (n, m, 2).
- population (array): offspring population of shape (n, m, d).
Example:
>>> from chromax import functional
>>> import numpy as np
>>> import jax
>>> n_chr, chr_len = 10, 100
>>> n_chr, chr_len, ploidy = 10, 100, 2
>>> n_crosses = 50
>>> parents_shape = (n_crosses, 2, n_chr * chr_len, 2)
>>> parents_shape = (n_crosses, 2, n_chr * chr_len, ploidy)
>>> parents = np.random.choice([False, True], size=parents_shape)
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
Expand All @@ -40,13 +40,16 @@ def cross(
>>> f2.shape
(50, 1000, 2)
"""
random_keys = jax.random.split(random_key, num=len(parents) * 2)
random_keys = random_keys.reshape(len(parents), 2, 2)
return _cross(parents, recombination_vec, random_keys)
parents = parents.reshape(*parents.shape[:3], -1, 2)
random_keys = jax.random.split(random_key, num=len(parents) * 2 * parents.shape[3])
random_keys = random_keys.reshape(len(parents), 2, parents.shape[3], 2)
offsprings = _cross(parents, recombination_vec, random_keys)
return offsprings.reshape(*offsprings.shape[:-2], -1)


@jax.jit
@partial(jax.vmap, in_axes=(0, None, 0)) # parallelize across individuals
@partial(jax.vmap, in_axes=(0, None, 0), out_axes=1) # parallelize parents
@partial(jax.vmap, in_axes=(0, None, 0), out_axes=2) # parallelize parents
def _cross(
parent: Individual,
recombination_vec: Float[Array, N_MARKERS],
Expand All @@ -68,22 +71,22 @@ def double_haploid(
"""Computes the double haploid of the input population.
Args:
- population (array): input population of shape (n, m, 2).
- population (array): input population of shape (n, m, d).
- n_offspring (int): number of offspring per plant.
- recombination_vec (array): array of m probabilities.
The i-th value represents the probability of recombining before marker i.
- random_key (array): a single PRNGKey; it is split internally into one key per individual, offspring and chromosome pair.
Returns:
- population (array): output population of shape (n, n_offspring, m, 2).
- population (array): output population of shape (n, n_offspring, m, d).
This population will be homozygous.
Example:
>>> from chromax import functional
>>> import numpy as np
>>> import jax
>>> n_chr, chr_len = 10, 100
>>> pop_shape = (50, n_chr * chr_len, 2)
>>> n_chr, chr_len, ploidy = 10, 100, 2
>>> pop_shape = (50, n_chr * chr_len, ploidy)
>>> f1 = np.random.choice([False, True], size=pop_shape)
>>> rec_vec = np.full((n_chr, chr_len), 1.5 / chr_len)
>>> rec_vec[:, 0] = 0.5 # equal probability on starting haploid
Expand All @@ -93,12 +96,14 @@ def double_haploid(
>>> dh.shape
(50, 10, 1000, 2)
"""
population = population.reshape(*population.shape[:2], -1, 2)
keys = jax.random.split(
random_key,
num=len(population) * n_offspring
).reshape(len(population), n_offspring, 2)
haploid = _double_haploid(population, recombination_vec, keys)
return jnp.broadcast_to(haploid[..., None], shape=(*haploid.shape, 2))
num=len(population) * n_offspring * population.shape[2]
).reshape(len(population), n_offspring, population.shape[2], 2)
haploids = _double_haploid(population, recombination_vec, keys)
dh_pop = jnp.broadcast_to(haploids[..., None], shape=(*haploids.shape, 2))
return dh_pop.reshape(*dh_pop.shape[:-2], -1)


@jax.jit
Expand All @@ -117,6 +122,7 @@ def _double_haploid(


@jax.jit
@partial(jax.vmap, in_axes=(1, None, 0), out_axes=1) # parallelize pair of chromosomes
def _meiosis(
individual: Individual,
recombination_vec: Float[Array, N_MARKERS],
Expand Down Expand Up @@ -144,22 +150,22 @@ def select(
"""Function to select individuals based on their score (index).
Args:
- population (array): input grouped population of shape (n, m, 2)
- population (array): input grouped population of shape (n, m, d)
- k (int): number of individual to select.
- f_index (function): function that computes a score for each individual.
The function accepts as input a population, i.e. an array of shape
(n, m, d), and returns an array of n floats.
Returns:
- population (array): output population of (k, m, 2)
- population (array): output population of (k, m, d)
Example:
>>> from chromax import functional
>>> from chromax.trait_model import TraitModel
>>> from chromax.index_functions import conventional_index
>>> import numpy as np
>>> n_chr, chr_len = 10, 100
>>> pop_shape = (50, n_chr * chr_len, 2)
>>> n_chr, chr_len, ploidy = 10, 100, 2
>>> pop_shape = (50, n_chr * chr_len, ploidy)
>>> f1 = np.random.choice([False, True], size=pop_shape)
>>> marker_effects = np.random.randn(n_chr * chr_len)
>>> gebv_model = TraitModel(marker_effects[:, None])
Expand Down
50 changes: 23 additions & 27 deletions chromax/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,12 @@ def cross(self, parents: Parents["n"]) -> Population["n"]:
"""Main function that computes crosses from a list of parents.
:param parents: parents to compute the cross. The shape of
the parents is (n, 2, m, 2), where n is the number of parents
and m is the number of markers.
the parents is (n, 2, m, d), where n is the number of parents,
m is the number of markers, and d is the ploidy.
:type parents: ndarray
:return: offspring population of shape (n, m, 2).
:return: offspring population of shape (n, m, d).
:rtype: ndarray
:Example:
Expand Down Expand Up @@ -237,21 +237,21 @@ def differentiable_cross_func(self) -> Callable:
The differentiable crossing function takes as input:
- population (array): starting population from which performing the crosses.
The shape of the population is (n, m, 2).
- cross_weights (array): Array of shape (l, n, 2). It is used to compute
The shape of the population is (n, m, d).
- cross_weights (array): Array of shape (l, n, d). It is used to compute
l crosses, starting from a weighted average of the n possible parents.
When the n-axis has all zeros except for a single element equal to one,
this function is equivalent to the cross function.
- random_key (JAX random key): random key used for recombination sampling.
And returns a population of shape (l, m, 2).
And returns a population of shape (l, m, d).
:Example:
>>> from chromax import Simulator, sample_data
>>> import numpy as np
>>> import jax
>>> simulator = Simulator(genetic_map=sample_data.genetic_map)
>>> diff_cross = simulator.differentiable_cross_func()
>>> diff_cross = simulator.differentiable_cross_func
>>> def mean_gebv(pop, weights, random_key):
new_pop = diff_cross(pop, weights, random_key)
return simulator.GEBV(new_pop, raw_array=True).mean()
Expand Down Expand Up @@ -279,10 +279,12 @@ def diff_cross_f(
cross_weights: Float[Array, "m n 2"],
random_key: jax.random.PRNGKeyArray
) -> Population["m"]:
num_keys = len(cross_weights) * len(population) * 2
keys = jax.random.split(random_key, num=num_keys)
keys = keys.reshape(len(cross_weights), len(population), 2, 2)
population = population.reshape(*population.shape[:-1], -1, 2)
keys_shape = len(cross_weights), len(population), 2, population.shape[-2]
keys = jax.random.split(random_key, num=np.prod(keys_shape))
keys = keys.reshape(*keys_shape, 2)
outer_res = cross_pop(population, self.recombination_vec, keys)
outer_res = outer_res.reshape(*outer_res.shape[:-2], -1)
return (cross_weights[:, :, None, :] * outer_res).sum(axis=1)

return diff_cross_f
Expand Down Expand Up @@ -332,13 +334,13 @@ def diallel(
"""Diallel crossing function, i.e. crossing between every possible
couple, except self-crossing.
:param population: input population of shape (n, m, 2).
:param population: input population of shape (n, m, d).
:type population: ndarray
:param n_offspring: number of offspring per cross.
The default value is 1.
:type n_offspring: int
:return: output population of shape (l, n_offspring, m, 2),
:return: output population of shape (l, n_offspring, m, d),
where l is the number of possible pairs, i.e. `n * (n-1) / 2`.
:rtype: ndarray
Expand Down Expand Up @@ -380,15 +382,15 @@ def random_crosses(
) -> Population["n_crosses n_offspring"]:
"""Computes random crosses on a population.
:param population: input population of shape (n, m, 2).
:param population: input population of shape (n, m, d).
:type population: ndarray
:param n_crosses: number of random crosses to perform.
:type n_crosses: int
:param n_offspring: number of offspring per cross.
The default value is 1.
:type n_offspring: int
:return: output population of shape (n_crosses, n_offspring, m, 2).
:return: output population of shape (n_crosses, n_offspring, m, d).
:rtype: ndarray
:Example:
Expand All @@ -399,12 +401,6 @@ def random_crosses(
>>> f2.shape
(100, 10, 9839, 2)
"""

if n_crosses < 1:
raise ValueError("n_crosses must be higher or equal to 1")
if n_offspring < 1:
raise ValueError("n_offspring must be higher or equal to 1")

all_indices = np.arange(len(population))
diallel_indices = self._diallel_indices(all_indices)
if n_crosses > len(diallel_indices):
Expand Down Expand Up @@ -434,18 +430,18 @@ def select(
) -> Population["_g k"]:
"""Function to select individuals based on their score (index).
:param population: input population of shape (n, m, 2),
or shape (g, n, m, 2), to select k individual from each group population group g.
:param population: input population of shape (n, m, d),
or shape (g, n, m, d), to select k individuals from each population group g.
:type population: ndarray
:param k: number of individual to select.
:type k: int
:param f_index: function that computes a score from each individual.
The function accepts as input the population, i.e. an array of shape
(n, m, 2) and returns a n float numbers. The default f_index is the conventional index,
(n, m, d) and returns n float numbers. The default f_index is the conventional index,
i.e. the sum of the marker effects masked with the SNPs from the genetic_map.
:type f_index: Callable
:return: output population of shape (k, m, 2) or (g, k, m, 2),
:return: output population of shape (k, m, d) or (g, k, m, d),
depending on the input population.
:rtype: ndarray
Expand Down Expand Up @@ -480,7 +476,7 @@ def GEBV(
"""Computes the Genomic Estimated Breeding Values using the
marker effects from the genetic_map.
:param population: input population of shape (n, m, 2).
:param population: input population of shape (n, m, d).
:type population: ndarray
:param raw_array: whether to return a raw array or a DataFrame.
Default value is False.
Expand Down Expand Up @@ -541,7 +537,7 @@ def phenotype(
This uses the Genotype-by-Environment model described in the following:
https://cran.r-project.org/web/packages/AlphaSimR/vignettes/traits.pdf
:param population: input population of shape (n, m, 2)
:param population: input population of shape (n, m, d)
:type population: ndarray
:param num_environments: number of environments to test the population.
Default value is 1.
Expand Down Expand Up @@ -597,7 +593,7 @@ def corrcoef(
"""Computes the correlation coefficient of the population against its centroid.
It can be used as an indicator of variance in the population.
:param population: input population of shape (n, m, 2)
:param population: input population of shape (n, m, d)
:type population: ndarray
:return: vector of length n, containing the correlation coefficient
Expand Down
2 changes: 1 addition & 1 deletion chromax/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jaxtyping import Array, Bool

N_MARKERS = "m"
DIPLOID_SHAPE = N_MARKERS + " 2"
# jaxtyping splits the shape string on whitespace, so the leading space is
# required: "m d" declares two axes (markers, ploidy), whereas "md" would
# declare a single axis named "md".
DIPLOID_SHAPE = N_MARKERS + " d"

Haploid = Bool[Array, N_MARKERS]
Individual = Bool[Array, DIPLOID_SHAPE]
Expand Down
4 changes: 2 additions & 2 deletions tests/mock_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(
super().__init__(genetic_map=genetic_map, **kwargs)
self.recombination_vec = recombination_vec

def load_population(self, n_individual=100):
def load_population(self, n_individual=100, ploidy=2):
return np.random.choice(
a=[False, True],
size=(n_individual, self.n_markers, 2),
size=(n_individual, self.n_markers, ploidy),
p=[0.5, 0.5]
)
53 changes: 53 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from chromax.index_functions import conventional_index
from chromax.trait_model import TraitModel
import numpy as np
import jax
from chromax import functional
import pytest


@pytest.mark.parametrize("idx", [0, 1])
def test_cross(idx):
    """Check which parental strands end up in the offspring of a tetraploid cross.

    The recombination vector is zero everywhere except its first entry,
    which is set to ``idx`` (0 or 1). The test then asserts that offspring
    strand ``i`` is an exact copy of strand ``i - i % 2 + idx`` taken from
    parent ``i % 2`` of the same cross.
    """
    ploidy = 4
    n_crosses, n_markers = 50, 1000
    parents = np.random.choice(
        [False, True],
        size=(n_crosses, 2, n_markers, ploidy),
    )

    # Only the first marker carries a non-zero value; presumably this pins
    # the starting strand of every chromosome pair -- confirm against _meiosis.
    rec_vec = np.zeros(n_markers)
    rec_vec[0] = idx

    offspring = functional.cross(parents, rec_vec, jax.random.PRNGKey(42))

    for strand in range(ploidy):
        parent_idx = strand % 2
        source_strand = strand - parent_idx + idx
        expected = parents[:, parent_idx, :, source_strand]
        assert np.all(offspring[..., strand] == expected)


def test_double_haploid():
    """Check the output shape and strand pairing of ``functional.double_haploid``.

    A random diploid population of 50 individuals is doubled with 10
    offspring each. The result must have shape
    ``(n_individuals, n_offspring, n_markers, ploidy)``, and within every
    consecutive pair of strands the two strands must be identical.
    """
    ploidy, n_offspring = 2, 10
    n_chr, chr_len = 10, 100
    n_markers = n_chr * chr_len

    f1 = np.random.choice([False, True], size=(50, n_markers, ploidy))
    rec_vec = np.full((n_markers,), 1.5 / chr_len)
    dh = functional.double_haploid(
        f1, n_offspring, rec_vec, jax.random.PRNGKey(42)
    )
    assert dh.shape == (len(f1), n_offspring, n_markers, ploidy)

    # Walk over the first strand of each complete pair: (0, 1), (2, 3), ...
    for first in range(0, 2 * (ploidy // 2), 2):
        assert np.all(dh[..., first] == dh[..., first + 1])


def test_select():
    """Check that ``functional.select`` keeps the k best-scoring individuals.

    A random tetraploid population is scored with a conventional index built
    from random marker effects. The selected population must keep the
    per-individual shape, contain the overall best GEBV, and have a strictly
    higher mean and minimum GEBV than the starting population.
    """
    k = 10
    n_markers, ploidy = 1000, 4

    f1 = np.random.choice([False, True], size=(50, n_markers, ploidy))
    effects = np.random.randn(n_markers)
    gebv_model = TraitModel(effects[:, None])

    f2 = functional.select(f1, k=k, f_index=conventional_index(gebv_model))
    assert f2.shape == (k, *f1.shape[1:])

    before, after = gebv_model(f1), gebv_model(f2)
    assert np.max(after) == np.max(before)
    assert np.mean(after) > np.mean(before)
    assert np.min(after) > np.min(before)
Loading

0 comments on commit e26d2ae

Please sign in to comment.