Skip to content

Commit

Permalink
Added RNG seed support
Browse files Browse the repository at this point in the history
  • Loading branch information
niemasd committed Mar 9, 2024
1 parent 2e2cc79 commit 7052691
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# test files
config.json
test.json
tmp
test*.json
tmp*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
25 changes: 25 additions & 0 deletions favites_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,27 @@
Niema Moshiri 2022
'''

# useful constants
RNG_SEED_MIN = 0
RNG_SEED_MAX = 2147483647

# general imports and load global.json
from os import makedirs, remove
from os.path import abspath, expanduser, isdir, isfile
from shutil import rmtree
from sys import argv, stderr
from time import time
import json
import random
GLOBAL_JSON_PATH = "%s/global.json" % '/'.join(abspath(expanduser(argv[0])).split('/')[:-1])
GLOBAL = json.loads(open(GLOBAL_JSON_PATH).read())

# external imports
try:
import numpy
except:
error("Unable to import numpy. Install with: pip install numpy")

# FAVITES-Lite-specific imports
from plugins import PLUGIN_FUNCTIONS
from plugins.common import *
Expand All @@ -25,16 +36,29 @@ def parse_args():
parser.add_argument('-c', '--config', required=True, type=str, help="FAVITES-Lite Config File")
parser.add_argument('-o', '--output', required=True, type=str, help="Output Directory")
parser.add_argument('--overwrite', action="store_true", help="Overwrite output directory if it exists")
parser.add_argument('--rng_seed', required=False, type=int, default=None, help="Random Number Generator Seed")
parser.add_argument('--quiet', action="store_true", help="Suppress Log Messages")
parser.add_argument('--version', action="store_true", help="Show FAVITES-Lite version")
return parser.parse_args()

# validate user args
def validate_args(args, verbose=True):
# check RNG seed
if args.rng_seed is None:
args.rng_seed = random.randint(RNG_SEED_MIN, RNG_SEED_MAX)
elif args.rng_seed < RNG_SEED_MIN or args.rng_seed > RNG_SEED_MAX:
error("Invalid RNG seed (%s). Must be in the range [%d, %d]" % (args.rng_seed, RNG_SEED_MIN, RNG_SEED_MAX))
GLOBAL['RNG_SEED'] = args.rng_seed; random.seed(GLOBAL['RNG_SEED']); numpy.random.seed(GLOBAL['RNG_SEED'])
if verbose:
print_log("RNG Seed: %d" % GLOBAL['RNG_SEED'])

# check config file
if not isfile(args.config):
error("Config file not found: %s" % args.config)
if verbose:
print_log("Config File: %s" % args.config)

# check output directory
if isdir(args.output) or isfile(args.output):
if args.overwrite or input('Output directory exists: "%s". Overwrite? (Y/N) ' % args.output).upper().startswith('Y'):
if verbose:
Expand Down Expand Up @@ -74,6 +98,7 @@ def validate_config(config):
args = parse_args(); verbose = not args.quiet
if verbose:
print_log("=== FAVITES-Lite v%s ===" % GLOBAL['VERSION'])
print_log("Command: %s" % ' '.join(argv))
validate_args(args, verbose=verbose)
config = json.loads(open(args.config).read()); validate_config(config)
makedirs(args.output); f = open("%s/config.json" % args.output, 'w'); json.dump(config, f); f.close()
Expand Down
3 changes: 1 addition & 2 deletions global.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"VERSION": "0.0.5",
"VERSION": "1.0.0",
"CONFIG_KEYS": ["Contact Network", "Transmission Network", "Sample Times", "Viral Phylogeny (Transmissions)", "Viral Phylogeny (Seeds)", "Mutation Rates", "Ancestral Sequence", "Sequence Evolution"],
"TYPES": ["positive integer"],
"DESC": {
"Contact Network": "The <b style='color:red;'>Contact Network</b> graph model describes all social interactions:<ul><li>Nodes represent individuals in the population</li><li>Edges represent all interactions across which the pathogen can transmit</li><li>Currently, FAVITES-Lite only supports static (i.e., unchanging) contact networks</li></ul>",
"Transmission Network": "The <b style='color:red;'>Transmission Network</b> compartmental model describes how the pathogen propagates along the contact network. State transition rates are in unit of Poisson process arrivals per time (reciprocal of expected time to next arrival). The user will specify the number of individuals in each compartment, and individuals will be placed into each compartment uniformly.",
Expand Down
4 changes: 3 additions & 1 deletion plugins/contact_network/ngg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#! /usr/bin/env python3
from .. import *
from os import environ
from subprocess import call

# simulate contact network using NiemaGraphGen
def ngg(exe, params, out_fn, config, GLOBAL, verbose=True):
env = dict(environ); env['NGG_RNG_SEED'] = str(GLOBAL['RNG_SEED'])
if exe == 'ngg_barabasi_albert':
command = [exe, str(params['n']), str(params['m'])]
elif exe == 'ngg_barbell':
Expand All @@ -22,7 +24,7 @@ def ngg(exe, params, out_fn, config, GLOBAL, verbose=True):
print_log("Command: %s" % ' '.join(command))
f = open(out_fn['contact_network'], 'w')
try:
call(command, stdout=f)
call(command, stdout=f, env=env)
except FileNotFoundError as e:
error("Unable to run NiemaGraphGen. Make sure all ngg_* executables are in your PATH (e.g. /usr/local/bin)")
f.close()
Expand Down
7 changes: 6 additions & 1 deletion plugins/sequence_evolution/seqgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def seqgen(mode, params, out_fn, config, GLOBAL, verbose=True):
seqgen_tree_fn = "%s/seqgen.phy" % out_fn['intermediate']
seqgen_log_fn = "%s/seqgen.log" % out_fn['intermediate']
f = open(seqgen_tree_fn, 'w'); f.write("1 %d\nROOT %s\n1\n%s" % (len(root_seq),root_seq,treestr)); f.close()
command = ['seq-gen', '-of', '-k1']
command = [
'seq-gen',
'-of',
'-k1',
'-z', str(GLOBAL['RNG_SEED']),
]
if mode in {'GTR', 'GTR+G', 'GTR+Codon'}: # GTR model
command += ['-m', 'GTR']
if mode in {'GTR', 'GTR+G', 'GTR+Codon'}: # add base frequencies
Expand Down
1 change: 1 addition & 0 deletions plugins/transmission_network/gemf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def gemf_favites(model, params, out_fn, config, GLOBAL, verbose=True):
'-t', str(params['duration']),
'-o', gemf_out,
'--output_all_transitions',
'--rng_seed', str(GLOBAL['RNG_SEED']),
'--quiet',
]
if verbose:
Expand Down
5 changes: 3 additions & 2 deletions plugins/viral_phylogeny_trans/coatran.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#! /usr/bin/env python3
from .. import *
from os import stat
from os import environ, stat
from subprocess import call
try:
from treeswift import read_tree_newick
Expand All @@ -9,6 +9,7 @@

# simulate a coalescent viral phylogeny using CoaTran
def coatran(exe, params, out_fn, config, GLOBAL, verbose=True):
env = dict(environ); env['COATRAN_RNG_SEED'] = str(GLOBAL['RNG_SEED'])
if exe in {'coatran_inftime', 'coatran_transtree'}:
command = [exe, out_fn['transmission_network'], out_fn['sample_times']]
elif exe == 'coatran_constant':
Expand All @@ -19,7 +20,7 @@ def coatran(exe, params, out_fn, config, GLOBAL, verbose=True):
print_log("Command: %s" % ' '.join(command))
f = open(out_fn['viral_phylogeny_all_chains_time'], 'w')
try:
call(command, stdout=f)
call(command, stdout=f, env=env)
except FileNotFoundError as e:
error("Unable to run CoaTran. Make sure all coatran_* executables are in your PATH (e.g. /usr/local/bin)")
f.close()
Expand Down

0 comments on commit 7052691

Please sign in to comment.