Skip to content

Commit

Permalink
Merge pull request #111 from szhan/raise_errors_core
Browse files Browse the repository at this point in the history
Raise errors instead of asserting
  • Loading branch information
szhan authored Jun 18, 2024
2 parents 7bb3108 + 219da54 commit d328110
Showing 1 changed file with 49 additions and 35 deletions.
84 changes: 49 additions & 35 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
"""Create Numpy-like functions for max, sum, etc."""
assert arr.ndim == 2
assert axis in [0, 1]
if arr.ndim != 2:
err_msg = "Array does not have two dimensions."
raise ValueError(err_msg)
if axis not in [0, 1]:
err_msg = "Axis is not 0 or 1."
raise ValueError(err_msg)
if axis == 0:
result = np.empty(arr.shape[1])
for i in range(len(result)):
Expand Down Expand Up @@ -56,9 +60,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
Convert a set of haplotypes into a matrix of diploid genotypes encoded as allele dosages,
and return the genotypes.
TODO: Handle multiallelic sites.
It is assumed all sites are biallelic and alleles are encoded as ancestral/derived.
The only allowable allele states are 0, 1, and NONCOPY (for partial ancestral haplotypes).
TODO: Handle multiallelic sites.
Allowable genotype values are 0, 1, 2, and NONCOPY. If either one haplotype is NONCOPY
at a site, then the genotype at the site is assigned NONCOPY.
Expand All @@ -73,9 +77,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int8)
assert np.all(
np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)
), f"Reference haplotypes contain illegal allele states."
if not np.all(np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)):
err_msg = "Reference haplotypes contain illegal allele states."
raise ValueError(err_msg)
num_sites = ref_panel.shape[0]
num_haps = ref_panel.shape[1]
genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int8) - np.inf
Expand Down Expand Up @@ -107,12 +111,14 @@ def convert_haplotypes_to_unphased_genotypes(query):
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int8)
assert np.all(
np.isin(np.unique(query), ALLOWED_ALLELE_STATES)
), f"Query haplotypes contain illegal allele states."
if not np.all(np.isin(np.unique(query), ALLOWED_ALLELE_STATES)):
err_msg = "Query haplotypes contain illegal allele states."
raise ValueError(err_msg)
num_sites = query.shape[1]
num_haps = query.shape[0]
assert num_haps == 2, "Two haplotypes are expected in a diploid query."
if num_haps != 2:
err_msg = "Two haplotypes are expected in a diploid query."
raise ValueError(err_msg)
genotypes = np.zeros((1, num_sites), dtype=np.int8) - np.inf
genotypes[0, :] = np.sum(query, axis=0)
genotypes[0, np.any(query == MISSING, axis=0)] = MISSING
Expand All @@ -137,36 +143,40 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps):
:return: True if the condition is satisfied, otherwise False.
:rtype: bool
"""
assert np.all(
genotype_matrix != MISSING
), "Reference panel cannot contain any MISSING values."
if not np.all(genotype_matrix != MISSING):
err_msg = "Reference panel cannot contain any MISSING values."
raise ValueError(err_msg)
max_num_copiable_entries = 2 * num_sample_haps - 1
num_copiable_entries_per_site = np.sum(genotype_matrix != NONCOPY, axis=1)
return np.all(num_copiable_entries_per_site <= max_num_copiable_entries)


@jit.numba_njit
def get_num_copiable_entries(ref_panel):
assert ref_panel.ndim in [2, 3], "Reference panel array has incorrect dimensions."
assert np.all(
ref_panel != MISSING
), "Reference panel cannot contain any MISSING values."
if ref_panel.ndim not in [2, 3]:
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)
if not np.all(ref_panel != MISSING):
err_msg = "Reference panel cannot contain any MISSING values."
raise ValueError(err_msg)
if ref_panel.ndim == 2:
num_copiable_entries = np.sum(ref_panel != NONCOPY, axis=1)
else:
num_sites = ref_panel.shape[0]
num_copiable_entries = np.zeros(num_sites, dtype=np.int32)
for i in range(num_sites):
num_copiable_entries[i] = np.sum(ref_panel[i, :, :] != NONCOPY)
assert np.all(
num_copiable_entries > 0
), "Number of copiable entries must be greater than zero at all sites."
if not np.all(num_copiable_entries > 0):
err_msg = "Number of copiable entries must be > 0 at all sites."
raise ValueError(err_msg)
return num_copiable_entries


def get_num_alleles(ref_panel, query):
assert ref_panel.shape[0] == query.shape[1]
num_sites = ref_panel.shape[0]
if ref_panel.shape[0] != query.shape[1]:
err_msg = "Number of sites in the reference panel and query do not match."
raise ValueError(err_msg)
allele_lists = []
for i in range(num_sites):
all_alleles = np.append(ref_panel[i, :], query[:, i])
Expand Down Expand Up @@ -232,9 +242,9 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)
:param numpy.ndarray: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
), "Arrays of mutation probability and number of alleles are unequal in length."
if len(mu) != len(num_alleles):
err_msg = "Arrays of mutation probability and number of alleles are unequal in length."
raise ValueError(err_msg)
if isinstance(mu, float):
mu = np.zeros(num_sites, dtype=np.float64) + mu
emission_matrix = np.zeros((num_sites, 2), np.float64) - np.inf
Expand Down Expand Up @@ -272,11 +282,15 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma
:return: Emission probability.
:rtype: float
"""
assert ref_allele != MISSING, "Reference allele cannot be MISSING."
assert query_allele != NONCOPY, "Query allele cannot be NONCOPY."
assert (
emission_matrix.shape[1] == 2
), "Emission probability matrix has incorrect shape."
if ref_allele == MISSING:
err_msg = "Reference allele cannot be MISSING."
raise ValueError(err_msg)
if query_allele == NONCOPY:
err_msg = "Query allele cannot be NONCOPY."
raise ValueError(err_msg)
if emission_matrix.shape[1] != 2:
err_msg = "Emission probability matrix has incorrect shape."
raise ValueError(err_msg)
if ref_allele == NONCOPY:
return 0.0
elif query_allele == MISSING:
Expand Down Expand Up @@ -321,9 +335,9 @@ def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate)
:param numpy.ndarray: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
), "Arrays of mutation probability and number of alleles are unequal in length."
if len(mu) != len(num_alleles):
err_msg = "Arrays of mutation probability and number of alleles are unequal in length."
raise ValueError(err_msg)
if isinstance(mu, float):
mu = np.zeros(num_sites, dtype=np.float64) + mu
prob_mutation = np.zeros(num_sites, dtype=np.float64) - np.inf
Expand Down Expand Up @@ -368,9 +382,9 @@ def get_emission_probability_diploid(
def get_emission_probability_diploid_genotypes(
ref_genotypes, query_genotype, site, emission_matrix
):
assert (
ref_genotypes.shape[0] == ref_genotypes.shape[1]
), "Reference genotype matrix must be a square matrix."
if ref_genotypes.shape[0] != ref_genotypes.shape[1]:
err_msg = "Reference genotype matrix must be a square matrix."
raise ValueError(err_msg)
num_ref_genotypes = len(ref_genotypes)
emission_probs = np.zeros((num_ref_genotypes, num_ref_genotypes), dtype=np.float64)
for i in range(num_ref_genotypes):
Expand Down

0 comments on commit d328110

Please sign in to comment.