From ba7af6900540326cc2540256aef688ecb2034446 Mon Sep 17 00:00:00 2001 From: Garrett 'Karto' Keating Date: Wed, 3 Jan 2024 08:09:00 -0500 Subject: [PATCH] Minor reworking of COMPASS import, related equality checking --- pyuvdata/uvdata/mir_parser.py | 94 +++++++++++++----------- pyuvdata/uvdata/tests/test_mir_parser.py | 59 ++++++++++----- 2 files changed, 95 insertions(+), 58 deletions(-) diff --git a/pyuvdata/uvdata/mir_parser.py b/pyuvdata/uvdata/mir_parser.py index 2e52bdfcd..ef9c25005 100644 --- a/pyuvdata/uvdata/mir_parser.py +++ b/pyuvdata/uvdata/mir_parser.py @@ -145,7 +145,10 @@ def __init__( self._has_cross = False self._tsys_applied = False self._tsys_use_cont_det = True - self._compass_solns = None + self._has_compass_soln = False + self._compass_bp_soln = None + self._compass_sphid_flags = None + self._compass_static_flags = None # This value is the forward gain of the antenna (in units of Jy/K), which is # multiplied against the system temperatures in order to produce values in units @@ -211,6 +214,14 @@ def __eq__(self, other, *, verbose=True, metadata_only=False): "raw_data": ["data", "scale_fac"], "vis_data": ["data", "flags"], "auto_data": ["data", "flags"], + "_compass_bp_soln": [ + "cal_soln", + "cal_flags", + "weight_soln", + "weight_flags", + ], + "_compass_sphid_flags": [...], + "_compass_static_flags": [...], "_stored_masks": [ "in_data", "bl_data", @@ -292,19 +303,28 @@ def __eq__(self, other, *, verbose=True, metadata_only=False): is_same &= this_item[subkey] == other_item[subkey] elif not np.array_equal(this_item[subkey], other_item[subkey]): if this_item[subkey].shape == other_item[subkey].shape: - # The atol here is set by the max value in the spectrum - # times 2^-10. That turns out to be _about_ the worst - # case scenario for moving to and from the raw data - # format, which compresses the data down from floats to - # ints. 
- atol = 1e-3 - if np.any(np.isfinite(this_item[subkey])): - atol *= np.nanmax(np.abs(this_item[subkey])) + atol = rtol = 0 + if subkey == "data": + # The atol here is set by the max value in the + # spectrum times 2^-10. That turns out to be _about_ + # the worst case scenario for moving to and from the + # raw data format, which compresses the data down + # from floats to ints. + atol = 1e-3 + if np.any(np.isfinite(this_item[subkey])): + atol = 1e-3 * np.nanmax( + np.abs(this_item[subkey]) + ) + else: + # Otherwise, if not looking at data, use a tolerance + # close to single-precision floating-point accuracy. + rtol = 1e-6 is_same &= np.allclose( this_item[subkey], other_item[subkey], atol=atol, + rtol=rtol, equal_nan=True, ) else: @@ -326,25 +346,6 @@ def __eq__(self, other, *, verbose=True, metadata_only=False): f"{item} has different keys, left is {this_attr.keys()}, " f"right is {other_attr.keys()}." ) - elif item == "_compass_solns": - is_eq = True - try: - for i_key in this_attr: - for j_key in this_attr[i_key]: - for k_key in this_attr[i_key][j_key]: - if is_eq: - arr1 = this_attr[i_key][j_key][k_key] - arr2 = other_attr[i_key][j_key][k_key] - # Note we use the two checks here b/c array_equal - # is much faster if it passes, and if it does - # python doesn't need to look at the second or arg - is_eq = np.array_equal( - arr1, arr2, equal_nan=True - ) or np.allclose(arr1, arr2, equal_nan=True) - except KeyError: - is_eq = False - if not is_eq: - verbose_print(f"{item} is different (skipping full print).") else: # We don't have special handling for this attribute at this point, so # we just use the generic __ne__ method. @@ -1084,12 +1085,12 @@ def _read_data( 'Argument for data_type not recognized, must be "cross" or "auto".' 
) if apply_cal: - if self._compass_solns is None: + if not self._has_compass_soln: raise ValueError("Cannot apply calibration if no tables loaded.") if not scale_data: raise ValueError("Cannot return raw data if setting apply_cal=True") elif apply_cal is None and scale_data: - apply_cal = self._compass_solns is not None + apply_cal = self._has_compass_soln if data_type == "cross": if apply_cal: @@ -1169,7 +1170,7 @@ def _read_data( if apply_cal and common_scale: temp_dict = self._convert_raw_to_vis(temp_dict) - temp_dict = self._apply_compass_solns(self._compass_solns, temp_dict) + temp_dict = self._apply_compass_solns(temp_dict) if np.all(chan_avg_arr == 1): if apply_cal: @@ -3363,7 +3364,7 @@ def _read_compass_solns(self, filename=None): return compass_soln_dict - def read_compass_solns(self, filename=None): + def read_compass_solns(self, filename=None, load_flags=True, load_bandpass=True): """ Read in COMPASS-formatted bandpass and flagging solutions. @@ -3381,17 +3382,26 @@ def read_compass_solns(self, filename=None): If the COMPASS solutions do not appear to overlap in time with that in the MirParser object. """ + if not load_flags and not load_bandpass: + # Nothing was requested to be loaded, so this is a no-op. + return if not (self.vis_data is None and self.raw_data is None): raise ValueError( "Cannot call read_compass_solns when data have already been loaded, " "call unload_data first in order to resolve this error." 
) - self._compass_solns = self._read_compass_solns(filename) + compass_soln_dict = self._read_compass_solns(filename) + if load_flags: + self._compass_sphid_flags = compass_soln_dict["sphid_flags"] + self._compass_static_flags = compass_soln_dict["static_flags"] - def _apply_compass_solns( - self, compass_soln_dict=None, vis_data=None, *, apply_flags=True, apply_bp=True - ): + if load_bandpass: + self._compass_bp_soln = compass_soln_dict["bp_gains_corr"] + + self._has_compass_soln = True + + def _apply_compass_solns(self, vis_data=None): """ Apply COMPASS-derived gains and flagging. @@ -3432,10 +3442,10 @@ def _apply_compass_solns( chunk_arr = self.sp_data.get_value("corrchunk", header_key=sphid_arr) # SPW# sb_arr = self.bl_data.get_value("isb", header_key=blhid_arr) # SB| 0:LSB 1:USB - if apply_bp: + if self._compass_bp_soln is not None: # Let's grab the bandpass solns upfront before we iterate through # all of the individual spectral records. - bp_soln = compass_soln_dict["bp_gains_corr"] + bp_soln = self._compass_bp_soln for sphid, sb, ant1, rx1, ant2, rx2, chunk in zip( sphid_arr, sb_arr, ant1_arr, rx1_arr, ant2_arr, rx2_arr, chunk_arr @@ -3460,11 +3470,13 @@ def _apply_compass_solns( cal_soln["cal_flags"] | cal_soln["weight_flags"] ) - if apply_flags: + if not ( + self._compass_sphid_flags is None or self._compass_static_flags is None + ): # For the sake of reading/coding, let's assign the two catalogs of flags # to their own variables, so that we can easily call them later. - sphid_flags = compass_soln_dict["sphid_flags"] - static_flags = compass_soln_dict["static_flags"] + sphid_flags = self._compass_sphid_flags + static_flags = self._compass_static_flags for idx, sphid in enumerate(sphid_arr): # Now we'll step through each spectral record that we have to process. 
diff --git a/pyuvdata/uvdata/tests/test_mir_parser.py b/pyuvdata/uvdata/tests/test_mir_parser.py index 97bca1656..d52c8795a 100644 --- a/pyuvdata/uvdata/tests/test_mir_parser.py +++ b/pyuvdata/uvdata/tests/test_mir_parser.py @@ -316,7 +316,7 @@ def test_compass_read_err(mir_data: MirParser, compass_soln_file): mir_data.read_compass_solns(compass_soln_file) -def test_compass_flag_sphid_apply(mir_data, compass_soln_file): +def test_compass_flag_sphid_apply(mir_data: MirParser, compass_soln_file): """ Test COMPASS per-sphid flagging. @@ -326,10 +326,12 @@ def test_compass_flag_sphid_apply(mir_data, compass_soln_file): for entry in mir_data.vis_data.values(): entry["flags"][:] = False - compass_solns = mir_data._read_compass_solns(compass_soln_file) - mir_data._apply_compass_solns( - compass_solns, mir_data.vis_data, apply_bp=False, apply_flags=True - ) + assert mir_data._compass_bp_soln is None + vis_data = mir_data.vis_data + mir_data.vis_data = None + mir_data.read_compass_solns(compass_soln_file, load_flags=True, load_bandpass=False) + mir_data.vis_data = vis_data + mir_data._apply_compass_solns(mir_data.vis_data) for key, entry in mir_data.vis_data.items(): if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0: assert not np.all(entry["flags"][1::2]) @@ -348,14 +350,19 @@ def test_compass_flag_static_apply(mir_data, compass_soln_file): entry["flags"][-1] = True mir_data.in_data["mjd"] += 1 + + vis_data = mir_data.vis_data + mir_data.vis_data = None with uvtest.check_warnings( UserWarning, "No metadata from COMPASS matches that in this data set." 
): - compass_solns = mir_data._read_compass_solns(compass_soln_file) + mir_data.read_compass_solns( + compass_soln_file, load_flags=True, load_bandpass=False + ) - mir_data._apply_compass_solns( - compass_solns, mir_data.vis_data, apply_bp=False, apply_flags=True - ) + assert mir_data._compass_bp_soln is None + mir_data.vis_data = vis_data + mir_data._apply_compass_solns(mir_data.vis_data) for key, entry in mir_data.vis_data.items(): if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0: @@ -365,7 +372,7 @@ def test_compass_flag_static_apply(mir_data, compass_soln_file): @pytest.mark.parametrize("muck_solns", ["none", "some", "all"]) -def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns): +def test_compass_bp_apply(mir_data: MirParser, compass_soln_file, muck_solns): """ Test COMPASS bandpass calibration. @@ -381,15 +388,23 @@ def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns): if muck_solns == "all": mir_data.bl_data["iant2"] += 1 + vis_data = mir_data.vis_data + mir_data.vis_data = None + with uvtest.check_warnings( None if (muck_solns == "none") else UserWarning, None if (muck_solns == "none") else "No metadata from COMPASS matches", ): - compass_solns = mir_data._read_compass_solns(compass_soln_file) + mir_data.read_compass_solns( + compass_soln_file, load_flags=False, load_bandpass=True + ) - mir_data._apply_compass_solns( - compass_solns, mir_data.vis_data, apply_bp=True, apply_flags=False - ) + assert mir_data._compass_static_flags is None + assert mir_data._compass_sphid_flags is None + + mir_data.vis_data = vis_data + + mir_data._apply_compass_solns(mir_data.vis_data) for key, entry in mir_data.vis_data.items(): if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0: @@ -399,6 +414,16 @@ def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns): assert (muck_solns != "none") == np.all(entry["flags"]) +def test_compass_no_op(mir_data: MirParser, compass_soln_file): + mir_data.read_compass_solns( + 
compass_soln_file, load_flags=False, load_bandpass=False + ) + assert not mir_data._has_compass_soln + assert mir_data._compass_bp_soln is None + assert mir_data._compass_sphid_flags is None + assert mir_data._compass_static_flags is None + + def test_compass_rechunk_routing(mir_data: MirParser, compass_soln_file): mir_data.unload_data() mir_data.read_compass_solns(compass_soln_file) @@ -668,12 +693,12 @@ def test_data_errs(mir_data, attr): @pytest.mark.parametrize( "compass_soln,kwargs,err_msg", [ - [None, {}, "Cannot apply calibration if no tables loaded."], - [{}, {"scale_data": False}, "Cannot return raw data if setting apply_cal=True"], + [False, {}, "Cannot apply calibration if no tables loaded."], + [True, {"scale_data": False}, "Cannot return raw data if setting apply_cal=Tr"], ], ) def test_read_data_errs(mir_data, compass_soln, kwargs, err_msg): - mir_data._compass_solns = compass_soln + mir_data._has_compass_soln = compass_soln with pytest.raises(ValueError, match=err_msg): mir_data._read_data("cross", apply_cal=True, **kwargs)