Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 20, 2024
1 parent efe28ac commit 052d804
Show file tree
Hide file tree
Showing 13 changed files with 1,513 additions and 1,984 deletions.
164 changes: 55 additions & 109 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@

import numpy as np

from .forward_backward.fb_diploid import backward_ls_dip_loop, forward_ls_dip_loop
from .forward_backward.fb_haploid import backwards_ls_hap, forwards_ls_hap
from . import core
from .fb_diploid import (
backward_ls_dip_loop,
forward_ls_dip_loop,
)
from .fb_haploid import (
backwards_ls_hap,
forwards_ls_hap,
)
from .vit_diploid import (
backwards_viterbi_dip,
forwards_viterbi_dip_low_mem,
Expand All @@ -18,47 +25,6 @@
path_ll_hap,
)

# Column indices into the (m, 8) diploid emission matrix built in
# set_emission_probabilities; each names the reference/query genotype
# configuration whose emission probability is stored in that column.
# NOTE(review): indices 5 and 6 are not assigned by any constant here.
EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
REF_HOM_OBS_HET = 1
REF_HET_OBS_HOM = 2
MISSING_INDEX = 3  # This column is set to an emission probability of 1.

# Value excluded when counting the distinct alleles at a site
# (presumably the integer encoding of a missing allele -- confirm with callers).
MISSING = -1


def check_alleles(alleles, m):
    """
    Check a list of allele lists (or strings representing alleles) at m sites, and
    return an array of counts of distinct alleles at the m sites.

    If alleles is a list of strings, then each string represents distinct alleles
    at a site, and each character in a string represents a distinct allele.
    It is assumed that MISSING is not encoded in these strings.

    Note MISSING values in allele lists are excluded from the counts.

    :param list alleles: A list of lists of alleles (or strings).
    :param int m: Number of sites.
    :return: An array of counts of distinct alleles at each site.
    :rtype: numpy.ndarray
    :raises ValueError: If the number of entries in alleles is not equal to m.
    """
    num_sites = m
    if len(alleles) != num_sites:
        err_msg = "Number of allele lists (or strings) is not equal to number of sites."
        raise ValueError(err_msg)

    # With no sites there is nothing to count; returning early also avoids
    # indexing alleles[0] on an empty input.
    if num_sites == 0:
        return np.zeros(0, dtype=np.int8)

    # Process string encoding of distinct alleles: every character is one allele,
    # so the count at a site is the length of that site's own string (not the
    # number of sites).
    if isinstance(alleles[0], str):
        return np.array([len(site_alleles) for site_alleles in alleles], dtype=np.int8)

    # Otherwise, process allele lists, ignoring MISSING when counting.
    exclusion_set = np.array([MISSING])
    n_alleles = np.zeros(num_sites, dtype=np.int8)
    for i, site_alleles in enumerate(alleles):
        uniq_alleles = np.unique(site_alleles)
        n_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set))
    return n_alleles


def checks(
reference_panel,
Expand All @@ -68,7 +34,7 @@ def checks(
scale_mutation_based_on_n_alleles,
):
"""
Checks that the input data and parameters are valid.
Check that the input data and parameters are valid.
The reference panel must be a matrix of size (m, n) or (m, n, n).
The query must be a matrix of size (k, m) or (k, m, 2).
Expand All @@ -77,17 +43,21 @@ def checks(
n = number of samples in the reference panel (haplotypes, not individuals).
k = number of samples in the query (haplotypes, not individuals).
The mutation rate can be scaled according to the set of alleles
that can be mutated to based on the number of distinct alleles at each site.
:param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n).
:param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2).
:param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m.
:param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m.
:param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False).
:param bool scale_mutation_based_on_n_alleles: Scale the mutation probability or not.
:return: n, m, ploidy
:rtype: tuple
"""
# Check reference panel
if not len(reference_panel.shape) in (2, 3):
raise ValueError("Reference panel array must have 2 or 3 dimensions.")
err_msg = "Reference panel array must have 2 or 3 dimensions."
raise ValueError(err_msg)

if len(reference_panel.shape) == 2:
m, n = reference_panel.shape
Expand All @@ -97,42 +67,49 @@ def checks(
ploidy = 2

if ploidy == 2 and (reference_panel.shape[1] != reference_panel.shape[2]):
raise ValueError(
"Reference_panel dimensions are incorrect, perhaps a sample x sample x variant matrix was passed. Expected sites x samples x samples."
err_msg = (
"Reference_panel dimensions are incorrect, "
"perhaps a sample x sample x variant matrix was passed. "
"Expected sites x samples x samples."
)
raise ValueError(err_msg)

# Check query sequence(s)
if query.shape[1] != m:
raise ValueError(
"Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed."
err_msg = (
"Number of sites in query does not match reference panel. "
"If haploid, ensure a sites x samples matrix is passed."
)
raise ValueError(err_msg)

# Ensure that the mutation probability is either a scalar or vector of length m
# Ensure that the mutation probability is either a scalar or vector of length m.
if isinstance(p_mutation, (int, float)):
if not scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site."
)
warn_msg = "Passed a scalar mutation probability, but not rescaling it."
warnings.warn(warn_msg)
elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m:
if scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site."
)
warn_msg = "Passed a vector of mutation probabilities. Rescaling them."
warnings.warn(warn_msg)
elif p_mutation is None:
warnings.warn(
"No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)"
warn_msg = (
"No mutation probability passed. "
"Setting it based on Li & Stephens (2003) equations A2 and A3."
)
warnings.warn(warn_msg)
else:
raise ValueError(
f"Mutation probability is not None, a scalar, or vector of length m: {m}"
err_msg = (
f"Mutation probability is not None, a scalar, or vector of length {m}."
)
raise ValueError(err_msg)

# Ensure that the recombination probability is either a scalar or a vector of length m
if not (
isinstance(p_recombination, (int, float))
or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m)
):
raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}")
err_msg = f"Recombination probability is not a scalar or vector of length {m}."
raise ValueError(err_msg)

return (n, m, ploidy)

Expand All @@ -148,60 +125,32 @@ def set_emission_probabilities(
scale_mutation_based_on_n_alleles,
):
# Check alleles should go in here, and modify e before passing to the algorithm
# If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel.
# If alleles is not passed, we don't perform a test of alleles,
# but set n_alleles based on the reference_panel.
if alleles is None:
exclusion_set = np.array([MISSING])
n_alleles = np.zeros(m, dtype=np.int8)
for j in range(reference_panel.shape[0]):
uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j]))
n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set))
n_alleles = core.get_num_alleles(reference_panel, query)
else:
n_alleles = check_alleles(alleles, m)
n_alleles = core.check_alleles(alleles, m)

if p_mutation is None:
# Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003).
# Set the mutation probability to be the proposed mutation probability in Li & Stephens (2003).
theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)])
p_mutation = 0.5 * (theta_tilde / (n + theta_tilde))

if isinstance(p_mutation, float):
p_mutation = p_mutation * np.ones(m)

if ploidy == 1:
# Haploid
# Evaluate emission probabilities here using p_mutation - this can take a scalar or vector.
e = np.zeros((m, 2))

if scale_mutation_based_on_n_alleles:
# Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles.
# The overall mutation probability is then (n_alleles - 1) * p_mutation.
e[:, 0] = p_mutation - p_mutation * np.equal(
n_alleles, np.ones(m)
) # Added boolean in case we're at an invariant site
e[:, 1] = 1 - (n_alleles - 1) * p_mutation
else:
# No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything
# (summing over the states we can switch to). This means that we must rescale the probability of mutation to
# a different allele by the number of alleles at the site.
for j in range(m):
if n_alleles[j] == 1: # In case we're at an invariant site
e[j, 0] = 0
e[j, 1] = 1
else:
e[j, 0] = p_mutation[j] / (n_alleles[j] - 1)
e[j, 1] = 1 - p_mutation[j]
emission_probs = core.get_emission_matrix_haploid(
mu=p_mutation,
m=m,
n_alleles=n_alleles,
scale_mutation_based_on_n_alleles=scale_mutation_based_on_n_alleles
)
else:
# Diploid
# Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector.
# DEV: there's a wrinkle here.
e = np.zeros((m, 8))
e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
e[:, UNEQUAL_BOTH_HOM] = p_mutation**2
e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
e[:, MISSING_INDEX] = 1
emission_probs = core.get_emission_matrix_diploid(mu=p_mutation, m=m)

return e
return emission_probs


def viterbi_hap(n, m, reference_panel, query, emissions, p_recombination):
Expand Down Expand Up @@ -233,8 +182,7 @@ def forwards(
norm=True,
):
"""
Run the Li and Stephens forwards algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens forwards algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down Expand Up @@ -281,8 +229,7 @@ def backwards(
scale_mutation_based_on_n_alleles=True,
):
"""
Run the Li and Stephens backwards algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens backwards algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down Expand Up @@ -330,8 +277,7 @@ def viterbi(
scale_mutation_based_on_n_alleles=True,
):
"""
Run the Li and Stephens Viterbi algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens Viterbi algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down
Loading

0 comments on commit 052d804

Please sign in to comment.