Skip to content

Commit

Permalink
Rework check_inputs to no longer take num_alleles
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent 8d474c5 commit 6c62be7
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 97 deletions.
173 changes: 102 additions & 71 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
def check_inputs(
reference_panel,
query,
num_alleles,
ploidy,
prob_recombination,
prob_mutation,
scale_mutation_rate,
Expand All @@ -38,63 +38,66 @@ def check_inputs(
Check that the input data and parameters are valid, and return data to run
the HMM algorithms.
The reference panel must be an array of size (m, n) in the haploid case or
(m, n, n) in the diploid case, and the query must be an array of size (k, m),
The reference panel and query are arrays of size (m, n) and (k, m), respectively,
where:
m = number of sites.
n = number of samples in the reference panel (haplotypes, not individuals).
k = number of samples in the query (haplotypes, not individuals).
TODO: Support running on multiple queries.
TODO: Support running on multiple queries. Currently, only k = 1 or 2 is supported.
The mutation rate can be scaled according to the set of alleles
that can be mutated to based on the number of distinct alleles at each site.
:param numpy.ndarray reference_panel: A panel of reference sequences.
:param numpy.ndarray query: A query sequence.
:param numpy.ndarray num_alleles: Number of distinct alleles per site.
:param numpy.ndarray ploidy: Ploidy (only 1 or 2 are supported).
:param numpy.ndarray prob_recombination: Recombination probability.
:param numpy.ndarray prob_mutation: Mutation probability.
:param bool scale_mutation_rate: Scale mutation rate or not.
:return: Number of ref. haplotypes, number of sites, ploidy, emission prob. matrix.
:return: Num. ref. hap., num. sites, checked ref. panel, checked query, emission prob. matrix.
:rtype: tuple
"""
# Check the reference panel.
if not len(reference_panel.shape) in (2, 3):
err_msg = "Reference panel array has incorrect number of dimensions."
# Check ploidy.
if not ploidy in [1, 2]:
err_msg = "Only ploidy levels 1 and 2 are supported."
raise ValueError(err_msg)

if len(reference_panel.shape) == 2:
num_sites, num_ref_haps = reference_panel.shape
ploidy = 1
else:
num_sites, num_ref_haps, _num_ref_haps = reference_panel.shape
if num_ref_haps != _num_ref_haps:
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)
ploidy = 2
# Check the reference panel.
if not len(reference_panel.shape) == 2:
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)

if np.any(reference_panel == core.MISSING):
err_msg = "Reference panel cannot have any MISSING values."
raise ValueError(err_msg)

if ploidy == 2:
if not np.all(np.isin(reference_panel, [0, 1, core.NONCOPY])):
err_msg = "Reference panel has illegal alleles. "
err_msg += "Only 0/1 encoding is supported in diploid mode."
raise ValueError(err_msg)

num_sites, num_ref_haps = reference_panel.shape

# Check the queries.
if query.shape[1] != num_sites:
err_msg = "Number of sites in the query and reference panel do not match."
if query.shape[0] != ploidy:
err_msg = "Query array has incorrect dimensions."
raise ValueError(err_msg)

if np.any(query == core.NONCOPY):
err_msg = "Queries cannot have any NONCOPY values."
if query.shape[1] != num_sites:
err_msg = "Number of sites in the query and reference panel don't match."
raise ValueError(err_msg)

# Check the number of distinct alleles per site.
if len(num_alleles) != num_sites:
err_msg = "Number of alleles is not an array of expected length."
if np.any(query == core.NONCOPY):
err_msg = "Query cannot have any NONCOPY values."
raise ValueError(err_msg)

if not np.all(num_alleles > 0) or not np.issubdtype(num_alleles.dtype, np.integer):
err_msg = "Number of alleles must be positive integers."
raise ValueError(err_msg)
if ploidy == 2:
if not np.all(np.isin(query, [0, 1, core.MISSING])):
err_msg = "Query has illegal alleles. "
err_msg += "Only 0/1 encoding is supported in diploid mode."
raise ValueError(err_msg)

# Check the recombination probability.
if isinstance(prob_recombination, (int, float)):
Expand Down Expand Up @@ -137,6 +140,7 @@ def check_inputs(
raise ValueError(err_msg)

# Calculate the emission probability matrix.
num_alleles = core.get_num_alleles(reference_panel, query)
if ploidy == 1:
emission_matrix = core.get_emission_matrix_haploid(
mu=prob_mutation,
Expand All @@ -152,13 +156,32 @@ def check_inputs(
scale_mutation_rate=scale_mutation_rate,
)

return num_ref_haps, num_sites, ploidy, emission_matrix
if ploidy == 1:
return (
num_ref_haps,
num_sites,
reference_panel,
query,
emission_matrix,
)
else:
ref_panel_genotypes = core.convert_haplotypes_to_phased_genotypes(
reference_panel
)
query_genotypes = core.convert_haplotypes_to_unphased_genotypes(query)
return (
num_ref_haps,
num_sites,
ref_panel_genotypes,
query_genotypes,
emission_matrix,
)


def forwards(
reference_panel,
query,
num_alleles,
ploidy,
prob_recombination,
*,
prob_mutation=None,
Expand All @@ -169,13 +192,15 @@ def forwards(
if normalise is None:
normalise = True

num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = (
check_inputs(
reference_panel=reference_panel,
query=query,
ploidy=ploidy,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
)

if ploidy == 1:
Expand All @@ -190,8 +215,8 @@ def forwards(
) = forward_function(
num_ref_haps,
num_sites,
reference_panel,
query,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
norm=normalise,
Expand All @@ -203,21 +228,23 @@ def forwards(
def backwards(
reference_panel,
query,
num_alleles,
ploidy,
normalisation_factor_from_forward,
prob_recombination,
*,
prob_mutation=None,
scale_mutation_rate=None,
):
"""Run the backwards algorithm on haploid or diploid genotype data."""
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = (
check_inputs(
reference_panel=reference_panel,
query=query,
ploidy=ploidy,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
)

if ploidy == 1:
Expand All @@ -228,8 +255,8 @@ def backwards(
backwards_array = backward_function(
num_ref_haps,
num_sites,
reference_panel,
query,
ref_panel_checked,
query_checked,
emission_matrix,
normalisation_factor_from_forward,
prob_recombination,
Expand All @@ -241,28 +268,30 @@ def backwards(
def viterbi(
reference_panel,
query,
num_alleles,
ploidy,
prob_recombination,
*,
prob_mutation=None,
scale_mutation_rate=None,
):
"""Run the Viterbi algorithm on haploid or diploid genotype data."""
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = (
check_inputs(
reference_panel=reference_panel,
query=query,
ploidy=ploidy,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
)

if ploidy == 1:
V, P, log_lik = forwards_viterbi_hap_lower_mem_rescaling(
num_ref_haps,
num_sites,
reference_panel,
query,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
)
Expand All @@ -271,8 +300,8 @@ def viterbi(
V, P, log_lik = forwards_viterbi_dip_low_mem(
num_ref_haps,
num_sites,
reference_panel,
query,
ref_panel_checked,
query_checked,
emission_matrix,
prob_recombination,
)
Expand All @@ -285,21 +314,23 @@ def viterbi(
def path_loglik(
reference_panel,
query,
num_alleles,
ploidy,
path,
prob_recombination,
*,
prob_mutation=None,
scale_mutation_rate=None,
):
"""Evaluate the log-likelihood of a copying path for a query through a reference panel."""
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
num_ref_haps, num_sites, ref_panel_checked, query_checked, emission_matrix = (
check_inputs(
reference_panel=reference_panel,
query=query,
ploidy=ploidy,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
)

if ploidy == 1:
Expand All @@ -310,9 +341,9 @@ def path_loglik(
log_lik = path_ll_function(
num_ref_haps,
num_sites,
reference_panel,
ref_panel_checked,
path,
query,
query_checked,
emission_matrix,
prob_recombination,
)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_api_fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 2
for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=2,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
n=n,
Expand All @@ -38,18 +38,18 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
)
F, c, ll = ls.forwards(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
reference_panel=H_vs,
query=query,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
normalise=True,
)
B = ls.backwards(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
reference_panel=H_vs,
query=query,
ploidy=ploidy,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=1,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
Expand All @@ -36,7 +37,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
F, c, ll = ls.forwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
Expand All @@ -45,7 +46,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
B = ls.backwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down
Loading

0 comments on commit 6c62be7

Please sign in to comment.