Skip to content

Commit

Permalink
Fixed bug in mutation rate Truncated-Normal model
Browse files Browse the repository at this point in the history
  • Loading branch information
niemasd committed May 3, 2024
1 parent 567b93d commit bdb3990
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion global.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"VERSION": "1.0.1",
"VERSION": "1.0.2",
"CONFIG_KEYS": ["Contact Network", "Transmission Network", "Sample Times", "Viral Phylogeny (Transmissions)", "Viral Phylogeny (Seeds)", "Mutation Rates", "Ancestral Sequence", "Sequence Evolution"],
"DESC": {
"Contact Network": "The <b style='color:red;'>Contact Network</b> graph model describes all social interactions:<ul><li>Nodes represent individuals in the population</li><li>Edges represent all interactions across which the pathogen can transmit</li><li>Currently, FAVITES-Lite only supports static (i.e., unchanging) contact networks</li></ul>",
Expand Down
17 changes: 17 additions & 0 deletions plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
#! /usr/bin/env python3
# standard imports
from datetime import datetime
from sys import stderr
import math

# constants
ZERO_THRESH = 0.00000000001

# non-standard imports
try:
from scipy.stats import truncnorm
except:
error("Unable to import scipy. Install with: pip install scipy")

# dummy plugin function
def DUMMY_PLUGIN_FUNC(params, out_fn, config, GLOBAL, verbose=True):
pass
Expand All @@ -28,3 +37,11 @@ def check_props(props):
return False
tot += p
return abs(tot - 1) <= ZERO_THRESH

# sample from a truncated normal distribution with (non-truncated) mean `loc` and (non-truncated) stdev `scale` in range [`a`,`b`]
# I'm using the Wikipedia notation: https://en.wikipedia.org/wiki/Truncated_normal_distribution
# SciPy's `truncnorm` defines `a` and `b` as "standard deviations above/below `loc`", so I need to convert
def truncnorm_rvs(loc, scale, a_min, b_max, size):
a = (a_min - loc) / scale
b = (b_max - loc) / scale
return truncnorm.rvs(a=a, b=b, loc=loc, scale=scale, size=size)
10 changes: 4 additions & 6 deletions plugins/mutation_rates/common_treeswift.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
from numpy.random import f as f_dist
except:
error("Unable to import numpy. Install with: pip install numpy")
try:
from scipy.stats import truncnorm
except:
error("Unable to import scipy. Install with: pip install scipy")
try:
from treeswift import read_tree_newick
except:
Expand Down Expand Up @@ -216,9 +212,11 @@ def treeswift_triangular(params, out_fn, config, GLOBAL, verbose=True):

# Truncated Normal
def treeswift_truncnorm(params, out_fn, config, GLOBAL, verbose=True):
tree = read_tree_newick(out_fn['viral_phylogeny_time']); mu = params['mu']; sigma = params['sigma']; a = params['a']; b = params['b']
mu = params['mu']; sigma = params['sigma']; a_min = params['a']; b_max = params['b']
tree = read_tree_newick(out_fn['viral_phylogeny_time'])
nodes = [node for node in tree.traverse_preorder() if node.edge_length is not None]
rates = truncnorm.rvs(a=a, b=b, loc=mu, scale=sigma, size=len(nodes))
rates = truncnorm_rvs(loc=mu, scale=sigma, a_min=a_min, b_max=b_max, size=len(nodes))
print('\n'.join(str(r) for r in rates)); exit() # TODO
for i in range(len(nodes)):
nodes[i].edge_length *= rates[i]
tree.write_tree_newick(out_fn['viral_phylogeny_mut'])
Expand Down
5 changes: 2 additions & 3 deletions plugins/sample_times/time_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
except:
error("Unable to import numpy. Install with: pip install numpy")
try:
from scipy.stats import truncexpon, truncnorm
from scipy.stats import truncexpon
except:
error("Unable to import scipy. Install with: pip install scipy")

Expand All @@ -28,8 +28,7 @@ def time_windows(model, params, out_fn, config, GLOBAL, verbose=True):
if model == "Truncated Exponential":
variates = list(truncexpon.rvs(1, size=tot_num_samples))
elif model == "Truncated Normal":
corrected_min = (0-params['mu'])/params['sigma']; corrected_max = (1-params['mu'])/params['sigma']
variates = list(truncnorm.rvs(corrected_min, corrected_max, loc=params['mu'], scale=params['sigma'], size=tot_num_samples))
variates = list(truncnorm_rvs(loc=params['mu'], scale=params['sigma'], a_min=0, b_max=1, size=tot_num_samples))
for node in windows:
for _ in range(params['num_samples']):
state, start, end = choice(windows[node]); length = end - start; delta = None
Expand Down

0 comments on commit bdb3990

Please sign in to comment.