From 912528a4c35dea7eea9291b9b7e2836fc330dc59 Mon Sep 17 00:00:00 2001 From: jurjen93 Date: Mon, 5 Aug 2024 17:04:53 +0200 Subject: [PATCH] interpolate freq/time weights --- ms_helpers/default_StokesV.lua | 160 ++++++++++++++ ms_helpers/doppler_correction.py | 12 +- .../interpolate_flags_diff_timefreqres.py | 198 ++++++++++++++++++ 3 files changed, 365 insertions(+), 5 deletions(-) create mode 100644 ms_helpers/default_StokesV.lua create mode 100644 ms_helpers/interpolate_flags_diff_timefreqres.py diff --git a/ms_helpers/default_StokesV.lua b/ms_helpers/default_StokesV.lua new file mode 100644 index 00000000..22802667 --- /dev/null +++ b/ms_helpers/default_StokesV.lua @@ -0,0 +1,160 @@ +--[[ + This is the default AOFlagger strategy, version 2021-03-30 + Author: André Offringa + + This strategy is made as generic / easy to tweak as possible, with the most important +'tweaking' parameters available as variables at the beginning of function 'execute'. +]] + +aoflagger.require_min_version("3.0") + +function execute(input) + -- + -- Generic settings + -- + + -- What polarizations to flag? Default: input:get_polarizations() (=all that are in the input data) + -- Other options are e.g.: + -- { 'XY', 'YX' } to flag only XY and YX, or + -- { 'I', 'Q' } to flag only on Stokes I and Q + local flag_polarizations = {'V'} + + local base_threshold = 1.0 -- lower means more sensitive detection + -- How to flag complex values, options are: phase, amplitude, real, imaginary, complex + -- May have multiple values to perform detection multiple times + local flag_representations = { "amplitude" } + local iteration_count = 3 -- how many iterations to perform? + local threshold_factor_step = 2.0 -- How much to increase the sensitivity each iteration? + -- If the following variable is true, the strategy will consider existing flags + -- as bad data. It will exclude flagged data from detection, and make sure that any existing + -- flags on input will be flagged on output. 
If set to false, existing flags are ignored. + local exclude_original_flags = true + local frequency_resize_factor = 1.0 -- Amount of "extra" smoothing in frequency direction + local transient_threshold_factor = 1.0 -- decreasing this value makes detection of transient RFI more aggressive + + -- + -- End of generic settings + -- + + local inpPolarizations = input:get_polarizations() + + if not exclude_original_flags then + input:clear_mask() + end + -- For collecting statistics. Note that this is done after clear_mask(), + -- so that the statistics ignore any flags in the input data. + local copy_of_input = input:copy() + + for ipol, polarization in ipairs(flag_polarizations) do + local pol_data = input:convert_to_polarization(polarization) + local converted_data + local converted_copy + + for _, representation in ipairs(flag_representations) do + converted_data = pol_data:convert_to_complex(representation) + converted_copy = converted_data:copy() + + for i = 1, iteration_count - 1 do + local threshold_factor = threshold_factor_step ^ (iteration_count - i) + + local sumthr_level = threshold_factor * base_threshold + if exclude_original_flags then + aoflagger.sumthreshold_masked( + converted_data, + converted_copy, + sumthr_level, + sumthr_level * transient_threshold_factor, + true, + true + ) + else + aoflagger.sumthreshold(converted_data, sumthr_level, sumthr_level * transient_threshold_factor, true, true) + end + + -- Do timestep & channel flagging + local chdata = converted_data:copy() + aoflagger.threshold_timestep_rms(converted_data, 3.5) + aoflagger.threshold_channel_rms(chdata, 3.0 * threshold_factor, true) + converted_data:join_mask(chdata) + + -- High pass filtering steps + converted_data:set_visibilities(converted_copy) + if exclude_original_flags then + converted_data:join_mask(converted_copy) + end + + local resized_data = aoflagger.downsample(converted_data, 3, frequency_resize_factor, true) + aoflagger.low_pass_filter(resized_data, 21, 31, 2.5, 5.0) + 
aoflagger.upsample(resized_data, converted_data, 3, frequency_resize_factor) + + -- In case this script is run from inside rfigui, calling + -- the following visualize function will add the current result + -- to the list of displayable visualizations. + -- If the script is not running inside rfigui, the call is ignored. + aoflagger.visualize(converted_data, "Fit #" .. i, i - 1) + + local tmp = converted_copy - converted_data + tmp:set_mask(converted_data) + converted_data = tmp + + aoflagger.visualize(converted_data, "Residual #" .. i, i + iteration_count) + aoflagger.set_progress((ipol - 1) * iteration_count + i, #flag_polarizations * iteration_count) + end -- end of iterations + + if exclude_original_flags then + aoflagger.sumthreshold_masked( + converted_data, + converted_copy, + base_threshold, + base_threshold * transient_threshold_factor, + true, + true + ) + else + aoflagger.sumthreshold(converted_data, base_threshold, base_threshold * transient_threshold_factor, true, true) + end + end -- end of complex representation iteration + + if exclude_original_flags then + converted_data:join_mask(converted_copy) + end + + -- Helper function used below + function contains(arr, val) + for _, v in ipairs(arr) do + if v == val then + return true + end + end + return false + end + + if contains(inpPolarizations, polarization) then + if input:is_complex() then + converted_data = converted_data:convert_to_complex("complex") + end + input:set_polarization_data(polarization, converted_data) + else + input:join_mask(converted_data) + end + + aoflagger.visualize(converted_data, "Residual #" .. 
iteration_count, 2 * iteration_count) + aoflagger.set_progress(ipol, #flag_polarizations) + end -- end of polarization iterations + + if exclude_original_flags then + aoflagger.scale_invariant_rank_operator_masked(input, copy_of_input, 0.2, 0.2) + else + aoflagger.scale_invariant_rank_operator(input, 0.2, 0.2) + end + + aoflagger.threshold_timestep_rms(input, 4.0) + + if input:is_complex() and input:has_metadata() then + -- This command will calculate a few statistics like flag% and stddev over + -- time, frequency and baseline and write those to the MS. These can be + -- visualized with aoqplot. + aoflagger.collect_statistics(input, copy_of_input) + end + input:flag_nans() +end diff --git a/ms_helpers/doppler_correction.py b/ms_helpers/doppler_correction.py index b0287d4d..33afe936 100644 --- a/ms_helpers/doppler_correction.py +++ b/ms_helpers/doppler_correction.py @@ -92,8 +92,10 @@ def correct_doppler_shift_casacore(msname, restfreq, outms, frame='LSRK'): print(f"Doppler correction applied. Output MS saved as {outms}") -# Example usage: -msname = 'your_measurement_set.ms' -restfreq = '1420.40575177MHz' -outms = 'corrected_measurement_set.ms' -correct_doppler_shift_casacore(msname, restfreq, outms) \ No newline at end of file + +if __name__ == '__main__': + # Example usage: + msname = 'your_measurement_set.ms' + restfreq = '1420.40575177MHz' + outms = 'corrected_measurement_set.ms' + correct_doppler_shift_casacore(msname, restfreq, outms) \ No newline at end of file diff --git a/ms_helpers/interpolate_flags_diff_timefreqres.py b/ms_helpers/interpolate_flags_diff_timefreqres.py new file mode 100644 index 00000000..3a51231a --- /dev/null +++ b/ms_helpers/interpolate_flags_diff_timefreqres.py @@ -0,0 +1,198 @@ +""" +With this script you can flag data from a lower freq/time resolution to a higher one. +Make sure that both datasets have the same antennas with same antenna indices and originate from the same observation. 
+ +Strategy: + 1) aoflagger with default_StokesV.lua strategy (Stokes V) + 2) interpolate the new flags to the output measurement set (higher freq/time resolution dataset) + +Usage: + python interpolate_flags_diff_timefreqres.py --msin dataset_lowres.ms --msout dataset_original.ms + + dataset_lowres.ms --> a dataset with a lower time/freq resolution which originates from dataset_original.ms + dataset_original.ms --> the original dataset +""" + +from casacore.tables import table +import numpy as np +from argparse import ArgumentParser +from subprocess import call +from scipy.interpolate import griddata +from sys import stdout + + +def run(command): + """ + Execute a shell command through subprocess + + Args: + command (str): the command to execute. + Returns: + int: the command's return value (0; a non-zero value raises Exception) + """ + + retval = call(command, shell=True) + if retval != 0: + print('FAILED to run ' + command + ': return value is ' + str(retval)) + raise Exception(command) + return retval + + +def print_progress_bar(index, total, bar_length=50): + """ + Prints a progress bar to the console. + + :param: + - index: the current index (0-based) in the iteration. + - total: the total number of indices. + - bar_length: the character length of the progress bar (default 50). + """ + + percent_complete = (index + 1) / total + filled_length = int(bar_length * percent_complete) + bar = "█" * filled_length + '-' * (bar_length - filled_length) + stdout.write(f'\rProgress: |{bar}| {percent_complete * 100:.1f}% Complete') + stdout.flush() # Important to ensure the progress bar is updated in place + + # Print a new line on completion + if index == total - 1: + print() + + +def runaoflagger(ms, strategy='default_StokesV.lua'): + """ + Run aoflagger on a Measurement Set. + + Args: + ms (str): path of the Measurement Set to flag. 
+ Returns: + None + """ + + if strategy is not None: + cmd = 'aoflagger -strategy ' + strategy + ' ' + ms + else: + cmd = 'aoflagger ' + ms + print(cmd) + run(cmd) + return + + +def make_ant_pairs(n_ant, n_time): + """ + Generate ANTENNA1 and ANTENNA2 arrays for an array with M antennas over N time slots. + + :param: + - n_ant: Number of antennas in the array. + - n_time: Number of time slots. + + :return: + - ANTENNA1 + - ANTENNA2 + """ + + # Generate all unique pairs of antennas for one time slot + antenna_pairs = [(i, j) for i in range(n_ant) for j in range(i + 1, n_ant)] + + # Expand the pairs across n_time time slots + antenna1 = np.array([pair[0] for pair in antenna_pairs] * n_time) + antenna2 = np.array([pair[1] for pair in antenna_pairs] * n_time) + + return antenna1, antenna2 + + +def interpolate_weights(flagged_ms, ms): + """ + Args: + flagged_ms: measurement set from where to interpolate + ms: the pre-averaged measurement set + Returns: + None; the interpolated flags are written to the FLAG column of ms + """ + + ants = table(flagged_ms + "::ANTENNA", ack=False) + baselines = np.c_[make_ant_pairs(ants.nrows(), 1)] + ants.close() + + t1 = table(flagged_ms, ack=False) + t2 = table(ms, ack=False, readonly=False) + + # Get freq axis first table + t = table(flagged_ms+'::SPECTRAL_WINDOW', ack=False) + freq_flagged_axis = t.getcol("CHAN_FREQ")[0] + t.close() + + # Get freq axis second table + t = table(ms+'::SPECTRAL_WINDOW', ack=False) + freq_axis = t.getcol("CHAN_FREQ")[0] + t.close() + + # Loop over baselines + for n, baseline in enumerate(baselines): + print_progress_bar(n, len(baselines)) + sub1 = t1.query(f"ANTENNA1={baseline[0]} AND ANTENNA2={baseline[1]}") + sub2 = t2.query(f"ANTENNA1={baseline[0]} AND ANTENNA2={baseline[1]}") + + time_flagged_axis = sub1.getcol("TIME") + data_flagged = np.take(sub1.getcol('FLAG'), indices=0, axis=-1).astype(int) + + time_axis = sub2.getcol("TIME") + data = np.take(sub2.getcol('FLAG'), indices=0, axis=-1).astype(int) + + # Create the grid for interpolation + 
grid_x, grid_y = np.meshgrid(freq_flagged_axis, time_flagged_axis) + + # Flatten the grid and data for griddata function + points = np.column_stack((grid_x.ravel(), grid_y.ravel())) + values = data_flagged.ravel() + + # Create the interpolation points + interp_grid_x, interp_grid_y = np.meshgrid(freq_axis, time_axis) + + # Perform the nearest-neighbor interpolation + new_flags = griddata(points, values, (interp_grid_x, interp_grid_y), method='nearest') + + # Reshape the data to match the original data shape + new_flags = new_flags.reshape(time_axis.size, freq_axis.size) + + # Apply the new flags to the data + data += new_flags + + # Store the updated data for the current baseline + sub2.putcol('FLAG', np.tile(np.expand_dims(np.clip(data, a_min=0, a_max=1), axis=-1), 4).astype(bool)) + + sub1.close() + sub2.close() + + t1.close() + t2.close() + + +def parse_args(): + """ + Parse input arguments + """ + + parser = ArgumentParser(description='Flag data from a lower freq/time resolution to a higher one') + parser.add_argument('--msin', help='MS input from where to interpolate') + parser.add_argument('--msout', help='MS output from where to apply new interpolated flags') + + return parser.parse_args() + + +def main(): + """ + Main script + """ + + args = parse_args() + + # run aoflagger on the input MS + runaoflagger(args.msin) + + # interpolate weights + interpolate_weights(args.msin, args.msout) + + +if __name__ == '__main__': + main()