Skip to content

Commit

Permalink
got select on read working on the baseline axis
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Mar 21, 2024
1 parent bc69f98 commit 02fb517
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 26 deletions.
163 changes: 144 additions & 19 deletions pyuvdata/uvdata/mwa_corr_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .. import telescopes as uvtel
from .. import utils as uvutils
from ..docstrings import copy_replace_short_description
from .uvdata import UVData, _future_array_shapes_warning
from .uvdata import UVData, _future_array_shapes_warning, _select_blt_preprocess

__all__ = ["input_output_mapping", "MWACorrFITS"]

Expand Down Expand Up @@ -66,6 +66,7 @@ def read_metafits(
meta_tbl = meta[1].data

# because of polarization, each antenna # is listed twice
# antenna_inds are the correlator input numbers.
antenna_inds = meta_tbl["Antenna"][1::2]
antenna_numbers = meta_tbl["Tile"][1::2]
antenna_names = meta_tbl["TileName"][1::2]
Expand Down Expand Up @@ -629,6 +630,7 @@ def flag_init(

def _read_fits_file(
self,
*,
filename,
time_array,
file_nums,
Expand All @@ -638,6 +640,7 @@ def _read_fits_file(
map_inds,
conj,
pol_index_array,
bl_inds=None,
):
"""
Read the fits file and populate into memory.
Expand Down Expand Up @@ -666,6 +669,8 @@ def _read_fits_file(
Indices for conjugating data_array from weird correlator packing.
pol_index_array : array
Indices for reordering polarizations to the 'AIPS' convention
bl_inds : array, optional
Baseline indices (after any re-mapping) to select on read.
"""
# get the file number from the file name
Expand All @@ -688,6 +693,34 @@ def _read_fits_file(
(self.Ntimes, num_fine_chans, self.Nbls * self.Npols),
dtype=np.complex64,
)

if bl_inds is not None:
bl_inds = np.array(bl_inds)
if not mwax:
# map_inds gives the baseline-pol ordering
n_orig_bls = int(
len(self.antenna_numbers) * (len(self.antenna_numbers) + 1) / 2.0
)
# reshape, do selection along bl axis, then flatten
bl_inds_map = np.take(
map_inds.reshape(n_orig_bls, self.Npols), bl_inds, axis=0
).flatten()
conj = np.take(
conj.reshape(n_orig_bls, self.Npols), bl_inds, axis=0
).flatten()

# The data array is written with real, imaginary parts interleaved.
# This corresponds to a 2d array flattened where the last axis is
# real, imaginary
# So the indices need to be updated for that structure.
bl_inds_map_ri = np.concatenate(
(
bl_inds_map[:, np.newaxis] * 2,
bl_inds_map[:, np.newaxis] * 2 + 1,
),
axis=1,
).flatten()

with fits.open(filename, mode="denywrite") as hdu_list:
# if mwax, data is in every other hdu
if mwax:
Expand All @@ -704,31 +737,46 @@ def _read_fits_file(
time_ind = np.where(time_array == time)[0][0]
# dump data into matrix
# and take data from real to complex numbers
coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data
if bl_inds is not None:
if not mwax:
coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data[
:, bl_inds_map_ri
]
else:
temp_data = hdu.data[bl_inds]
coarse_chan_data.view(np.float32)[time_ind, :, :] = temp_data
else:
coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data
# fill nsample and flag arrays
# think about using the mwax weights array in the future
self.nsample_array[
time_ind, :, freq_ind : freq_ind + num_fine_chans, :
] = 1.0
self.flag_array[time_ind, :, coarse_ind, :] = False
if not mwax:
# do mapping and reshaping here to avoid copying whole data_array
np.take(coarse_chan_data, map_inds, axis=2, out=coarse_chan_data)
if bl_inds is None:
# do mapping and reshaping here to avoid copying whole data_array
# map_inds gives the baseline-pol ordering
np.take(coarse_chan_data, map_inds, axis=2, out=coarse_chan_data)
# conjugate data
coarse_chan_data[:, :, conj] = np.conj(coarse_chan_data[:, :, conj])
# reshape
# each time gets its own HDU. MWAX has 2 HDUs per time (data/weights alternate)
if mwax:
# freq and pol axes are combined, baseline axis is separate
coarse_chan_data = coarse_chan_data.reshape(
(self.Ntimes, self.Nbls, num_fine_chans, self.Npols)
)
else:
# freq axis, then baseline-pol axis
coarse_chan_data = coarse_chan_data.reshape(
(self.Ntimes, num_fine_chans, self.Nbls, self.Npols)
)
coarse_chan_data = np.swapaxes(coarse_chan_data, 1, 2)
coarse_chan_data = coarse_chan_data.reshape(
self.Nblts, num_fine_chans, self.Npols
)

# reorder pols here to avoid memory spike from self.reorder_pols
np.take(coarse_chan_data, pol_index_array, axis=-1, out=coarse_chan_data)
# make a mask where data actually is so coarse channels that
Expand Down Expand Up @@ -759,7 +807,7 @@ def _read_flag_file(self, filename, file_nums, num_fine_chans):
freq_ind = np.where(file_nums == flag_num)[0][0] * num_fine_chans
with fits.open(filename, mode="denywrite") as aoflags:
flags = aoflags[1].data.field("FLAGS")
# some flag files are longer than data; crop the ends
# some flag files are longer than data; crop the end
flags = flags[: self.Nblts, :]
# some flag files are shorter than data; assume same end time
blt_ind = self.Nblts - len(flags)
Expand Down Expand Up @@ -1269,6 +1317,17 @@ def _apply_corrections(
def read_mwa_corr_fits(
self,
filelist,
antenna_nums=None,
antenna_names=None,
bls=None,
# frequencies=None,
# freq_chans=None,
# times=None,
# time_range=None,
# lsts=None,
# lst_range=None,
# polarizations=None,
keep_all_metadata=True,
use_aoflagger_flags=None,
remove_dig_gains=True,
remove_coarse_band=True,
Expand Down Expand Up @@ -1322,6 +1381,12 @@ def read_mwa_corr_fits(
if start_flag != "goodtime":
raise ValueError("start_flag must be int or float or 'goodtime'")

# check that bls are a list of 2-tuples as required by _select_blt_preprocess
if bls is not None and not all(len(item) == 2 for item in bls):
raise ValueError(
"bls must be a list of 2-tuples giving antenna number pairs"
)

# set future array shapes
self._set_future_array_shapes()

Expand Down Expand Up @@ -1373,6 +1438,8 @@ def read_mwa_corr_fits(
# check headers for first and last times containing data
headstart = hdu_list[1].header
headfin = hdu_list[-1].header
# start & end times are for the full file set
# first & last are for this file
first_time = headstart["TIME"] + headstart["MILLITIM"] / 1000.0
last_time = headfin["TIME"] + headfin["MILLITIM"] / 1000.0
if start_time == 0.0:
Expand Down Expand Up @@ -1561,8 +1628,6 @@ def read_mwa_corr_fits(
ant_1_inds, ant_2_inds = np.transpose(
list(itertools.combinations_with_replacement(np.arange(self.Nants_data), 2))
)
ant_1_inds = np.tile(np.array(ant_1_inds), self.Ntimes).astype(np.int_)
ant_2_inds = np.tile(np.array(ant_2_inds), self.Ntimes).astype(np.int_)

if not mwax:
# coarse channel mapping for the legacy correlator:
Expand Down Expand Up @@ -1647,7 +1712,6 @@ def read_mwa_corr_fits(
corr_ants_to_pfb_inputs[(meta_dict["antenna_inds"][i], p)] = (
2 * i + p
)

# for mapping, start with a pair of antennas/polarizations
# this is the pair we want to find the data for
# map the pair to the corresponding coarse pfb input indices
Expand All @@ -1669,6 +1733,55 @@ def read_mwa_corr_fits(
else:
map_inds = None
conj = None

# check if we want to do any select on the baseline axis
# Note: only passing the ant_1/2_arrays and baseline_array for one time.
bl_inds, bl_selections = _select_blt_preprocess(
select_antenna_nums=antenna_nums,
select_antenna_names=antenna_names,
bls=bls,
times=None,
time_range=None,
lsts=None,
lst_range=None,
blt_inds=None,
phase_center_ids=None,
antenna_names=self.antenna_names,
antenna_numbers=self.antenna_numbers,
ant_1_array=ant_1_array,
ant_2_array=ant_2_array,
baseline_array=self.antnums_to_baseline(ant_1_array, ant_2_array),
time_array=self.time_array,
time_tols=self._time_array.tols,
lst_array=self.lst_array,
lst_tols=self._lst_array.tols,
phase_center_id_array=self.phase_center_id_array,
)

if bl_inds is not None:
ant_1_inds = ant_1_inds[bl_inds]
ant_2_inds = ant_2_inds[bl_inds]

history_update_string = (
" Downselected to specific "
+ ", ".join(bl_selections)
+ " using pyuvdata."
)
# do select operations on everything except data_array, flag_array
# and nsample_array
blt_inds = np.take(
np.arange(self.Nblts).reshape(self.Ntimes, self.Nbls),
bl_inds,
axis=1,
).flatten()
self._select_by_index(
blt_inds, None, None, history_update_string, keep_all_metadata
)

ant_1_inds = np.tile(np.array(ant_1_inds), self.Ntimes).astype(np.int_)
ant_2_inds = np.tile(np.array(ant_2_inds), self.Ntimes).astype(np.int_)

if read_data:
# create arrays for data, nsamples, and flags
self.data_array = np.zeros(
(self.Nblts, self.Nfreqs, self.Npols), dtype=data_array_dtype
Expand All @@ -1684,15 +1797,16 @@ def read_mwa_corr_fits(
# read data files
for filename in file_dict["data"]:
self._read_fits_file(
filename,
time_array,
file_nums,
num_fine_chans,
meta_dict["int_time"],
mwax,
map_inds,
conj,
pol_index_array,
filename=filename,
time_array=time_array,
file_nums=file_nums,
num_fine_chans=num_fine_chans,
int_time=meta_dict["int_time"],
mwax=mwax,
map_inds=map_inds,
conj=conj,
pol_index_array=pol_index_array,
bl_inds=bl_inds,
)

# propagate coarse flags
Expand Down Expand Up @@ -1804,9 +1918,20 @@ def read_mwa_corr_fits(

# remove bad antennas
# select must be called after lst thread is re-joined
if remove_flagged_ants:
if (
remove_flagged_ants
and meta_dict["flagged_ant_inds"].size > 0
and np.sum(
np.isin(
meta_dict["flagged_ant_inds"],
np.union1d(self.ant_1_array, self.ant_2_array),
)
)
> 0
):
good_ants = np.delete(
np.array(self.antenna_numbers), meta_dict["flagged_ant_inds"]
np.union1d(self.ant_1_array, self.ant_2_array),
meta_dict["flagged_ant_inds"],
)
self.select(antenna_nums=good_ants, run_check=False)

Expand Down
65 changes: 64 additions & 1 deletion pyuvdata/uvdata/tests/test_mwa_corr_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def test_select_on_read():
with uvtest.check_warnings(
UserWarning,
[
'Warning: select on read keyword set, but file_type is "mwa_corr_fits"',
"Warning: a select on read keyword is set that is not supported by "
"read_mwa_corr_fits. This select will be done after reading the file.",
"some coarse channel files were not submitted",
],
):
Expand Down Expand Up @@ -1234,3 +1235,65 @@ def test_bscale(tmp_path):
# check mwax data
uv4.read(filelist[11:13], use_future_array_shapes=True)
assert "SCALEFAC" not in uv4.extra_keywords.keys()


@pytest.mark.filterwarnings("ignore:some coarse channel files were not submitted")
@pytest.mark.filterwarnings("ignore:Fixing auto-correlations to be be real-only")
@pytest.mark.parametrize(
    ["select_kwargs", "warn_msg"],
    [
        [{"antenna_nums": [18, 31, 66, 95]}, ""],
        [{"antenna_names": [f"Tile{ant:03d}" for ant in [18, 31, 66, 95]]}, ""],
        [{"bls": [(48, 34), (96, 11), (22, 87)]}, ""],
        [
            {"ant_str": "48_34,96_11,22_87"},
            "a select on read keyword is set that is not supported by "
            "read_mwa_corr_fits. This select will be done after reading the file.",
        ],
    ],
)
@pytest.mark.parametrize("mwax", [False, True])
def test_partial_read_bl_axis(tmp_path, flag_file_init, mwax, select_kwargs, warn_msg):
    """
    Check that a baseline-axis select on read equals a full read plus select.

    A spoofed data file is written to ``tmp_path`` by repeating the data of one
    of the test files, then the file set is read twice: once in full and once
    with ``select_kwargs`` passed to the reader. The partially-read object must
    equal the fully-read object after the same selection is applied to it.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Temporary directory that the spoofed files are written into.
    flag_file_init : fixture
        Requested but not referenced in the test body.
    mwax : bool
        If True build the spoofed inputs from the MWAX test files (data file
        plus a metafits copy with FINECHAN set to 80), otherwise from the
        legacy correlator test files.
    select_kwargs : dict
        Selection keywords passed both to the reader and to ``select``.
    warn_msg : str
        Additional warning expected during the partial read, or an empty
        string if there is none.

    """
    if mwax:
        # spoofed MWAX inputs: the data are repeated 16 times along axis 1 and
        # the metafits copy gets FINECHAN = 80
        data_path = str(tmp_path / "mwax_cb_spoof80_ch137_000.fits")
        metafits_path = str(tmp_path / "mwax_cb_spoof80.metafits")

        with fits.open(filelist[12]) as data_hdus:
            data_hdus[1].data = np.repeat(data_hdus[1].data, 16, axis=1)
            data_hdus.writeto(data_path)

        with fits.open(filelist[11]) as meta_hdus:
            meta_hdus[0].header["FINECHAN"] = 80
            meta_hdus.writeto(metafits_path)

        files_use = [metafits_path, data_path]
    else:
        # spoofed legacy input: the data are repeated 32 times along axis 0 and
        # the existing metafits file is used unchanged
        data_path = str(tmp_path / "cb_spoof_01_00.fits")
        with fits.open(filelist[1]) as data_hdus:
            data_hdus[1].data = np.repeat(data_hdus[1].data, 32, axis=0)
            data_hdus.writeto(data_path)
        files_use = [filelist[0], data_path]

    # reference object: read everything, the selection is applied afterwards
    uv_full = UVData.from_file(files_use, use_future_array_shapes=True)

    expected_warnings = ["some coarse channel files were not submitted"]
    if warn_msg:
        expected_warnings.append(warn_msg)

    # The bls selection has no autos, so the auto-correlation warning is only
    # expected for the other selections.
    if mwax and "bls" not in select_kwargs:
        expected_warnings.append("Fixing auto-correlations to be be real-only")

    with uvtest.check_warnings(UserWarning, match=expected_warnings):
        uv_partial = UVData.from_file(
            files_use, use_future_array_shapes=True, **select_kwargs
        )
    exp_uv = uv_full.select(inplace=False, **select_kwargs)

    # history doesn't match because of different order of operations.
    exp_uv.history = uv_partial.history

    assert uv_partial == exp_uv
Loading

0 comments on commit 02fb517

Please sign in to comment.