diff --git a/lshmm/core.py b/lshmm/core.py index e912959..137a8bc 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -322,14 +322,14 @@ def get_emission_probability_haploid_hkylike(ref_allele, query_allele, site, emi raise ValueError("Reference allele cannot be MISSING.") if query_allele == NONCOPY: raise ValueError("Query allele cannot be NONCOPY.") - if emission_matrix.shape[1] != 2: + if emission_matrix.shape[1] != 4 or emission_matrix.shape[2] != 4: raise ValueError("Emission probability matrix has incorrect shape.") if ref_allele == NONCOPY: return 0.0 elif query_allele == MISSING: return 1.0 else: - return emission_matrix[ref_allele, query_allele] + return emission_matrix[site, ref_allele, query_allele] @jit.numba_njit @@ -343,22 +343,25 @@ def get_emission_matrix_hkylike(mu, kappa=None): :param float mu: Probability of mutation to any allele. :param float kappa: Transition-to-transversion rate ratio. """ - if kappa <= 0: - raise ValueError("Ts/tv ratio must be positive.") - # Assume that ACGT are encoded as 0 to 3. - num_alleles = 4 - emission_matrix = np.zeros((num_alleles, num_alleles), dtype=np.float64) - 1 - for i in range(num_alleles): + if kappa is not None: + if kappa <= 0: + raise ValueError("Ts/tv ratio must be positive.") + num_sites = len(mu) + num_alleles = 4 # Assume that ACGT are encoded as 0 to 3. + emission_matrix = np.zeros((num_sites, num_alleles, num_alleles), dtype=np.float64) - 1 + for i in range(num_sites): for j in range(num_alleles): - if i == j: - emission_matrix[i, j] = 1 - mu - else: - emission_matrix[i, j] = mu / 3 - if kappa is not None: - # Transitions: A <-> G, C <-> T. - is_transition = (i in [0, 2] and j in [0, 2]) or (i in [1, 3] and j in [1, 3]) - if is_transition: - emission_matrix[i, j] *= kappa + for k in range(num_alleles): + if j == k: + emission_matrix[i, j, k] = 1 - mu[i] + else: + emission_matrix[i, j, k] = mu[i] / 3 + if kappa is not None: + # Transitions: A <-> G, C <-> T. + is_transition_AG = i in [0, 2] and j in [0, 2] + is_transition_CT = i in [1, 3] and j in [1, 3] + if is_transition_AG or is_transition_CT: + emission_matrix[i, j, k] *= kappa return emission_matrix