-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
71 lines (59 loc) · 3.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#import sys
#sys.path.append("/usr/local/lib/python3.9/site-packages")
from mctsRNA.LoadRNA import LoadData as DataModel
import argparse
import yaml
import logging
from mctsRNA.RNAState import State as State
from mctsRNA.MCTS import MCTS as MCTS
from mctsRNA.DesignedRNA import Sequence as DesignedRNA
from copy import deepcopy
from joblib import Parallel, delayed
def _str2bool(value):
    """Convert a command-line token such as ``true``/``false`` to a bool.

    ``type=bool`` does not work with argparse: it calls ``bool("False")``,
    and every non-empty string is truthy, so ``-v False`` used to switch
    verbose output *on*. This converter interprets the text instead.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The boolean the token spells out; the empty string maps to False,
        matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. The parsed namespace is bound to the module-level
# name ``args``, which the functions below read directly.
_parser = argparse.ArgumentParser()
_parser.add_argument('-r', '--root', type=str, help="root dir", default="./")
_parser.add_argument('-t', '--test_mode', type=int, help="test MCTS", default=1)
_parser.add_argument('-m', '--rolls', type=int, help="rollouts", default=1)
_parser.add_argument('-s', '--sims', type=int, help=" MCTS simulations", default=500)
_parser.add_argument('-l', '--search', type=int, help="local search", default=0)
_parser.add_argument('-c', '--config', type=str, help="configs", default="./mctsRNA/config.yml")
# ``nargs='?'`` with ``const=True`` lets a bare ``-v`` enable verbose output,
# while ``-v true`` / ``-v false`` keep working with their literal meaning.
_parser.add_argument('-v', '--verbose', type=_str2bool, nargs='?', const=True,
                     help="verbose", default=False)
_parser.add_argument('-x', '--mx_seq', type=int, help="max. seq.", default=200)
_parser.add_argument('-f', '--freq', type=int, help="print freq.", default=20)
_parser.add_argument('-d', '--dataset', type=str, help="dataset", default="modena")
_parser.add_argument('-w', '--workers', type=int, help="num processors", default=13)
_parser.add_argument('-e', '--interval_iter', type=int, help="sampling", default=3)
args = _parser.parse_args()
# Read the YAML configuration, build the data model from its 'data' section
# and fetch the dataset selected on the command line.
# NOTE(review): yaml.FullLoader is kept so existing configs load unchanged;
# if the config file can come from an untrusted source, switch to
# yaml.safe_load.
# The context manager closes the config file once it has been parsed (the
# previous bare open() left the handle to the garbage collector).
with open(args.config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
data_model = DataModel(**config['data'])
logging.log(logging.INFO, "Dataset loaded")
dataset = data_model.get_dataset(args.dataset)
#if int(args.test_mode) ==1: dataset = dataset[:5]
def seq_processor(seq, sample_iter):
    """Design an RNA sequence for one target with MCTS and write the result.

    Reads the module-level ``args``, ``config`` and ``dataset``.

    Args:
        seq: A target taken from the module-level ``dataset``.
        sample_iter: Index of the current sampling round; forwarded to
            ``write_results``.

    Returns:
        None. Targets longer than ``args.mx_seq`` are skipped and nothing
        is written for them.
    """
    # Reject over-long targets first so skipped entries do not pay for the
    # linear ``dataset.index`` lookup below.
    if len(list(seq)) > args.mx_seq:
        return None
    # NOTE(review): if ``dataset`` is a list, ``index`` returns the first
    # match, so duplicated targets would share one id - confirm the dataset
    # holds no duplicates.
    seq_id = dataset.index(seq)
    state = State(seq)
    mcts = MCTS(state, args.sims, args.search, configs=config)
    designed_rna = DesignedRNA(state.target, config)
    action_ix = 0
    while not state.is_terminal():
        # ``--freq 0`` means "no progress output"; without this guard the
        # modulo below raised ZeroDivisionError on the first step.
        if args.freq != 0 and action_ix % args.freq == 0:
            print(f"current_seq {seq_id+1} action {action_ix+1} of {state.max_seq_len}")
        best_action = mcts.tree_search(action_ix=action_ix, verbose=args.verbose)
        paired, location = state.paired(action_ix, True)
        designed_rna.update(best_action, paired, location)
        # Hand the state its own copy of the partial design.
        state.designed = deepcopy(designed_rna.rna_seq)
        action_ix += 1
    # Internal invariant: all three views of the design must agree that the
    # sequence is complete before results are written.
    assert (
        designed_rna.is_terminal()
        and state.is_terminal()
        and mcts.state.is_terminal()
    )
    designed_rna.write_results(seq_id + 1, args.dataset, sample_iter)
def runner(_iter):
    """Run one sampling round over the dataset, then write its summary CSV.

    Args:
        _iter: Index of the sampling round; passed to every
            ``seq_processor`` call and used in the summary file name.
    """
    # Fan the targets out over ``args.workers`` joblib workers; the print
    # frequency doubles as joblib's verbosity level.
    pool = Parallel(n_jobs=args.workers, verbose=args.freq)
    jobs = (delayed(seq_processor)(target, _iter) for target in dataset)
    pool(jobs)

    summary_path = f"{config['result_path']}{args.dataset}/summary-{_iter}-.csv"
    DesignedRNA.write_summary(summary_path, args.dataset, _iter)
if __name__ == "__main__":
    # One parallel pass over the dataset per sample, args.interval_iter
    # passes in total.
    for sample_iter in range(args.interval_iter):
        runner(sample_iter)

    # Combine the per-sample output into the final results and intervals.
    DesignedRNA.generate_intervals()
    print("Done!!")