Skip to content

Commit

Permalink
Rework check_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent e518bf6 commit 9349f7a
Showing 1 changed file with 54 additions and 95 deletions.
149 changes: 54 additions & 95 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
def check_inputs(
reference_panel,
query,
num_alleles,
prob_recombination,
prob_mutation,
scale_mutation_rate,
):
"""
Check that the input data and parameters are valid, and return basic info
about the data: the number of reference haplotypes, number of sites, and ploidy.
Check that the input data and parameters are valid, and return data to run
the HMM algorithms.
The reference panel must be an array of size (m, n) in the haploid case or
(m, n, n) in the diploid case, and the query must be an array of size (k, m),
Expand All @@ -44,23 +45,23 @@ def check_inputs(
n = number of samples in the reference panel (haplotypes, not individuals).
k = number of samples in the query (haplotypes, not individuals).
TODO: Support running on multiple queries.
The mutation rate can be scaled according to the set of alleles
that can be mutated to based on the number of distinct alleles at each site.
:param numpy.ndarray reference_panel: An array of size (m, n) or (m, n, n).
:param numpy.ndarray query: An array of size (k, m).
:param numpy.ndarray reference_panel: A panel of reference sequences.
:param numpy.ndarray query: A query sequence.
:param numpy.ndarray num_alleles: Number of distinct alleles per site.
:param numpy.ndarray prob_recombination: Recombination probability.
:param numpy.ndarray prob_mutation: Mutation probability.
:param bool scale_mutation_rate: Scale mutation rate.
:return: Number of reference haplotypes, number of sites, ploidy
:rtype: tuple
"""
if scale_mutation_rate is None:
scale_mutation_rate = True

# Check the reference panel.
if not len(reference_panel.shape) in (2, 3):
err_msg = "Reference panel array has incorrect dimensions."
err_msg = "Reference panel array has incorrect number of dimensions."
raise ValueError(err_msg)

if len(reference_panel.shape) == 2:
Expand All @@ -79,35 +80,26 @@ def check_inputs(

# Check the queries.
if query.shape[1] != num_sites:
err_msg = "Number of sites in query does not match reference panel."
err_msg = "Number of sites in the query and reference panel do not match."
raise ValueError(err_msg)

if np.any(query == core.NONCOPY):
err_msg = "Queries cannot have any NONCOPY values."
raise ValueError(err_msg)

# Check the mutation probability.
if isinstance(prob_mutation, (int, float)):
if not scale_mutation_rate:
warn_msg = "Passed a scalar mutation probability, but not rescaling it."
warnings.warn(warn_msg)
elif isinstance(prob_mutation, np.ndarray) and len(prob_mutation) == num_sites:
if scale_mutation_rate:
warn_msg = "Passed an array of mutation probabilities. Rescaling them."
warnings.warn(warn_msg)
elif prob_mutation is None:
warn_msg = (
"No mutation probability is passed. "
"Setting it based on Li & Stephens (2003) equations A2 and A3."
)
warnings.warn(warn_msg)
else:
err_msg = "Mutation probability is not a scalar or an array of expected length."
# Check the number of distinct alleles per site.
if len(num_alleles) != num_sites:
err_msg = "Number of alleles is not an array of expected length."
raise ValueError(err_msg)

if not np.all(num_alleles > 0) or not np.issubdtype(num_alleles.dtype, np.integer):
err_msg = "Number of alleles must be positive integers."
raise ValueError(err_msg)

# Check the recombination probability.
if isinstance(prob_recombination, (int, float)):
pass
prob_recombination = np.zeros(num_sites, dtype=np.float64) + prob_recombination
prob_recombination[0] = 0.0
elif (
isinstance(prob_recombination, np.ndarray)
and len(prob_recombination) == num_sites
Expand All @@ -121,35 +113,42 @@ def check_inputs(
)
raise ValueError(err_msg)

return (num_ref_haps, num_sites, ploidy)


def set_emission_probabilities(
num_sites,
ploidy,
num_alleles,
prob_mutation,
scale_mutation_rate,
):
if isinstance(prob_mutation, float):
prob_mutation = np.zeros(num_sites) + prob_mutation
# Check the mutation probability.
if prob_mutation is None:
warn_msg = "No mutation probability is passed; setting it as per Li & Stephens (2003) eqn. A2 and A3."
warnings.warn(warn_msg)
prob_mutation = core.estimate_mutation_probability(num_ref_haps)
prob_mutation = np.zeros(num_sites, dtype=np.float64) + prob_mutation
elif isinstance(prob_mutation, (int, float)):
if not scale_mutation_rate:
warn_msg = "A scalar mutation probability is passed, but not rescaling it."
warnings.warn(warn_msg)
prob_mutation = np.zeros(num_sites, dtype=np.float64) + prob_mutation
elif isinstance(prob_mutation, np.ndarray) and len(prob_mutation) == num_sites:
if scale_mutation_rate:
warn_msg = "Rescaling an array of mutation probabilities."
warnings.warn(warn_msg)
else:
err_msg = "Mutation probability is not a scalar or an array of expected length."
raise ValueError(err_msg)

# Calculate the emission probability matrix.
if ploidy == 1:
emission_probs = core.get_emission_matrix_haploid(
emission_matrix = core.get_emission_matrix_haploid(
mu=prob_mutation,
num_sites=num_sites,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)
else:
emission_probs = core.get_emission_matrix_diploid(
emission_matrix = core.get_emission_matrix_diploid(
mu=prob_mutation,
num_sites=num_sites,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)

return emission_probs
return num_ref_haps, num_sites, ploidy, emission_matrix


def forwards(
Expand All @@ -169,21 +168,11 @@ def forwards(
if normalise is None:
normalise = True

num_ref_haps, num_sites, ploidy = check_inputs(
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)

if prob_mutation is None:
prob_mutation = core.estimate_mutation_probability(num_ref_haps)

emission_probs = set_emission_probabilities(
num_sites=num_sites,
ploidy=ploidy,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
Expand All @@ -202,7 +191,7 @@ def forwards(
num_sites,
reference_panel,
query,
emission_probs,
emission_matrix,
prob_recombination,
norm=normalise,
)
Expand All @@ -224,21 +213,11 @@ def backwards(
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy = check_inputs(
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)

if prob_mutation is None:
prob_mutation = core.estimate_mutation_probability(num_ref_haps)

emission_probs = set_emission_probabilities(
num_sites=num_sites,
ploidy=ploidy,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
Expand All @@ -253,7 +232,7 @@ def backwards(
num_sites,
reference_panel,
query,
emission_probs,
emission_matrix,
normalisation_factor_from_forward,
prob_recombination,
)
Expand All @@ -274,21 +253,11 @@ def viterbi(
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy = check_inputs(
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)

if prob_mutation is None:
prob_mutation = core.estimate_mutation_probability(num_ref_haps)

emission_probs = set_emission_probabilities(
num_sites=num_sites,
ploidy=ploidy,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
Expand All @@ -299,7 +268,7 @@ def viterbi(
num_sites,
reference_panel,
query,
emission_probs,
emission_matrix,
prob_recombination,
)
best_path = backwards_viterbi_hap(num_sites, V, P)
Expand All @@ -309,7 +278,7 @@ def viterbi(
num_sites,
reference_panel,
query,
emission_probs,
emission_matrix,
prob_recombination,
)
unphased_path = backwards_viterbi_dip(num_sites, V, P)
Expand All @@ -332,21 +301,11 @@ def path_loglik(
if scale_mutation_rate is None:
scale_mutation_rate = True

num_ref_haps, num_sites, ploidy = check_inputs(
num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs(
reference_panel=reference_panel,
query=query,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)

if prob_mutation is None:
prob_mutation = core.estimate_mutation_probability(num_ref_haps)

emission_probs = set_emission_probabilities(
num_sites=num_sites,
ploidy=ploidy,
num_alleles=num_alleles,
prob_recombination=prob_recombination,
prob_mutation=prob_mutation,
scale_mutation_rate=scale_mutation_rate,
)
Expand All @@ -362,7 +321,7 @@ def path_loglik(
reference_panel,
path,
query,
emission_probs,
emission_matrix,
prob_recombination,
)

Expand Down

0 comments on commit 9349f7a

Please sign in to comment.