From 7052691a77db3df74289a07dc0f2e4da6ad2e6f4 Mon Sep 17 00:00:00 2001 From: Niema Moshiri Date: Sat, 9 Mar 2024 15:34:19 -0800 Subject: [PATCH] Added RNG seed support --- .gitignore | 4 ++-- favites_lite.py | 25 ++++++++++++++++++++++++ global.json | 3 +-- plugins/contact_network/ngg.py | 4 +++- plugins/sequence_evolution/seqgen.py | 7 ++++++- plugins/transmission_network/gemf.py | 1 + plugins/viral_phylogeny_trans/coatran.py | 5 +++-- 7 files changed, 41 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index dd291dc..5397ba9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ # test files config.json -test.json -tmp +test*.json +tmp* # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/favites_lite.py b/favites_lite.py index 6fd966a..870984b 100755 --- a/favites_lite.py +++ b/favites_lite.py @@ -4,6 +4,10 @@ Niema Moshiri 2022 ''' +# useful constants +RNG_SEED_MIN = 0 +RNG_SEED_MAX = 2147483647 + # general imports and load global.json from os import makedirs, remove from os.path import abspath, expanduser, isdir, isfile @@ -11,9 +15,16 @@ from sys import argv, stderr from time import time import json +import random GLOBAL_JSON_PATH = "%s/global.json" % '/'.join(abspath(expanduser(argv[0])).split('/')[:-1]) GLOBAL = json.loads(open(GLOBAL_JSON_PATH).read()) +# external imports +try: + import numpy +except: + error("Unable to import numpy. Install with: pip install numpy") + # FAVITES-Lite-specific imports from plugins import PLUGIN_FUNCTIONS from plugins.common import * @@ -25,16 +36,29 @@ def parse_args(): parser.add_argument('-c', '--config', required=True, type=str, help="FAVITES-Lite Config File") parser.add_argument('-o', '--output', required=True, type=str, help="Output Directory") parser.add_argument('--overwrite', action="store_true", help="Overwrite output directory if it exists") + parser.add_argument('--rng_seed', required=False, type=int, default=None, help="Random Number Generator Seed") parser.add_argument('--quiet', action="store_true", help="Suppress Log Messages") parser.add_argument('--version', action="store_true", help="Show FAVITES-Lite version") return parser.parse_args() # validate user args def validate_args(args, verbose=True): + # check RNG seed + if args.rng_seed is None: + args.rng_seed = random.randint(RNG_SEED_MIN, RNG_SEED_MAX) + elif args.rng_seed < RNG_SEED_MIN or args.rng_seed > RNG_SEED_MAX: + error("Invalid RNG seed (%s). Must be in the range [%d, %d]" % (args.rng_seed, RNG_SEED_MIN, RNG_SEED_MAX)) + GLOBAL['RNG_SEED'] = args.rng_seed; random.seed(GLOBAL['RNG_SEED']); numpy.random.seed(GLOBAL['RNG_SEED']) + if verbose: + print_log("RNG Seed: %d" % GLOBAL['RNG_SEED']) + + # check config file if not isfile(args.config): error("Config file not found: %s" % args.config) if verbose: print_log("Config File: %s" % args.config) + + # check output directory if isdir(args.output) or isfile(args.output): if args.overwrite or input('Output directory exists: "%s". Overwrite? (Y/N) ' % args.output).upper().startswith('Y'): if verbose: @@ -74,6 +98,7 @@ def validate_config(config): args = parse_args(); verbose = not args.quiet if verbose: print_log("=== FAVITES-Lite v%s ===" % GLOBAL['VERSION']) + print_log("Command: %s" % ' '.join(argv)) validate_args(args, verbose=verbose) config = json.loads(open(args.config).read()); validate_config(config) makedirs(args.output); f = open("%s/config.json" % args.output, 'w'); json.dump(config, f); f.close() diff --git a/global.json b/global.json index f3740f0..4f7d7c3 100644 --- a/global.json +++ b/global.json @@ -1,7 +1,6 @@ { - "VERSION": "0.0.5", + "VERSION": "1.0.0", "CONFIG_KEYS": ["Contact Network", "Transmission Network", "Sample Times", "Viral Phylogeny (Transmissions)", "Viral Phylogeny (Seeds)", "Mutation Rates", "Ancestral Sequence", "Sequence Evolution"], - "TYPES": ["positive integer"], "DESC": { "Contact Network": "The Contact Network graph model describes all social interactions:", "Transmission Network": "The Transmission Network compartmental model describes how the pathogen propagates along the contact network. State transition rates are in unit of Poisson process arrivals per time (reciprocal of expected time to next arrival). The user will specify the number of individuals in each compartment, and individuals will be placed into each compartment uniformly.", diff --git a/plugins/contact_network/ngg.py b/plugins/contact_network/ngg.py index f3e1a65..b581595 100644 --- a/plugins/contact_network/ngg.py +++ b/plugins/contact_network/ngg.py @@ -1,9 +1,11 @@ #! /usr/bin/env python3 from .. import * +from os import environ from subprocess import call # simulate contact network using NiemaGraphGen def ngg(exe, params, out_fn, config, GLOBAL, verbose=True): + env = dict(environ); env['NGG_RNG_SEED'] = str(GLOBAL['RNG_SEED']) if exe == 'ngg_barabasi_albert': command = [exe, str(params['n']), str(params['m'])] elif exe == 'ngg_barbell': @@ -22,7 +24,7 @@ def ngg(exe, params, out_fn, config, GLOBAL, verbose=True): print_log("Command: %s" % ' '.join(command)) f = open(out_fn['contact_network'], 'w') try: - call(command, stdout=f) + call(command, stdout=f, env=env) except FileNotFoundError as e: error("Unable to run NiemaGraphGen. Make sure all ngg_* executables are in your PATH (e.g. /usr/local/bin)") f.close() diff --git a/plugins/sequence_evolution/seqgen.py b/plugins/sequence_evolution/seqgen.py index 84211ce..85da989 100644 --- a/plugins/sequence_evolution/seqgen.py +++ b/plugins/sequence_evolution/seqgen.py @@ -22,7 +22,12 @@ def seqgen(mode, params, out_fn, config, GLOBAL, verbose=True): seqgen_tree_fn = "%s/seqgen.phy" % out_fn['intermediate'] seqgen_log_fn = "%s/seqgen.log" % out_fn['intermediate'] f = open(seqgen_tree_fn, 'w'); f.write("1 %d\nROOT %s\n1\n%s" % (len(root_seq),root_seq,treestr)); f.close() - command = ['seq-gen', '-of', '-k1'] + command = [ + 'seq-gen', + '-of', + '-k1', + '-z', str(GLOBAL['RNG_SEED']), + ] if mode in {'GTR', 'GTR+G', 'GTR+Codon'}: # GTR model command += ['-m', 'GTR'] if mode in {'GTR', 'GTR+G', 'GTR+Codon'}: # add base frequencies diff --git a/plugins/transmission_network/gemf.py b/plugins/transmission_network/gemf.py index e61d1ec..925f5d4 100644 --- a/plugins/transmission_network/gemf.py +++ b/plugins/transmission_network/gemf.py @@ -65,6 +65,7 @@ def gemf_favites(model, params, out_fn, config, GLOBAL, verbose=True): '-t', str(params['duration']), '-o', gemf_out, '--output_all_transitions', + '--rng_seed', str(GLOBAL['RNG_SEED']), '--quiet', ] if verbose: diff --git a/plugins/viral_phylogeny_trans/coatran.py b/plugins/viral_phylogeny_trans/coatran.py index 7c0f15d..6333c61 100644 --- a/plugins/viral_phylogeny_trans/coatran.py +++ b/plugins/viral_phylogeny_trans/coatran.py @@ -1,6 +1,6 @@ #! /usr/bin/env python3 from .. import * -from os import stat +from os import environ, stat from subprocess import call try: from treeswift import read_tree_newick @@ -9,6 +9,7 @@ # simulate a coalescent viral phylogeny using CoaTran def coatran(exe, params, out_fn, config, GLOBAL, verbose=True): + env = dict(environ); env['COATRAN_RNG_SEED'] = str(GLOBAL['RNG_SEED']) if exe in {'coatran_inftime', 'coatran_transtree'}: command = [exe, out_fn['transmission_network'], out_fn['sample_times']] elif exe == 'coatran_constant': @@ -19,7 +20,7 @@ def coatran(exe, params, out_fn, config, GLOBAL, verbose=True): print_log("Command: %s" % ' '.join(command)) f = open(out_fn['viral_phylogeny_all_chains_time'], 'w') try: - call(command, stdout=f) + call(command, stdout=f, env=env) except FileNotFoundError as e: error("Unable to run CoaTran. Make sure all coatran_* executables are in your PATH (e.g. /usr/local/bin)") f.close()