Skip to content

Commit

Permalink
subtract upd
Browse files Browse the repository at this point in the history
  • Loading branch information
jurjen93 committed Oct 18, 2024
1 parent d17396f commit 7390f39
Showing 1 changed file with 45 additions and 69 deletions.
114 changes: 45 additions & 69 deletions subtract/subtract_with_wsclean.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tables
from astropy.io import fits
from astropy.wcs import WCS
import casacore.tables as ct
from casacore.tables import table, taql


def add_trailing_zeros(s, digitsize=4):
Expand Down Expand Up @@ -216,14 +216,13 @@ def isfulljones(h5: str = None):


class SubtractWSClean:
def __init__(self, mslist: list = None, region: str = None, localnorth: bool = True, onlyprint: bool = False, inverse: bool = False):
def __init__(self, mslist: list = None, region: str = None, localnorth: bool = True, inverse: bool = False):
"""
Subtract image with WSClean
:param mslist: measurement set list
:param region: region file to mask
:param model_image: model image
:param onlyprint: print only the commands (for testing purposes)
"""

# list with MS
Expand All @@ -234,8 +233,6 @@ def __init__(self, mslist: list = None, region: str = None, localnorth: bool = T
hdu = fits.open(model_images[0])
self.imshape = (hdu[0].header['NAXIS1'], hdu[0].header['NAXIS2'])

self.onlyprint = onlyprint

if len(glob('*-????-model-pb.fits')) >= 1:
self.model_images = glob('*-????-model-pb.fits')
elif len(glob('*-????-model.fits')) >= 1:
Expand Down Expand Up @@ -268,9 +265,8 @@ def clean_model_images(self):

freqs = []
for ms in self.mslist:
t = ct.table(ms + "::SPECTRAL_WINDOW", ack=False)
freqs += list(t.getcol("CHAN_FREQ")[0])
t.close()
with table(ms + "::SPECTRAL_WINDOW", ack=False) as t:
freqs += list(t.getcol("CHAN_FREQ")[0])
self.fmax_ms = max(freqs)
self.fmin_ms = min(freqs)
model_images = glob('*-model*.fits')
Expand Down Expand Up @@ -401,9 +397,8 @@ def mask_region(self, region_cube: bool = False):

print('Mask ' + fits_model)

if not self.onlyprint:

hdu = fits.open(fits_model)
with fits.open(fits_model) as hdu:

b = not self.inverse

Expand All @@ -419,8 +414,6 @@ def mask_region(self, region_cube: bool = False):
hdu[0].data[0][0][np.where(manualmask == b)] = 0.0
hdu.writeto(fits_model, overwrite=True)

# hdu.close()

return self

def subtract_col(self, out_column: str = None):
Expand All @@ -432,56 +425,40 @@ def subtract_col(self, out_column: str = None):

for ms in self.mslist:
print('Subtract ' + ms)
with ct.table(ms, readonly=False, ack=False) as ts:
with table(ms, readonly=False, ack=False) as ts:
colnames = ts.colnames()

if "MODEL_DATA" not in colnames:
sys.exit(
f"ERROR: MODEL_DATA does not exist in {ms}.\nThis is most likely due to a failed predict step.")

if not self.onlyprint:
if out_column not in colnames:
# get column description from DATA
desc = ts.getcoldesc('DATA')
# create output column
print('Create ' + out_column)
desc['name'] = out_column
# create template for output column
ts.addcols(desc)

else:
print(out_column, ' already exists')

# get number of rows
nrows = ts.nrows()
if out_column not in colnames:
# get column description from DATA
desc = ts.getcoldesc('DATA')
# create output column
print('Create ' + out_column)
desc['name'] = out_column
# create template for output column
ts.addcols(desc)

# make sure every slice has the same size
best_slice = get_largest_divider(nrows, 1000)

if self.inverse:
sign = '+'
else:
sign = '-'
print(out_column, ' already exists')

if 'SUBTRACT_DATA' in colnames:
colmn = 'SUBTRACT_DATA'
elif 'CORRECTED_DATA' in colnames:
colmn = 'CORRECTED_DATA'
else:
colmn = 'DATA'

for c in range(0, nrows, best_slice):
if self.inverse:
sign = '+'
else:
sign = '-'

if c == 0:
print(f'Output --> {colmn} {sign} MODEL_DATA')
if 'SUBTRACT_DATA' in colnames:
colmn = 'SUBTRACT_DATA'
else:
colmn = 'DATA'

if not self.onlyprint:
data = ts.getcol(colmn, startrow=c, nrow=best_slice)
print(f'Output --> {colmn} {sign} MODEL_DATA')

if not self.onlyprint:
model = ts.getcol('MODEL_DATA', startrow=c, nrow=best_slice)
ts.putcol(out_column, data - model if not self.inverse else data + model, startrow=c, nrow=best_slice)
ts.removecols(['MODEL_DATA'])
taql(f"UPDATE {ms} SET {out_column}={colmn}{sign}MODEL_DATA")
taql(f"ALTER TABLE {ms} DROP COLUMN MODEL_DATA")

return self

Expand Down Expand Up @@ -545,8 +522,7 @@ def predict(self, h5parm: str = None, facet_regions: str = None):
predict_cmd.write('\n'.join(command))
predict_cmd.close()

if not self.onlyprint:
os.system(' '.join(command) + ' > log_predict.log')
os.system(' '.join(command) + ' > log_predict.log')

return self

Expand Down Expand Up @@ -724,8 +700,7 @@ def run_DP3(self, phaseshift: str = None, freqavg: str = None,

print(f"Make subtract_concat.ms")

if not self.onlyprint:
os.system(' '.join(command) + " > dp3.subtract.log")
os.system(' '.join(command) + " > dp3.subtract.log")
else:
msout = []
for n, ms in enumerate(self.mslist):
Expand All @@ -738,8 +713,7 @@ def run_DP3(self, phaseshift: str = None, freqavg: str = None,

print(f"Make sub{self.scale}_{ms}")

if not self.onlyprint:
os.system(' '.join(command + [f'msin={ms}', f'msout=sub{self.scale}_{ms}']) + f" > dp3.sub{n}.log")
os.system(' '.join(command + [f'msin={ms}', f'msout=sub{self.scale}_{ms}']) + f" > dp3.sub{n}.log")
msout.append(f'sub{self.scale}_{ms}')

return msout
Expand All @@ -761,12 +735,11 @@ def parse_args():
parser.add_argument('--facets_predict', type=str, help='facet region file for prediction')
parser.add_argument('--phasecenter', type=str, help='phaseshift to given point (example: --phaseshift 16h06m07.61855,55d21m35.4166)')
parser.add_argument('--freqavg', type=str, help='frequency averaging')
parser.add_argument('--timeres', type=str, help='time resolution averaging in secondsZ')
parser.add_argument('--timeres', type=str, help='time resolution averaging in seconds')
parser.add_argument('--concat', action='store_true', help='concat MS')
parser.add_argument('--applybeam', action='store_true', help='apply beam in phaseshift center or center of field')
parser.add_argument('--applycal', action='store_true', help='applycal after subtraction and phaseshifting')
parser.add_argument('--applycal_h5', type=str, help='applycal solution file')
parser.add_argument('--print_only_commands', action='store_true', help='only print commands for testing purposes')
parser.add_argument('--forwidefield', action='store_true', help='will search for the polygon_info.csv file to extract information from')
parser.add_argument('--skip_predict', action='store_true', help='skip predict and do only subtract')
parser.add_argument('--even_time_avg', action='store_true', help='(only if --forwidefield) only allow even time averaging (in case of stacking nights with different averaging)')
Expand Down Expand Up @@ -826,14 +799,12 @@ def main():
else:
sys.exit('ERROR: using --forwidefield option needs polygon_info.csv file to read polygon information from')

t = ct.table(args.mslist[0] + "::SPECTRAL_WINDOW", ack=False)
channum = len(t.getcol("CHAN_FREQ")[0])
t.close()
with table(args.mslist[0] + "::SPECTRAL_WINDOW", ack=False) as t:
channum = len(t.getcol("CHAN_FREQ")[0])

t = ct.table(args.mslist[0], ack=False)
time = np.unique(t.getcol("TIME"))
dtime = abs(time[1] - time[0])
t.close()
with table(args.mslist[0], ack=False) as t:
time = np.unique(t.getcol("TIME"))
dtime = abs(time[1] - time[0])

polygon = polygon_info.loc[polygon_info.polygon_file == args.region.split('/')[-1]]
try:
Expand Down Expand Up @@ -879,11 +850,12 @@ def main():

# mkdir and copy files
command = [f'mkdir -p {runpath}',
f'cp *.fits {runpath}',
f'cp {args.region} {runpath}']
f'cp *-model*.fits {runpath}']
if args.region is not None:
command += [f'cp {args.region} {runpath}']
command += [f'rsync -a --no-perms {dataset} {runpath}' for dataset in args.mslist]
# when running with scratch + toil, the next commands are to clean up the tmp* files
command += ['rm *.fits', f'rm -rf {args.model_image_folder}']
command += ['rm *-model*.fits', f'rm -rf {args.model_image_folder}']
# command += [f'rm -rf {dataset}' for dataset in args.mslist]
os.system('&&'.join(command))
outpath = os.getcwd()
Expand All @@ -898,7 +870,6 @@ def main():
subpred = SubtractWSClean(mslist=args.mslist if not args.scratch_toil else [ms.split('/')[-1] for ms in args.mslist],
region=args.region if not args.scratch_toil else args.region.split('/')[-1],
localnorth=not args.no_local_north,
onlyprint=args.print_only_commands,
inverse=args.inverse)

if not args.skip_predict:
Expand Down Expand Up @@ -961,12 +932,17 @@ def main():

if args.scratch_toil:
# copy averaged MS back to output folder
for ms in msout: os.system(f'cp -r {ms} {outpath}/{dirname.replace("Dir","facet_")}-{ms.split("/")[-1]}')
# Copy each averaged MS back to the output folder under its facet-prefixed name.
# rsync needs both a source and a destination: with a single path argument it
# only lists the files and copies nothing. The trailing slash on the source
# copies the MS contents into the renamed destination directory (same result
# as `cp -r SRC DEST` when DEST does not exist yet).
for ms in msout:
    os.system(f'rsync -a --no-perms {ms}/ {outpath}/{dirname.replace("Dir","facet_")}-{ms.split("/")[-1]}')
# clean up scratch directory (for big MS)
os.system(f'cp *.log {outpath} && rm -rf *.ms')
os.chdir(outpath)

print(f'DONE: See output --> {dirname.replace("Dir","facet_")}-*.ms')
elif args.scratch_toil:
# copy back the subtracted MS to the output path
for ms in subpred.mslist: os.system(f'rsync -a --no-perms {ms} {outpath}')
os.system(f'cp *.log {outpath}')
os.chdir(outpath)
else:
print(f"DONE: Output is SUBTRACT_DATA column in input MS")

Expand Down

0 comments on commit 7390f39

Please sign in to comment.