Skip to content

Commit

Permalink
concat upd flag
Browse files Browse the repository at this point in the history
  • Loading branch information
jurjen93 committed Oct 24, 2024
1 parent 93dc922 commit cd2045f
Showing 1 changed file with 77 additions and 9 deletions.
86 changes: 77 additions & 9 deletions ms_helpers/concat_with_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from glob import glob
from pprint import pprint

import casacore.tables as ct
from casacore.tables import table
import numpy as np


def case_insensitive_replace(text, old, new):
"""
Case insensivitve_replace function
"""
return re.sub(re.escape(old), new, text, flags=re.IGNORECASE)


Expand Down Expand Up @@ -57,7 +60,7 @@ def get_channels(input_ms):

# collect all frequencies in 1 numpy array
for n, ms in enumerate(input_ms):
with ct.table(ms + '/::SPECTRAL_WINDOW', ack=False) as t:
with table(ms + '/::SPECTRAL_WINDOW', ack=False) as t:

if n == 0:
chans = t.getcol("CHAN_FREQ")
Expand All @@ -68,7 +71,7 @@ def get_channels(input_ms):
return np.sort(chans), input_ms


def fill_freq_gaps(input, make_dummies, output_name, only_basename):
def fill_freq_gaps(input: list = None, make_dummies: bool = None, output_name: str= None, only_basename: bool = None):
"""
Fill the frequency gaps between sub-blocks with dummies (if requested)
and return txt file with MS in order
Expand Down Expand Up @@ -118,7 +121,7 @@ def fill_freq_gaps(input, make_dummies, output_name, only_basename):
return True


def split_ms_phasedir(mslist):
def split_ms_phasedir(mslist: list = None):
"""
Separate MS with different phase centers
Expand All @@ -129,7 +132,7 @@ def split_ms_phasedir(mslist):

d = {}
for ms in mslist:
t = ct.table(f"{ms}::FIELD", ack=False)
t = table(f"{ms}::FIELD", ack=False)
pd = ''.join([str(round(i, 6)) for i in t.getcol("PHASE_DIR").squeeze()])
if pd not in d.keys():
d.update({pd: [ms]})
Expand All @@ -138,6 +141,53 @@ def split_ms_phasedir(mslist):
return d


def remove_flagged_antennas(msin: str = None):
"""
Remove antennas that are full flagged (to save storage)
input:
- msfile: measurement set name
"""

# Read antenna names from Measurement Set
with table(f"{msin}::ANTENNA", ack=False) as ants:
ants_names = ants.getcol("NAME")

# Read main tables Measurement Set
with table(msin, readonly=True, ack=False) as ms:
# Read the antenna ID columns
antenna1 = ms.getcol('ANTENNA1')
antenna2 = ms.getcol('ANTENNA2')

# Read the FLAG column
flags = ms.getcol('FLAG')

# Get the unique antenna indices
unique_antennas = np.unique(np.concatenate((antenna1, antenna2)))

# Identify fully flagged antennas
fully_flagged_antennas = []
for ant in unique_antennas:
# Find rows with this antenna
ant_rows = np.where((antenna1 == ant) | (antenna2 == ant))
# Check if all data for this antenna is flagged
if np.all(flags[ant_rows]):
fully_flagged_antennas.append(ant)

if len(fully_flagged_antennas) == 0:
print(f'No flagged antennas for {msin}, move on.')
return None

else:
# Get names of ants to filter
ants_to_filter = ','.join([ants_names[idx] for idx in fully_flagged_antennas])
print(f"Filtering fully flagged antennas: {ants_to_filter}")

# Run DP3
return f'\nfilter.type=filter\nfilter.remove=true\nfilter.baseline=!{ants_to_filter}'



def parse_args():
"""
Parse input arguments
Expand All @@ -153,13 +203,16 @@ def parse_args():
parser.add_argument('--freq_avg', help='Frequency averaging factor', type=int)
parser.add_argument('--time_res', help='Time resolution (in seconds)', type=int)
parser.add_argument('--freq_res', help='Frequency resolution', type=str)
parser.add_argument('--remove_flagged_station', action='store_true', help='Remove flagged station (save output)')
parser.add_argument('--make_only_parset', action='store_true', help='Make only parset')
parser.add_argument('--only_basename', action='store_true', help='Return only basename of msin')

return parser.parse_args()


def make_parset(mss, concat_name, data_column, time_avg, freq_avg, time_res, freq_res, phase_center, only_basename):
def make_parset(mss: list = None, concat_name: str = None, data_column: str = None,
time_avg: int= None, freq_avg: int = None, time_res=None, freq_res=None, phase_center: str = None,
only_basename: bool = None, remove_flagged_station: bool = None):
"""
Make parset for DP3
Expand All @@ -172,6 +225,7 @@ def make_parset(mss, concat_name, data_column, time_avg, freq_avg, time_res, fre
:param freq_res: frequency resolution
:param phase_center: phase center
:param only_basename: return only basename
:param remove_flagged_station: remove station when fully flagged
:return: parset
"""
Expand All @@ -185,9 +239,16 @@ def make_parset(mss, concat_name, data_column, time_avg, freq_avg, time_res, fre

if concat_name is None:

# Special case
# Parse facet + L-number (if available). Specifically for VLBI pipeline
matchf = re.search(r'facet_\d{2}-', ms[0].split('/')[-1])
matchL = re.search(r'L\d{6}', ms[0].split('/')[-1])
if matchL is None:
with table(ms[0]+"::OBSERVATION", ack=False) as t:
try:
matchL = t.getcol("LOFAR_FILENAME")[0].split("_")[0]
except RuntimeError:
matchL = None

if matchf is not None and matchL is not None:
concatname = matchf.group()+matchL.group()+'.concat.ms'

Expand Down Expand Up @@ -246,7 +307,7 @@ def make_parset(mss, concat_name, data_column, time_avg, freq_avg, time_res, fre
if time_res is not None:
parset += f'\navg.timeresolution={time_res}'
if freq_avg is not None:
with ct.table(ms[0] + "::SPECTRAL_WINDOW", ack=False) as t:
with table(ms[0] + "::SPECTRAL_WINDOW", ack=False) as t:
channum = len(t.getcol("CHAN_FREQ")[0])
freqavg = get_largest_divider(channum, freq_avg + 1)
if freqavg!=freq_avg:
Expand All @@ -255,6 +316,13 @@ def make_parset(mss, concat_name, data_column, time_avg, freq_avg, time_res, fre
if freq_res is not None:
parset += f'\navg.freqresolution={freq_res}'

# Remove station when fully flagged
if remove_flagged_station:
rmv = remove_flagged_antennas(ms[0])
if rmv is not None:
parset += rmv
steps.append('filter')

parset += '\nsteps='+str(steps).replace(" ", "").replace("'", "")
with open(parsetname, 'w') as f:
f.write(parset)
Expand All @@ -271,7 +339,7 @@ def main():
args = parse_args()
parsets = make_parset(args.msin, args.msout, args.data_column,
args.time_avg, args.freq_avg, args.time_res,
args.freq_res, args.phase_center, args.only_basename)
args.freq_res, args.phase_center, args.only_basename, args.remove_flagged_station)
if not args.make_only_parset:
for parset in parsets:
os.system('DP3 ' + parset)
Expand Down

0 comments on commit cd2045f

Please sign in to comment.