Skip to content

Commit

Permalink
Modify haploid Viterbi to handle NONCOPY state in reference panel
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 26, 2024
1 parent f94dd05 commit 3f7a936
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 25 deletions.
70 changes: 45 additions & 25 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import jit

MISSING = -1
NONCOPY = -2


@jit.numba_njit
Expand All @@ -13,10 +14,10 @@ def viterbi_naive_init(n, m, H, s, e, r):
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
for i in range(n):
V[0, i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)

em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[0, i] = 1 / n * em_prob
return V, P, r_n


Expand All @@ -29,9 +30,10 @@ def viterbi_init(n, m, H, s, e, r):
r_n = r / n

for i in range(n):
V_previous[i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V_previous[i] = 1 / n * em_prob

return V, V_previous, P, r_n

Expand All @@ -47,10 +49,10 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V[j - 1, k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V[j - 1, k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand All @@ -74,7 +76,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
for i in range(n):
v = np.copy(v_tmp)
v[i] += V[j - 1, i] * (1 - r[j])
v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v *= em_prob
P[j, i] = np.argmax(v)
V[j, i] = v[P[j, i]]

Expand All @@ -94,10 +99,10 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = (em_prob * V_previous[k])
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -125,10 +130,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V_previous[k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -161,7 +166,10 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob
V_previous = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand All @@ -175,7 +183,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
c = np.ones(m)
Expand All @@ -190,7 +201,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

ll = np.sum(np.log10(c)) + np.log10(np.max(V))

Expand All @@ -203,7 +217,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
r_n = r / n
c = np.ones(m)
recombs = [
Expand All @@ -224,7 +241,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
recombs[j] = np.append(
recombs[j], i
) # We add template i as a potential template to recombine to at site j.
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

V_argmaxes[m - 1] = np.argmax(V)
ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand Down
232 changes: 232 additions & 0 deletions tests/test_API_noncopy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import numpy as np
import pytest

import lshmm.vit_haploid as vh

MISSING = -1
NONCOPY = -2
NC = NONCOPY # Sugar


# Helper functions
# TODO: Use the functions in the API instead.
def _get_emission_probabilities(m, p_mutation, n_alleles):
    """
    Build the per-site emission probability matrix, without any scaling.

    Note that this is different than `set_emission_probabilities` in `api.py`.

    Parameters
    ----------
    m : int
        Number of rows (one per site) in the returned matrix.
    p_mutation : array-like
        Mutation probability, indexed by site.
    n_alleles : array-like
        Number of distinct alleles, indexed by site.

    Returns
    -------
    numpy.ndarray
        Array of shape (m, 2). Column 0 is the mismatch probability and
        column 1 the match probability: the Viterbi code indexes the second
        axis with 1 when the query allele equals the reference allele.
    """
    emission = np.zeros((m, 2))
    for site in range(m):
        allele_count = n_alleles[site]
        if allele_count == 1:
            # Only one allele is present, so a mismatch gets zero probability.
            mismatch, match = 0, 1
        else:
            # Spread the mutation probability over the other alleles.
            mismatch = p_mutation[site] / (allele_count - 1)
            match = 1 - p_mutation[site]
        emission[site, 0] = mismatch
        emission[site, 1] = match
    return emission


def _get_num_alleles_per_site(H):
    """
    Count the distinct alleles at each site of a reference panel.

    Used to rescale mutation and recombination probabilities.

    Parameters
    ----------
    H : numpy.ndarray
        Two-dimensional reference panel with one row per site.

    Returns
    -------
    numpy.ndarray
        int64 array with one entry per site, holding the number of distinct
        values in that row, not counting NONCOPY.

    Raises
    ------
    AssertionError
        If a row is empty or contains MISSING.
    """
    num_sites = H.shape[0]
    # Start from -1 so an unfilled entry would stand out.
    allele_counts = np.full(num_sites, -1, dtype=np.int64)
    for site in range(num_sites):
        distinct = np.unique(H[site, :])
        assert len(distinct) > 0
        assert MISSING not in distinct
        # NONCOPY marks a non-copiable entry, not an allele, so leave it out.
        allele_counts[site] = np.count_nonzero(distinct != NONCOPY)
    return allele_counts


# Prepare example data for the tests.
def get_example_data():
    """
    Build the example cases used to parametrize the tests in this module.

    Returns a list of `(H, s, expected_path)` tuples:
    - `H` is the reference panel, transposed so that rows are sites and
      columns are reference haplotypes.
    - `s` is the query, stored as a single-row 2D array.
    - `expected_path` has one entry per site (presumably the index of the
      reference haplotype copied at that site).

    Assumptions:
    1. Non-NONCOPY states are contiguous.
    2. No MISSING states in ref. panel.
    """
    # Each entry: (reference haplotypes, one per row; query; expected path).
    cases = [
        # Only NONCOPY
        (
            [
                [NC, NC, NC, NC, NC, NC, NC, NC, NC, NC],
                [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            ],
            [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        ),
        # NONCOPY on right
        (
            [
                [ 0,  0,  0,  0,  0, NC, NC, NC, NC, NC],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1],
            [ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1],
        ),
        # NONCOPY on left
        (
            [
                [NC, NC, NC, NC, NC,  0,  0,  0,  0,  0],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0],
            [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0],
        ),
        # NONCOPY in middle
        (
            [
                [NC, NC, NC,  0,  0,  0,  0, NC, NC, NC],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 1,  1,  1,  0,  0,  0,  0,  1,  1,  1],
            [ 1,  1,  1,  0,  0,  0,  0,  1,  1,  1],
        ),
        # Two switches
        (
            [
                [ 0,  0,  0, NC, NC, NC, NC, NC, NC, NC],
                [NC, NC, NC,  0,  0,  0, NC, NC, NC, NC],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1],
            [ 0,  0,  0,  1,  1,  1,  2,  2,  2,  2],
        ),
        # MISSING at switch position.
        # This gives rise to more than one best path.
        (
            [
                [NC, NC, NC,  0,  0,  0,  0, NC, NC, NC],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 1,  1,  1, -1,  0,  0,  0,  1,  1,  1],
            [ 1,  1,  1,  1,  0,  0,  0,  1,  1,  1],
        ),
        # MISSING left of switch position.
        (
            [
                [NC, NC, NC,  0,  0,  0,  0, NC, NC, NC],
                [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
            ],
            [ 1,  1, -1,  0,  0,  0,  0,  1,  1,  1],
            [ 1,  1,  1,  0,  0,  0,  0,  1,  1,  1],
        ),
    ]

    # Transpose each panel so sites run along the first axis, and wrap each
    # query in an outer list so it becomes a one-row 2D array.
    return [
        (np.array(haplotypes).T, np.array([query]), np.array(path))
        for haplotypes, query, path in cases
    ]


# Tests for naive matrix-based implementation.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_naive(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_naive` against the
    `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_naive(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for naive matrix-based implementation using numpy.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_naive_vec(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_naive_vec` against the
    `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_naive_vec(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for naive matrix-based implementation with reduced memory.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_naive_low_mem(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_naive_low_mem` against
    the `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_naive_low_mem(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for naive matrix-based implementation with reduced memory and rescaling.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_naive_low_mem_rescaling(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_naive_low_mem_rescaling`
    against the `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for implementation with reduced memory and rescaling.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_low_mem_rescaling(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_low_mem_rescaling`
    against the `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_low_mem_rescaling(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for implementation with even more reduced memory and rescaling.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_lower_mem_rescaling(H, s, expected_path):
    """
    Check the `ll` value from `forwards_viterbi_hap_lower_mem_rescaling`
    against the `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, actual_ll = vh.forwards_viterbi_hap_lower_mem_rescaling(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)


# Tests for implementation with even more reduced memory and rescaling,
# without keeping pointers.
@pytest.mark.parametrize("H, s, expected_path", get_example_data())
def test_forwards_viterbi_hap_lower_mem_rescaling_no_pointer(H, s, expected_path):
    """
    Check the `ll` value from
    `forwards_viterbi_hap_lower_mem_rescaling_no_pointer` against the
    `path_ll_hap` value of the expected path.

    Only this value is compared, not the path itself. This variant returns
    four values; the `ll` value is the last one.
    """
    num_sites, num_haps = H.shape
    recomb_prob = np.full(num_sites, 0.20, dtype=np.float64)
    mut_prob = np.full(num_sites, 0.10, dtype=np.float64)

    emission = _get_emission_probabilities(
        num_sites, mut_prob, _get_num_alleles_per_site(H)
    )

    _, _, _, actual_ll = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
        num_haps, num_sites, H, s, emission, recomb_prob
    )
    expected_ll = vh.path_ll_hap(
        num_haps, num_sites, H, expected_path, s, emission, recomb_prob
    )

    assert np.allclose(expected_ll, actual_ll)

0 comments on commit 3f7a936

Please sign in to comment.