From 7390f39b387e411fcbb6e86577283e53ab31b84b Mon Sep 17 00:00:00 2001 From: jurjen93 Date: Fri, 18 Oct 2024 09:24:38 +0200 Subject: [PATCH] subtract upd --- subtract/subtract_with_wsclean.py | 114 ++++++++++++------------------ 1 file changed, 45 insertions(+), 69 deletions(-) diff --git a/subtract/subtract_with_wsclean.py b/subtract/subtract_with_wsclean.py index 8a5ba5d..eed7533 100644 --- a/subtract/subtract_with_wsclean.py +++ b/subtract/subtract_with_wsclean.py @@ -13,7 +13,7 @@ import tables from astropy.io import fits from astropy.wcs import WCS -import casacore.tables as ct +from casacore.tables import table, taql def add_trailing_zeros(s, digitsize=4): @@ -216,14 +216,13 @@ def isfulljones(h5: str = None): class SubtractWSClean: - def __init__(self, mslist: list = None, region: str = None, localnorth: bool = True, onlyprint: bool = False, inverse: bool = False): + def __init__(self, mslist: list = None, region: str = None, localnorth: bool = True, inverse: bool = False): """ Subtract image with WSClean :param mslist: measurement set list :param region: region file to mask :param model_image: model image - :param onlyprint: print only the commands (for testing purposes) """ # list with MS @@ -234,8 +233,6 @@ def __init__(self, mslist: list = None, region: str = None, localnorth: bool = T hdu = fits.open(model_images[0]) self.imshape = (hdu[0].header['NAXIS1'], hdu[0].header['NAXIS2']) - self.onlyprint = onlyprint - if len(glob('*-????-model-pb.fits')) >= 1: self.model_images = glob('*-????-model-pb.fits') elif len(glob('*-????-model.fits')) >= 1: @@ -268,9 +265,8 @@ def clean_model_images(self): freqs = [] for ms in self.mslist: - t = ct.table(ms + "::SPECTRAL_WINDOW", ack=False) - freqs += list(t.getcol("CHAN_FREQ")[0]) - t.close() + with table(ms + "::SPECTRAL_WINDOW", ack=False) as t: + freqs += list(t.getcol("CHAN_FREQ")[0]) self.fmax_ms = max(freqs) self.fmin_ms = min(freqs) model_images = glob('*-model*.fits') @@ -401,9 +397,8 @@ def mask_region(self, region_cube: bool = False): print('Mask ' + fits_model) - if not self.onlyprint: - hdu = fits.open(fits_model) + with fits.open(fits_model) as hdu: b = not self.inverse @@ -419,8 +414,6 @@ def mask_region(self, region_cube: bool = False): hdu[0].data[0][0][np.where(manualmask == b)] = 0.0 hdu.writeto(fits_model, overwrite=True) - # hdu.close() - return self def subtract_col(self, out_column: str = None): @@ -432,56 +425,40 @@ def subtract_col(self, out_column: str = None): for ms in self.mslist: print('Subtract ' + ms) - with ct.table(ms, readonly=False, ack=False) as ts: + with table(ms, readonly=False, ack=False) as ts: colnames = ts.colnames() if "MODEL_DATA" not in colnames: sys.exit( f"ERROR: MODEL_DATA does not exist in {ms}.\nThis is most likely due to a failed predict step.") - if not self.onlyprint: - if out_column not in colnames: - # get column description from DATA - desc = ts.getcoldesc('DATA') - # create output column - print('Create ' + out_column) - desc['name'] = out_column - # create template for output column - ts.addcols(desc) - - else: - print(out_column, ' already exists') - - # get number of rows - nrows = ts.nrows() + if out_column not in colnames: + # get column description from DATA + desc = ts.getcoldesc('DATA') + # create output column + print('Create ' + out_column) + desc['name'] = out_column + # create template for output column + ts.addcols(desc) - # make sure every slice has the same size - best_slice = get_largest_divider(nrows, 1000) - - if self.inverse: - sign = '+' else: - sign = '-' + print(out_column, ' already exists') - if 'SUBTRACT_DATA' in colnames: - colmn = 'SUBTRACT_DATA' - elif 'CORRECTED_DATA' in colnames: - colmn = 'CORRECTED_DATA' - else: - colmn = 'DATA' - for c in range(0, nrows, best_slice): + if self.inverse: + sign = '+' + else: + sign = '-' - if c == 0: - print(f'Output --> {colmn} {sign} MODEL_DATA') + if 'SUBTRACT_DATA' in colnames: + colmn = 'SUBTRACT_DATA' + else: + colmn = 'DATA' - if not self.onlyprint: - data = ts.getcol(colmn, startrow=c, nrow=best_slice) + print(f'Output --> {colmn} {sign} MODEL_DATA') - if not self.onlyprint: - model = ts.getcol('MODEL_DATA', startrow=c, nrow=best_slice) - ts.putcol(out_column, data - model if not self.inverse else data + model, startrow=c, nrow=best_slice) - ts.removecols(['MODEL_DATA']) + taql(f"UPDATE {ms} SET {out_column}={colmn}{sign}MODEL_DATA") + taql(f"ALTER TABLE {ms} DROP COLUMN MODEL_DATA") return self @@ -545,8 +522,7 @@ def predict(self, h5parm: str = None, facet_regions: str = None): predict_cmd.write('\n'.join(command)) predict_cmd.close() - if not self.onlyprint: - os.system(' '.join(command) + ' > log_predict.log') + os.system(' '.join(command) + ' > log_predict.log') return self @@ -724,8 +700,7 @@ def run_DP3(self, phaseshift: str = None, freqavg: str = None, print(f"Make subtract_concat.ms") - if not self.onlyprint: - os.system(' '.join(command) + " > dp3.subtract.log") + os.system(' '.join(command) + " > dp3.subtract.log") else: msout = [] for n, ms in enumerate(self.mslist): @@ -738,8 +713,7 @@ def run_DP3(self, phaseshift: str = None, freqavg: str = None, print(f"Make sub{self.scale}_{ms}") - if not self.onlyprint: - os.system(' '.join(command + [f'msin={ms}', f'msout=sub{self.scale}_{ms}']) + f" > dp3.sub{n}.log") + os.system(' '.join(command + [f'msin={ms}', f'msout=sub{self.scale}_{ms}']) + f" > dp3.sub{n}.log") msout.append(f'sub{self.scale}_{ms}') return msout @@ -761,12 +735,11 @@ def parse_args(): parser.add_argument('--facets_predict', type=str, help='facet region file for prediction') parser.add_argument('--phasecenter', type=str, help='phaseshift to given point (example: --phaseshift 16h06m07.61855,55d21m35.4166)') parser.add_argument('--freqavg', type=str, help='frequency averaging') - parser.add_argument('--timeres', type=str, help='time resolution averaging in secondsZ') + parser.add_argument('--timeres', type=str, help='time resolution averaging in seconds') parser.add_argument('--concat', action='store_true', help='concat MS') parser.add_argument('--applybeam', action='store_true', help='apply beam in phaseshift center or center of field') parser.add_argument('--applycal', action='store_true', help='applycal after subtraction and phaseshifting') parser.add_argument('--applycal_h5', type=str, help='applycal solution file') - parser.add_argument('--print_only_commands', action='store_true', help='only print commands for testing purposes') parser.add_argument('--forwidefield', action='store_true', help='will search for the polygon_info.csv file to extract information from') parser.add_argument('--skip_predict', action='store_true', help='skip predict and do only subtract') parser.add_argument('--even_time_avg', action='store_true', help='(only if --forwidefield) only allow even time averaging (in case of stacking nights with different averaging)') @@ -826,14 +799,12 @@ def main(): else: sys.exit('ERROR: using --forwidefield option needs polygon_info.csv file to read polygon information from') - t = ct.table(args.mslist[0] + "::SPECTRAL_WINDOW", ack=False) - channum = len(t.getcol("CHAN_FREQ")[0]) - t.close() + with table(args.mslist[0] + "::SPECTRAL_WINDOW", ack=False) as t: + channum = len(t.getcol("CHAN_FREQ")[0]) - t = ct.table(args.mslist[0], ack=False) - time = np.unique(t.getcol("TIME")) - dtime = abs(time[1] - time[0]) - t.close() + with table(args.mslist[0], ack=False) as t: + time = np.unique(t.getcol("TIME")) + dtime = abs(time[1] - time[0]) polygon = polygon_info.loc[polygon_info.polygon_file == args.region.split('/')[-1]] try: @@ -879,11 +850,12 @@ def main(): # mkdir and copy files command = [f'mkdir -p {runpath}', - f'cp *.fits {runpath}', - f'cp {args.region} {runpath}'] + f'cp *-model*.fits {runpath}'] + if args.region is not None: + command += [f'cp {args.region} {runpath}'] command += [f'rsync -a --no-perms {dataset} {runpath}' for dataset in args.mslist] # when running with scratch + toil, the next commands are to clean up the tmp* files - command += ['rm *.fits', f'rm -rf {args.model_image_folder}'] + command += ['rm *-model*.fits', f'rm -rf {args.model_image_folder}'] # command += [f'rm -rf {dataset}' for dataset in args.mslist] os.system('&&'.join(command)) outpath = os.getcwd() @@ -898,7 +870,6 @@ def main(): subpred = SubtractWSClean(mslist=args.mslist if not args.scratch_toil else [ms.split('/')[-1] for ms in args.mslist], region=args.region if not args.scratch_toil else args.region.split('/')[-1], localnorth=not args.no_local_north, - onlyprint=args.print_only_commands, inverse=args.inverse) if not args.skip_predict: @@ -961,12 +932,17 @@ def main(): if args.scratch_toil: # copy averaged MS back to output folder - for ms in msout: os.system(f'cp -r {ms} {outpath}/{dirname.replace("Dir","facet_")}-{ms.split("/")[-1]}') + for ms in msout: os.system(f'rsync -a --no-perms {outpath}/{dirname.replace("Dir","facet_")}-{ms.split("/")[-1]}') # clean up scratch directory (for big MS) os.system(f'cp *.log {outpath} && rm -rf *.ms') os.chdir(outpath) print(f'DONE: See output --> {dirname.replace("Dir","facet_")}-*.ms') + elif args.scratch_toil: + # copy back the subtracted MS to the output path + for ms in subpred.mslist: os.system(f'rsync -a --no-perms {ms} {outpath}') + os.system(f'cp *.log {outpath}') + os.chdir(outpath) else: print(f"DONE: Output is SUBTRACT_DATA column in input MS")