-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathFES_from_State.py
executable file
·345 lines (337 loc) · 12.4 KB
/
FES_from_State.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#! /usr/bin/env python3
### usage: python plumed_scripts/FES_from_State.py -f STATE_FILE -o OUT_FILE --temp TEMPERATURE_K
### Get the FES estimate used by OPES, from a dumped state file (STATE_WFILE). 1D or 2D only ###
# usage is similar to plumed sum_hills
import sys
import argparse
import numpy as np
import pandas as pd #much faster reading from file
# Optional: before each output file is written, run an external backup script on
# its path (presumably to move aside an existing file - see the script itself).
# Switching this on requires the bck.meup.sh script.
do_bck = False
if do_bck:
    import subprocess
    bck_script = 'bck.meup.sh'  # e.g. place the script in your ~/bin
### Parser stuff ###
parser = argparse.ArgumentParser(
    description='get the FES estimate used by OPES, from a dumped state file (STATE_WFILE). 1D or 2D only')
# files
parser.add_argument('--state', '-f', dest='filename', type=str, default='STATE',
                    help='the state file name, with the compressed kernels')
parser.add_argument('--outfile', '-o', dest='outfile', type=str, default='fes.dat',
                    help='name of the output file')
# compulsory: the thermal energy, given either directly or as a temperature
kbt_group = parser.add_mutually_exclusive_group(required=True)
kbt_group.add_argument('--kt', dest='kbt', type=float,
                       help='the temperature in energy units')
kbt_group.add_argument('--temp', dest='temp', type=float,
                       help='the temperature in Kelvin. Energy units is kJ/mol')
# grid related (comma separated, one value per CV)
parser.add_argument('--min', dest='grid_min', type=str, required=False,
                    help='lower bounds for the grid')
parser.add_argument('--max', dest='grid_max', type=str, required=False,
                    help='upper bounds for the grid')
parser.add_argument('--bin', dest='grid_bin', type=str, default="100,100",
                    help='number of bins for the grid')
# other options
parser.add_argument('--fmt', dest='fmt', type=str, default='% 12.6f',
                    help='specify the output format')
parser.add_argument('--deltaFat', dest='deltaFat', type=float, required=False,
                    help='calculate the free energy difference between left and right of given c1 value')
parser.add_argument('--all_stored', dest='all_stored', action='store_true', default=False,
                    help='print all the FES stored instead of only the last one')
parser.add_argument('--nomintozero', dest='nomintozero', action='store_true', default=False,
                    help='do not shift the minimum to zero')
parser.add_argument('--der', dest='der', action='store_true', default=False,
                    help='calculate also FES derivatives')

# some easy parsing
args = parser.parse_args()
filename = args.filename
outfile = args.outfile
if args.kbt is not None:
    kbt = args.kbt
else:
    kbt = args.temp * 0.0083144621  # molar gas constant in kJ/(mol K)
fmt = args.fmt
calc_deltaF = args.deltaFat is not None
if calc_deltaF:
    ts = args.deltaFat  # CV value separating the two basins for DeltaF
mintozero = not args.nomintozero
calc_der = args.der
all_stored = args.all_stored
# a bare directory was given: write the default file name inside it
# (previously done only with --all_stored, otherwise open() failed on the directory)
if outfile.endswith('/'):
    outfile += 'fes.dat'
if all_stored:
    # Build a printf-style template used as outfile_n % state_number,
    # e.g. 'dir/fes.dat' -> 'dir/fes_%d.dat'. Any literal '%' in the user's
    # path is escaped so that the later %-formatting cannot break on it.
    slash = outfile.rfind('/')
    head, tail = outfile[:slash + 1], outfile[slash + 1:]
    dot = tail.rfind('.')
    if dot == -1:
        stem, suffix = tail, ''
    else:
        stem, suffix = tail[:dot], tail[dot:]
    outfile_n = (head + stem).replace('%', '%%') + '_%d' + suffix.replace('%', '%%')
explore = 'unset'  # updated per stored state to 'no' / 'yes'
### Get data ###
# Read the whole state file as a whitespace-separated table. The raw string
# avoids the invalid-escape-sequence warning that '\s' triggers in a plain literal.
data = pd.read_table(filename, sep=r'\s+', header=None)
tot_lines = len(data)
# every stored state starts with a '#! FIELDS ...' line: collect their row positions
fields_pos = np.flatnonzero((data.iloc[:, 1] == 'FIELDS').to_numpy()).tolist()
if not fields_pos:
    sys.exit(' no FIELDS found in file "' + filename + '"')
if len(fields_pos) > 1:
    print(' a total of %d stored states were found' % len(fields_pos))
    if all_stored:
        print(' -> all will be printed')
    else:
        print(' -> only the last one will be printed. use --all_stored to instead print them all')
        fields_pos = [fields_pos[-1]]
# sentinel: state n spans rows fields_pos[n] .. fields_pos[n+1]-1
fields_pos.append(tot_lines)
def _to_float(token):
    """Convert a grid bound to float, accepting the symbolic values 'pi' and '-pi'."""
    if token == 'pi':
        return np.pi
    if token == '-pi':
        return -np.pi
    return float(token)


def _scaled_distance(point, centers, sigmas, period):
    """Return the distance between `point` and each kernel center, in units of sigma.

    For a non-periodic CV (period == 0) the distance is signed, as needed by the
    derivatives; for a periodic CV the unsigned minimum-image distance is returned.
    """
    if period == 0:
        return (point - centers) / sigmas
    delta = np.absolute(point - centers)
    return np.minimum(delta, period - delta) / sigmas


# fields_pos holds the first row of each state followed by one end-of-data sentinel
n_states = len(fields_pos) - 1
for n in range(n_states):
    print(' working... 0% of {:.0%}'.format(n / n_states), end='\r')
    l = fields_pos[n]
    state_end = fields_pos[n + 1]  # first row after this state: never read past it

    ### Read the header ###
    # the FIELDS line fixes the dimensionality: 6 columns for a 1D bias, 8 for a 2D one
    dim2 = False
    n_columns = len(data.iloc[l, :])
    if n_columns == 6:
        name_cv_x = data.iloc[l, 3]
    elif n_columns == 8:
        dim2 = True
        name_cv_x = data.iloc[l, 3]
        name_cv_y = data.iloc[l, 4]
    else:
        sys.exit(' wrong number of FIELDS in file "' + filename + '": only 1 or 2 dimensional bias are supported')
    # the header has a fixed layout of 10 lines, and at least one kernel must follow it
    if l + 10 >= state_end:
        sys.exit(' missing data!')
    action = data.iloc[l + 1, 3]
    if action == "OPES_METAD_state":
        if explore != 'no':
            explore = 'no'
            print(' building free energy from OPES_METAD')
    elif action == "OPES_METAD_EXPLORE_state":
        if explore != 'yes':
            explore = 'yes'
            print(' building free energy from OPES_METAD_EXPLORE')
    else:
        sys.exit(' This script works only with OPES_METAD_state and OPES_METAD_EXPLORE_state')
    if data.iloc[l + 2, 2] != 'biasfactor':
        sys.exit(' biasfactor not found!')
    # in explore mode the FES is scaled by the biasfactor (sf multiplies kbt below)
    sf = 1
    if explore == 'yes':
        sf = float(data.iloc[l + 2, 3])
    if data.iloc[l + 3, 2] != 'epsilon':
        sys.exit(' epsilon not found!')
    epsilon = float(data.iloc[l + 3, 3])
    if data.iloc[l + 4, 2] != 'kernel_cutoff':
        sys.exit(' kernel_cutoff not found!')
    cutoff = float(data.iloc[l + 4, 3])
    # kernels are shifted down by their value at the cutoff, so they vanish there
    val_at_cutoff = np.exp(-0.5 * cutoff**2)
    if data.iloc[l + 6, 2] != 'zed':
        sys.exit(' zed not found!')
    Zed = float(data.iloc[l + 6, 3])
    if explore == 'no':
        if data.iloc[l + 7, 2] != 'sum_weights':
            sys.exit(' sum_weights not found!')
        Zed *= float(data.iloc[l + 7, 3])
    else:  # explore == 'yes' (any other action exited above)
        if data.iloc[l + 9, 2] != 'counter':
            sys.exit(' counter not found!')
        Zed *= float(data.iloc[l + 9, 3])
    l += 10  # skip the fixed header

    # optional '#!' lines declaring the periodic domain of each CV
    period_x = 0
    period_y = 0
    while l < state_end and data.iloc[l, 0] == '#!':
        if data.iloc[l, 2] == 'min_' + name_cv_x:
            grid_min_x = _to_float(data.iloc[l, 3])
            l += 1
            if l >= state_end or data.iloc[l, 2] != 'max_' + name_cv_x:
                sys.exit(' min_%s was found, but not max_%s !' % (name_cv_x, name_cv_x))
            grid_max_x = _to_float(data.iloc[l, 3])
            period_x = grid_max_x - grid_min_x
            if calc_der:
                sys.exit(' derivatives not supported with periodic CVs, remove --der option')
        if dim2 and data.iloc[l, 2] == 'min_' + name_cv_y:
            grid_min_y = _to_float(data.iloc[l, 3])
            l += 1
            if l >= state_end or data.iloc[l, 2] != 'max_' + name_cv_y:
                sys.exit(' min_%s was found, but not max_%s !' % (name_cv_y, name_cv_y))
            grid_max_y = _to_float(data.iloc[l, 3])
            period_y = grid_max_y - grid_min_y
            if calc_der:
                sys.exit(' derivatives not supported with periodic CVs, remove --der option')
        l += 1
    if l >= state_end:
        sys.exit(' missing data!')

    ### Get kernels ###
    kernel_rows = data.iloc[l:state_end]
    center_x = np.array(kernel_rows.iloc[:, 1], dtype=float)
    if dim2:
        center_y = np.array(kernel_rows.iloc[:, 2], dtype=float)
        sigma_x = np.array(kernel_rows.iloc[:, 3], dtype=float)
        sigma_y = np.array(kernel_rows.iloc[:, 4], dtype=float)
        height = np.array(kernel_rows.iloc[:, 5], dtype=float)
    else:
        sigma_x = np.array(kernel_rows.iloc[:, 2], dtype=float)
        height = np.array(kernel_rows.iloc[:, 3], dtype=float)

    ### Prepare the grid ###
    # command-line bounds win; otherwise periodic CVs keep the domain read above,
    # and non-periodic CVs span the range of the kernel centers
    grid_bins = args.grid_bin.split(',')
    grid_bin_x = int(grid_bins[0])
    if period_x == 0:
        grid_bin_x += 1  # same as plumed sum_hills
    if args.grid_min is not None:
        grid_min_x = _to_float(args.grid_min.split(',')[0])
    elif period_x == 0:
        grid_min_x = center_x.min()
    if args.grid_max is not None:
        grid_max_x = _to_float(args.grid_max.split(',')[0])
    elif period_x == 0:
        grid_max_x = center_x.max()
    grid_cv_x = np.linspace(grid_min_x, grid_max_x, grid_bin_x)
    if dim2:
        if len(grid_bins) != 2:
            sys.exit('two comma separated integers expected after --bin')
        grid_bin_y = int(grid_bins[1])
        if period_y == 0:
            grid_bin_y += 1  # same as plumed sum_hills
        if args.grid_min is not None:
            if len(args.grid_min.split(',')) != 2:
                sys.exit('two comma separated floats expected after --min')
            grid_min_y = _to_float(args.grid_min.split(',')[1])
        elif period_y == 0:
            grid_min_y = center_y.min()
        if args.grid_max is not None:
            if len(args.grid_max.split(',')) != 2:
                sys.exit('two comma separated floats expected after --max')
            grid_max_y = _to_float(args.grid_max.split(',')[1])
        elif period_y == 0:
            grid_max_y = center_y.max()
        grid_cv_y = np.linspace(grid_min_y, grid_max_y, grid_bin_y)
        x, y = np.meshgrid(grid_cv_x, grid_cv_y)
    # decide per state: with --all_stored the grid can differ from state to state
    do_deltaF = calc_deltaF
    if do_deltaF and (ts <= grid_min_x or ts >= grid_max_x):
        print(' +++ WARNING: the provided --deltaFat is out of the CV grid +++')
        do_deltaF = False

    ### Calculate FES ###
    # prob = sum of truncated kernels / Zed + epsilon, evaluated on every grid point
    if not dim2:
        prob = np.zeros(grid_bin_x)
        if calc_der:
            der_prob_x = np.zeros(grid_bin_x)
        for i in range(grid_bin_x):
            print(' working... {:.0%} of '.format(i / grid_bin_x), end='\r')
            dist_x = _scaled_distance(grid_cv_x[i], center_x, sigma_x, period_x)
            kernels_i = height * np.maximum(np.exp(-0.5 * dist_x * dist_x) - val_at_cutoff, 0)
            prob[i] = np.sum(kernels_i) / Zed + epsilon
            if calc_der:
                der_prob_x[i] = np.sum(-dist_x / sigma_x * kernels_i) / Zed
    else:
        prob = np.zeros((grid_bin_y, grid_bin_x))
        if calc_der:
            der_prob_x = np.zeros((grid_bin_y, grid_bin_x))
            der_prob_y = np.zeros((grid_bin_y, grid_bin_x))
        for i in range(grid_bin_y):
            print(' working... {:.0%} of '.format(i / grid_bin_y), end='\r')
            for j in range(grid_bin_x):
                dist_x = _scaled_distance(x[i, j], center_x, sigma_x, period_x)
                dist_y = _scaled_distance(y[i, j], center_y, sigma_y, period_y)
                kernels_ij = height * np.maximum(np.exp(-0.5 * (dist_x**2 + dist_y**2)) - val_at_cutoff, 0)
                prob[i, j] = np.sum(kernels_ij) / Zed + epsilon
                if calc_der:
                    der_prob_x[i, j] = np.sum(-dist_x / sigma_x * kernels_ij) / Zed
                    der_prob_y[i, j] = np.sum(-dist_y / sigma_y * kernels_ij) / Zed
    # normalising by the highest probability puts the FES minimum at zero
    max_prob = prob.max() if mintozero else 1
    fes = -kbt * sf * np.log(prob / max_prob)
    if calc_der:
        der_fes_x = -kbt * sf / prob * der_prob_x
        if dim2:
            der_fes_y = -kbt * sf / prob * der_prob_y
    # calculate deltaF between the regions left and right of ts (first CV only)
    # NB: summing is as accurate as trapz, and logaddexp avoids overflows
    if do_deltaF:
        cv_values = x if dim2 else grid_cv_x
        fesA = -kbt * np.logaddexp.reduce(-1 / kbt * fes[cv_values < ts])
        fesB = -kbt * np.logaddexp.reduce(-1 / kbt * fes[cv_values > ts])
        deltaF = fesB - fesA

    ### Print to file ###
    if all_stored:
        outfile = outfile_n % (n + 1)
    if do_bck:
        # argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through intact; a missing script only warns
        try:
            subprocess.run([bck_script, '-i', outfile], check=False)
        except OSError as err:
            print(' +++ WARNING: could not run the backup script: %s +++' % err)
    fields = '#! FIELDS ' + name_cv_x
    if dim2:
        fields += ' ' + name_cv_y
    fields += ' file.free'
    if calc_der:
        fields += ' der_' + name_cv_x
        if dim2:
            fields += ' der_' + name_cv_y
    with open(outfile, 'w') as f:
        f.write(fields + '\n')
        if do_deltaF:
            f.write('#! SET DeltaF %g\n' % deltaF)
        f.write('#! SET min_' + name_cv_x + ' %g\n' % grid_min_x)
        f.write('#! SET max_' + name_cv_x + ' %g\n' % grid_max_x)
        f.write('#! SET nbins_' + name_cv_x + ' %g\n' % grid_bin_x)
        if period_x == 0:
            f.write('#! SET periodic_' + name_cv_x + ' false\n')
        else:
            f.write('#! SET periodic_' + name_cv_x + ' true\n')
        if not dim2:
            for i in range(grid_bin_x):
                line = (fmt + ' ' + fmt) % (grid_cv_x[i], fes[i])
                if calc_der:
                    line += (' ' + fmt) % der_fes_x[i]
                f.write(line + '\n')
        else:
            f.write('#! SET min_' + name_cv_y + ' %g\n' % grid_min_y)
            f.write('#! SET max_' + name_cv_y + ' %g\n' % grid_max_y)
            f.write('#! SET nbins_' + name_cv_y + ' %g\n' % grid_bin_y)
            if period_y == 0:
                f.write('#! SET periodic_' + name_cv_y + ' false\n')
            else:
                f.write('#! SET periodic_' + name_cv_y + ' true\n')
            for i in range(grid_bin_y):
                for j in range(grid_bin_x):
                    line = (fmt + ' ' + fmt + ' ' + fmt) % (x[i, j], y[i, j], fes[i, j])
                    if calc_der:
                        line += (' ' + fmt + ' ' + fmt) % (der_fes_x[i, j], der_fes_y[i, j])
                    f.write(line + '\n')
                f.write('\n')  # blank line separates consecutive grid rows