From 2fc8476ef80ec73ac4fa2e08450b969e25051555 Mon Sep 17 00:00:00 2001 From: szhan Date: Tue, 18 Jun 2024 15:18:09 +0100 Subject: [PATCH] Pre-release checks --- lshmm/api.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index b5c0e10..7b32f38 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -30,8 +30,8 @@ def check_inputs( reference_panel, query, prob_recombination, - prob_mutation=None, - scale_mutation_rate=None, + prob_mutation, + scale_mutation_rate, ): """ Check that the input data and parameters are valid, and return basic info @@ -50,8 +50,8 @@ def check_inputs( :param numpy.ndarray reference_panel: An array of size (m, n) or (m, n, n). :param numpy.ndarray query: An array of size (k, m). :param numpy.ndarray prob_recombination: Recombination probability. - :param numpy.ndarray prob_mutation: Mutation probability. If None (default), set as per Li & Stephens (2003). - :param bool scale_mutation_rate: Scale mutation rate if True (default). + :param numpy.ndarray prob_mutation: Mutation probability. + :param bool scale_mutation_rate: Scale mutation rate. :return: Number of reference haplotypes, number of sites, ploidy :rtype: tuple """ @@ -60,7 +60,7 @@ def check_inputs( # Check the reference panel. if not len(reference_panel.shape) in (2, 3): - err_msg = "Reference panel array must have 2 or 3 dimensions." + err_msg = "Reference panel array has incorrect dimensions." raise ValueError(err_msg) if len(reference_panel.shape) == 2: @@ -129,7 +129,7 @@ def set_emission_probabilities( scale_mutation_rate, ): if isinstance(prob_mutation, float): - prob_mutation = prob_mutation * np.ones(num_sites) + prob_mutation = np.zeros(num_sites) + prob_mutation if ploidy == 1: emission_probs = core.get_emission_matrix_haploid( @@ -159,7 +159,7 @@ def forwards( scale_mutation_rate=None, normalise=None, ): - """Run the forwards algorithm on haplotype or unphased genotype data.""" + """Run the forwards algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True @@ -217,7 +217,7 @@ def backwards( prob_mutation=None, scale_mutation_rate=None, ): - """Run the backwards algorithm on haplotype or unphased genotype data.""" + """Run the backwards algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True @@ -267,7 +267,7 @@ def viterbi( prob_mutation=None, scale_mutation_rate=None, ): - """Run the Viterbi algorithm on haplotype or unphased genotype data.""" + """Run the Viterbi algorithm on haploid or diploid genotype data.""" if scale_mutation_rate is None: scale_mutation_rate = True