From 5af92382bcee3d9638c85444c304aa83f06d1749 Mon Sep 17 00:00:00 2001 From: Garrett 'Karto' Keating Date: Thu, 21 Sep 2023 13:49:26 -0400 Subject: [PATCH] Adding compare_value method to make it easier to check UVParameter values (particularly w/ ndarrays) --- pyuvdata/parameter.py | 42 ++++++++++++++++++++++++++++++++ pyuvdata/tests/test_parameter.py | 27 ++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/pyuvdata/parameter.py b/pyuvdata/parameter.py index 7803d8e0dd..3c09521dd3 100644 --- a/pyuvdata/parameter.py +++ b/pyuvdata/parameter.py @@ -659,6 +659,48 @@ def check_acceptability(self): ) return False, message + def compare_value(self, value): + """ + Compare UVParameter value to a supplied value. + + Parameters + ---------- + value + The value to compare against that stored in the UVParameter object. Must + be the same type. + + Returns + ------- + same : bool + True if the values are equivalent (or within specified tolerances), + otherwise false. + """ + # Catch the case when the values are different types + if not ( + isinstance(value, self.value.__class__) + and isinstance(self.value, value.__class__) + ): + raise ValueError( + "UVParameter value and supplied values are of different types." + ) + + # If these are numeric types, handle them via allclose + if isinstance(value, (np.ndarray, int, float, complex)): + # Check that we either have a number or an ndarray + if not isinstance(value, np.ndarray) or value.shape == self.value.shape: + if np.allclose( + value, + self.value, + rtol=self.tols[0], + atol=self.tols[1], + equal_nan=True, + ): + return True + return False + else: + # Otherwise just default to checking equality + return value == self.value + class AngleParameter(UVParameter): """ diff --git a/pyuvdata/tests/test_parameter.py b/pyuvdata/tests/test_parameter.py index 7a8d58823b..94dc0f3bb5 100644 --- a/pyuvdata/tests/test_parameter.py +++ b/pyuvdata/tests/test_parameter.py @@ -657,3 +657,30 @@ def test_spoof(): assert param.value is None param.apply_spoof() assert param.value == 1.0 + + +def test_compare_value_err(): + param = uvp.UVParameter("_test1", value=3.0, tols=[0, 1], expected_type=float) + with pytest.raises( + ValueError, + match="UVParameter value and supplied values are of different types.", + ): + param.compare_value("test") + + +@pytest.mark.parametrize( + "value,status", + [ + (np.array([1, 2]), False), + (np.array([1, 2, 3]), True), + (np.array([1.0, 2.0, 3.0]), True), + (np.array([2, 3, 4]), True), + (np.array([4, 5, 6]), False), + (np.array([1, 2, 3, 4, 5, 6]), False), + ], +) +def test_compare_value(value, status): + param = uvp.UVParameter( + "_test1", value=np.array([1, 2, 3]), tols=[0, 1], expected_type=float + ) + assert param.compare_value(value) == status