-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
365 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
--[[ | ||
This is the default AOFlagger strategy, version 2021-03-30 | ||
Author: André Offringa | ||
This strategy is made as generic / easy to tweak as possible, with the most important | ||
'tweaking' parameters available as variables at the beginning of function 'execute'. | ||
]] | ||
|
||
aoflagger.require_min_version("3.0") | ||
|
||
--- Strategy entry point: runs the detection steps below on `input` and
-- updates its data/mask in place.
-- @param input the data object handed to the strategy; it is only accessed
--   through the methods called in this function (get_polarizations,
--   convert_to_polarization, join_mask, ...).
function execute(input)
  --
  -- Generic settings
  --

  -- What polarizations to flag? Default: input:get_polarizations() (=all that are in the input data)
  -- Other options are e.g.:
  -- { 'XY', 'YX' } to flag only XY and YX, or
  -- { 'I', 'Q' } to flag only on Stokes I and Q
  local flag_polarizations = { 'V' }

  local base_threshold = 1.0 -- lower means more sensitive detection
  -- How to flag complex values, options are: phase, amplitude, real, imaginary, complex
  -- May have multiple values to perform detection multiple times
  local flag_representations = { "amplitude" }
  local iteration_count = 3 -- how many iterations to perform?
  local threshold_factor_step = 2.0 -- How much to increase the sensitivity each iteration?
  -- If the following variable is true, the strategy will consider existing flags
  -- as bad data. It will exclude flagged data from detection, and make sure that any existing
  -- flags on input will be flagged on output. If set to false, existing flags are ignored.
  local exclude_original_flags = true
  local frequency_resize_factor = 1.0 -- Amount of "extra" smoothing in frequency direction
  local transient_threshold_factor = 1.0 -- decreasing this value makes detection of transient RFI more aggressive

  --
  -- End of generic settings
  --

  local inpPolarizations = input:get_polarizations()

  if not exclude_original_flags then
    input:clear_mask()
  end
  -- For collecting statistics. Note that this is done after clear_mask(),
  -- so that the statistics ignore any flags in the input data.
  local copy_of_input = input:copy()

  -- Helper used in the polarization loop below: true when `val` is an
  -- element of the sequence `arr`. Declared local (and once, outside the
  -- loop) so it is not re-created per polarization and does not leak into
  -- the global environment.
  local function contains(arr, val)
    for _, v in ipairs(arr) do
      if v == val then
        return true
      end
    end
    return false
  end

  for ipol, polarization in ipairs(flag_polarizations) do
    local pol_data = input:convert_to_polarization(polarization)
    local converted_data
    local converted_copy

    for _, representation in ipairs(flag_representations) do
      converted_data = pol_data:convert_to_complex(representation)
      converted_copy = converted_data:copy()

      -- The last iteration is performed after this loop with the base threshold,
      -- hence only iteration_count - 1 passes here.
      for i = 1, iteration_count - 1 do
        local threshold_factor = threshold_factor_step ^ (iteration_count - i)

        local sumthr_level = threshold_factor * base_threshold
        if exclude_original_flags then
          aoflagger.sumthreshold_masked(
            converted_data,
            converted_copy,
            sumthr_level,
            sumthr_level * transient_threshold_factor,
            true,
            true
          )
        else
          aoflagger.sumthreshold(converted_data, sumthr_level, sumthr_level * transient_threshold_factor, true, true)
        end

        -- Do timestep & channel flagging
        local chdata = converted_data:copy()
        aoflagger.threshold_timestep_rms(converted_data, 3.5)
        aoflagger.threshold_channel_rms(chdata, 3.0 * threshold_factor, true)
        converted_data:join_mask(chdata)

        -- High pass filtering steps
        converted_data:set_visibilities(converted_copy)
        if exclude_original_flags then
          converted_data:join_mask(converted_copy)
        end

        local resized_data = aoflagger.downsample(converted_data, 3, frequency_resize_factor, true)
        aoflagger.low_pass_filter(resized_data, 21, 31, 2.5, 5.0)
        aoflagger.upsample(resized_data, converted_data, 3, frequency_resize_factor)

        -- In case this script is run from inside rfigui, calling
        -- the following visualize function will add the current result
        -- to the list of displayable visualizations.
        -- If the script is not running inside rfigui, the call is ignored.
        aoflagger.visualize(converted_data, "Fit #" .. i, i - 1)

        -- Residual = original visibilities minus the smooth fit, keeping the
        -- mask accumulated so far.
        local tmp = converted_copy - converted_data
        tmp:set_mask(converted_data)
        converted_data = tmp

        aoflagger.visualize(converted_data, "Residual #" .. i, i + iteration_count)
        aoflagger.set_progress((ipol - 1) * iteration_count + i, #flag_polarizations * iteration_count)
      end -- end of iterations

      if exclude_original_flags then
        aoflagger.sumthreshold_masked(
          converted_data,
          converted_copy,
          base_threshold,
          base_threshold * transient_threshold_factor,
          true,
          true
        )
      else
        aoflagger.sumthreshold(converted_data, base_threshold, base_threshold * transient_threshold_factor, true, true)
      end
    end -- end of complex representation iteration

    if exclude_original_flags then
      converted_data:join_mask(converted_copy)
    end

    if contains(inpPolarizations, polarization) then
      if input:is_complex() then
        converted_data = converted_data:convert_to_complex("complex")
      end
      input:set_polarization_data(polarization, converted_data)
    else
      -- The flagged polarization is not one of the input polarizations,
      -- so only its mask is merged into the input.
      input:join_mask(converted_data)
    end

    aoflagger.visualize(converted_data, "Residual #" .. iteration_count, 2 * iteration_count)
    aoflagger.set_progress(ipol, #flag_polarizations)
  end -- end of polarization iterations

  if exclude_original_flags then
    aoflagger.scale_invariant_rank_operator_masked(input, copy_of_input, 0.2, 0.2)
  else
    aoflagger.scale_invariant_rank_operator(input, 0.2, 0.2)
  end

  aoflagger.threshold_timestep_rms(input, 4.0)

  if input:is_complex() and input:has_metadata() then
    -- This command will calculate a few statistics like flag% and stddev over
    -- time, frequency and baseline and write those to the MS. These can be
    -- visualized with aoqplot.
    aoflagger.collect_statistics(input, copy_of_input)
  end
  input:flag_nans()
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
""" | ||
With this script you can flag data from a lower freq/time resolution to a higher one. | ||
Make sure that both datasets have the same antennas with same antenna indices and originate from the same observation. | ||
Strategy: | ||
1) aoflagger with default_StokesV.lua strategy (Stokes V) | ||
2) interpolate the new flags to the output measurement set (higher freq/time resolution dataset) | ||
Usage: | ||
python interpolate_flags_diff_timefreqres.py --msin dataset_lowres.ms --msout dataset_original.ms | ||
dataset_lowres.ms --> a dataset with a lower time/freq resolution which originates from dataset_original.ms | ||
dataset_original.ms --> the original dataset | ||
""" | ||
|
||
from casacore.tables import table | ||
import numpy as np | ||
from argparse import ArgumentParser | ||
from subprocess import call | ||
from scipy.interpolate import griddata | ||
from sys import stdout | ||
|
||
|
||
def run(command):
    """
    Execute a shell command through subprocess

    Args:
        command (str): the command to execute (interpreted by the shell).
    Returns:
        int: the exit status of the command. Because any non-zero status
            raises, the returned value is always 0.
    Raises:
        Exception: when the command exits with a non-zero status; the
            exception argument is the command string.
    """

    retval = call(command, shell=True)
    if retval != 0:
        # Report the failing command and its status before aborting
        print('FAILED to run ' + command + ': return value is ' + str(retval))
        raise Exception(command)
    return retval
|
||
|
||
def print_progress_bar(index, total, bar_length=50):
    """
    Prints a progress bar to the console.

    :param:
        - index: the current index (0-based) in the iteration.
        - total: the total number of indices.
        - bar_length: the character length of the progress bar (default 50).

    :return:
        None. Nothing is printed when total is zero or negative.
    """

    # Nothing to report for an empty iteration; this also avoids a division by zero
    if total <= 0:
        return

    percent_complete = (index + 1) / total
    filled_length = int(bar_length * percent_complete)
    bar = "█" * filled_length + '-' * (bar_length - filled_length)
    stdout.write(f'\rProgress: |{bar}| {percent_complete * 100:.1f}% Complete')
    stdout.flush()  # Important to ensure the progress bar is updated in place

    # Print a new line on completion
    if index == total - 1:
        print()
|
||
|
||
def runaoflagger(ms, strategy='default_StokesV.lua'):
    """
    Run aoflagger on a Measurement Set.

    Args:
        ms (str): path of the Measurement Set to flag.
        strategy (str or None): strategy file passed to aoflagger through
            its -strategy option; when None, aoflagger is called without
            that option.
    Returns:
        None
    """

    # Function-scope import: quoting is only needed to build this command
    from shlex import quote

    # The command is executed through a shell (see run), so quote the paths:
    # names with spaces or shell metacharacters must reach aoflagger intact.
    if strategy is not None:
        cmd = 'aoflagger -strategy ' + quote(strategy) + ' ' + quote(ms)
    else:
        cmd = 'aoflagger ' + quote(ms)
    print(cmd)
    run(cmd)
    return
|
||
|
||
def make_ant_pairs(n_ant, n_time):
    """
    Generate ANTENNA1 and ANTENNA2 arrays for an array with M antennas over N time slots.

    Only pairs with ANTENNA1 < ANTENNA2 are generated (no auto-correlations).

    :param:
        - n_ant: Number of antennas in the array.
        - n_time: Number of time slots.

    :return:
        - ANTENNA1
        - ANTENNA2
        Both are 1-D integer arrays holding the pair list of one time slot
        repeated n_time times.
    """

    # Indices of the strict upper triangle enumerate all unique pairs (i < j)
    # for one time slot, ordered by i and then by j.
    first, second = np.triu_indices(n_ant, k=1)

    # Expand the pairs across n_time time slots
    antenna1 = np.tile(first, n_time)
    antenna2 = np.tile(second, n_time)

    return antenna1, antenna2
|
||
|
||
def interpolate_weights(flagged_ms, ms):
    """
    Copy the flags of one measurement set onto another one with a different
    time/frequency resolution, using nearest-neighbour interpolation per
    baseline. Despite its name, this function transfers the FLAG column, not
    weights.

    A sample of the output is flagged when its nearest input sample is
    flagged in any correlation; the result is OR-ed with the flags already
    present in the output, for every correlation.

    Args:
        flagged_ms: measurement set from where to interpolate
        ms: the pre-averaged measurement set; its FLAG column is updated in place
    Returns:
        None
    """

    # Baselines are derived from the antenna count of the flagged dataset;
    # only pairs with ANTENNA1 < ANTENNA2 are visited (see make_ant_pairs).
    ants = table(flagged_ms + "::ANTENNA", ack=False)
    try:
        baselines = np.c_[make_ant_pairs(ants.nrows(), 1)]
    finally:
        ants.close()

    # Get freq axis first table (first row of CHAN_FREQ only)
    t = table(flagged_ms + '::SPECTRAL_WINDOW', ack=False)
    try:
        freq_flagged_axis = t.getcol("CHAN_FREQ")[0]
    finally:
        t.close()

    # Get freq axis second table (first row of CHAN_FREQ only)
    t = table(ms + '::SPECTRAL_WINDOW', ack=False)
    try:
        freq_axis = t.getcol("CHAN_FREQ")[0]
    finally:
        t.close()

    t1 = table(flagged_ms, ack=False)
    t2 = table(ms, ack=False, readonly=False)
    try:
        # Loop over baselines
        for n, baseline in enumerate(baselines):
            print_progress_bar(n, len(baselines))
            selection = f"ANTENNA1={int(baseline[0])} AND ANTENNA2={int(baseline[1])}"
            sub1 = t1.query(selection)
            sub2 = t2.query(selection)
            try:
                # A baseline missing from either dataset has nothing to
                # interpolate from/to, and griddata cannot work on an empty
                # point set.
                if sub1.nrows() == 0 or sub2.nrows() == 0:
                    continue

                time_flagged_axis = sub1.getcol("TIME")
                # Collapse the correlation axis: flagged in any correlation
                data_flagged = sub1.getcol('FLAG').any(axis=-1).astype(int)

                time_axis = sub2.getcol("TIME")
                flags_out = sub2.getcol('FLAG')

                # Create the grid for interpolation
                grid_x, grid_y = np.meshgrid(freq_flagged_axis, time_flagged_axis)

                # Flatten the grid and data for griddata function
                points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
                values = data_flagged.ravel()

                # Create the interpolation points
                interp_grid_x, interp_grid_y = np.meshgrid(freq_axis, time_axis)

                # Perform the nearest-neighbor interpolation
                new_flags = griddata(points, values, (interp_grid_x, interp_grid_y), method='nearest')

                # Reshape the data to match the original data shape
                new_flags = new_flags.reshape(time_axis.size, freq_axis.size).astype(bool)

                # Apply the new flags on top of the existing ones; broadcasting
                # over the last axis handles any number of correlations.
                flags_out = flags_out.astype(bool) | new_flags[..., np.newaxis]

                # Store the updated data for the current baseline
                sub2.putcol('FLAG', flags_out)
            finally:
                sub1.close()
                sub2.close()
    finally:
        t1.close()
        t2.close()
|
||
|
||
def parse_args():
    """
    Parse input arguments

    Both --msin and --msout are mandatory: argparse reports a usage error
    and exits when either one is missing, instead of letting a None value
    fail later while building file names.

    Returns:
        argparse.Namespace with the attributes msin and msout.
    """

    parser = ArgumentParser(description='Flag data from a lower freq/time resolution to a higher one')
    parser.add_argument('--msin', help='MS input from where to interpolate', required=True)
    parser.add_argument('--msout', help='MS output from where to apply new interpolated flags', required=True)

    return parser.parse_args()
|
||
|
||
def main():
    """
    Script entry point: flag the input MS with aoflagger, then transfer the
    resulting flags to the output MS.
    """

    options = parse_args()
    source_ms, target_ms = options.msin, options.msout

    # Step 1: run aoflagger on the input MS
    runaoflagger(source_ms)

    # Step 2: interpolate the flags onto the output MS
    interpolate_weights(source_ms, target_ms)


if __name__ == '__main__':
    main()