Skip to content

Commit

Permalink
Minor reworking of COMPASS import, related equality checking
Browse files Browse the repository at this point in the history
  • Loading branch information
kartographer committed Jan 5, 2024
1 parent 7a2bca4 commit ba7af69
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 58 deletions.
94 changes: 53 additions & 41 deletions pyuvdata/uvdata/mir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def __init__(
self._has_cross = False
self._tsys_applied = False
self._tsys_use_cont_det = True
self._compass_solns = None
self._has_compass_soln = False
self._compass_bp_soln = None
self._compass_sphid_flags = None
self._compass_static_flags = None

# This value is the forward gain of the antenna (in units of Jy/K), which is
# multiplied against the system temperatures in order to produce values in units
Expand Down Expand Up @@ -211,6 +214,14 @@ def __eq__(self, other, *, verbose=True, metadata_only=False):
"raw_data": ["data", "scale_fac"],
"vis_data": ["data", "flags"],
"auto_data": ["data", "flags"],
"_compass_bp_soln": [
"cal_soln",
"cal_flags",
"weight_soln",
"weight_flags",
],
"_compass_sphid_flags": [...],
"_compass_static_flags": [...],
"_stored_masks": [
"in_data",
"bl_data",
Expand Down Expand Up @@ -292,19 +303,28 @@ def __eq__(self, other, *, verbose=True, metadata_only=False):
is_same &= this_item[subkey] == other_item[subkey]
elif not np.array_equal(this_item[subkey], other_item[subkey]):
if this_item[subkey].shape == other_item[subkey].shape:
# The atol here is set by the max value in the spectrum
# times 2^-10. That turns out to be _about_ the worst
# case scenario for moving to and from the raw data
# format, which compresses the data down from floats to
# ints.
atol = 1e-3
if np.any(np.isfinite(this_item[subkey])):
atol *= np.nanmax(np.abs(this_item[subkey]))
atol = rtol = 0
if subkey == "data":
# The atol here is set by the max value in the
# spectrum times 2^-10. That turns out to be _about_
# the worst case scenario for moving to and from the
# raw data format, which compresses the data down
# from floats to ints.
atol = 1e-3
if np.any(np.isfinite(this_item[subkey])):
atol = 1e-3 * np.nanmax(
np.abs(this_item[subkey])
)
else:
# Otherwise if not looking at data, use something
# close to the single precision floating point.
rtol = 1e-6

is_same &= np.allclose(
this_item[subkey],
other_item[subkey],
atol=atol,
rtol=rtol,
equal_nan=True,
)
else:
Expand All @@ -326,25 +346,6 @@ def __eq__(self, other, *, verbose=True, metadata_only=False):
f"{item} has different keys, left is {this_attr.keys()}, "
f"right is {other_attr.keys()}."
)
elif item == "_compass_solns":
is_eq = True
try:
for i_key in this_attr:
for j_key in this_attr[i_key]:
for k_key in this_attr[i_key][j_key]:
if is_eq:
arr1 = this_attr[i_key][j_key][k_key]
arr2 = other_attr[i_key][j_key][k_key]
# Note we use the two checks here b/c array_equal
# is much faster if it passes, and if it does
# python doesn't need to look at the second or arg
is_eq = np.array_equal(
arr1, arr2, equal_nan=True
) or np.allclose(arr1, arr2, equal_nan=True)
except KeyError:
is_eq = False
if not is_eq:
verbose_print(f"{item} is different (skipping full print).")
else:
# We don't have special handling for this attribute at this point, so
# we just use the generic __ne__ method.
Expand Down Expand Up @@ -1084,12 +1085,12 @@ def _read_data(
'Argument for data_type not recognized, must be "cross" or "auto".'
)
if apply_cal:
if self._compass_solns is None:
if not self._has_compass_soln:
raise ValueError("Cannot apply calibration if no tables loaded.")
if not scale_data:
raise ValueError("Cannot return raw data if setting apply_cal=True")
elif apply_cal is None and scale_data:
apply_cal = self._compass_solns is not None
apply_cal = self._has_compass_soln

if data_type == "cross":
if apply_cal:
Expand Down Expand Up @@ -1169,7 +1170,7 @@ def _read_data(

if apply_cal and common_scale:
temp_dict = self._convert_raw_to_vis(temp_dict)
temp_dict = self._apply_compass_solns(self._compass_solns, temp_dict)
temp_dict = self._apply_compass_solns(temp_dict)

if np.all(chan_avg_arr == 1):
if apply_cal:
Expand Down Expand Up @@ -3363,7 +3364,7 @@ def _read_compass_solns(self, filename=None):

return compass_soln_dict

def read_compass_solns(self, filename=None):
def read_compass_solns(self, filename=None, load_flags=True, load_bandpass=True):
"""
Read in COMPASS-formatted bandpass and flagging solutions.
Expand All @@ -3381,17 +3382,26 @@ def read_compass_solns(self, filename=None):
If the COMPASS solutions do not appear to overlap in time with that in
the MirParser object.
"""
if not load_flags and not load_bandpass:
# Insert snarky no-op comment here
return
if not (self.vis_data is None and self.raw_data is None):
raise ValueError(
"Cannot call read_compass_solns when data have already been loaded, "
"call unload_data first in order to resolve this error."
)

self._compass_solns = self._read_compass_solns(filename)
compass_soln_dict = self._read_compass_solns(filename)
if load_flags:
self._compass_sphid_flags = compass_soln_dict["sphid_flags"]
self._compass_static_flags = compass_soln_dict["static_flags"]

def _apply_compass_solns(
self, compass_soln_dict=None, vis_data=None, *, apply_flags=True, apply_bp=True
):
if load_bandpass:
self._compass_bp_soln = compass_soln_dict["bp_gains_corr"]

self._has_compass_soln = True

def _apply_compass_solns(self, vis_data=None):
"""
Apply COMPASS-derived gains and flagging.
Expand Down Expand Up @@ -3432,10 +3442,10 @@ def _apply_compass_solns(
chunk_arr = self.sp_data.get_value("corrchunk", header_key=sphid_arr) # SPW#
sb_arr = self.bl_data.get_value("isb", header_key=blhid_arr) # SB| 0:LSB 1:USB

if apply_bp:
if self._compass_bp_soln is not None:
# Let's grab the bandpass solns upfront before we iterate through
# all of the individual spectral records.
bp_soln = compass_soln_dict["bp_gains_corr"]
bp_soln = self._compass_bp_soln

for sphid, sb, ant1, rx1, ant2, rx2, chunk in zip(
sphid_arr, sb_arr, ant1_arr, rx1_arr, ant2_arr, rx2_arr, chunk_arr
Expand All @@ -3460,11 +3470,13 @@ def _apply_compass_solns(
cal_soln["cal_flags"] | cal_soln["weight_flags"]
)

if apply_flags:
if not (
self._compass_sphid_flags is None or self._compass_static_flags is None
):
# For the sake of reading/coding, let's assign the two catalogs of flags
# to their own variables, so that we can easily call them later.
sphid_flags = compass_soln_dict["sphid_flags"]
static_flags = compass_soln_dict["static_flags"]
sphid_flags = self._compass_sphid_flags
static_flags = self._compass_static_flags

for idx, sphid in enumerate(sphid_arr):
# Now we'll step through each spectral record that we have to process.
Expand Down
59 changes: 42 additions & 17 deletions pyuvdata/uvdata/tests/test_mir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_compass_read_err(mir_data: MirParser, compass_soln_file):
mir_data.read_compass_solns(compass_soln_file)


def test_compass_flag_sphid_apply(mir_data, compass_soln_file):
def test_compass_flag_sphid_apply(mir_data: MirParser, compass_soln_file):
"""
Test COMPASS per-sphid flagging.
Expand All @@ -326,10 +326,12 @@ def test_compass_flag_sphid_apply(mir_data, compass_soln_file):
for entry in mir_data.vis_data.values():
entry["flags"][:] = False

compass_solns = mir_data._read_compass_solns(compass_soln_file)
mir_data._apply_compass_solns(
compass_solns, mir_data.vis_data, apply_bp=False, apply_flags=True
)
assert mir_data._compass_bp_soln is None
vis_data = mir_data.vis_data
mir_data.vis_data = None
mir_data.read_compass_solns(compass_soln_file, load_flags=True, load_bandpass=False)
mir_data.vis_data = vis_data
mir_data._apply_compass_solns(mir_data.vis_data)
for key, entry in mir_data.vis_data.items():
if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0:
assert not np.all(entry["flags"][1::2])
Expand All @@ -348,14 +350,19 @@ def test_compass_flag_static_apply(mir_data, compass_soln_file):
entry["flags"][-1] = True

mir_data.in_data["mjd"] += 1

vis_data = mir_data.vis_data
mir_data.vis_data = None
with uvtest.check_warnings(
UserWarning, "No metadata from COMPASS matches that in this data set."
):
compass_solns = mir_data._read_compass_solns(compass_soln_file)
mir_data.read_compass_solns(
compass_soln_file, load_flags=True, load_bandpass=False
)

mir_data._apply_compass_solns(
compass_solns, mir_data.vis_data, apply_bp=False, apply_flags=True
)
assert mir_data._compass_bp_soln is None
mir_data.vis_data = vis_data
mir_data._apply_compass_solns(mir_data.vis_data)

for key, entry in mir_data.vis_data.items():
if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0:
Expand All @@ -365,7 +372,7 @@ def test_compass_flag_static_apply(mir_data, compass_soln_file):


@pytest.mark.parametrize("muck_solns", ["none", "some", "all"])
def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns):
def test_compass_bp_apply(mir_data: MirParser, compass_soln_file, muck_solns):
"""
Test COMPASS bandpass calibration.
Expand All @@ -381,15 +388,23 @@ def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns):
if muck_solns == "all":
mir_data.bl_data["iant2"] += 1

vis_data = mir_data.vis_data
mir_data.vis_data = None

with uvtest.check_warnings(
None if (muck_solns == "none") else UserWarning,
None if (muck_solns == "none") else "No metadata from COMPASS matches",
):
compass_solns = mir_data._read_compass_solns(compass_soln_file)
mir_data.read_compass_solns(
compass_soln_file, load_flags=False, load_bandpass=True
)

mir_data._apply_compass_solns(
compass_solns, mir_data.vis_data, apply_bp=True, apply_flags=False
)
assert mir_data._compass_static_flags is None
assert mir_data._compass_sphid_flags is None

mir_data.vis_data = vis_data

mir_data._apply_compass_solns(mir_data.vis_data)

for key, entry in mir_data.vis_data.items():
if mir_data.sp_data.get_value("corrchunk", header_key=key) != 0:
Expand All @@ -399,6 +414,16 @@ def test_compass_bp_apply(mir_data, compass_soln_file, muck_solns):
assert (muck_solns != "none") == np.all(entry["flags"])


def test_compass_no_op(mir_data: MirParser, compass_soln_file):
mir_data.read_compass_solns(
compass_soln_file, load_flags=False, load_bandpass=False
)
assert not mir_data._has_compass_soln
assert mir_data._compass_bp_soln is None
assert mir_data._compass_sphid_flags is None
assert mir_data._compass_static_flags is None


def test_compass_rechunk_routing(mir_data: MirParser, compass_soln_file):
mir_data.unload_data()
mir_data.read_compass_solns(compass_soln_file)
Expand Down Expand Up @@ -668,12 +693,12 @@ def test_data_errs(mir_data, attr):
@pytest.mark.parametrize(
"compass_soln,kwargs,err_msg",
[
[None, {}, "Cannot apply calibration if no tables loaded."],
[{}, {"scale_data": False}, "Cannot return raw data if setting apply_cal=True"],
[False, {}, "Cannot apply calibration if no tables loaded."],
[True, {"scale_data": False}, "Cannot return raw data if setting apply_cal=Tr"],
],
)
def test_read_data_errs(mir_data, compass_soln, kwargs, err_msg):
mir_data._compass_solns = compass_soln
mir_data._has_compass_soln = compass_soln
with pytest.raises(ValueError, match=err_msg):
mir_data._read_data("cross", apply_cal=True, **kwargs)

Expand Down

0 comments on commit ba7af69

Please sign in to comment.