Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 10, 2024
1 parent 4f84dd3 commit 9e4a278
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,18 @@ def get_index_in_emission_matrix_haploid(ref_allele, query_allele):

@jit.numba_njit
def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype):
"""
Compare the implied unphased genotypes (allele dosages) of
the reference and query to get the index of the entry
in the emission probability matrix, and return the index.
"""
if query_genotype == MISSING:
return MISSING_INDEX
else:
is_match = ref_genotype == query_genotype
is_ref_one = ref_genotype == 1
is_query_one = query_genotype == 1
return 4 * is_match + 2 * is_ref_one + is_query_one
is_ref_het = ref_genotype == 1
is_query_het = query_genotype == 1
return 4 * is_match + 2 * is_ref_het + is_query_het


@jit.numba_njit
Expand All @@ -191,9 +196,9 @@ def get_index_in_emission_matrix_diploid_genotypes(
return MISSING_INDEX * np.ones((num_ref_haps, num_ref_haps), dtype=np.int64)
else:
is_match = ref_genotypes == query_genotype
is_ref_one = ref_genotypes == 1
is_query_one = query_genotype == 1
return 4 * is_match + 2 * is_ref_one + is_query_one
is_ref_het = ref_genotypes == 1
is_query_het = query_genotype == 1
return 4 * is_match + 2 * is_ref_het + is_query_het


def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate):
Expand Down Expand Up @@ -282,9 +287,11 @@ def get_emission_probability_diploid(
def get_emission_probability_diploid_genotypes(
ref_genotypes, query_genotype, site, emission_matrix
):
emission_probs = np.zeros(ref_genotypes.shape, dtype=np.float64)
for i in range(len(ref_genotypes)):
for j in range(len(ref_genotypes)):
assert ref_genotypes.shape[0] == ref_genotypes.shape[1]
num_ref_haps = len(ref_genotypes)
emission_probs = np.zeros((num_ref_haps, num_ref_haps), dtype=np.float64)
for i in range(num_ref_haps):
for j in range(num_ref_haps):
if ref_genotypes[i, j] == NONCOPY:
emission_probs[i, j] = 0.0
else:
Expand Down

0 comments on commit 9e4a278

Please sign in to comment.