Skip to content

Commit

Permalink
splith5
Browse files Browse the repository at this point in the history
  • Loading branch information
jurjen93 committed Aug 27, 2024
1 parent 212b197 commit f366d21
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 86 deletions.
2 changes: 1 addition & 1 deletion ms_helpers/concat_with_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Concat measurement sets with dummies for missing frequency subbands.
Example:
python concat_with_dummies.py --concat_name concat.ms --freq_avg 4 --time_avg 8 *.ms
python concat_with_dummies.py --msout concat.ms --freq_avg 4 --time_avg 8 *.ms
"""

__author__ = "Jurjen de Jong"
Expand Down
175 changes: 90 additions & 85 deletions subtract/subtract_with_wsclean.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_time_preavg_factor(ms: str = None):
factor = int(float(avg_num))
if factor != 1:
print("WARNING: " + ms + " time has been pre-averaged with factor " + str(
factor) + ". This might cause time smearing effects.")
factor) + ". This might cause stronger time smearing effects in your final image.")
return factor
elif isfloat(avg_num):
factor = float(avg_num)
Expand All @@ -109,6 +109,62 @@ def get_time_preavg_factor(ms: str = None):
return None


def split_facet_h5(h5parm: str = None, dirname: str = None):
    """
    Split one direction out of a multi-facet h5parm.

    The input file is copied to ``<h5parm>.<dirname>.h5``. In the copy, the
    ``val``/``weight`` arrays of ``phase000`` and ``amplitude000`` are reduced
    to the requested direction, the ``dir`` arrays and the ``source`` table
    are reduced to a single entry, and that source entry is renamed ``Dir00``.

    :param h5parm: multi-facet h5parm
    :param dirname: direction name (as stored in the ``source`` table)

    :return: name of the single-direction output h5parm
    :raises ValueError: if ``dirname`` is not present in the ``source`` table
    """
    # Local import so the block does not depend on the file-level import list.
    import shutil

    outputh5 = f'{h5parm}.{dirname}.h5'
    # Copy without a shell: works for paths with spaces/metacharacters and
    # raises on failure instead of silently continuing with a missing file.
    shutil.copy(h5parm, outputh5)

    soltabs = ('phase000', 'amplitude000')
    data_axes = ('val', 'weight')

    with tables.open_file(outputh5, 'r+') as outh5:

        # Decode once; the str form is needed both for parsing and for
        # re-encoding the attribute on the re-created arrays below.
        axes = make_utf8(outh5.root.sol000.phase000.val.attrs["AXES"])
        # NOTE(review): the axis order of phase000 is also applied to
        # amplitude000 -- assumes both soltabs share the same AXES; confirm.
        dir_axis = axes.split(',').index('dir')

        sources = outh5.root.sol000.source[:]
        names = [make_utf8(name) for name in sources['name']]
        if dirname not in names:
            raise ValueError(f"Direction {dirname!r} not found in {h5parm}; available directions: {names}")
        dir_idx = names.index(dirname)

        def get_data(soltab, axis):
            return np.take(outh5.root.sol000._f_get_child(soltab)._f_get_child(axis)[:],
                           indices=[dir_idx], axis=dir_axis)

        # Read everything that is needed before any node is removed.
        selected = {(soltab, axis): get_data(soltab, axis)
                    for soltab in soltabs for axis in data_axes}
        new_dirs = np.array([sources[dir_idx]])
        new_dirs['name'][0] = b'Dir00'
        dirs = np.array([outh5.root.sol000.phase000.dir[:][dir_idx]])

        # The arrays change shape, so the old nodes are dropped and re-created.
        for soltab in soltabs:
            for node in data_axes + ('dir',):
                outh5.remove_node(f"/sol000/{soltab}", node, recursive=True)
        outh5.remove_node("/sol000", "source", recursive=True)

        axes_attr = bytes(axes, 'utf-8')
        for soltab in soltabs:
            for axis in data_axes:
                array = outh5.create_array(f'/sol000/{soltab}', axis, selected[soltab, axis])
                # Re-created nodes start without attributes; restore AXES.
                array.attrs['AXES'] = axes_attr
            outh5.create_array(f'/sol000/{soltab}', 'dir', dirs)
        outh5.create_table('/sol000', 'source', new_dirs, title='Source names and directions')

    return outputh5


class SubtractWSClean:
def __init__(self, mslist: list = None, region: str = None, localnorth: bool = True, onlyprint: bool = False, inverse: bool = False):
"""
Expand Down Expand Up @@ -162,7 +218,7 @@ def clean_model_images(self):

freqs = []
for ms in self.mslist:
t = ct.table(ms + "::SPECTRAL_WINDOW")
t = ct.table(ms + "::SPECTRAL_WINDOW", ack=False)
freqs += list(t.getcol("CHAN_FREQ")[0])
t.close()
self.fmax_ms = max(freqs)
Expand Down Expand Up @@ -326,7 +382,7 @@ def subtract_col(self, out_column: str = None):

for ms in self.mslist:
print('Subtract ' + ms)
ts = ct.table(ms, readonly=False)
ts = ct.table(ms, readonly=False, ack=False)
colnames = ts.colnames()

if "MODEL_DATA" not in colnames:
Expand Down Expand Up @@ -461,61 +517,6 @@ def isfulljones(h5: str = None):
T.close()
return False

def split_facet_h5(self, h5parm: str = None, dirname: str = None):
    """
    Split one direction out of a multi-facet h5parm.

    The input file is copied to ``<h5parm>.<dirname>.h5``. In the copy, the
    ``val``/``weight`` arrays of ``phase000`` and ``amplitude000`` are reduced
    to the requested direction, and the ``dir`` arrays and the ``source``
    table are reduced to a single entry.

    :param h5parm: multi-facet h5parm
    :param dirname: direction name (as stored in the ``source`` table)

    :return: name of the single-direction output h5parm
    """
    # Local import so the block does not depend on the file-level import list.
    import shutil

    outputh5 = f'{h5parm}.{dirname}.h5'
    # Copy without a shell: works for paths with spaces/metacharacters and
    # raises on failure instead of silently continuing with a missing file.
    shutil.copy(h5parm, outputh5)

    with tables.open_file(outputh5, 'r+') as outh5:

        # Decode the attribute once. Keeping the raw value and later calling
        # bytes(axes, 'utf-8') raises TypeError when the attribute is read
        # back as bytes rather than str.
        axes = make_utf8(outh5.root.sol000.phase000.val.attrs["AXES"])

        dir_axis = axes.split(',').index('dir')

        sources = outh5.root.sol000.source[:]
        names = [make_utf8(name) for name in sources['name']]
        dir_idx = names.index(dirname)

        def get_data(soltab, axis):
            return np.take(outh5.root.sol000._f_get_child(soltab)._f_get_child(axis)[:],
                           indices=[dir_idx], axis=dir_axis)

        # Read everything that is needed before any node is removed.
        phase_w = get_data('phase000', 'weight')
        amplitude_w = get_data('amplitude000', 'weight')
        phase_v = get_data('phase000', 'val')
        amplitude_v = get_data('amplitude000', 'val')
        new_dirs = np.array([sources[dir_idx]])
        dirs = np.array([outh5.root.sol000.phase000.dir[:][dir_idx]])

        # The arrays change shape, so the old nodes are dropped and re-created.
        outh5.remove_node("/sol000/phase000", "val", recursive=True)
        outh5.remove_node("/sol000/phase000", "weight", recursive=True)
        outh5.remove_node("/sol000/phase000", "dir", recursive=True)
        outh5.remove_node("/sol000/amplitude000", "val", recursive=True)
        outh5.remove_node("/sol000/amplitude000", "weight", recursive=True)
        outh5.remove_node("/sol000/amplitude000", "dir", recursive=True)
        outh5.remove_node("/sol000", "source", recursive=True)

        outh5.create_array('/sol000/phase000', 'val', phase_v)
        outh5.create_array('/sol000/phase000', 'weight', phase_w)
        outh5.create_array('/sol000/phase000', 'dir', dirs)
        outh5.create_array('/sol000/amplitude000', 'val', amplitude_v)
        outh5.create_array('/sol000/amplitude000', 'weight', amplitude_w)
        outh5.create_array('/sol000/amplitude000', 'dir', dirs)
        outh5.create_table('/sol000', 'source', new_dirs, title='Source names and directions')

        # Re-created nodes start without attributes; restore AXES.
        axes_attr = bytes(axes, 'utf-8')
        outh5.root.sol000.phase000.val.attrs['AXES'] = axes_attr
        outh5.root.sol000.phase000.weight.attrs['AXES'] = axes_attr
        outh5.root.sol000.amplitude000.val.attrs['AXES'] = axes_attr
        outh5.root.sol000.amplitude000.weight.attrs['AXES'] = axes_attr

    return outputh5


def run_DP3(self, phaseshift: str = None, freqavg: str = None,
timeres: str = None, concat: bool = None,
applybeam: bool = None, applycal_h5: str = None, dirname: str = None):
Expand Down Expand Up @@ -579,13 +580,14 @@ def run_DP3(self, phaseshift: str = None, freqavg: str = None,
ac_count = 0
T = tables.open_file(applycal_h5)
for corr in T.root.sol000._v_groups.keys():
command += [f'ac{ac_count}.type=applycal',
f'ac{ac_count}.parmdb={applycal_h5}',
f'ac{ac_count}.correction={corr}']
if phaseshift is not None and dirname is not None:
command += [f'ac{ac_count}.direction=' + dirname]
steps.append(f'ac{ac_count}')
ac_count += 1
if 'phase' in corr or 'amp' in corr:
command += [f'ac{ac_count}.type=applycal',
f'ac{ac_count}.parmdb={applycal_h5}',
f'ac{ac_count}.correction={corr}']
if phaseshift is not None and dirname is not None:
command += [f'ac{ac_count}.direction=' + dirname]
steps.append(f'ac{ac_count}')
ac_count += 1
T.close()

# 4) APPLY BEAM (OPTIONAL IF APPLY BEAM USED FOR APPLYCAL)
Expand Down Expand Up @@ -725,11 +727,11 @@ def main():
else:
sys.exit('ERROR: using --forwidefield option needs polygon_info.csv file to read polygon information from')

t = ct.table(args.mslist[0] + "::SPECTRAL_WINDOW")
t = ct.table(args.mslist[0] + "::SPECTRAL_WINDOW", ack=False)
channum = len(t.getcol("CHAN_FREQ")[0])
t.close()

t = ct.table(args.mslist[0])
t = ct.table(args.mslist[0], ack=False)
time = np.unique(t.getcol("TIME"))
dtime = abs(time[1] - time[0])
t.close()
Expand Down Expand Up @@ -778,49 +780,51 @@ def main():
command = [f'mkdir -p {runpath}',
f'cp *.fits {runpath}',
f'cp {args.region} {runpath}']
command += [f'mv {dataset} {runpath}' for dataset in args.mslist]
command += [f'rsync -a --no-perms {dataset} {runpath}' for dataset in args.mslist]
command += [f'rm -rf {dataset}' for dataset in args.mslist]
os.system('&&'.join(command))
outpath = os.getcwd()
os.chdir(runpath)


object = SubtractWSClean(mslist=args.mslist if not args.scratch else [ms.split('/')[-1] for ms in args.mslist],
subpred = SubtractWSClean(mslist=args.mslist if not args.scratch else [ms.split('/')[-1] for ms in args.mslist],
region=args.region if not args.scratch else args.region.split('/')[-1],
localnorth=not args.no_local_north,
onlyprint=args.print_only_commands,
inverse=args.inverse)

if not args.skip_predict:

# clean model images
object.clean_model_images()
subpred.clean_model_images()

# mask
print('############## MASK REGION ##############')
if args.region is not None:
object.mask_region(region_cube=args.use_region_cube)
subpred.mask_region(region_cube=args.use_region_cube)

if args.scratch:
os.system(f'cp {outpath}/{args.h5parm_predict.split("/")[-1]} {runpath}')

if args.inverse:
if args.scratch:
os.system(f'cp {args.h5parm_predict} {runpath}')
faceth5 = object.split_facet_h5(h5parm=args.h5parm_predict if not args.scratch else args.h5parm_predict.split('/')-[-1],
faceth5 = split_facet_h5(h5parm=args.h5parm_predict if not args.scratch else args.h5parm_predict.split('/')[-1],
dirname=dirname)
# predict
print('############## PREDICT ##############')
object.predict(h5parm=faceth5,
subpred.predict(h5parm=faceth5,
facet_regions=args.region if not args.scratch else args.region.split('/')[-1])

else:
# predict
print('############## PREDICT ##############')
if args.scratch:
os.system(f'cp {args.h5parm_predict} {runpath}')
os.system(f'cp {args.facets_predict} {runpath}')
object.predict(h5parm=args.h5parm_predict if not args.scratch else args.h5parm_predict.split('/')-[-1],
facet_regions=args.facets_predict if not args.scratch else args.facets_predict.split('/')-[-1])
os.system(f'cp {outpath}/{args.facets_predict.split("/")[-1]} {runpath}')
subpred.predict(h5parm=args.h5parm_predict if not args.scratch else args.h5parm_predict.split('/')[-1],
facet_regions=args.facets_predict if not args.scratch else args.facets_predict.split('/')[-1])

# subtract
print('############## SUBTRACT ##############')
object.subtract_col(out_column='SUBTRACT_DATA' if not args.inverse else "DATA")
subpred.subtract_col(out_column='SUBTRACT_DATA' if not args.inverse else "DATA")

# extra DP3 step
if args.phasecenter is not None or \
Expand All @@ -840,17 +844,18 @@ def main():
applycalh5 = None

if args.scratch:
os.system(f'cp {applycalh5} {runpath}')
os.system(f'cp {outpath}/{applycalh5.split("/")[-1]} {runpath}')

msout = object.run_DP3(phaseshift=phasecenter, freqavg=freqavg, timeres=timeres,
msout = subpred.run_DP3(phaseshift=phasecenter, freqavg=freqavg, timeres=timeres,
concat=args.concat, applybeam=args.applybeam,
applycal_h5=applycalh5 if not args.scratch else applycalh5.split('/')[-1], dirname=dirname)

if args.scratch:
for ms in msout: os.system(f'mv {ms} {outpath}')
os.system('rm -r *.ms')
for ms in msout: os.system(f'cp -r {ms} {outpath}')
os.system('rm -rf *.ms')
os.chdir(outpath)

print(f"DONE: See output --> sub{object.scale}*.ms")
print(f"DONE: See output --> sub{subpred.scale}*.ms")
else:
print(f"DONE: Output is SUBTRACT_DATA column in input MS")

Expand Down

0 comments on commit f366d21

Please sign in to comment.