
Commit

Reformat using black
szhan committed Apr 19, 2024
1 parent 30bf68f commit cbad776
Showing 7 changed files with 50 additions and 39 deletions.
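The hunks below are pure formatting changes produced by black; no behaviour changes. As a hedged summary (the exact black invocation and version are not recorded in this commit), the rewrites visible here are: long assignments wrapped in parentheses, spaces around ** dropped when both operands are simple, padding just inside docstring quotes stripped, and two blank lines enforced between top-level definitions. A minimal sketch with made-up values and a made-up function, for illustration only:

m = 10
p_mutation = 0.01

# Long assignments are wrapped by adding parentheses:
err_msg = (
    f"Mutation probability is not None, a scalar, or vector of length {m}."
)

# Spaces around ** are dropped when both operands are simple:
prob_both_het = (1 - p_mutation) ** 2 + p_mutation**2

# Padding inside docstring quotes is stripped, and top-level definitions
# are separated by two blank lines:


def describe():
    """Return the example message and probability."""
    return err_msg, prob_both_het


print(describe())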
10 changes: 6 additions & 4 deletions lshmm/api.py
@@ -124,7 +124,9 @@ def checks(
)
warnings.warn(warn_msg)
else:
err_msg = f"Mutation probability is not None, a scalar, or vector of length {m}."
err_msg = (
f"Mutation probability is not None, a scalar, or vector of length {m}."
)
raise ValueError(err_msg)

# Ensure that the recombination probability is either a scalar or a vector of length m
@@ -149,7 +151,7 @@ def set_emission_probabilities(
scale_mutation_based_on_n_alleles,
):
# Check alleles should go in here, and modify e before passing to the algorithm
# If alleles is not passed, we don't perform a test of alleles,
# If alleles is not passed, we don't perform a test of alleles,
# but set n_alleles based on the reference_panel.
if alleles is None:
exclusion_set = np.array([core.MISSING])
@@ -200,8 +202,8 @@ def set_emission_probabilities(
# DEV: there's a wrinkle here.
e = np.zeros((m, 8))
e[:, core.EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
e[:, core.UNEQUAL_BOTH_HOM] = p_mutation ** 2
e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation ** 2
e[:, core.UNEQUAL_BOTH_HOM] = p_mutation**2
e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
e[:, core.REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
e[:, core.REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
e[:, core.MISSING_INDEX] = 1
9 changes: 7 additions & 2 deletions lshmm/core.py
@@ -14,10 +14,12 @@


""" Helper functions. """


# https://github.com/numba/numba/issues/1269
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
""" Create numpy-like functions for max, sum etc. """
"""Create numpy-like functions for max, sum etc."""
assert arr.ndim == 2
assert axis in [0, 1]
if axis == 0:
@@ -30,6 +32,7 @@ def np_apply_along_axis(func1d, axis, arr):
result[i] = func1d(arr[i, :])
return result


@jit.numba_njit
def np_amax(array, axis):
"""Numba implementation of numpy vectorised maximum."""
@@ -44,11 +47,13 @@ def np_sum(array, axis):

@jit.numba_njit
def np_argmax(array, axis):
""" Numba implementation of numpy vectorised argmax. """
"""Numba implementation of numpy vectorised argmax."""
return np_apply_along_axis(np.argmax, axis, array)


""" Functions used across different implementations of the LS HMM. """


@jit.numba_njit
def get_index_in_emission_prob_matrix(a1, a2):
return np.int64(np.equal(a1, a2) or a2 == MISSING)
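The np_apply_along_axis helper above exists because numba's njit-compiled code does not support the axis keyword of reductions such as np.max, np.sum, and np.argmax (the numba issue linked in the hunk). A minimal pure-numpy sketch of the pattern; the axis == 0 branch is elided in the hunk above and is reconstructed here by assumption:

import numpy as np


def np_apply_along_axis(func1d, axis, arr):
    # Apply a 1-D reduction along one axis of a 2-D array, mirroring the
    # helper in lshmm/core.py; the axis == 0 branch is an assumed reconstruction.
    assert arr.ndim == 2
    assert axis in [0, 1]
    if axis == 0:
        result = np.empty(arr.shape[1])
        for i in range(arr.shape[1]):
            result[i] = func1d(arr[:, i])
    else:
        result = np.empty(arr.shape[0])
        for i in range(arr.shape[0]):
            result[i] = func1d(arr[i, :])
    return result


A = np.array([[0.1, 0.7], [0.4, 0.2]])
print(np_apply_along_axis(np.max, 0, A))     # column maxima: [0.4 0.7]
print(np_apply_along_axis(np.argmax, 1, A))  # per-row argmax: [1. 0.]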
14 changes: 8 additions & 6 deletions lshmm/forward_backward/fb_diploid.py
@@ -92,7 +92,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True):


def backwards_ls_dip(n, m, G, s, e, c, r):
""" Matrix based diploid LS backward algorithm using numpy vectorisation. """
"""Matrix based diploid LS backward algorithm using numpy vectorisation."""
# Initialise the backward tensor
B = np.zeros((m, n, n))

@@ -124,7 +124,9 @@ def backwards_ls_dip(n, m, G, s, e, c, r):
)

# One changes
sum_j = core.np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n))
sum_j = (
core.np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n))
)
B[l, :, :] += ((1 - r[l + 1]) * r_n[l + 1]) * (sum_j + sum_j.T)
B[l, :, :] *= 1 / c[l + 1]

@@ -133,7 +135,7 @@ def backwards_ls_dip(n, m, G, s, e, c, r):

@jit.numba_njit
def forward_ls_dip_starting_point(n, m, G, s, e, r):
""" Naive implementation of LS diploid forwards algorithm. """
"""Naive implementation of LS diploid forwards algorithm."""
# Initialise the forward tensor
F = np.zeros((m, n, n))
r_n = r / n
@@ -214,7 +216,7 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r):

@jit.numba_njit
def backward_ls_dip_starting_point(n, m, G, s, e, r):
""" Naive implementation of LS diploid backwards algorithm. """
"""Naive implementation of LS diploid backwards algorithm."""
# Backwards
B = np.zeros((m, n, n))

@@ -292,7 +294,7 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r):

@jit.numba_njit
def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):
""" LS diploid forwards algoritm without vectorisation. """
"""LS diploid forwards algoritm without vectorisation."""
# Initialise the forward tensor
F = np.zeros((m, n, n))
for j1 in range(n):
@@ -411,7 +413,7 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True):

@jit.numba_njit
def backward_ls_dip_loop(n, m, G, s, e, c, r):
""" LS diploid backwards algoritm without vectorisation. """
"""LS diploid backwards algoritm without vectorisation."""
# Initialise the backward tensor
B = np.zeros((m, n, n))
B[m - 1, :, :] = 1
8 changes: 5 additions & 3 deletions lshmm/forward_backward/fb_haploid.py
@@ -8,7 +8,7 @@

@jit.numba_njit
def forwards_ls_hap(n, m, H, s, e, r, norm=True):
""" Matrix based haploid LS forward algorithm using numpy vectorisation. """
"""Matrix based haploid LS forward algorithm using numpy vectorisation."""
F = np.zeros((m, n))
r_n = r / n

@@ -56,7 +56,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):

@jit.numba_njit
def backwards_ls_hap(n, m, H, s, e, c, r):
""" Matrix based haploid LS backward algorithm using numpy vectorisation. """
"""Matrix based haploid LS backward algorithm using numpy vectorisation."""
B = np.zeros((m, n))
for i in range(n):
B[m - 1, i] = 1
@@ -67,7 +67,9 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(H[l + 1, i], s[0, l + 1])
emission_idx = core.get_index_in_emission_prob_matrix(
H[l + 1, i], s[0, l + 1]
)
tmp_B[i] = e[l + 1, emission_idx] * B[l + 1, i]
tmp_B_sum += tmp_B[i]
for i in range(n):
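In the haploid code above, every emission lookup reduces to two columns of e: get_index_in_emission_prob_matrix returns 1 when the panel allele matches the query allele (or the query allele is missing) and 0 otherwise. A hedged standalone sketch; the MISSING sentinel value and the mismatch/match probabilities below are assumptions for illustration, not the library's actual values:

import numpy as np

MISSING = -1  # assumed sentinel; in lshmm this is core.MISSING


def get_index_in_emission_prob_matrix(a1, a2):
    # 1 if the panel allele a1 equals the query allele a2 (or a2 is missing),
    # 0 otherwise -- mirrors the function shown in lshmm/core.py.
    return np.int64(np.equal(a1, a2) or a2 == MISSING)


mu = 0.01
e = np.array([[mu, 1 - mu]])  # one site: column 0 = mismatch, column 1 = match/missing

idx = get_index_in_emission_prob_matrix(1, 1)  # panel allele 1, query allele 1
print(e[0, idx])  # 0.99, the match probability
print(e[0, get_index_in_emission_prob_matrix(0, MISSING)])  # missing query also gets 0.99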
14 changes: 7 additions & 7 deletions lshmm/vit_diploid.py
@@ -8,7 +8,7 @@

@jit.numba_njit
def forwards_viterbi_dip_naive(n, m, G, s, e, r):
""" Naive implementation of LS diploid Viterbi algorithm. """
"""Naive implementation of LS diploid Viterbi algorithm."""
# Initialise
V = np.zeros((m, n, n))
P = np.zeros((m, n, n)).astype(np.int64)
@@ -64,7 +64,7 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r):

@jit.numba_njit
def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):
""" Naive implementation of LS diploid Viterbi algorithm, with reduced memory. """
"""Naive implementation of LS diploid Viterbi algorithm, with reduced memory."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
@@ -123,7 +123,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r):

@jit.numba_njit
def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):
""" LS diploid Viterbi algorithm, with reduced memory. """
"""LS diploid Viterbi algorithm, with reduced memory."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
@@ -206,7 +206,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r):

@jit.numba_njit
def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):
""" LS diploid Viterbi algorithm, with reduced memory. """
"""LS diploid Viterbi algorithm, with reduced memory."""
# Initialise
V = np.zeros((n, n))
V_previous = np.zeros((n, n))
@@ -303,7 +303,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r):

@jit.numba_njit
def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):
""" Vectorised LS diploid Viterbi algorithm using numpy. """
"""Vectorised LS diploid Viterbi algorithm using numpy."""
# Initialise
V = np.zeros((m, n, n))
P = np.zeros((m, n, n)).astype(np.int64)
@@ -352,7 +352,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r):


def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r):
""" Fully vectorised naive LS diploid Viterbi algorithm using numpy. """
"""Fully vectorised naive LS diploid Viterbi algorithm using numpy."""
char_both = np.eye(n * n).ravel().reshape((n, n, n, n))
char_col = np.tile(np.sum(np.eye(n * n).reshape((n, n, n, n)), 3), (n, 1, 1, 1))
char_row = np.copy(char_col).T
@@ -432,7 +432,7 @@ def backwards_viterbi_dip_no_pointer(
recombs_double,
V_last,
):
""" Run a backwards pass to determine the most likely path. """
"""Run a backwards pass to determine the most likely path."""
assert V_last.ndim == 2
assert V_last.shape[0] == V_last.shape[1]
# Initialisation
24 changes: 12 additions & 12 deletions lshmm/vit_haploid.py
@@ -8,7 +8,7 @@

@jit.numba_njit
def viterbi_naive_init(n, m, H, s, e, r):
""" Initialise a naive implementation of the LS Viterbi algorithm. """
"""Initialise a naive implementation of the LS Viterbi algorithm."""
V = np.zeros((m, n))
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
@@ -22,7 +22,7 @@ def viterbi_naive_init(n, m, H, s, e, r):

@jit.numba_njit
def viterbi_init(n, m, H, s, e, r):
""" Initialise a naive, but more space memory efficient, implementation of the LS Viterbi algorithm. """
"""Initialise a naive, but more space memory efficient, implementation of the LS Viterbi algorithm."""
V_previous = np.zeros(n)
V = np.zeros(n)
P = np.zeros((m, n)).astype(np.int64)
@@ -37,7 +37,7 @@ def viterbi_init(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive(n, m, H, s, e, r):
""" Naive implementation of the haploid LS Viterbi algorithm. """
"""Naive implementation of the haploid LS Viterbi algorithm."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)

for j in range(1, m):
@@ -61,7 +61,7 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
""" Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy. """
"""Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)

for j in range(1, m):
@@ -81,7 +81,7 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
""" Naive implementation of LS haploid Viterbi algorithm, with reduced memory. """
"""Naive implementation of LS haploid Viterbi algorithm, with reduced memory."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)

for j in range(1, m):
@@ -106,7 +106,7 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
""" Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling. """
"""Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)
c = np.ones(m)

@@ -135,7 +135,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
""" LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure. """
"""LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure."""
V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r)
c = np.ones(m)

@@ -161,7 +161,7 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
""" LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure. """
"""LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure."""
V = np.zeros(n)
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0])
@@ -190,7 +190,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):

@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
""" LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure. """
"""LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure."""
V = np.zeros(n)
for i in range(n):
emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0])
@@ -227,7 +227,7 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
# Speedier version, variants x samples
@jit.numba_njit
def backwards_viterbi_hap(m, V_last, P):
""" Run a backwards pass to determine the most likely path. """
"""Run a backwards pass to determine the most likely path."""
assert len(V_last.shape) == 1
path = np.zeros(m).astype(np.int64)
path[m - 1] = np.argmax(V_last)
@@ -240,7 +240,7 @@ def backwards_viterbi_hap(m, V_last, P):

@jit.numba_njit
def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs):
""" Run a backwards pass to determine the most likely path. """
"""Run a backwards pass to determine the most likely path."""
path = np.zeros(m).astype(np.int64)
path[m - 1] = V_argmaxes[m - 1]

@@ -255,7 +255,7 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs):

@jit.numba_njit
def path_ll_hap(n, m, H, path, s, e, r):
""" Evaluate log-likelihood path through a reference panel which results in sequence s. """
"""Evaluate log-likelihood path through a reference panel which results in sequence s."""
emission_idx = core.get_index_in_emission_prob_matrix(H[0, path[0]], s[0, 0])
log_prob_path = np.log10((1 / n) * e[0, emission_idx])
old = path[0]
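backwards_viterbi_hap above recovers the most likely copying path by starting from the argmax of the final Viterbi values and following the pointer matrix P backwards. The backtracking loop is elided in the hunk and is reconstructed below by assumption, together with a tiny made-up example:

import numpy as np


def backwards_viterbi_hap(m, V_last, P):
    # Trace back the most likely path; the loop is an assumed reconstruction
    # of the part elided in the hunk above.
    assert len(V_last.shape) == 1
    path = np.zeros(m).astype(np.int64)
    path[m - 1] = np.argmax(V_last)
    for j in range(m - 2, -1, -1):
        path[j] = P[j + 1, path[j + 1]]
    return path


V_last = np.array([0.2, 0.8])           # Viterbi values at the last of m = 3 sites
P = np.array([[0, 0], [0, 1], [1, 1]])  # made-up pointer matrix, shape (m, n)
print(backwards_viterbi_hap(3, V_last, P))  # [1 1 1]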
10 changes: 5 additions & 5 deletions tests/test_API.py
@@ -228,7 +228,7 @@ class FBAlgorithmBase(LSBase):


class TestMethodsHap(FBAlgorithmBase):
""" Test that the computed likelihood is the same across all implementations. """
"""Test that the computed likelihood is the same across all implementations."""

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
@@ -245,7 +245,7 @@ def verify(self, ts):


class TestMethodsDip(FBAlgorithmBase):
""" Test that the computed likelihood is the same across all implementations. """
"""Test that the computed likelihood is the same across all implementations."""

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts):
@@ -261,11 +261,11 @@ def verify(self, ts):


class VitAlgorithmBase(LSBase):
""" Base for Viterbi algoritm tests. """
"""Base for Viterbi algoritm tests."""


class TestViterbiHap(VitAlgorithmBase):
""" Test that the computed log-likelihood is the same across all implementations. """
"""Test that the computed log-likelihood is the same across all implementations."""

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
@@ -280,7 +280,7 @@ def verify(self, ts):


class TestViterbiDip(VitAlgorithmBase):
""" Test that the computed log-likelihood is the same across all implementations. """
"""Test that the computed log-likelihood is the same across all implementations."""

def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts):
