From 784d86e24517409a1c144943fab6ac629885294a Mon Sep 17 00:00:00 2001 From: jurjen93 Date: Mon, 24 Jun 2024 13:16:21 +0200 Subject: [PATCH] chunking bug fix --- fits_helpers/crop_nan_boundaries.py | 56 ++++++ ms_stacker.py | 298 ++++++++++++++++------------ 2 files changed, 228 insertions(+), 126 deletions(-) create mode 100644 fits_helpers/crop_nan_boundaries.py diff --git a/fits_helpers/crop_nan_boundaries.py b/fits_helpers/crop_nan_boundaries.py new file mode 100644 index 00000000..83b50f65 --- /dev/null +++ b/fits_helpers/crop_nan_boundaries.py @@ -0,0 +1,56 @@ +import numpy as np +from astropy.io import fits +from argparse import ArgumentParser + +def crop_nan_boundaries(fits_in, fits_out): + """ + Crop nan boundaries + + input: + - fits_in: input fits file + - fits_out: output fits file + """ + + with fits.open(fits_in) as hdul: + image_data = hdul[0].data + header = hdul[0].header + + mask = ~np.isnan(image_data) + non_nan_indices = np.where(mask) + + ymin, ymax = non_nan_indices[0].min(), non_nan_indices[0].max() + xmin, xmax = non_nan_indices[1].min(), non_nan_indices[1].max() + + print(f"Original shape {image_data.shape}") + print(ymin, ymax) + print(xmin, xmax) + + cropped_image = image_data[ymin:ymax + 1, xmin:xmax + 1] + + header['NAXIS1'] = cropped_image.shape[1] + header['NAXIS2'] = cropped_image.shape[0] + header['CRPIX1'] -= xmin + header['CRPIX2'] -= ymin + + print(f"New shape {cropped_image.shape}") + + hdu = fits.PrimaryHDU(cropped_image, header=header) + hdu.writeto(fits_out, overwrite=True) + +def parse_args(): + """ + Command line argument parser + :return: parsed arguments + """ + parser = ArgumentParser(description='Crop fits file with nan boundaries') + parser.add_argument('--fits_input', help='fits input file', required=True, type=str) + parser.add_argument('--fits_output', help='fits output file', required=True, type=str) + return parser.parse_args() + +def main(): + """ Main function""" + args = parse_args() + crop_nan_boundaries(args.fits_input, args.fits_output) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/ms_stacker.py b/ms_stacker.py index 29bb79aa..22c141f3 100644 --- a/ms_stacker.py +++ b/ms_stacker.py @@ -13,7 +13,11 @@ 2) Map baselines from input MS to template MS. This step makes *baseline_mapping folders with the baseline mappings in json files. - 3) Stack measurement sets on the template (Stack class). + 3) Interpolate new UV data with nearest neighbours. + + 4) Make exact mapping between input MS and template MS, using only UV data points. + + 5) Stack measurement sets on the template (Stack class). The stacking is done with a weighted average, using the FLAG and WEIGHT_SPECTRUM columns. """ @@ -28,7 +32,7 @@ from pprint import pprint from argparse import ArgumentParser import json -from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from itertools import compress import time import dask.array as da @@ -38,6 +42,7 @@ import matplotlib.pyplot as plt from scipy.spatial import cKDTree from scipy.interpolate import interp1d +from scipy.ndimage import gaussian_filter1d try: @@ -525,60 +530,7 @@ def add_axis(arr, ax_size): return np.repeat(arr, ax_size).reshape(new_shape) -def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, mappingfiles: list = None, UV=True): - """ - Plot baseline track - - :input: - - t_final_name: table with final name - - t_input_names: tables to compare with - - mappingfiles: baseline mapping files - """ - - if len(t_input_names) > 4: - sys.exit("ERROR: Can just plot 4 inputs") - - colors = ['red', 'orange', 'pink', 'brown'] - - if not UV: - print("MAKE UW PLOT") - - for n, t_input_name in enumerate(t_input_names): - - m = open(mappingfiles[n]) - mapping = {int(i): int(j) for i, j in json.load(m).items()} - - with table(t_final_name) as f: - uvw1 = f.getcol("UVW")[list(mapping.values())] - - with table(t_input_name) as f: - uvw2 = f.getcol("UVW")[list(mapping.keys())] - - # Scatter plot for uvw1 - if n == 0: - lbl = 'Final MS' - else: - lbl = None - - plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.4, s=100, marker='o') - - # Scatter plot for uvw2 - plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'MS {n} to stack', color=colors[n], edgecolor='black', alpha=0.7, s = 40, marker='*') - - # Adding labels and title - plt.xlabel("U (m)", fontsize=14) - plt.ylabel("V (m)" if UV else "W (m)", fontsize=14) - - # Adding grid - plt.grid(True, linestyle='--', alpha=0.6) - - # Adding legend - plt.legend(fontsize=12) - - plt.show() - - -def resample_uwv(uvw_arrays, row_idxs, time, time_ref): +def resample_uwv(uvw_arrays, time, time_ref): """ Resample a uvw array to have N rows. """ @@ -587,20 +539,19 @@ def resample_uwv(uvw_arrays, row_idxs, time, time_ref): num_points, num_coords = uvw_arrays.shape if num_coords != 3: - raise ValueError("Input array must have shape (num_points, 3)") + raise ValueError(f"Input array must have shape ({num_points}, 3)") sorted_indices = np.argsort(time) - # Resample each coordinate separately - resampled_array = np.zeros((len(row_idxs), num_coords)) + resampled_array = np.zeros((len(time_ref), num_coords)) - # Use the interpolation functions to compute resampled values + # Create a single interpolation function for the entire UVW array + interp_funcs = [ + interp1d(time[sorted_indices], uvw_arrays[:, i][sorted_indices], kind='nearest', fill_value='extrapolate') + for i in range(num_coords) + ] for i in range(num_coords): - interp_funcs = [ - interp1d(time[sorted_indices], uvw_arrays[:, i][sorted_indices], kind='nearest', fill_value='extrapolate') - for i in range(num_coords) - ] - resampled_array[:, i] = interp_funcs[i](time_ref[list(row_idxs)]) + resampled_array[:, i] = interp_funcs[i](time_ref) return resampled_array @@ -633,6 +584,57 @@ def resample_array(data, factor): return resampled_data +def get_data_arrays(column: str = 'DATA', nrows: int = None, freq_len: int = None): + """ + Get data arrays (new data and weights) + + :param: + - column: column name (DATA, WEIGHT_SPECTRUM, WEIGHT, OR UVW) + - nrows: number of rows + - freq_len: frequency axis length + + :return: + - new_data: new data array (empty array with correct shape) + - weights: weights corresponding to new data array (empty array with correct shape) + """ + + tmpfilename = column.lower()+'.tmp.dat' + tmpfilename_weights = column.lower()+'_weights.tmp.dat' + + if column in ['UVW']: + weights = np.memmap(tmpfilename_weights, dtype=np.float16, mode='w+', shape=(nrows, 3)) + weights[:] = 0 + else: + weights = None + + if column in ['DATA', 'WEIGHT_SPECTRUM']: + if column == 'DATA': + dtp = np.complex128 + elif column == 'WEIGHT_SPECTRUM': + dtp = np.float32 + else: + dtp = np.float32 + shape = (nrows, freq_len, 4) + + elif column == 'WEIGHT': + shape, dtp = (nrows, freq_len), np.float32 + + elif column == 'UVW': + shape, dtp = (nrows, 3), np.float32 + + else: + sys.exit("ERROR: Use only DATA, WEIGHT_SPECTRUM, WEIGHT, or UVW") + + new_data = np.memmap(tmpfilename, dtype=dtp, mode='w+', shape=shape) + + return new_data, weights + + +def load_json(file_path): + with open(file_path, 'r') as file: + return json.load(file) + + class Template: """Make template measurement set based on input measurement sets""" def __init__(self, msin: list = None, outname: str = 'empty.ms'): @@ -867,12 +869,15 @@ def make_template(self, overwrite: bool = True, avg_factor: float = 1): for subtbl in ['FIELD', 'HISTORY', 'FLAG_CMD', 'DATA_DESCRIPTION', 'LOFAR_ELEMENT_FAILURE', 'OBSERVATION', 'POINTING', 'POLARIZATION', 'PROCESSOR', 'STATE']: - print("----------\nADD " + self.outname + "::" + subtbl + "\n----------") + try: + print("----------\nADD " + self.outname + "::" + subtbl + "\n----------") - tsub = table(self.tmpfile+"::"+subtbl, ack=False, readonly=False) - tsub.copy(self.outname + '/' + subtbl, deep=True) - tsub.flush(True) - tsub.close() + tsub = table(self.tmpfile+"::"+subtbl, ack=False, readonly=False) + tsub.copy(self.outname + '/' + subtbl, deep=True) + tsub.flush(True) + tsub.close() + except: + print(subtbl+" unknown") self.ref_table.close() @@ -1018,8 +1023,8 @@ def process_baselines(baseline_indices, baselines, mslist): T = table(self.outname, readonly=False, ack=False) UVW = np.memmap('UVW.tmp.dat', dtype=np.float32, mode='w+', shape=(T.nrows(), 3)) - TIME = np.memmap('TIME.tmp.dat', dtype=np.float64, mode='w+', shape=(T.nrows())) - TIME[:] = T.getcol("TIME") + # TIME = np.memmap('TIME.tmp.dat', dtype=np.float64, mode='w+', shape=(T.nrows())) + TIME = np.unique(T.getcol("TIME")) # Determine the optimal number of workers cpu_count = max(os.cpu_count()-3, 1) @@ -1050,8 +1055,8 @@ def process_baselines(baseline_indices, baselines, mslist): batch_start_idx = future_to_baseline[future] try: results = future.result() - for row_idxs, uvws, b_idx, time in results: - UVW[row_idxs] = resample_uwv(uvws, row_idxs, time, TIME) + for _, uvws, b_idx, time in results: + UVW[np.array(range(len(TIME)))*len(baselines)+b_idx] = resample_uwv(uvws, time, TIME) except Exception as exc: print(f'Batch starting at index {batch_start_idx} generated an exception: {exc}') @@ -1062,6 +1067,7 @@ def process_baselines(baseline_indices, baselines, mslist): # Make final mapping self.make_mapping_uvw() + def make_mapping_uvw(self): """ Make mapping json files essential for efficient stacking based on UVW points @@ -1136,53 +1142,33 @@ def __init__(self, msin: list = None, outname: str = 'empty.ms', chunkmem: float target_chunk_size = total_memory / chunkmem self.chunk_size = min(int(target_chunk_size * (1024 ** 3) / np.dtype(np.float128).itemsize), 1_000_000) - def get_data_arrays(self, column: str = 'DATA', nrows: int = None, freq_len: int = None): + def smooth_uvw(self): """ - Get data arrays (new data and weights) - - :param: - - column: column name (DATA, WEIGHT_SPECTRUM, WEIGHT, OR UVW) - - nrows: number of rows - - freq_len: frequency axis length - - :return: - - new_data: new data array (empty array with correct shape) - - weights: weights corresponding to new data array (empty array with correct shape) + Smooth UVW values """ - tmpfilename = column.lower()+'.tmp.dat' - tmpfilename_weights = column.lower()+'_weights.tmp.dat' - - if column in ['UVW']: - weights = np.memmap(tmpfilename_weights, dtype=np.float16, mode='w+', shape=(nrows, 3)) - weights[:] = 0 - else: - weights = None - - if column in ['DATA', 'WEIGHT_SPECTRUM']: - if column == 'DATA': - dtp = np.complex128 - elif column == 'WEIGHT_SPECTRUM': - dtp = np.float32 - else: - dtp = np.float32 - shape = (nrows, freq_len, 4) - - elif column == 'WEIGHT': - shape, dtp = (nrows, freq_len), np.float32 + uvw, _ = get_data_arrays('UVW', self.T.nrows()) + uvw[:] = self.T.getcol("UVW") + time = self.T.getcol("TIME") - elif column == 'UVW': - shape, dtp = (nrows, 3), np.float32 - - else: - sys.exit("ERROR: Use only DATA, WEIGHT_SPECTRUM, WEIGHT, or UVW") + ants = table(self.outname + "::ANTENNA", ack=False) + baselines = np.c_[make_ant_pairs(ants.nrows(), 1)] + ants.close() - new_data = np.memmap(tmpfilename, dtype=dtp, mode='w+', shape=shape) + print('Smooth UVW') + for idx_b, baseline in enumerate(baselines): + print_progress_bar(idx_b, len(baselines)) + idxs = [] + for baseline_json in glob(f"*baseline_mapping/{baseline[0]}-{baseline[1]}.json"): + idxs += list(load_json(baseline_json).values()) + sorted_indices = np.argsort(time[idxs]) + for i in range(3): + uvw[np.array(idxs)[sorted_indices], i] = gaussian_filter1d(uvw[np.array(idxs)[sorted_indices], i], sigma=2) - return new_data, weights + self.T.putcol('UVW', uvw) - def stack_all(self, column: str = 'DATA'): + def stack_all(self, column: str = 'DATA', advanced: bool = False): """ Stack all MS @@ -1190,10 +1176,6 @@ def stack_all(self, column: str = 'DATA'): - column: column name (currently only DATA) """ - def load_json(file_path): - with open(file_path, 'r') as file: - return json.load(file) - def read_mapping(mapping_folder): """ Read mapping with multi-threads @@ -1217,8 +1199,10 @@ def read_mapping(mapping_folder): return indices, ref_indices if column == 'DATA': - columns = ['UVW', column, 'WEIGHT_SPECTRUM'] - # columns = [column, 'WEIGHT_SPECTRUM'] + if advanced: + columns = ['UVW', column, 'WEIGHT_SPECTRUM'] + else: + columns = [column, 'WEIGHT_SPECTRUM'] else: sys.exit("ERROR: Only column 'DATA' allowed (for now)") @@ -1235,9 +1219,9 @@ def read_mapping(mapping_folder): for col in columns: if col == 'UVW': - new_data, uvw_weights = self.get_data_arrays(col, self.T.nrows(), freq_len) + new_data, uvw_weights = get_data_arrays(col, self.T.nrows(), freq_len) else: - new_data, _ = self.get_data_arrays(col, self.T.nrows(), freq_len) + new_data, _ = get_data_arrays(col, self.T.nrows(), freq_len) # Loop over measurement sets for ms in self.mslist: @@ -1266,7 +1250,7 @@ def read_mapping(mapping_folder): indices, ref_indices = read_mapping(mapping_folder) # Chunked stacking! - chunks = t.nrows()//self.chunk_size + 1 + chunks = len(indices)//self.chunk_size + 1 print(f'Stacking in {chunks} chunks') for chunk_idx in range(chunks): print_progress_bar(chunk_idx, chunks+1) @@ -1288,19 +1272,22 @@ def read_mapping(mapping_folder): new_data[np.ix_(row_idxs_new, freq_idxs)] += data[row_idxs, :, :] new_data.flush() + print_progress_bar(chunk_idx, chunks) t.close() print(f'Put column {col}') if col == 'UVW': - uvw_weights[uvw_weights==0] = 1 + uvw_weights[uvw_weights == 0] = 1 new_data /= uvw_weights new_data[new_data != new_data] = 0. - for chunk_idx in range(chunks): + for chunk_idx in range(self.T.nrows()//self.chunk_size+1): self.T.putcol(col, new_data[chunk_idx * self.chunk_size:self.chunk_size * (chunk_idx+1)], startrow=chunk_idx * self.chunk_size, nrow=self.chunk_size) + # self.smooth_uvw() + self.T.close() # NORM DATA @@ -1339,6 +1326,65 @@ def check_folder_exists(folder_path): return os.path.isdir(folder_path) +def plot_baseline_track(t_final_name: str = None, t_input_names: list = None, mappingfiles: list = None, UV=True): + """ + Plot baseline track + + :input: + - t_final_name: table with final name + - t_input_names: tables to compare with + - mappingfiles: baseline mapping files + """ + + if len(t_input_names) > 4: + sys.exit("ERROR: Can just plot 4 inputs") + + colors = ['red', 'orange', 'pink', 'brown'] + + if not UV: + print("MAKE UW PLOT") + + for n, t_input_name in enumerate(t_input_names): + + try: + + m = open(mappingfiles[n]) + mapping = {int(i): int(j) for i, j in json.load(m).items()} + + with table(t_final_name) as f: + uvw1 = f.getcol("UVW")[list(mapping.values())] + + with table(t_input_name) as f: + uvw2 = f.getcol("UVW")[list(mapping.keys())] + + # Scatter plot for uvw1 + if n == 0: + lbl = 'Final MS' + else: + lbl = None + + plt.scatter(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], label=lbl, color='blue', edgecolor='black', alpha=0.4, s=100, marker='o') + plt.plot(uvw1[:, 0], uvw1[:, 2] if UV else uvw1[:, 3], color='blue', alpha=0.4) + + # Scatter plot for uvw2 + plt.scatter(uvw2[:, 0], uvw2[:, 2] if UV else uvw2[:, 3], label=f'MS {n} to stack', color=colors[n], edgecolor='black', alpha=0.7, s = 40, marker='*') + + except: + pass + + # Adding labels and title + plt.xlabel("U (m)", fontsize=14) + plt.ylabel("V (m)" if UV else "W (m)", fontsize=14) + + # Adding grid + plt.grid(True, linestyle='--', alpha=0.6) + + # Adding legend + plt.legend(fontsize=12) + + plt.show() + + def parse_args(): """ Parse input arguments @@ -1377,7 +1423,7 @@ def ms_merger(): sys.exit("ERROR: --less_avg only used in combination with --advanced_stacking") avg = 1 else: - avg = get_avg_factor(args.msin, args.less_avg)*2 + avg = get_avg_factor(args.msin, args.less_avg) print(f"Intermediate averaging factor {avg}\n" f"Final averaging factor {max(int(avg * args.avg), 1) if args.advanced_stacking else int(args.avg)}") @@ -1391,7 +1437,7 @@ def ms_merger(): if args.record_time: start_time = time.time() s = Stack(args.msin, args.msout, chunkmem=args.chunk_mem) - s.stack_all() + s.stack_all(advanced=args.advanced_stacking) if args.record_time: end_time = time.time() elapsed_time = end_time - start_time