-
Notifications
You must be signed in to change notification settings - Fork 1
/
localScript.py
199 lines (156 loc) · 7.02 KB
/
localScript.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from secrets import choice
import subprocess
import sys
import os
import argparse
import glob
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from astropy.time import Time
from fit_utils import get_bestfit_lightcurve, parse_csv
from astropy.time import Time
from nmma.em.model import SVDLightCurveModel, GRBLightCurveModel, KilonovaGRBLightCurveModel, SupernovaGRBLightCurveModel
from nmma.em.utils import loadEvent, getFilteredMag
import seaborn as sns
## Goal of this script is to manually execute model fits of specific csv files and save the results to a folder.
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataDir", type=str, default=None,
                    help="directory containing candidate .dat light-curve files")
## Currently not in use: allows for manual fitting of a specific candidate or candidates
parser.add_argument("-c", "--candidate", nargs="+", type=str, default=None)
## Would have to pass args.models as model_list when submitting jobs
## Would have to pass when executing fit bot as " ".join(f'"{m}"' for m in args.models)
parser.add_argument("-m", "--models", nargs="+", type=str,
                    default=["TrPi2018", "nugent-hyper", "Piro2021", "Bu2019lm"],
                    choices=["TrPi2018", "nugent-hyper", "Piro2021", "Bu2019lm"])
parser.add_argument("--svdmodels", type=str, default="/home/cough052/shared/NMMA/svdmodels",
                    help="Path to the SVD models. Note: Not present in the repo, need to be aquired separately (Files are very large)")
parser.add_argument("--nlive", type=int, default=256, help="Number of live points to use")
parser.add_argument("-p", "--prior", type=str, default=None,
                    help="path to manual prior file (Note: expect issues if trying to use one prior for multiple models)")  ## should add ability to pass multiple priors for multiple models (Dict?)
parser.add_argument("-v", "--verbose", action="store_true", help="whether to use verbose output")
## where to output plots
parser.add_argument('-o', "--outdir", type=str, default='./outdir/')
parser.add_argument('-s', "--sampler", type=str, default='pymultinest', choices=['pymultinest', 'dynesty'])
args = parser.parse_args()

og_directory = os.getcwd()
outdir = args.outdir
svd_path = args.svdmodels

# At least one source of light curves must be supplied.
if not args.dataDir and not args.candidate:
    print("Please pass --dataDir (-d) and/or --candidate (-c)")
    # fix: exit with a non-zero status so callers/schedulers see the failure
    # (the original bare sys.exit() reported success).
    sys.exit(1)
if not args.models:
    # NOTE(review): unreachable in practice -- argparse always supplies the
    # default model list above, and nargs="+" requires at least one value.
    print("No --models argument: Fitting to all models")
model_list = args.models
verbose = args.verbose

if not os.path.exists(outdir):
    os.makedirs(outdir)

## No check on the number of detections, making assumption that all submissions have a requisite number of detections since this is a manually executed script
if args.dataDir:
    # fix: os.path.join tolerates a --dataDir given without a trailing
    # separator; the original args.dataDir + '*.dat' concatenation silently
    # matched nothing in that case.
    lc_data = glob.glob(os.path.join(args.dataDir, '*.dat'), recursive=False)
elif args.candidate:
    lc_data = list(args.candidate)
## assumes working with .dat files already (see nmma_fit.py for how those are found)
## could probably add that here so this is fully independent of nmma_fit being done
## Random Stuff from nmma_fit.py
# Fixed fitting settings carried over from nmma_fit.py.
svd_mag_ncoeff = 10
svd_lbol_ncoeff = 10
Ebv_max = 0.5724
grb_resolution = 7
jet_type = 0
joint_light_curve = False
sampler = args.sampler
seed = 42
nlive = args.nlive
error_budget = 1.0
t0 = 1
trigger_time_heuristic = False
fit_trigger_time = True
# Fit every candidate light curve against every requested model, invoking the
# external `light_curve_analysis` tool once per (candidate, model) pair and
# mirroring its stdout/stderr to ours.
for cand in lc_data:  ## hacky way of doing things
    print("candidate path: {}".format(cand))
    candName = cand.split("/")[-1].split(".")[0]
    candDir = os.path.join(outdir, candName, "")
    if not os.path.exists(candDir):
        os.makedirs(candDir)
    # Whitespace-delimited photometry table with no header.
    # NOTE(review): column 0 is an ISOT timestamp and column 3 an error/limit
    # where inf marks a non-detection -- inferred from usage below; confirm
    # against the .dat format produced by nmma_fit.py.
    candTable = pd.read_table(cand, delimiter=r'\s+', header=None)
    for model in model_list:
        # Default light-curve time grid (days).
        tmin = 0.01
        tmax = 7.01
        dt = 0.1
        modelDir = os.path.join(candDir, model, "")
        if not os.path.exists(modelDir):
            os.makedirs(modelDir)
        # GRB model requires special values so lightcurves can be generated without NMMA running into timeout errors.
        if model == "TrPi2018":
            tmin = 0.01
            tmax = 7.01
            dt = 0.35
        # Choose the prior file: a manual --prior overrides the per-model default.
        if args.prior:
            prior = args.prior
            print('Using manual prior file: {}'.format(prior))
        elif model == 'nugent-hyper':
            # SN
            prior = './priors/ZTF_sn_t0.prior' if fit_trigger_time else './priors/ZTF_sn.prior'
        elif model == 'TrPi2018':
            # GRB
            prior = './priors/ZTF_grb_t0.prior' if fit_trigger_time else './priors/ZTF_grb.prior'
        elif model == 'Piro2021':
            # Shock cooling
            prior = './priors/ZTF_sc_t0.prior' if fit_trigger_time else './priors/ZTF_sc.prior'
        elif model == 'Bu2019lm':
            # KN
            prior = './priors/ZTF_kn_t0.prior' if fit_trigger_time else './priors/ZTF_kn.prior'
        else:
            print("nmma_fitter does not know of the prior file for model "+ model)
            # fix: use sys.exit for consistency with the rest of the script
            # (bare exit() is intended for interactive use only).
            sys.exit(1)
        if fit_trigger_time:
            # Set to earliest detection in preparation for fit
            # Need to search the whole file since they are not always ordered.
            trigger_time = np.inf
            for index, row in candTable.iterrows():
                if np.isinf(float(row[3])):
                    continue  # non-detection; ignore for trigger time
                # fix: compute the MJD once per row (originally converted twice).
                mjd = Time(row[0], format='isot').mjd
                if mjd < trigger_time:
                    trigger_time = mjd
        elif trigger_time_heuristic:
            # One day before the first non-zero point
            trigger_time = np.inf
            for index, row in candTable.iterrows():
                if np.isinf(float(row[3])):
                    continue
                mjd = Time(row[0], format='isot').mjd
                if (mjd - 1) < trigger_time:
                    trigger_time = mjd - 1
        else:
            # Set the trigger time
            trigger_time = t0
        if not np.isfinite(trigger_time):
            # fix: a candidate with no finite detections previously passed
            # trigger-time 'inf' straight to the fitter; skip it instead.
            print("No finite detections in {}; skipping model {}".format(cand, model))
            continue
        print("trigger time: {}".format(trigger_time))
        print("model directory: {}".format(modelDir))
        # fix: build the argv list directly and run with shell=False so paths
        # containing spaces or shell metacharacters can neither break the
        # command nor inject into a shell. The argv matches what the shell
        # produced from the original concatenated string (including the
        # detection-limit dict, which the shell previously unquoted).
        command = [
            "light_curve_analysis",
            "--model", model,
            "--svd-path", svd_path,
            "--outdir", modelDir,
            "--label", "{}_{}_nlive{}".format(candName, model, nlive),
            "--trigger-time", str(trigger_time),
            "--data", cand,
            "--prior", prior,
            "--tmin", str(tmin),
            "--tmax", str(tmax),
            "--dt", str(dt),
            "--error-budget", str(error_budget),
            "--nlive", str(nlive),
            "--Ebv-max", str(Ebv_max),
            "--detection-limit", "{'r':23.5, 'g':23.5, 'i':23.5}",
            "--plot",
            "--sampler", str(sampler),
        ]
        if verbose:
            command.append("--verbose")
        proc = subprocess.run(command, capture_output=True)
        # Mirror the child's output so logs are preserved when run under a scheduler.
        sys.stdout.buffer.write(proc.stdout)
        sys.stderr.buffer.write(proc.stderr)