forked from astheeggeggs/lshmm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,477 additions
and
1,924 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
from lshmm import jit | ||
|
||
|
||
EQUAL_BOTH_HOM = 4 | ||
UNEQUAL_BOTH_HOM = 0 | ||
BOTH_HET = 7 | ||
REF_HOM_OBS_HET = 1 | ||
REF_HET_OBS_HOM = 2 | ||
MISSING_INDEX = 3 | ||
|
||
MISSING = -1 | ||
|
||
|
||
""" Helper functions. """ | ||
|
||
|
||
# https://github.com/numba/numba/issues/1269 | ||
@jit.numba_njit | ||
def np_apply_along_axis(func1d, axis, arr): | ||
"""Create Numpy-like functions for max, sum, etc.""" | ||
assert arr.ndim == 2 | ||
assert axis in [0, 1] | ||
if axis == 0: | ||
result = np.empty(arr.shape[1]) | ||
for i in range(len(result)): | ||
result[i] = func1d(arr[:, i]) | ||
else: | ||
result = np.empty(arr.shape[0]) | ||
for i in range(len(result)): | ||
result[i] = func1d(arr[i, :]) | ||
return result | ||
|
||
|
||
@jit.numba_njit | ||
def np_amax(array, axis): | ||
"""Numba implementation of Numpy-vectorised max.""" | ||
return np_apply_along_axis(np.amax, axis, array) | ||
|
||
|
||
@jit.numba_njit | ||
def np_sum(array, axis): | ||
"""Numba implementation of Numpy-vectorised sum.""" | ||
return np_apply_along_axis(np.sum, axis, array) | ||
|
||
|
||
@jit.numba_njit | ||
def np_argmax(array, axis): | ||
"""Numba implementation of Numpy-vectorised argmax.""" | ||
return np_apply_along_axis(np.argmax, axis, array) | ||
|
||
|
||
""" Functions used across different implementations of the LS HMM. """ | ||
|
||
|
||
@jit.numba_njit | ||
def get_index_in_emission_matrix_haploid(ref_allele, query_allele): | ||
is_allele_match = ref_allele == query_allele | ||
is_query_missing = query_allele == MISSING | ||
if is_allele_match or is_query_missing: | ||
return 1 | ||
return 0 | ||
|
||
|
||
@jit.numba_njit | ||
def get_index_in_emission_matrix_diploid(ref_allele, query_allele): | ||
if query_allele == MISSING: | ||
return MISSING_INDEX | ||
else: | ||
is_match = ref_allele == query_allele | ||
is_ref_one = ref_allele == 1 | ||
is_query_one = query_allele == 1 | ||
return 4 * is_match + 2 * is_ref_one + is_query_one | ||
|
||
|
||
@jit.numba_njit | ||
def get_index_in_emission_matrix_diploid_G(ref_G, query_allele, n): | ||
if query_allele == MISSING: | ||
return MISSING_INDEX * np.ones((n, n), dtype=np.int64) | ||
else: | ||
is_match = ref_G == query_allele | ||
is_ref_one = ref_G == 1 | ||
is_query_one = query_allele == 1 | ||
return 4 * is_match + 2 * is_ref_one + is_query_one |
Oops, something went wrong.