From 219da54ed27947c7508993b410a557b05a55c55d Mon Sep 17 00:00:00 2001 From: szhan Date: Tue, 18 Jun 2024 14:14:03 +0100 Subject: [PATCH] Raise errors instead of asserting --- lshmm/core.py | 84 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 35 deletions(-) diff --git a/lshmm/core.py b/lshmm/core.py index c3d65b8..f6acf75 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -19,8 +19,12 @@ @jit.numba_njit def np_apply_along_axis(func1d, axis, arr): """Create Numpy-like functions for max, sum, etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] + if arr.ndim != 2: + err_msg = "Array does not have two dimensions." + raise ValueError(err_msg) + if axis not in [0, 1]: + err_msg = "Axis is not 0 or 1." + raise ValueError(err_msg) if axis == 0: result = np.empty(arr.shape[1]) for i in range(len(result)): @@ -56,9 +60,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): Convert a set of haplotypes into a matrix of diploid genotypes encoded as allele dosages, and return the genotypes. + TODO: Handle multiallelic sites. It is assumed all sites are biallelic and alleles are encoded as ancestral/derived. The only allowable allele states are 0, 1, and NONCOPY (for partial ancestral haplotypes). - TODO: Handle multiallelic sites. Allowable genotype values are 0, 1, 2, and NONCOPY. If either one haplotype is NONCOPY at a site, then the genotype at the site is assigned NONCOPY. @@ -73,9 +77,9 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): :rtype: numpy.ndarray """ ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int8) - assert np.all( - np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES) - ), f"Reference haplotypes contain illegal allele states." + if not np.all(np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)): + err_msg = "Reference haplotypes contain illegal allele states." + raise ValueError(err_msg) num_sites = ref_panel.shape[0] num_haps = ref_panel.shape[1] genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int8) - np.inf @@ -107,12 +111,14 @@ def convert_haplotypes_to_unphased_genotypes(query): :rtype: numpy.ndarray """ ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int8) - assert np.all( - np.isin(np.unique(query), ALLOWED_ALLELE_STATES) - ), f"Query haplotypes contain illegal allele states." + if not np.all(np.isin(np.unique(query), ALLOWED_ALLELE_STATES)): + err_msg = "Query haplotypes contain illegal allele states." + raise ValueError(err_msg) num_sites = query.shape[1] num_haps = query.shape[0] - assert num_haps == 2, "Two haplotypes are expected in a diploid query." + if num_haps != 2: + err_msg = "Two haplotypes are expected in a diploid query." + raise ValueError(err_msg) genotypes = np.zeros((1, num_sites), dtype=np.int8) - np.inf genotypes[0, :] = np.sum(query, axis=0) genotypes[0, np.any(query == MISSING, axis=0)] = MISSING @@ -137,9 +143,9 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps): :return: True if the condition is satisfied, otherwise False. :rtype: bool """ - assert np.all( - genotype_matrix != MISSING - ), "Reference panel cannot contain any MISSING values." + if not np.all(genotype_matrix != MISSING): + err_msg = "Reference panel cannot contain any MISSING values." + raise ValueError(err_msg) max_num_copiable_entries = 2 * num_sample_haps - 1 num_copiable_entries_per_site = np.sum(genotype_matrix != NONCOPY, axis=1) return np.all(num_copiable_entries_per_site <= max_num_copiable_entries) @@ -147,10 +153,12 @@ def check_genotype_matrix(genotype_matrix, num_sample_haps): @jit.numba_njit def get_num_copiable_entries(ref_panel): - assert ref_panel.ndim in [2, 3], "Reference panel array has incorrect dimensions." - assert np.all( - ref_panel != MISSING - ), "Reference panel cannot contain any MISSING values." + if ref_panel.ndim not in [2, 3]: + err_msg = "Reference panel array has incorrect dimensions." + raise ValueError(err_msg) + if not np.all(ref_panel != MISSING): + err_msg = "Reference panel cannot contain any MISSING values." + raise ValueError(err_msg) if ref_panel.ndim == 2: num_copiable_entries = np.sum(ref_panel != NONCOPY, axis=1) else: @@ -158,15 +166,17 @@ def get_num_copiable_entries(ref_panel): num_copiable_entries = np.zeros(num_sites, dtype=np.int32) for i in range(num_sites): num_copiable_entries[i] = np.sum(ref_panel[i, :, :] != NONCOPY) - assert np.all( - num_copiable_entries > 0 - ), "Number of copiable entries must be greater than zero at all sites." + if not np.all(num_copiable_entries > 0): + err_msg = "Number of copiable entries must be > 0 at all sites." + raise ValueError(err_msg) return num_copiable_entries def get_num_alleles(ref_panel, query): - assert ref_panel.shape[0] == query.shape[1] num_sites = ref_panel.shape[0] + if ref_panel.shape[0] != query.shape[1]: + err_msg = "Number of sites in the reference panel and query do not match." + raise ValueError(err_msg) allele_lists = [] for i in range(num_sites): all_alleles = np.append(ref_panel[i, :], query[:, i]) @@ -232,9 +242,9 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate) :param numpy.ndarray: Number of distinct alleles per site. :param bool scale_mutation_rate: Scale mutation rate based on the number of alleles. """ - assert len(mu) == len( - num_alleles - ), "Arrays of mutation probability and number of alleles are unequal in length." + if len(mu) != len(num_alleles): + err_msg = "Arrays of mutation probability and number of alleles are unequal in length." + raise ValueError(err_msg) if isinstance(mu, float): mu = np.zeros(num_sites, dtype=np.float64) + mu emission_matrix = np.zeros((num_sites, 2), np.float64) - np.inf @@ -272,11 +282,15 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma :return: Emission probability. :rtype: float """ - assert ref_allele != MISSING, "Reference allele cannot be MISSING." - assert query_allele != NONCOPY, "Query allele cannot be NONCOPY." - assert ( - emission_matrix.shape[1] == 2 - ), "Emission probability matrix has incorrect shape." + if ref_allele == MISSING: + err_msg = "Reference allele cannot be MISSING." + raise ValueError(err_msg) + if query_allele == NONCOPY: + err_msg = "Query allele cannot be NONCOPY." + raise ValueError(err_msg) + if emission_matrix.shape[1] != 2: + err_msg = "Emission probability matrix has incorrect shape." + raise ValueError(err_msg) if ref_allele == NONCOPY: return 0.0 elif query_allele == MISSING: @@ -321,9 +335,9 @@ def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate) :param numpy.ndarray: Number of distinct alleles per site. :param bool scale_mutation_rate: Scale mutation rate based on the number of alleles. """ - assert len(mu) == len( - num_alleles - ), "Arrays of mutation probability and number of alleles are unequal in length." + if len(mu) != len(num_alleles): + err_msg = "Arrays of mutation probability and number of alleles are unequal in length." + raise ValueError(err_msg) if isinstance(mu, float): mu = np.zeros(num_sites, dtype=np.float64) + mu prob_mutation = np.zeros(num_sites, dtype=np.float64) - np.inf @@ -368,9 +382,9 @@ def get_emission_probability_diploid( def get_emission_probability_diploid_genotypes( ref_genotypes, query_genotype, site, emission_matrix ): - assert ( - ref_genotypes.shape[0] == ref_genotypes.shape[1] - ), "Reference genotype matrix must be a square matrix." + if ref_genotypes.shape[0] != ref_genotypes.shape[1]: + err_msg = "Reference genotype matrix must be a square matrix." + raise ValueError(err_msg) num_ref_genotypes = len(ref_genotypes) emission_probs = np.zeros((num_ref_genotypes, num_ref_genotypes), dtype=np.float64) for i in range(num_ref_genotypes):