Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise errors instead of asserting #111

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 49 additions & 35 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
"""Create Numpy-like functions for max, sum, etc."""
assert arr.ndim == 2
assert axis in [0, 1]
if arr.ndim != 2:
err_msg = "Array does not have two dimensions."
raise ValueError(err_msg)
if axis not in [0, 1]:
err_msg = "Axis is not 0 or 1."
raise ValueError(err_msg)
if axis == 0:
result = np.empty(arr.shape[1])
for i in range(len(result)):
Expand Down Expand Up @@ -56,9 +60,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
Convert a set of haplotypes into a matrix of diploid genotypes encoded as allele dosages,
and return the genotypes.

TODO: Handle multiallelic sites.
It is assumed all sites are biallelic and alleles are encoded as ancestral/derived.
The only allowable allele states are 0, 1, and NONCOPY (for partial ancestral haplotypes).
TODO: Handle multiallelic sites.

Allowable genotype values are 0, 1, 2, and NONCOPY. If either one haplotype is NONCOPY
at a site, then the genotype at the site is assigned NONCOPY.
Expand All @@ -73,9 +77,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int8)
assert np.all(
np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)
), f"Reference haplotypes contain illegal allele states."
if not np.all(np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)):
err_msg = "Reference haplotypes contain illegal allele states."
raise ValueError(err_msg)
num_sites = ref_panel.shape[0]
num_haps = ref_panel.shape[1]
genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int8) - np.inf
Expand Down Expand Up @@ -107,12 +111,14 @@ def convert_haplotypes_to_unphased_genotypes(query):
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int8)
assert np.all(
np.isin(np.unique(query), ALLOWED_ALLELE_STATES)
), f"Query haplotypes contain illegal allele states."
if not np.all(np.isin(np.unique(query), ALLOWED_ALLELE_STATES)):
err_msg = "Query haplotypes contain illegal allele states."
raise ValueError(err_msg)
num_sites = query.shape[1]
num_haps = query.shape[0]
assert num_haps == 2, "Two haplotypes are expected in a diploid query."
if num_haps != 2:
err_msg = "Two haplotypes are expected in a diploid query."
raise ValueError(err_msg)
genotypes = np.zeros((1, num_sites), dtype=np.int8) - np.inf
genotypes[0, :] = np.sum(query, axis=0)
genotypes[0, np.any(query == MISSING, axis=0)] = MISSING
Expand All @@ -137,36 +143,40 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps):
:return: True if the condition is satisfied, otherwise False.
:rtype: bool
"""
assert np.all(
genotype_matrix != MISSING
), "Reference panel cannot contain any MISSING values."
if not np.all(genotype_matrix != MISSING):
err_msg = "Reference panel cannot contain any MISSING values."
raise ValueError(err_msg)
max_num_copiable_entries = 2 * num_sample_haps - 1
num_copiable_entries_per_site = np.sum(genotype_matrix != NONCOPY, axis=1)
return np.all(num_copiable_entries_per_site <= max_num_copiable_entries)


@jit.numba_njit
def get_num_copiable_entries(ref_panel):
assert ref_panel.ndim in [2, 3], "Reference panel array has incorrect dimensions."
assert np.all(
ref_panel != MISSING
), "Reference panel cannot contain any MISSING values."
if ref_panel.ndim not in [2, 3]:
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)
if not np.all(ref_panel != MISSING):
err_msg = "Reference panel cannot contain any MISSING values."
raise ValueError(err_msg)
if ref_panel.ndim == 2:
num_copiable_entries = np.sum(ref_panel != NONCOPY, axis=1)
else:
num_sites = ref_panel.shape[0]
num_copiable_entries = np.zeros(num_sites, dtype=np.int32)
for i in range(num_sites):
num_copiable_entries[i] = np.sum(ref_panel[i, :, :] != NONCOPY)
assert np.all(
num_copiable_entries > 0
), "Number of copiable entries must be greater than zero at all sites."
if not np.all(num_copiable_entries > 0):
err_msg = "Number of copiable entries must be > 0 at all sites."
raise ValueError(err_msg)
return num_copiable_entries


def get_num_alleles(ref_panel, query):
assert ref_panel.shape[0] == query.shape[1]
num_sites = ref_panel.shape[0]
if ref_panel.shape[0] != query.shape[1]:
err_msg = "Number of sites in the reference panel and query do not match."
raise ValueError(err_msg)
allele_lists = []
for i in range(num_sites):
all_alleles = np.append(ref_panel[i, :], query[:, i])
Expand Down Expand Up @@ -232,9 +242,9 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)
:param numpy.ndarray: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
), "Arrays of mutation probability and number of alleles are unequal in length."
if len(mu) != len(num_alleles):
err_msg = "Arrays of mutation probability and number of alleles are unequal in length."
raise ValueError(err_msg)
if isinstance(mu, float):
mu = np.zeros(num_sites, dtype=np.float64) + mu
emission_matrix = np.zeros((num_sites, 2), np.float64) - np.inf
Expand Down Expand Up @@ -272,11 +282,15 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma
:return: Emission probability.
:rtype: float
"""
assert ref_allele != MISSING, "Reference allele cannot be MISSING."
assert query_allele != NONCOPY, "Query allele cannot be NONCOPY."
assert (
emission_matrix.shape[1] == 2
), "Emission probability matrix has incorrect shape."
if ref_allele == MISSING:
err_msg = "Reference allele cannot be MISSING."
raise ValueError(err_msg)
if query_allele == NONCOPY:
err_msg = "Query allele cannot be NONCOPY."
raise ValueError(err_msg)
if emission_matrix.shape[1] != 2:
err_msg = "Emission probability matrix has incorrect shape."
raise ValueError(err_msg)
if ref_allele == NONCOPY:
return 0.0
elif query_allele == MISSING:
Expand Down Expand Up @@ -321,9 +335,9 @@ def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate)
:param numpy.ndarray: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
), "Arrays of mutation probability and number of alleles are unequal in length."
if len(mu) != len(num_alleles):
err_msg = "Arrays of mutation probability and number of alleles are unequal in length."
raise ValueError(err_msg)
if isinstance(mu, float):
mu = np.zeros(num_sites, dtype=np.float64) + mu
prob_mutation = np.zeros(num_sites, dtype=np.float64) - np.inf
Expand Down Expand Up @@ -368,9 +382,9 @@ def get_emission_probability_diploid(
def get_emission_probability_diploid_genotypes(
ref_genotypes, query_genotype, site, emission_matrix
):
assert (
ref_genotypes.shape[0] == ref_genotypes.shape[1]
), "Reference genotype matrix must be a square matrix."
if ref_genotypes.shape[0] != ref_genotypes.shape[1]:
err_msg = "Reference genotype matrix must be a square matrix."
raise ValueError(err_msg)
num_ref_genotypes = len(ref_genotypes)
emission_probs = np.zeros((num_ref_genotypes, num_ref_genotypes), dtype=np.float64)
for i in range(num_ref_genotypes):
Expand Down
Loading