Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 20, 2024
1 parent efe28ac commit c7ba3b2
Show file tree
Hide file tree
Showing 12 changed files with 1,270 additions and 1,564 deletions.
108 changes: 59 additions & 49 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@

import numpy as np

from .forward_backward.fb_diploid import backward_ls_dip_loop, forward_ls_dip_loop
from .forward_backward.fb_haploid import backwards_ls_hap, forwards_ls_hap
from . import core
from .fb_diploid import (
backward_ls_dip_loop,
forward_ls_dip_loop,
)
from .fb_haploid import (
backwards_ls_hap,
forwards_ls_hap,
)
from .vit_diploid import (
backwards_viterbi_dip,
forwards_viterbi_dip_low_mem,
Expand All @@ -18,15 +25,6 @@
path_ll_hap,
)

EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
REF_HOM_OBS_HET = 1
REF_HET_OBS_HOM = 2
MISSING_INDEX = 3

MISSING = -1


def check_alleles(alleles, m):
"""
Expand All @@ -52,7 +50,7 @@ def check_alleles(alleles, m):
if isinstance(alleles[0], str):
return np.int8([len(alleles) for _ in range(m)])
# Otherwise, process allele lists.
exclusion_set = np.array([MISSING])
exclusion_set = np.array([core.MISSING])
n_alleles = np.zeros(num_sites, dtype=np.int8)
for i in range(num_sites):
uniq_alleles = np.unique(alleles[i])
Expand All @@ -68,7 +66,7 @@ def checks(
scale_mutation_based_on_n_alleles,
):
"""
Checks that the input data and parameters are valid.
Check that the input data and parameters are valid.
The reference panel must be a matrix of size (m, n) or (m, n, n).
The query must be a matrix of size (k, m) or (k, m, 2).
Expand All @@ -77,17 +75,21 @@ def checks(
n = number of samples in the reference panel (haplotypes, not individuals).
k = number of samples in the query (haplotypes, not individuals).
The mutation rate can be scaled according to the set of alleles
that can be mutated to based on the number of distinct alleles at each site.
:param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n).
:param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2).
:param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m.
:param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m.
:param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False).
:param bool scale_mutation_based_on_n_alleles: Scale the mutation probability or not.
:return: n, m, ploidy
:rtype: tuple
"""
# Check reference panel
if not len(reference_panel.shape) in (2, 3):
raise ValueError("Reference panel array must have 2 or 3 dimensions.")
err_msg = "Reference panel array must have 2 or 3 dimensions."
raise ValueError(err_msg)

if len(reference_panel.shape) == 2:
m, n = reference_panel.shape
Expand All @@ -97,42 +99,49 @@ def checks(
ploidy = 2

if ploidy == 2 and (reference_panel.shape[1] != reference_panel.shape[2]):
raise ValueError(
"Reference_panel dimensions are incorrect, perhaps a sample x sample x variant matrix was passed. Expected sites x samples x samples."
err_msg = (
"Reference_panel dimensions are incorrect, "
"perhaps a sample x sample x variant matrix was passed. "
"Expected sites x samples x samples."
)
raise ValueError(err_msg)

# Check query sequence(s)
if query.shape[1] != m:
raise ValueError(
"Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed."
err_msg = (
"Number of sites in query does not match reference panel. "
"If haploid, ensure a sites x samples matrix is passed."
)
raise ValueError(err_msg)

# Ensure that the mutation probability is either a scalar or vector of length m
# Ensure that the mutation probability is either a scalar or vector of length m.
if isinstance(p_mutation, (int, float)):
if not scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site."
)
warn_msg = "Passed a scalar mutation probability, but not rescaling it."
warnings.warn(warn_msg)
elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m:
if scale_mutation_based_on_n_alleles:
warnings.warn(
"Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site."
)
warn_msg = "Passed a vector of mutation probabilities. Rescaling them."
warnings.warn(warn_msg)
elif p_mutation is None:
warnings.warn(
"No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)"
warn_msg = (
"No mutation probability passed. "
"Setting it based on Li & Stephens (2003) equations A2 and A3."
)
warnings.warn(warn_msg)
else:
raise ValueError(
f"Mutation probability is not None, a scalar, or vector of length m: {m}"
err_msg = (
f"Mutation probability is not None, a scalar, or vector of length {m}."
)
raise ValueError(err_msg)

# Ensure that the recombination probability is either a scalar or a vector of length m
if not (
isinstance(p_recombination, (int, float))
or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m)
):
raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}")
err_msg = f"Recombination probability is not a scalar or vector of length {m}."
raise ValueError(err_msg)

return (n, m, ploidy)

Expand All @@ -148,9 +157,10 @@ def set_emission_probabilities(
scale_mutation_based_on_n_alleles,
):
# Check alleles should go in here, and modify e before passing to the algorithm
# If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel.
# If alleles is not passed, we don't perform a test of alleles,
# but set n_alleles based on the reference_panel.
if alleles is None:
exclusion_set = np.array([MISSING])
exclusion_set = np.array([core.MISSING])
n_alleles = np.zeros(m, dtype=np.int8)
for j in range(reference_panel.shape[0]):
uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j]))
Expand All @@ -159,7 +169,7 @@ def set_emission_probabilities(
n_alleles = check_alleles(alleles, m)

if p_mutation is None:
# Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003).
# Set the mutation probability to be the proposed mutation probability in Li & Stephens (2003).
theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)])
p_mutation = 0.5 * (theta_tilde / (n + theta_tilde))

Expand All @@ -172,15 +182,18 @@ def set_emission_probabilities(
e = np.zeros((m, 2))

if scale_mutation_based_on_n_alleles:
# Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles.
# Scale mutation based on the number of alleles,
# so p_mutation is probability of mutation any given one of the alleles.
# The overall mutation probability is then (n_alleles - 1) * p_mutation.
e[:, 0] = p_mutation - p_mutation * np.equal(
n_alleles, np.ones(m)
) # Added boolean in case we're at an invariant site
e[:, 1] = 1 - (n_alleles - 1) * p_mutation
else:
# No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything
# (summing over the states we can switch to). This means that we must rescale the probability of mutation to
# No scaling based on the number of alleles,
# so p_mutation is the probability of mutation to anything
# (summing over the states we can switch to).
# This means that we must rescale the probability of mutation to
# a different allele by the number of alleles at the site.
for j in range(m):
if n_alleles[j] == 1: # In case we're at an invariant site
Expand All @@ -194,12 +207,12 @@ def set_emission_probabilities(
# Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector.
# DEV: there's a wrinkle here.
e = np.zeros((m, 8))
e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
e[:, UNEQUAL_BOTH_HOM] = p_mutation**2
e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
e[:, MISSING_INDEX] = 1
e[:, core.EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
e[:, core.UNEQUAL_BOTH_HOM] = p_mutation**2
e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
e[:, core.REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
e[:, core.REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
e[:, core.MISSING_INDEX] = 1

return e

Expand Down Expand Up @@ -233,8 +246,7 @@ def forwards(
norm=True,
):
"""
Run the Li and Stephens forwards algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens forwards algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down Expand Up @@ -281,8 +293,7 @@ def backwards(
scale_mutation_based_on_n_alleles=True,
):
"""
Run the Li and Stephens backwards algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens backwards algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down Expand Up @@ -330,8 +341,7 @@ def viterbi(
scale_mutation_based_on_n_alleles=True,
):
"""
Run the Li and Stephens Viterbi algorithm on haplotype or
unphased genotype data.
Run the Li & Stephens Viterbi algorithm on haplotype or unphased genotype data.
"""
n, m, ploidy = checks(
reference_panel,
Expand Down
85 changes: 85 additions & 0 deletions lshmm/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np

from lshmm import jit


EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
REF_HOM_OBS_HET = 1
REF_HET_OBS_HOM = 2
MISSING_INDEX = 3

MISSING = -1


""" Helper functions. """


# https://github.com/numba/numba/issues/1269
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
"""Create Numpy-like functions for max, sum, etc."""
assert arr.ndim == 2
assert axis in [0, 1]
if axis == 0:
result = np.empty(arr.shape[1])
for i in range(len(result)):
result[i] = func1d(arr[:, i])
else:
result = np.empty(arr.shape[0])
for i in range(len(result)):
result[i] = func1d(arr[i, :])
return result


@jit.numba_njit
def np_amax(array, axis):
"""Numba implementation of Numpy-vectorised max."""
return np_apply_along_axis(np.amax, axis, array)


@jit.numba_njit
def np_sum(array, axis):
"""Numba implementation of Numpy-vectorised sum."""
return np_apply_along_axis(np.sum, axis, array)


@jit.numba_njit
def np_argmax(array, axis):
"""Numba implementation of Numpy-vectorised argmax."""
return np_apply_along_axis(np.argmax, axis, array)


""" Functions used across different implementations of the LS HMM. """


@jit.numba_njit
def get_index_in_emission_matrix(ref_allele, query_allele):
is_allele_match = np.equal(ref_allele, query_allele)
is_query_missing = query_allele == MISSING
if is_allele_match or is_query_missing:
return 1
return 0


@jit.numba_njit
def get_index_in_emission_matrix_diploid(ref_allele, query_allele):
if query_allele == MISSING:
return MISSING_INDEX
else:
is_match = ref_allele == query_allele
is_ref_one = ref_allele == 1
is_query_one = query_allele == 1
return 4 * is_match + 2 * is_ref_one + is_query_one


@jit.numba_njit
def get_index_in_emission_matrix_diploid_G(ref_G, query_allele, n):
if query_allele == MISSING:
return MISSING_INDEX * np.ones((n, n), dtype=np.int64)
else:
is_match = ref_G == query_allele
is_ref_one = ref_G == 1
is_query_one = query_allele == 1
return 4 * is_match + 2 * is_ref_one + is_query_one
Loading

0 comments on commit c7ba3b2

Please sign in to comment.