diff --git a/pyuvdata/uvdata/mwa_corr_fits.py b/pyuvdata/uvdata/mwa_corr_fits.py index 542b8139d0..5b0c0c929e 100644 --- a/pyuvdata/uvdata/mwa_corr_fits.py +++ b/pyuvdata/uvdata/mwa_corr_fits.py @@ -29,7 +29,7 @@ from .. import telescopes as uvtel from .. import utils as uvutils from ..docstrings import copy_replace_short_description -from .uvdata import UVData, _future_array_shapes_warning +from .uvdata import UVData, _future_array_shapes_warning, _select_blt_preprocess __all__ = ["input_output_mapping", "MWACorrFITS"] @@ -66,6 +66,7 @@ def read_metafits( meta_tbl = meta[1].data # because of polarization, each antenna # is listed twice + # antenna_inds are the correlator input numbers. antenna_inds = meta_tbl["Antenna"][1::2] antenna_numbers = meta_tbl["Tile"][1::2] antenna_names = meta_tbl["TileName"][1::2] @@ -629,6 +630,7 @@ def flag_init( def _read_fits_file( self, + *, filename, time_array, file_nums, @@ -638,6 +640,7 @@ def _read_fits_file( map_inds, conj, pol_index_array, + bl_inds=None, ): """ Read the fits file and populate into memory. @@ -666,6 +669,8 @@ def _read_fits_file( Indices for conjugating data_array from weird correlator packing. pol_index_array : array Indices for reordering polarizations to the 'AIPS' convention + bl_inds : array, optional + Baseline indices (after any re-mapping) to select on read. """ # get the file number from the file name @@ -688,6 +693,34 @@ def _read_fits_file( (self.Ntimes, num_fine_chans, self.Nbls * self.Npols), dtype=np.complex64, ) + + if bl_inds is not None: + bl_inds = np.array(bl_inds) + if not mwax: + # map_inds gives the baseline-pol ordering + n_orig_bls = int( + len(self.antenna_numbers) * (len(self.antenna_numbers) + 1) / 2.0 + ) + # reshape, do selection along bl axis, then flatten + bl_inds_map = np.take( + map_inds.reshape(n_orig_bls, self.Npols), bl_inds, axis=0 + ).flatten() + conj = np.take( + conj.reshape(n_orig_bls, self.Npols), bl_inds, axis=0 + ).flatten() + + # The data array is written with real, imaginary parts interleaved. + # This corresponds to a 2d array flattened where the last axis is + # real, imaginary + # So the indices need to be updated for that structure. + bl_inds_map_ri = np.concatenate( + ( + bl_inds_map[:, np.newaxis] * 2, + bl_inds_map[:, np.newaxis] * 2 + 1, + ), + axis=1, + ).flatten() + with fits.open(filename, mode="denywrite") as hdu_list: # if mwax, data is in every other hdu if mwax: @@ -704,7 +737,16 @@ def _read_fits_file( time_ind = np.where(time_array == time)[0][0] # dump data into matrix # and take data from real to complex numbers - coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data + if bl_inds is not None: + if not mwax: + coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data[ + :, bl_inds_map_ri + ] + else: + temp_data = hdu.data[bl_inds] + coarse_chan_data.view(np.float32)[time_ind, :, :] = temp_data + else: + coarse_chan_data.view(np.float32)[time_ind, :, :] = hdu.data # fill nsample and flag arrays # think about using the mwax weights array in the future self.nsample_array[ @@ -712,16 +754,21 @@ def _read_fits_file( ] = 1.0 self.flag_array[time_ind, :, coarse_ind, :] = False if not mwax: - # do mapping and reshaping here to avoid copying whole data_array - np.take(coarse_chan_data, map_inds, axis=2, out=coarse_chan_data) + if bl_inds is None: + # do mapping and reshaping here to avoid copying whole data_array + # map_inds gives the baseline-pol ordering + np.take(coarse_chan_data, map_inds, axis=2, out=coarse_chan_data) # conjugate data coarse_chan_data[:, :, conj] = np.conj(coarse_chan_data[:, :, conj]) # reshape + # each time gets its own HDU. MWAX has 2 HDUs per time (data/weights alternate) if mwax: + # freq and pol axes are combined, baseline axis is separate coarse_chan_data = coarse_chan_data.reshape( (self.Ntimes, self.Nbls, num_fine_chans, self.Npols) ) else: + # freq axis, then baseline-pol axis coarse_chan_data = coarse_chan_data.reshape( (self.Ntimes, num_fine_chans, self.Nbls, self.Npols) ) @@ -729,6 +776,7 @@ def _read_fits_file( coarse_chan_data = coarse_chan_data.reshape( self.Nblts, num_fine_chans, self.Npols ) + # reorder pols here to avoid memory spike from self.reorder_pols np.take(coarse_chan_data, pol_index_array, axis=-1, out=coarse_chan_data) # make a mask where data actually is so coarse channels that @@ -759,7 +807,7 @@ def _read_flag_file(self, filename, file_nums, num_fine_chans): freq_ind = np.where(file_nums == flag_num)[0][0] * num_fine_chans with fits.open(filename, mode="denywrite") as aoflags: flags = aoflags[1].data.field("FLAGS") - # some flag files are longer than data; crop the ends + # some flag files are longer than data; crop the end flags = flags[: self.Nblts, :] # some flag files are shorter than data; assume same end time blt_ind = self.Nblts - len(flags) @@ -1269,6 +1317,17 @@ def _apply_corrections( def read_mwa_corr_fits( self, filelist, + antenna_nums=None, + antenna_names=None, + bls=None, + # frequencies=None, + # freq_chans=None, + # times=None, + # time_range=None, + # lsts=None, + # lst_range=None, + # polarizations=None, + keep_all_metadata=True, use_aoflagger_flags=None, remove_dig_gains=True, remove_coarse_band=True, @@ -1322,6 +1381,12 @@ def read_mwa_corr_fits( if start_flag != "goodtime": raise ValueError("start_flag must be int or float or 'goodtime'") + # check that bls are a list of 2-tuples as required by _select_blt_preprocess + if bls is not None and not all(len(item) == 2 for item in bls): + raise ValueError( + "bls must be a list of 2-tuples giving antenna number pairs" + ) + # set future array shapes self._set_future_array_shapes() @@ -1373,6 +1438,8 @@ def read_mwa_corr_fits( # check headers for first and last times containing data headstart = hdu_list[1].header headfin = hdu_list[-1].header + # start & end times are for the full file set + # first & last are for this file first_time = headstart["TIME"] + headstart["MILLITIM"] / 1000.0 last_time = headfin["TIME"] + headfin["MILLITIM"] / 1000.0 if start_time == 0.0: @@ -1561,8 +1628,6 @@ def read_mwa_corr_fits( ant_1_inds, ant_2_inds = np.transpose( list(itertools.combinations_with_replacement(np.arange(self.Nants_data), 2)) ) - ant_1_inds = np.tile(np.array(ant_1_inds), self.Ntimes).astype(np.int_) - ant_2_inds = np.tile(np.array(ant_2_inds), self.Ntimes).astype(np.int_) if not mwax: # coarse channel mapping for the legacy correlator: @@ -1647,7 +1712,6 @@ def read_mwa_corr_fits( corr_ants_to_pfb_inputs[(meta_dict["antenna_inds"][i], p)] = ( 2 * i + p ) - # for mapping, start with a pair of antennas/polarizations # this is the pair we want to find the data for # map the pair to the corresponding coarse pfb input indices @@ -1669,6 +1733,55 @@ def read_mwa_corr_fits( else: map_inds = None conj = None + + # check if we want to do any select on the baseline axis + # Note: only passing the ant_1/2_arrays and baseline_array for one time. + bl_inds, bl_selections = _select_blt_preprocess( + select_antenna_nums=antenna_nums, + select_antenna_names=antenna_names, + bls=bls, + times=None, + time_range=None, + lsts=None, + lst_range=None, + blt_inds=None, + phase_center_ids=None, + antenna_names=self.antenna_names, + antenna_numbers=self.antenna_numbers, + ant_1_array=ant_1_array, + ant_2_array=ant_2_array, + baseline_array=self.antnums_to_baseline(ant_1_array, ant_2_array), + time_array=self.time_array, + time_tols=self._time_array.tols, + lst_array=self.lst_array, + lst_tols=self._lst_array.tols, + phase_center_id_array=self.phase_center_id_array, + ) + + if bl_inds is not None: + ant_1_inds = ant_1_inds[bl_inds] + ant_2_inds = ant_2_inds[bl_inds] + + history_update_string = ( + " Downselected to specific " + + ", ".join(bl_selections) + + " using pyuvdata." + ) + # do select operations on everything except data_array, flag_array + # and nsample_array + blt_inds = np.take( + np.arange(self.Nblts).reshape(self.Ntimes, self.Nbls), + bl_inds, + axis=1, + ).flatten() + self._select_by_index( + blt_inds, None, None, history_update_string, keep_all_metadata + ) + + ant_1_inds = np.tile(np.array(ant_1_inds), self.Ntimes).astype(np.int_) + ant_2_inds = np.tile(np.array(ant_2_inds), self.Ntimes).astype(np.int_) + + if read_data: # create arrays for data, nsamples, and flags self.data_array = np.zeros( (self.Nblts, self.Nfreqs, self.Npols), dtype=data_array_dtype @@ -1684,15 +1797,16 @@ def read_mwa_corr_fits( # read data files for filename in file_dict["data"]: self._read_fits_file( - filename, - time_array, - file_nums, - num_fine_chans, - meta_dict["int_time"], - mwax, - map_inds, - conj, - pol_index_array, + filename=filename, + time_array=time_array, + file_nums=file_nums, + num_fine_chans=num_fine_chans, + int_time=meta_dict["int_time"], + mwax=mwax, + map_inds=map_inds, + conj=conj, + pol_index_array=pol_index_array, + bl_inds=bl_inds, ) # propagate coarse flags @@ -1804,9 +1918,20 @@ def read_mwa_corr_fits( # remove bad antennas # select must be called after lst thread is re-joined - if remove_flagged_ants: + if ( + remove_flagged_ants + and meta_dict["flagged_ant_inds"].size > 0 + and np.sum( + np.isin( + meta_dict["flagged_ant_inds"], + np.union1d(self.ant_1_array, self.ant_2_array), + ) + ) + > 0 + ): good_ants = np.delete( - np.array(self.antenna_numbers), meta_dict["flagged_ant_inds"] + np.union1d(self.ant_1_array, self.ant_2_array), + meta_dict["flagged_ant_inds"], ) self.select(antenna_nums=good_ants, run_check=False) diff --git a/pyuvdata/uvdata/tests/test_mwa_corr_fits.py b/pyuvdata/uvdata/tests/test_mwa_corr_fits.py index a98defa6db..cecb638874 100644 --- a/pyuvdata/uvdata/tests/test_mwa_corr_fits.py +++ b/pyuvdata/uvdata/tests/test_mwa_corr_fits.py @@ -203,7 +203,8 @@ def test_select_on_read(): with uvtest.check_warnings( UserWarning, [ - 'Warning: select on read keyword set, but file_type is "mwa_corr_fits"', + "Warning: a select on read keyword is set that is not supported by " + "read_mwa_corr_fits. This select will be done after reading the file.", "some coarse channel files were not submitted", ], ): @@ -1234,3 +1235,65 @@ def test_bscale(tmp_path): # check mwax data uv4.read(filelist[11:13], use_future_array_shapes=True) assert "SCALEFAC" not in uv4.extra_keywords.keys() + + +@pytest.mark.filterwarnings("ignore:some coarse channel files were not submitted") +@pytest.mark.filterwarnings("ignore:Fixing auto-correlations to be be real-only") +@pytest.mark.parametrize( + ["select_kwargs", "warn_msg"], + [ + [{"antenna_nums": [18, 31, 66, 95]}, ""], + [{"antenna_names": [f"Tile{ant:03d}" for ant in [18, 31, 66, 95]]}, ""], + [{"bls": [(48, 34), (96, 11), (22, 87)]}, ""], + [ + {"ant_str": "48_34,96_11,22_87"}, + "a select on read keyword is set that is not supported by " + "read_mwa_corr_fits. This select will be done after reading the file.", + ], + ], +) +@pytest.mark.parametrize("mwax", [False, True]) +def test_partial_read_bl_axis(tmp_path, flag_file_init, mwax, select_kwargs, warn_msg): + + if mwax: + # generate a spoof file with 16 channels + cb_spoof = str(tmp_path / "mwax_cb_spoof80_ch137_000.fits") + meta_spoof = str(tmp_path / "mwax_cb_spoof80.metafits") + + with fits.open(filelist[12]) as mini1: + mini1[1].data = np.repeat(mini1[1].data, 16, axis=1) + mini1.writeto(cb_spoof) + + with fits.open(filelist[11]) as meta: + meta[0].header["FINECHAN"] = 80 + meta.writeto(meta_spoof) + + files_use = [meta_spoof, cb_spoof] + + else: + cb_spoof = str(tmp_path / "cb_spoof_01_00.fits") + with fits.open(filelist[1]) as mini1: + mini1[1].data = np.repeat(mini1[1].data, 32, axis=0) + mini1.writeto(cb_spoof) + files_use = [filelist[0], cb_spoof] + + uv_full = UVData.from_file(files_use, use_future_array_shapes=True) + + warn_msg_list = ["some coarse channel files were not submitted"] + if warn_msg != "": + warn_msg_list.append(warn_msg) + + if mwax and "bls" not in select_kwargs.keys(): + # The bls selection has no autos + warn_msg_list.append("Fixing auto-correlations to be be real-only") + + with uvtest.check_warnings(UserWarning, match=warn_msg_list): + uv_partial = UVData.from_file( + files_use, use_future_array_shapes=True, **select_kwargs + ) + exp_uv = uv_full.select(**select_kwargs, inplace=False) + + # history doesn't match because of different order of operations. + exp_uv.history = uv_partial.history + + assert uv_partial == exp_uv diff --git a/pyuvdata/uvdata/uvdata.py b/pyuvdata/uvdata/uvdata.py index 8ba60cee98..5d700e8f0e 100644 --- a/pyuvdata/uvdata/uvdata.py +++ b/pyuvdata/uvdata/uvdata.py @@ -136,16 +136,16 @@ def _select_blt_preprocess( Parameters ---------- - antenna_nums : array_like of int, optional + select_antenna_nums : array_like of int, optional The antennas numbers to keep in the object (antenna positions and names for the removed antennas will be retained unless `keep_all_metadata` is False). This cannot be provided if - `antenna_names` is also provided. - antenna_names : array_like of str, optional + `select_antenna_names` is also provided. + select_antenna_names : array_like of str, optional The antennas names to keep in the object (antenna positions and names for the removed antennas will be retained unless `keep_all_metadata` is False). This cannot be provided if - `antenna_nums` is also provided. + `select_antenna_nums` is also provided. bls : list of 2-tuples, optional A list of antenna number tuples (e.g. [(0, 1), (3, 2)]) specifying baselines to keep in the object. The ordering of the numbers within the @@ -11069,6 +11069,57 @@ def read_mwa_corr_fits(self, filelist, **kwargs): filelist : list of str The list of MWA correlator files to read from. Must include at least one fits file and only one metafits file per data set. + antenna_nums : array_like of int, optional + The antennas numbers to include when reading data into the object + (antenna positions and names for the removed antennas will be retained + unless `keep_all_metadata` is False). This cannot be provided if + `antenna_names` is also provided. Ignored if read_data is False. + antenna_names : array_like of str, optional + The antennas names to include when reading data into the object + (antenna positions and names for the removed antennas will be retained + unless `keep_all_metadata` is False). This cannot be provided if + `antenna_nums` is also provided. Ignored if read_data is False. + bls : list of tuple, optional + A list of antenna number tuples (e.g. [(0, 1), (3, 2)]). The + ordering of the numbers within the tuple does not matter. + Note that this is different than what can be passed to the parameter + of the same name on `select` and other read methods -- this parameter + does not accept 3-tuples or baseline numbers. + Ignored if read_data is False. + # frequencies : array_like of float, optional + # The frequencies to include when reading data into the object, each + # value passed here should exist in the freq_array. Ignored if + # read_data is False. + # freq_chans : array_like of int, optional + # The frequency channel numbers to include when reading data into the + # object. Ignored if read_data is False. + # times : array_like of float, optional + # The times to include when reading data into the object, each value + # passed here should exist in the time_array. Cannot be used with + # `time_range`, `lsts`, or `lst_array`. + # time_range : array_like of float, optional + # The time range in Julian Date to include when reading data into + # the object, must be length 2. Some of the times in the file should + # fall between the first and last elements. + # Cannot be used with `times`. + # lsts : array_like of float, optional + # The local sidereal times (LSTs) to keep in the object, each value + # passed here should exist in the lst_array. Cannot be used with + # `times`, `time_range`, or `lst_range`. + # lst_range : array_like of float, optional + # The local sidereal time (LST) range in radians to keep in the + # object, must be of length 2. Some of the LSTs in the object should + # fall between the first and last elements. If the second value is + # smaller than the first, the LSTs are treated as having phase-wrapped + # around LST = 2*pi = 0, and the LSTs kept on the object will run from + # the larger value, through 0, and end at the smaller value. + # polarizations : array_like of int, optional + # The polarizations numbers to include when reading data into the + # object, each value passed here should exist in the polarization_array. + # Ignored if read_data is False. + keep_all_metadata : bool + Option to keep all the metadata associated with antennas, even those + that do not have data associated with them after the select option. use_aoflagger_flags : bool Option to use aoflagger mwaf flag files. Defaults to true if aoflagger flag files are submitted. @@ -11723,6 +11774,7 @@ def read( An ant_str cannot be passed in addition to any of `antenna_nums`, `antenna_names`, `bls` args or the `polarizations` parameters, if it is a ValueError will be raised. + Note that this keyword is not supported for MWA correlator FITS files. bls : list of tuple, optional A list of antenna number tuples (e.g. [(0, 1), (3, 2)]) or a list of baseline 3-tuples (e.g. [(0, 1, 'xx'), (2, 3, 'yy')]) specifying baselines @@ -11730,7 +11782,8 @@ def read( the ordering of the numbers within the tuple does not matter. For length-3 tuples, the polarization string is in the order of the two antennas. If length-3 tuples are provided, `polarizations` must be - None. + None. Note that for MWA correlator FITS files, this can only be a + list of antenna number 2-tuples. catalog_names : str or array-like of str, optional The names of the phase centers (sources) to include when reading data into the object, which should match exactly in spelling and capitalization. @@ -12432,7 +12485,7 @@ def read( # everything is merged into it at the end of this loop else: - if file_type in ["fhd", "ms", "mwa_corr_fits"]: + if file_type in ["fhd", "ms"]: if ( antenna_nums is not None or antenna_names is not None @@ -12568,6 +12621,58 @@ def read( "not supported by read_mir. This select will " "be done after reading the file." ) + elif file_type == "mwa_corr_fits": + select = True + # these are all done by partial read, so set to None + select_antenna_nums = None + select_antenna_names = None + select_bls = None + # select_lst_range = None + # select_time_range = None + # select_times = None + # select_lsts = None + + # MWA corr fits can only handle length-two bls tuples, anything + # else needs to be handled via select. + if bls is not None: + if not all(len(item) == 2 for item in bls): + select_bls = bls + + # these aren't supported by partial read, so do it in select + select_ant_str = ant_str + select_blt_inds = blt_inds + select_phase_center_ids = phase_center_ids + select_polarizations = polarizations + select_frequencies = frequencies + select_freq_chans = freq_chans + select_lst_range = lst_range + select_time_range = time_range + select_times = times + select_lsts = lsts + + if all( + item is None + for item in [ + select_bls, + blt_inds, + phase_center_ids, + ant_str, + frequencies, + freq_chans, + lst_range, + time_range, + times, + lsts, + ] + ): + # If there's nothing to select, just bypass that operation. + select = False + else: + warnings.warn( + "Warning: a select on read keyword is set that is " + "not supported by read_mwa_corr_fits. This select will " + "be done after reading the file." + ) # reading a single "file". Call the appropriate file-type read if file_type == "uvfits": self.read_uvfits( @@ -12660,6 +12765,17 @@ def read( elif file_type == "mwa_corr_fits": self.read_mwa_corr_fits( filename, + antenna_nums=antenna_nums, + antenna_names=antenna_names, + bls=bls, + # frequencies=frequencies, + # freq_chans=freq_chans, + # times=times, + # time_range=time_range, + # lsts=lsts, + # lst_range=lst_range, + # polarizations=polarizations, + keep_all_metadata=keep_all_metadata, use_aoflagger_flags=use_aoflagger_flags, remove_dig_gains=remove_dig_gains, remove_coarse_band=remove_coarse_band,