Skip to content

Commit

Permalink
interpolate freq/time weights
Browse files Browse the repository at this point in the history
  • Loading branch information
jurjen93 committed Aug 5, 2024
1 parent 7d45b0e commit 912528a
Show file tree
Hide file tree
Showing 3 changed files with 365 additions and 5 deletions.
160 changes: 160 additions & 0 deletions ms_helpers/default_StokesV.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
--[[
This is the default AOFlagger strategy, version 2021-03-30
Author: André Offringa
This strategy is made as generic / easy to tweak as possible, with the most important
'tweaking' parameters available as variables at the beginning of function 'execute'.
]]

aoflagger.require_min_version("3.0")

--- Strategy entry point: flags `input` in place.
-- Detection is run on every polarization listed in `flag_polarizations`
-- (only Stokes V in this file) and on every representation listed in
-- `flag_representations`. For each of them the data goes through
-- `iteration_count - 1` rounds of SumThreshold + RMS thresholding + low-pass
-- fit subtraction, followed by one final SumThreshold pass at `base_threshold`.
-- The resulting mask is then either stored back as that polarization's data
-- (when the polarization exists in the input) or joined into the input mask.
-- @param input the data object supplied to the strategy; modified in place
function execute(input)
  --
  -- Generic settings
  --

  -- What polarizations to flag? Default: input:get_polarizations() (=all that are in the input data)
  -- Other options are e.g.:
  -- { 'XY', 'YX' } to flag only XY and YX, or
  -- { 'I', 'Q' } to flag only on Stokes I and Q
  local flag_polarizations = {'V'}

  local base_threshold = 1.0 -- lower means more sensitive detection
  -- How to flag complex values, options are: phase, amplitude, real, imaginary, complex
  -- May have multiple values to perform detection multiple times
  local flag_representations = { "amplitude" }
  local iteration_count = 3 -- how many iterations to perform?
  local threshold_factor_step = 2.0 -- How much to increase the sensitivity each iteration?
  -- If the following variable is true, the strategy will consider existing flags
  -- as bad data. It will exclude flagged data from detection, and make sure that any existing
  -- flags on input will be flagged on output. If set to false, existing flags are ignored.
  local exclude_original_flags = true
  local frequency_resize_factor = 1.0 -- Amount of "extra" smoothing in frequency direction
  local transient_threshold_factor = 1.0 -- decreasing this value makes detection of transient RFI more aggressive

  --
  -- End of generic settings
  --

  -- Returns true when `val` equals one of the elements of the sequence `arr`.
  -- Declared local (and outside the polarization loop) so that running the
  -- strategy neither creates nor overwrites a global named `contains`, and
  -- the closure is built only once per call of execute().
  local function contains(arr, val)
    for _, v in ipairs(arr) do
      if v == val then
        return true
      end
    end
    return false
  end

  local inpPolarizations = input:get_polarizations()

  if not exclude_original_flags then
    input:clear_mask()
  end
  -- For collecting statistics. Note that this is done after clear_mask(),
  -- so that the statistics ignore any flags in the input data.
  local copy_of_input = input:copy()

  for ipol, polarization in ipairs(flag_polarizations) do
    local pol_data = input:convert_to_polarization(polarization)
    local converted_data
    local converted_copy

    for _, representation in ipairs(flag_representations) do
      converted_data = pol_data:convert_to_complex(representation)
      -- Untouched copy of this representation: used to restore the
      -- visibilities after thresholding and to compute the residual.
      converted_copy = converted_data:copy()

      for i = 1, iteration_count - 1 do
        -- With the defaults (step 2.0, 3 iterations) the factor is 4 in the
        -- first round and 2 in the second, i.e. each round is more sensitive.
        local threshold_factor = threshold_factor_step ^ (iteration_count - i)

        local sumthr_level = threshold_factor * base_threshold
        if exclude_original_flags then
          aoflagger.sumthreshold_masked(
            converted_data,
            converted_copy,
            sumthr_level,
            sumthr_level * transient_threshold_factor,
            true,
            true
          )
        else
          aoflagger.sumthreshold(converted_data, sumthr_level, sumthr_level * transient_threshold_factor, true, true)
        end

        -- Do timestep & channel flagging
        local chdata = converted_data:copy()
        aoflagger.threshold_timestep_rms(converted_data, 3.5)
        aoflagger.threshold_channel_rms(chdata, 3.0 * threshold_factor, true)
        converted_data:join_mask(chdata)

        -- High pass filtering steps
        converted_data:set_visibilities(converted_copy)
        if exclude_original_flags then
          converted_data:join_mask(converted_copy)
        end

        local resized_data = aoflagger.downsample(converted_data, 3, frequency_resize_factor, true)
        aoflagger.low_pass_filter(resized_data, 21, 31, 2.5, 5.0)
        aoflagger.upsample(resized_data, converted_data, 3, frequency_resize_factor)

        -- In case this script is run from inside rfigui, calling
        -- the following visualize function will add the current result
        -- to the list of displayable visualizations.
        -- If the script is not running inside rfigui, the call is ignored.
        aoflagger.visualize(converted_data, "Fit #" .. i, i - 1)

        -- Residual = original visibilities minus the low-pass fit, carrying
        -- the mask that was accumulated on the fit.
        local tmp = converted_copy - converted_data
        tmp:set_mask(converted_data)
        converted_data = tmp

        aoflagger.visualize(converted_data, "Residual #" .. i, i + iteration_count)
        aoflagger.set_progress((ipol - 1) * iteration_count + i, #flag_polarizations * iteration_count)
      end -- end of iterations

      -- Final detection pass on the last residual at the base threshold
      if exclude_original_flags then
        aoflagger.sumthreshold_masked(
          converted_data,
          converted_copy,
          base_threshold,
          base_threshold * transient_threshold_factor,
          true,
          true
        )
      else
        aoflagger.sumthreshold(converted_data, base_threshold, base_threshold * transient_threshold_factor, true, true)
      end
    end -- end of complex representation iteration

    if exclude_original_flags then
      converted_data:join_mask(converted_copy)
    end

    if contains(inpPolarizations, polarization) then
      if input:is_complex() then
        converted_data = converted_data:convert_to_complex("complex")
      end
      input:set_polarization_data(polarization, converted_data)
    else
      -- The flagged polarization is not part of the input (e.g. a derived
      -- Stokes parameter), so only its mask is merged into the input.
      input:join_mask(converted_data)
    end

    aoflagger.visualize(converted_data, "Residual #" .. iteration_count, 2 * iteration_count)
    aoflagger.set_progress(ipol, #flag_polarizations)
  end -- end of polarization iterations

  if exclude_original_flags then
    aoflagger.scale_invariant_rank_operator_masked(input, copy_of_input, 0.2, 0.2)
  else
    aoflagger.scale_invariant_rank_operator(input, 0.2, 0.2)
  end

  aoflagger.threshold_timestep_rms(input, 4.0)

  if input:is_complex() and input:has_metadata() then
    -- This command will calculate a few statistics like flag% and stddev over
    -- time, frequency and baseline and write those to the MS. These can be
    -- visualized with aoqplot.
    aoflagger.collect_statistics(input, copy_of_input)
  end
  input:flag_nans()
end
12 changes: 7 additions & 5 deletions ms_helpers/doppler_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def correct_doppler_shift_casacore(msname, restfreq, outms, frame='LSRK'):

print(f"Doppler correction applied. Output MS saved as {outms}")

# Example usage:
msname = 'your_measurement_set.ms'
restfreq = '1420.40575177MHz'
outms = 'corrected_measurement_set.ms'
correct_doppler_shift_casacore(msname, restfreq, outms)

if __name__ == '__main__':
# Example usage:
msname = 'your_measurement_set.ms'
restfreq = '1420.40575177MHz'
outms = 'corrected_measurement_set.ms'
correct_doppler_shift_casacore(msname, restfreq, outms)
198 changes: 198 additions & 0 deletions ms_helpers/interpolate_flags_diff_timefreqres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""
With this script you can flag data from a lower freq/time resolution to a higher one.
Make sure that both datasets have the same antennas with same antenna indices and originate from the same observation.
Strategy:
1) aoflagger with default_StokesV.lua strategy (Stokes V)
2) interpolate the new flags to the output measurement set (higher freq/time resolution dataset)
Usage:
python interpolate_flags_diff_timefreqres.py --msin dataset_lowres.ms --msout dataset_original.ms
dataset_lowres.ms --> a dataset with a lower time/freq resolution which originates from dataset_original.ms
dataset_original.ms --> the original dataset
"""

from casacore.tables import table
import numpy as np
from argparse import ArgumentParser
from subprocess import call
from scipy.interpolate import griddata
from sys import stdout


def run(command):
    """
    Run a command line through the system shell and insist that it succeeds.

    Args:
        command (str): the command line to execute.

    Returns:
        int: the exit status of the command; this is always 0 when the function returns.

    Raises:
        Exception: when the exit status is non-zero. The failing command string is the
            exception argument, and a diagnostic line is printed first.
    """

    exit_status = call(command, shell=True)

    # Success path first: nothing to report.
    if exit_status == 0:
        return exit_status

    print('FAILED to run ' + command + ': return value is ' + str(exit_status))
    raise Exception(command)


def print_progress_bar(index, total, bar_length=50):
    """
    Draw (or redraw) a single-line progress bar on stdout.

    :param:
        - index: the current index (0-based) in the iteration.
        - total: the total number of indices.
        - bar_length: the character length of the progress bar (default 50).
    """

    fraction_done = (index + 1) / total
    n_filled = int(bar_length * fraction_done)
    n_empty = bar_length - n_filled
    bar = "█" * n_filled + '-' * n_empty

    # The leading carriage return moves the cursor back to the start of the line,
    # so successive calls overwrite the previous bar instead of scrolling.
    stdout.write(f'\rProgress: |{bar}| {fraction_done * 100:.1f}% Complete')
    stdout.flush()  # push the partial line out immediately

    # After the last index, finish the line so later output starts on a new one
    if index == total - 1:
        print()


def runaoflagger(ms, strategy='default_StokesV.lua'):
    """
    Run aoflagger on a Measurement Set.

    Args:
        ms (str): path of the Measurement Set to flag.
        strategy (str or None): path of the Lua strategy file passed via ``-strategy``;
            when None, aoflagger is invoked without a ``-strategy`` option.

    Returns:
        None

    Raises:
        Exception: propagated from ``run`` when the command exits with a non-zero status.
    """

    # Local import: quoting keeps paths containing spaces or shell metacharacters
    # intact, since the command line is handed to a shell by run().
    from shlex import quote

    if strategy is not None:
        cmd = 'aoflagger -strategy ' + quote(strategy) + ' ' + quote(ms)
    else:
        cmd = 'aoflagger ' + quote(ms)
    print(cmd)
    run(cmd)
    return


def make_ant_pairs(n_ant, n_time):
    """
    Build ANTENNA1 and ANTENNA2 index arrays for all cross-correlation baselines.

    Every pair (i, j) with i < j is listed once per time slot, in row-major order
    of (i, j), and that list is repeated ``n_time`` times.

    :param:
        - n_ant: Number of antennas in the array.
        - n_time: Number of time slots.

    :return:
        - ANTENNA1: first antenna index of every pair
        - ANTENNA2: second antenna index of every pair
    """

    # Collect the pairs of a single time slot
    first_of_pair = []
    second_of_pair = []
    for lower in range(n_ant):
        for upper in range(lower + 1, n_ant):
            first_of_pair.append(lower)
            second_of_pair.append(upper)

    # Repeat the single-slot lists for each of the n_time time slots
    antenna1 = np.array(first_of_pair * n_time)
    antenna2 = np.array(second_of_pair * n_time)

    return antenna1, antenna2


def _copy_baseline_flags(sub_flagged, sub_target, freq_flagged_axis, freq_axis):
    """
    Map the flags of one baseline selection onto the grid of another and store them.

    Args:
        sub_flagged: table selection (single baseline) to read the flags from.
        sub_target: writable table selection (same baseline) whose FLAG column is updated.
        freq_flagged_axis: channel frequencies belonging to ``sub_flagged``.
        freq_axis: channel frequencies belonging to ``sub_target``.

    Returns:
        None
    """

    # A baseline that is absent from either MS has nothing to interpolate from or to;
    # griddata cannot be built from zero points, so skip it.
    if sub_flagged.nrows() == 0 or sub_target.nrows() == 0:
        return

    time_flagged_axis = sub_flagged.getcol("TIME")
    # Only index 0 of the last FLAG axis is used (presumably the first correlation
    # -- TODO confirm against the MS layout).
    data_flagged = np.take(sub_flagged.getcol('FLAG'), indices=0, axis=-1).astype(int)

    time_axis = sub_target.getcol("TIME")
    target_flags = sub_target.getcol('FLAG')
    # Size of the last FLAG axis of the target, so the written array always matches
    # the column shape instead of assuming 4.
    n_corr = target_flags.shape[-1]
    data = np.take(target_flags, indices=0, axis=-1).astype(int)

    # Create the grid for interpolation
    grid_x, grid_y = np.meshgrid(freq_flagged_axis, time_flagged_axis)

    # Flatten the grid and data for griddata function
    points = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    values = data_flagged.ravel()

    # Create the interpolation points
    interp_grid_x, interp_grid_y = np.meshgrid(freq_axis, time_axis)

    # Perform the nearest-neighbor interpolation
    new_flags = griddata(points, values, (interp_grid_x, interp_grid_y), method='nearest')

    # Reshape the data to match the original data shape
    new_flags = new_flags.reshape(time_axis.size, freq_axis.size)

    # Combine with the flags already present (sum then clip acts as a logical OR)
    data += new_flags
    merged = np.clip(data, a_min=0, a_max=1)

    # Write the same combined flags to every entry along the last FLAG axis
    sub_target.putcol('FLAG', np.tile(np.expand_dims(merged, axis=-1), n_corr).astype(bool))


def interpolate_weights(flagged_ms, ms):
    """
    Transfer flags from one measurement set to another with a different time/freq grid.

    For every antenna pair (i, j) with i < j, the FLAG values of ``flagged_ms`` are mapped
    by nearest-neighbour lookup in (frequency, time) onto the grid of ``ms``, combined with
    the flags already present in ``ms`` and written back to the FLAG column of ``ms``.
    Despite the function name, only the FLAG column is read and written.

    Args:
        flagged_ms: measurement set from where to interpolate
        ms: the pre-averaged measurement set; its FLAG column is modified in place

    Returns:
        None
    """

    ants = table(flagged_ms + "::ANTENNA", ack=False)
    try:
        baselines = np.c_[make_ant_pairs(ants.nrows(), 1)]
    finally:
        ants.close()

    # Get freq axis first table (first row of CHAN_FREQ only)
    t = table(flagged_ms+'::SPECTRAL_WINDOW', ack=False)
    try:
        freq_flagged_axis = t.getcol("CHAN_FREQ")[0]
    finally:
        t.close()

    # Get freq axis second table (first row of CHAN_FREQ only)
    t = table(ms+'::SPECTRAL_WINDOW', ack=False)
    try:
        freq_axis = t.getcol("CHAN_FREQ")[0]
    finally:
        t.close()

    t1 = table(flagged_ms, ack=False)
    try:
        t2 = table(ms, ack=False, readonly=False)
        try:
            # Loop over baselines
            for n, baseline in enumerate(baselines):
                print_progress_bar(n, len(baselines))
                selection = f"ANTENNA1={baseline[0]} AND ANTENNA2={baseline[1]}"
                sub1 = t1.query(selection)
                try:
                    sub2 = t2.query(selection)
                    try:
                        _copy_baseline_flags(sub1, sub2, freq_flagged_axis, freq_axis)
                    finally:
                        sub2.close()
                finally:
                    sub1.close()
        finally:
            t2.close()
    finally:
        # Tables are closed even when a baseline fails, so no table handle is left open
        t1.close()


def parse_args():
    """
    Parse input arguments

    Returns:
        argparse.Namespace: with attributes ``msin`` and ``msout``.

    Both options are mandatory: the rest of the script concatenates them into paths and
    command lines, so a missing value is rejected here with a usage message instead of
    failing later on a ``None``.
    """

    parser = ArgumentParser(description='Flag data from a lower freq/time resolution to a higher one')
    parser.add_argument('--msin', required=True, help='MS input from where to interpolate')
    parser.add_argument('--msout', required=True, help='MS output from where to apply new interpolated flags')

    return parser.parse_args()


def main():
    """
    Command-line entry point: flag the input MS, then transfer its flags to the output MS.
    """

    cli = parse_args()
    source_ms, target_ms = cli.msin, cli.msout

    # Step 1: run aoflagger on the input MS
    runaoflagger(source_ms)

    # Step 2: map the resulting flags onto the output MS
    interpolate_weights(source_ms, target_ms)


if __name__ == '__main__':
    main()

0 comments on commit 912528a

Please sign in to comment.