From 81a250c38f108c985bd5c26576f6dd9a677f3098 Mon Sep 17 00:00:00 2001 From: jurjen93 Date: Wed, 26 Jun 2024 14:10:09 +0200 Subject: [PATCH] ms utils and stacking speed --- ms_helpers/ms_flagger.py | 1 + ms_helpers/remove_flagged_stations.py | 89 ++++++++++++++ ms_stacker.py | 160 +++++++++++++++++--------- 3 files changed, 198 insertions(+), 52 deletions(-) create mode 100644 ms_helpers/remove_flagged_stations.py diff --git a/ms_helpers/ms_flagger.py b/ms_helpers/ms_flagger.py index 8a023f5d..61d76c4c 100644 --- a/ms_helpers/ms_flagger.py +++ b/ms_helpers/ms_flagger.py @@ -46,5 +46,6 @@ def main(): print(' '.join(command)) os.system(' '.join(command)) + if __name__ == '__main__': main() diff --git a/ms_helpers/remove_flagged_stations.py b/ms_helpers/remove_flagged_stations.py new file mode 100644 index 00000000..84864d06 --- /dev/null +++ b/ms_helpers/remove_flagged_stations.py @@ -0,0 +1,89 @@ +from casacore.tables import table +import numpy as np +import os +from shutil import rmtree, move +from argparse import ArgumentParser +from sys import exit + + +def remove_flagged_antennas(msin: str = None, msout: str = None, overwrite: bool = False): + """ + Remove antennas that are full flagged (to save storage) + + input: + - msfile: measurement set name + """ + + # Cannot both overwrite and give an output name + if msout is not None and overwrite: + exit('ERROR: You specified an --msout and ask to --overwrite. Please give only one of both.') + + # Set name for output if not given + if msout is None: + msout = f"flagged_{msin}" + + # Read antenna names from Measurement Set + with table(f"{msin}::ANTENNA", ack=False) as ants: + ants_names = ants.getcol("NAME") + + # Read main tables Measurement Set + with table(msin, readonly=True, ack=False) as ms: + # Read the antenna ID columns + antenna1 = ms.getcol('ANTENNA1') + antenna2 = ms.getcol('ANTENNA2') + + # Read the FLAG column + flags = ms.getcol('FLAG') + + # Get the unique antenna indices + unique_antennas = np.unique(np.concatenate((antenna1, antenna2))) + + # Identify fully flagged antennas + fully_flagged_antennas = [] + for ant in unique_antennas: + # Find rows with this antenna + ant_rows = np.where((antenna1 == ant) | (antenna2 == ant)) + # Check if all data for this antenna is flagged + if np.all(flags[ant_rows]): + fully_flagged_antennas.append(ant) + + # Get names of ants to filter + ants_to_filter = ','.join([ants_names[idx] for idx in fully_flagged_antennas]) + print(f"Filtering fully flagged antennas: {ants_to_filter}") + + # Run DP3 + dp3_cmd = f'DP3 msin={msin} msout={msout} msout.storagemanager=dysco steps=[filter] \ + filter.type=filter filter.remove=true filter.baseline=!{ants_to_filter}' + + os.system(dp3_cmd) + + # Overwrite input + if overwrite: + rmtree(msin) + move(msout, msin) + + +def parse_args(): + """ + Parse input arguments + """ + + parser = ArgumentParser(description='MS stacking') + parser.add_argument('msin', type=str, help='Input Measurement Set', required=True) + parser.add_argument('--msout', type=str, default=None, help='Output Measurement Set') + parser.add_argument('--overwrite', action='store_true', help='Overwrite input Measurement Set') + + return parser.parse_args() + + +def main(): + """ + Main function + """ + + args = parse_args() + remove_flagged_antennas(args.msin, args.msout, args.overwrite) + + +if __name__ == '__main__': + main() diff --git a/ms_stacker.py b/ms_stacker.py index a5ebba11..a1f1fcc8 100644 --- a/ms_stacker.py +++ b/ms_stacker.py @@ -23,6 +23,7 @@ from casacore.tables import table, default_ms, taql import numpy as np +from os import system as run_command import os import shutil import sys @@ -47,12 +48,11 @@ try: from dask.distributed import Client - use_dask = True + use_dask_distributed = True + print('AWESOME: dask.distributed installed, continue with advanced parallel stacking.') except ImportError: - print('WARNING: dask.distrubted not installed, continue without parallel stacking.') - use_dask = False - -use_dask = False # TODO: dask not yet implemented + print('WARNING: dask.distributed not installed, continue without parallel stacking.') + use_dask_distributed = False one_lst_day_sec = 86164.1 @@ -86,7 +86,7 @@ def is_dysco_compressed(ms): - ms: measurement set """ - t = table(ms, readonly=True) + t = table(ms, readonly=True, ack=False) dysco = t.getdesc()["DATA"]['dataManagerGroup'] == 'DyscoData' t.close() return dysco @@ -106,7 +106,7 @@ def decompress(ms): if os.path.exists(f'{ms}.tmp'): shutil.rmtree(f'{ms}.tmp') - os.system(f"DP3 msin={ms} msout={ms}.tmp steps=[]") + run_command(f"DP3 msin={ms} msout={ms}.tmp steps=[]") print('----------') return ms + '.tmp' @@ -153,10 +153,10 @@ def compress(ms, avg=None): steps = str(steps).replace("'", "").replace(' ','') cmd += f' steps={steps}' - os.system(cmd) + run_command(cmd) try: - t = table(f"{ms}.tmp") # test if exists + t = table(f"{ms}.tmp", ack=False) # test if exists t.close() except RuntimeError: sys.exit(f"ERROR: dysco compression failed (please check {ms})") @@ -656,7 +656,7 @@ def add_spectral_window(self): Add SPECTRAL_WINDOW as sub table """ - print("\n----------\nADD " + self.outname + "::SPECTRAL_WINDOW\n----------\n") + print("\n----------\nAdd table: " + self.outname + "::SPECTRAL_WINDOW\n----------\n") tnew_spw_tmp = table(self.ref_table.getkeyword('SPECTRAL_WINDOW'), ack=False) newdesc = tnew_spw_tmp.getdesc() @@ -685,7 +685,7 @@ def add_stations(self): Add ANTENNA and FEED tables """ - print("\n----------\nADD " + self.outname + "::ANTENNA\n----------\n") + print("\n----------\nAdd table: " + self.outname + "::ANTENNA\n----------\n") stations = [sp[0] for sp in self.station_info] st_id = dict(zip(set( @@ -723,7 +723,7 @@ def add_stations(self): tnew_ant.flush(True) tnew_ant.close() - print("\n----------\nADD " + self.outname + "::FEED\n----------\n") + print("\n----------\nAdd table: " + self.outname + "::FEED\n----------\n") tnew_ant_tmp = table(self.ref_table.getkeyword('FEED'), ack=False) newdesc = tnew_ant_tmp.getdesc() @@ -746,7 +746,7 @@ def add_stations(self): tnew_feed.flush(True) tnew_feed.close() - print("\n----------\nADD " + self.outname + "::LOFAR_ANTENNA_FIELD\n----------\n") + print("\n----------\nAdd table: " + self.outname + "::LOFAR_ANTENNA_FIELD\n----------\n") tnew_ant_tmp = table(self.ref_table.getkeyword('LOFAR_ANTENNA_FIELD'), ack=False) newdesc = tnew_ant_tmp.getdesc() @@ -766,7 +766,7 @@ def add_stations(self): tnew_field.flush(True) tnew_field.close() - print("\n----------\nADD " + self.outname + "::LOFAR_STATION\n----------\n") + print("\n----------\nAdd table: " + self.outname + "::LOFAR_STATION\n----------\n") tnew_ant_tmp = table(self.ref_table.getkeyword('LOFAR_STATION'), ack=False) newdesc = tnew_ant_tmp.getdesc() @@ -793,26 +793,32 @@ def make_template(self, overwrite: bool = True, avg_factor: float = 1): same_phasedir(self.mslist) # Get data columns - unique_stations = [] - unique_lofar_stations = [] - unique_channels = [] - for k, ms in enumerate(self.mslist): + # Initialize variables outside the loop + unique_stations, unique_channels, unique_lofar_stations = [], [], [] + min_t_lst, min_dt, dfreq_min, max_t_lst = None, None, None, None + + def process_ms(ms): + """Parallel""" mscontent = get_ms_content(ms) stations, lofar_stations, channels, dfreq, total_time_seconds, dt, min_t, max_t = mscontent.values() - if k == 0: - min_t_lst = min_t - min_dt = dt - dfreq_min = dfreq - max_t_lst = max_t - else: - min_t_lst = min(min_t_lst, min_t) - min_dt = min(min_dt, dt) - dfreq_min = min(dfreq_min, dfreq) - max_t_lst = max(max_t_lst, max_t) + return stations, lofar_stations, channels, dfreq, dt, min_t, max_t - unique_stations += list(stations) - unique_channels += list(channels) - unique_lofar_stations += list(lofar_stations) + with ThreadPoolExecutor() as executor: + future_to_ms = {executor.submit(process_ms, ms): ms for ms in self.mslist} + for future in as_completed(future_to_ms): + stations, lofar_stations, channels, dfreq, dt, min_t, max_t = future.result() + + if min_t_lst is None: + min_t_lst, min_dt, dfreq_min, max_t_lst = min_t, dt, dfreq, max_t + else: + min_t_lst = min(min_t_lst, min_t) + min_dt = min(min_dt, dt) + dfreq_min = min(dfreq_min, dfreq) + max_t_lst = max(max_t_lst, max_t) + + unique_stations.extend(stations) + unique_channels.extend(channels) + unique_lofar_stations.extend(lofar_stations) self.station_info = unique_station_list(unique_stations) self.lofar_stations_info = unique_station_list(unique_lofar_stations) @@ -875,12 +881,12 @@ def make_template(self, overwrite: bool = True, avg_factor: float = 1): # Set ANTENNA/STATION info self.add_stations() - # Set other tables + # Set other tables (annoying table locks prevent parallel processing) for subtbl in ['FIELD', 'HISTORY', 'FLAG_CMD', 'DATA_DESCRIPTION', 'LOFAR_ELEMENT_FAILURE', 'OBSERVATION', 'POINTING', 'POLARIZATION', 'PROCESSOR', 'STATE']: try: - print("----------\nADD " + self.outname + "::" + subtbl + "\n----------") + print("\n----------\nAdd table: " + self.outname + "::" + subtbl + "\n----------\n") tsub = table(self.tmpfile+"::"+subtbl, ack=False, readonly=False) tsub.copy(self.outname + '/' + subtbl, deep=True) @@ -1140,7 +1146,7 @@ def __init__(self, msin: list = None, outname: str = 'empty.ms', chunkmem: float num_cpus = psutil.cpu_count(logical=True) total_memory = psutil.virtual_memory().total / (1024 ** 3) # in GB - if use_dask: + if use_dask_distributed: print(f"CPU number: {num_cpus}\nMemory limit: {int(total_memory)}GB") self.client = Client(n_workers=num_cpus, threads_per_worker=1, @@ -1229,6 +1235,12 @@ def read_mapping(mapping_folder): else: new_data, _ = get_data_arrays(col, self.T.nrows(), freq_len) + # Scatter data with dask + if use_dask_distributed: + new_data_future = da.from_delayed(self.client.scatter(new_data, broadcast=True), shape=new_data.shape, dtype=new_data.dtype) + if col == 'UVW': + uvw_weights_future = da.from_delayed(self.client.scatter(uvw_weights, broadcast=True), shape=uvw_weights.shape, dtype=uvw_weights.dtype) + # Loop over measurement sets for ms in self.mslist: @@ -1267,15 +1279,33 @@ def read_mapping(mapping_folder): row_idxs = [int(i - chunk_idx * self.chunk_size) for i in indices[chunk_idx * self.chunk_size:self.chunk_size * (chunk_idx+1)]] - # Stack columns + # Stack columns (using dask for multi-processing) if col == 'UVW': - new_data[row_idxs_new, :] += data[row_idxs, :] + if use_dask_distributed: + new_data_dask = da.from_delayed(new_data_future[row_idxs_new, :], + shape=new_data[row_idxs_new, :].shape, + dtype=new_data.dtype) + else: + new_data_dask = da.from_array(new_data[row_idxs_new, :], chunks='auto') + new_data[row_idxs_new, :] = (new_data_dask + data[row_idxs, :]).compute() new_data.flush() - uvw_weights[row_idxs_new, :] += 1 + if use_dask_distributed: + uvw_weights_dask = da.from_delayed(uvw_weights_future[row_idxs_new, :], + shape=uvw_weights[row_idxs_new, :].shape, + dtype=uvw_weights.dtype) + else: + uvw_weights_dask = da.from_array(uvw_weights[row_idxs_new, :], chunks='auto') + uvw_weights[row_idxs_new, :] = (uvw_weights_dask + 1).compute() uvw_weights.flush() else: - new_data[np.ix_(row_idxs_new, freq_idxs)] += data[row_idxs, :, :] + if use_dask_distributed: + new_data_dask = da.from_delayed(new_data_future[np.ix_(row_idxs_new, freq_idxs)], + shape=new_data[np.ix_(row_idxs_new, freq_idxs)].shape, + dtype=new_data.dtype) + else: + new_data_dask = da.from_array(new_data[np.ix_(row_idxs_new, freq_idxs)], chunks='auto') + new_data[np.ix_(row_idxs_new, freq_idxs)] = (new_data_dask + data[row_idxs, :]).compute() new_data.flush() @@ -1320,7 +1350,7 @@ def clean_binary_files(): """ for b in glob('*.tmp.dat'): - os.system(f'rm {b}') + run_command(f'rm {b}') def check_folder_exists(folder_path): @@ -1330,7 +1360,7 @@ def check_folder_exists(folder_path): return os.path.isdir(folder_path) -def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, mappingfiles: list = None, UV=True): +def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, mappingfiles: list = None, UV=True, saveas=None): """ Plot baseline track @@ -1351,14 +1381,13 @@ def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, ma for n, t_input_name in enumerate(t_input_names): try: - m = open(mappingfiles[n]) mapping = {int(i): int(j) for i, j in json.load(m).items()} - with table(t_final_name) as f: + with table(t_final_name, ack=False) as f: uvw1 = f.getcol("UVW")[list(mapping.values())] - with table(t_input_name) as f: + with table(t_input_name, ack=False) as f: uvw2 = f.getcol("UVW")[list(mapping.keys())] # Scatter plot for uvw1 @@ -1368,7 +1397,7 @@ def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, ma lbl = None plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.4, s=100, marker='o') - plt.plot(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], color='blue', alpha=0.4) + # plt.plot(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], color='blue', alpha=0.4) # Scatter plot for uvw2 plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'MS {n} to stack', color=colors[n], edgecolor='black', alpha=0.7, s = 40, marker='*') @@ -1376,17 +1405,40 @@ def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, ma except: pass - # Adding labels and title - plt.xlabel("U (m)", fontsize=14) - plt.ylabel("V (m)" if UV else "W (m)", fontsize=14) + # Adding labels and title + plt.xlabel("U (m)", fontsize=14) + plt.ylabel("V (m)" if UV else "W (m)", fontsize=14) + + # Adding grid + plt.grid(True, linestyle='--', alpha=0.6) - # Adding grid - plt.grid(True, linestyle='--', alpha=0.6) + # Adding legend + plt.legend(fontsize=12) - # Adding legend - plt.legend(fontsize=12) + if saveas is None: + plt.show() + else: + plt.savefig(saveas, dpi=150) + plt.close() + + +def make_baseline_uvw_plots(tabl, mslist): + """ + Make baseline plots + """ - plt.show() + run_command('mkdir -p baseline_plots') + + ants = table(tabl + "::ANTENNA", ack=False) + baselines = np.c_[make_ant_pairs(ants.nrows(), 1)] + ants.close() + + for baseline in baselines: + bl = '-'.join([str(a) for a in baseline]) + plot_baseline_track(tabl, sorted(mslist), sorted([ms + '_baseline_mapping/'+ bl+'.json' for ms in mslist]), + saveas=f'baseline_plots/{bl}.png') + + return def parse_args(): @@ -1405,6 +1457,7 @@ def parse_args(): parser.add_argument('--record_time', action='store_true', help='Record wall-time of stacking') parser.add_argument('--no_compression', action='store_true', help='No compression of data') parser.add_argument('--make_only_template', action='store_true', help='Stop after making empty template') + parser.add_argument('--plot_uv_baseline_coverage', action='store_true', help='make plots with baseline versus UV') return parser.parse_args() @@ -1451,6 +1504,9 @@ def ms_merger(): if not args.no_compression: compress(args.msout) #, max(int(avg * args.avg), 1) if args.advanced_stacking else int(args.avg)) + if args.plot_uv_baseline_coverage: + make_baseline_uvw_plots(args.msout, args.msin) + # Clean up mapping files if not args.keep_mapping: clean_mapping_files(args.msin)