diff --git a/pyuvdata/parameter.py b/pyuvdata/parameter.py index 7803d8e0d..3c09521dd 100644 --- a/pyuvdata/parameter.py +++ b/pyuvdata/parameter.py @@ -659,6 +659,48 @@ def check_acceptability(self): ) return False, message + def compare_value(self, value): + """ + Compare UVParameter value to a supplied value. + + Parameters + ---------- + value + The value to compare against that stored in the UVParameter object. Must + be the same type. + + Returns + ------- + same : bool + True if the values are equivalent (or within specified tolerances), + otherwise false. + """ + # Catch the case when the values are different types + if not ( + isinstance(value, self.value.__class__) + and isinstance(self.value, value.__class__) + ): + raise ValueError( + "UVParameter value and supplied values are of different types." + ) + + # If these are numeric types, handle them via allclose + if isinstance(value, (np.ndarray, int, float, complex)): + # Check that we either have a number or an ndarray + if not isinstance(value, np.ndarray) or value.shape == self.value.shape: + if np.allclose( + value, + self.value, + rtol=self.tols[0], + atol=self.tols[1], + equal_nan=True, + ): + return True + return False + else: + # Otherwise just default to checking equality + return value == self.value + class AngleParameter(UVParameter): """ diff --git a/pyuvdata/tests/test_parameter.py b/pyuvdata/tests/test_parameter.py index 7a8d58823..94dc0f3bb 100644 --- a/pyuvdata/tests/test_parameter.py +++ b/pyuvdata/tests/test_parameter.py @@ -657,3 +657,30 @@ def test_spoof(): assert param.value is None param.apply_spoof() assert param.value == 1.0 + + +def test_compare_value_err(): + param = uvp.UVParameter("_test1", value=3.0, tols=[0, 1], expected_type=float) + with pytest.raises( + ValueError, + match="UVParameter value and supplied values are of different types.", + ): + param.compare_value("test") + + +@pytest.mark.parametrize( + "value,status", + [ + (np.array([1, 2]), False), + (np.array([1, 2, 3]), True), + (np.array([1.0, 2.0, 3.0]), True), + (np.array([2, 3, 4]), True), + (np.array([4, 5, 6]), False), + (np.array([1, 2, 3, 4, 5, 6]), False), + ], +) +def test_compare_value(value, status): + param = uvp.UVParameter( + "_test1", value=np.array([1, 2, 3]), tols=[0, 1], expected_type=float + ) + assert param.compare_value(value) == status