diff --git a/lshmm/api.py b/lshmm/api.py index 2a16765..2e37c6b 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -29,13 +29,14 @@ def check_inputs( reference_panel, query, + num_alleles, prob_recombination, prob_mutation, scale_mutation_rate, ): """ - Check that the input data and parameters are valid, and return basic info - about the data: the number of reference haplotypes, number of sites, and ploidy. + Check that the input data and parameters are valid, and return data to run + the HMM algorithms. The reference panel must be an array of size (m, n) in the haploid case or (m, n, n) in the diploid case, and the query must be an array of size (k, m), @@ -44,23 +45,23 @@ def check_inputs( n = number of samples in the reference panel (haplotypes, not individuals). k = number of samples in the query (haplotypes, not individuals). + TODO: Support running on multiple queries. + The mutation rate can be scaled according to the set of alleles that can be mutated to based on the number of distinct alleles at each site. - :param numpy.ndarray reference_panel: An array of size (m, n) or (m, n, n). - :param numpy.ndarray query: An array of size (k, m). + :param numpy.ndarray reference_panel: A panel of reference sequences. + :param numpy.ndarray query: A query sequence. + :param numpy.ndarray num_alleles: Number of distinct alleles per site. :param numpy.ndarray prob_recombination: Recombination probability. :param numpy.ndarray prob_mutation: Mutation probability. :param bool scale_mutation_rate: Scale mutation rate. - :return: Number of reference haplotypes, number of sites, ploidy + :return: Number of ref. haplotypes, number of sites, ploidy, emission prob. matrix. :rtype: tuple """ - if scale_mutation_rate is None: - scale_mutation_rate = True - # Check the reference panel. if not len(reference_panel.shape) in (2, 3): - err_msg = "Reference panel array has incorrect dimensions." + err_msg = "Reference panel array has incorrect number of dimensions." raise ValueError(err_msg) if len(reference_panel.shape) == 2: @@ -79,35 +80,26 @@ def check_inputs( # Check the queries. if query.shape[1] != num_sites: - err_msg = "Number of sites in query does not match reference panel." + err_msg = "Number of sites in the query and reference panel do not match." raise ValueError(err_msg) if np.any(query == core.NONCOPY): err_msg = "Queries cannot have any NONCOPY values." raise ValueError(err_msg) - # Check the mutation probability. - if isinstance(prob_mutation, (int, float)): - if not scale_mutation_rate: - warn_msg = "Passed a scalar mutation probability, but not rescaling it." - warnings.warn(warn_msg) - elif isinstance(prob_mutation, np.ndarray) and len(prob_mutation) == num_sites: - if scale_mutation_rate: - warn_msg = "Passed an array of mutation probabilities. Rescaling them." - warnings.warn(warn_msg) - elif prob_mutation is None: - warn_msg = ( - "No mutation probability is passed. " - "Setting it based on Li & Stephens (2003) equations A2 and A3." - ) - warnings.warn(warn_msg) - else: - err_msg = "Mutation probability is not a scalar or an array of expected length." + # Check the number of distinct alleles per site. + if len(num_alleles) != num_sites: + err_msg = "Number of alleles is not an array of expected length." + raise ValueError(err_msg) + + if not np.all(num_alleles > 0) or not np.issubdtype(num_alleles.dtype, np.integer): + err_msg = "Number of alleles must be positive integers." raise ValueError(err_msg) # Check the recombination probability. if isinstance(prob_recombination, (int, float)): - pass + prob_recombination = np.zeros(num_sites, dtype=np.float64) + prob_recombination + prob_recombination[0] = 0.0 elif ( isinstance(prob_recombination, np.ndarray) and len(prob_recombination) == num_sites @@ -121,35 +113,42 @@ def check_inputs( ) raise ValueError(err_msg) - return (num_ref_haps, num_sites, ploidy) - - -def set_emission_probabilities( - num_sites, - ploidy, - num_alleles, - prob_mutation, - scale_mutation_rate, -): - if isinstance(prob_mutation, float): - prob_mutation = np.zeros(num_sites) + prob_mutation + # Check the mutation probability. + if prob_mutation is None: + warn_msg = "No mutation probability is passed; setting it as per Li & Stephens (2003) eqn. A2 and A3." + warnings.warn(warn_msg) + prob_mutation = core.estimate_mutation_probability(num_ref_haps) + prob_mutation = np.zeros(num_sites, dtype=np.float64) + prob_mutation + elif isinstance(prob_mutation, (int, float)): + if not scale_mutation_rate: + warn_msg = "A scalar mutation probability is passed, but not rescaling it." + warnings.warn(warn_msg) + prob_mutation = np.zeros(num_sites, dtype=np.float64) + prob_mutation + elif isinstance(prob_mutation, np.ndarray) and len(prob_mutation) == num_sites: + if scale_mutation_rate: + warn_msg = "Rescaling an array of mutation probabilities." + warnings.warn(warn_msg) + else: + err_msg = "Mutation probability is not a scalar or an array of expected length." + raise ValueError(err_msg) + # Calculate the emission probability matrix. if ploidy == 1: - emission_probs = core.get_emission_matrix_haploid( + emission_matrix = core.get_emission_matrix_haploid( mu=prob_mutation, num_sites=num_sites, num_alleles=num_alleles, scale_mutation_rate=scale_mutation_rate, ) else: - emission_probs = core.get_emission_matrix_diploid( + emission_matrix = core.get_emission_matrix_diploid( mu=prob_mutation, num_sites=num_sites, num_alleles=num_alleles, scale_mutation_rate=scale_mutation_rate, ) - return emission_probs + return num_ref_haps, num_sites, ploidy, emission_matrix def forwards( @@ -169,21 +168,11 @@ def forwards( if normalise is None: normalise = True - num_ref_haps, num_sites, ploidy = check_inputs( + num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, - ) - - if prob_mutation is None: - prob_mutation = core.estimate_mutation_probability(num_ref_haps) - - emission_probs = set_emission_probabilities( - num_sites=num_sites, - ploidy=ploidy, num_alleles=num_alleles, + prob_recombination=prob_recombination, prob_mutation=prob_mutation, scale_mutation_rate=scale_mutation_rate, ) @@ -202,7 +191,7 @@ def forwards( num_sites, reference_panel, query, - emission_probs, + emission_matrix, prob_recombination, norm=normalise, ) @@ -224,21 +213,11 @@ def backwards( if scale_mutation_rate is None: scale_mutation_rate = True - num_ref_haps, num_sites, ploidy = check_inputs( + num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, - ) - - if prob_mutation is None: - prob_mutation = core.estimate_mutation_probability(num_ref_haps) - - emission_probs = set_emission_probabilities( - num_sites=num_sites, - ploidy=ploidy, num_alleles=num_alleles, + prob_recombination=prob_recombination, prob_mutation=prob_mutation, scale_mutation_rate=scale_mutation_rate, ) @@ -253,7 +232,7 @@ def backwards( num_sites, reference_panel, query, - emission_probs, + emission_matrix, normalisation_factor_from_forward, prob_recombination, ) @@ -274,21 +253,11 @@ def viterbi( if scale_mutation_rate is None: scale_mutation_rate = True - num_ref_haps, num_sites, ploidy = check_inputs( + num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, - ) - - if prob_mutation is None: - prob_mutation = core.estimate_mutation_probability(num_ref_haps) - - emission_probs = set_emission_probabilities( - num_sites=num_sites, - ploidy=ploidy, num_alleles=num_alleles, + prob_recombination=prob_recombination, prob_mutation=prob_mutation, scale_mutation_rate=scale_mutation_rate, ) @@ -299,7 +268,7 @@ def viterbi( num_sites, reference_panel, query, - emission_probs, + emission_matrix, prob_recombination, ) best_path = backwards_viterbi_hap(num_sites, V, P) @@ -309,7 +278,7 @@ def viterbi( num_sites, reference_panel, query, - emission_probs, + emission_matrix, prob_recombination, ) unphased_path = backwards_viterbi_dip(num_sites, V, P) @@ -332,21 +301,11 @@ def path_loglik( if scale_mutation_rate is None: scale_mutation_rate = True - num_ref_haps, num_sites, ploidy = check_inputs( + num_ref_haps, num_sites, ploidy, emission_matrix = check_inputs( reference_panel=reference_panel, query=query, - prob_recombination=prob_recombination, - prob_mutation=prob_mutation, - scale_mutation_rate=scale_mutation_rate, - ) - - if prob_mutation is None: - prob_mutation = core.estimate_mutation_probability(num_ref_haps) - - emission_probs = set_emission_probabilities( - num_sites=num_sites, - ploidy=ploidy, num_alleles=num_alleles, + prob_recombination=prob_recombination, prob_mutation=prob_mutation, scale_mutation_rate=scale_mutation_rate, ) @@ -362,7 +321,7 @@ def path_loglik( reference_panel, path, query, - emission_probs, + emission_matrix, prob_recombination, )