-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
448 additions
and
495 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
|
||
from lshmm import jit | ||
|
||
|
||
EQUAL_BOTH_HOM = 4 | ||
UNEQUAL_BOTH_HOM = 0 | ||
BOTH_HET = 7 | ||
REF_HOM_OBS_HET = 1 | ||
REF_HET_OBS_HOM = 2 | ||
MISSING_INDEX = 3 | ||
|
||
MISSING = -1 | ||
|
||
|
||
""" Helper functions. """ | ||
|
||
|
||
# https://github.com/numba/numba/issues/1269 | ||
@jit.numba_njit | ||
def np_apply_along_axis(func1d, axis, arr): | ||
"""Create numpy-like functions for max, sum etc.""" | ||
assert arr.ndim == 2 | ||
assert axis in [0, 1] | ||
if axis == 0: | ||
result = np.empty(arr.shape[1]) | ||
for i in range(len(result)): | ||
result[i] = func1d(arr[:, i]) | ||
else: | ||
result = np.empty(arr.shape[0]) | ||
for i in range(len(result)): | ||
result[i] = func1d(arr[i, :]) | ||
return result | ||
|
||
|
||
@jit.numba_njit | ||
def np_amax(array, axis): | ||
"""Numba implementation of numpy vectorised maximum.""" | ||
return np_apply_along_axis(np.amax, axis, array) | ||
|
||
|
||
@jit.numba_njit | ||
def np_sum(array, axis): | ||
"""Numba implementation of numpy vectorised sum.""" | ||
return np_apply_along_axis(np.sum, axis, array) | ||
|
||
|
||
@jit.numba_njit | ||
def np_argmax(array, axis): | ||
"""Numba implementation of numpy vectorised argmax.""" | ||
return np_apply_along_axis(np.argmax, axis, array) | ||
|
||
|
||
""" Functions used across different implementations of the LS HMM. """ | ||
|
||
|
||
@jit.numba_njit | ||
def get_index_in_emission_prob_matrix(ref_allele, query_allele): | ||
is_allele_match = np.equal(ref_allele, query_allele) | ||
is_query_missing = query_allele == MISSING | ||
if is_allele_match or is_query_missing: | ||
return 1 | ||
return 0 | ||
|
||
|
||
@jit.numba_njit | ||
def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele): | ||
if query_allele == MISSING: | ||
return MISSING_INDEX | ||
else: | ||
is_allele_match = ref_allele == query_allele | ||
is_ref_one = ref_allele == 1 | ||
is_query_one = query_allele == 1 | ||
return 4 * is_allele_match + 2 * is_ref_one + is_query_one |
Oops, something went wrong.