Skip to content

Commit

Permalink
Adding compare_value method to make it easier to check UVParameter va…
Browse files Browse the repository at this point in the history
…lues (particularly w/ ndarrays)
  • Loading branch information
kartographer committed Sep 21, 2023
1 parent bda36ae commit 5af9238
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
42 changes: 42 additions & 0 deletions pyuvdata/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,48 @@ def check_acceptability(self):
)
return False, message

def compare_value(self, value):
"""
Compare UVParameter value to a supplied value.
Parameters
----------
value
The value to compare against that stored in the UVParameter object. Must
be the same type.
Returns
-------
same : bool
True if the values are equivalent (or within specified tolerances),
otherwise false.
"""
# Catch the case when the values are different types
if not (
isinstance(value, self.value.__class__)
and isinstance(self.value, value.__class__)
):
raise ValueError(
"UVParameter value and supplied values are of different types."
)

# If these are numeric types, handle them via allclose
if isinstance(value, (np.ndarray, int, float, complex)):
# Check that we either have a number or an ndarray
if not isinstance(value, np.ndarray) or value.shape == self.value.shape:
if np.allclose(
value,
self.value,
rtol=self.tols[0],
atol=self.tols[1],
equal_nan=True,
):
return True
return False
else:
# Otherwise just default to checking equality
return value == self.value


class AngleParameter(UVParameter):
"""
Expand Down
27 changes: 27 additions & 0 deletions pyuvdata/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,30 @@ def test_spoof():
assert param.value is None
param.apply_spoof()
assert param.value == 1.0


def test_compare_value_err():
param = uvp.UVParameter("_test1", value=3.0, tols=[0, 1], expected_type=float)
with pytest.raises(
ValueError,
match="UVParameter value and supplied values are of different types.",
):
param.compare_value("test")


@pytest.mark.parametrize(
"value,status",
[
(np.array([1, 2]), False),
(np.array([1, 2, 3]), True),
(np.array([1.0, 2.0, 3.0]), True),
(np.array([2, 3, 4]), True),
(np.array([4, 5, 6]), False),
(np.array([1, 2, 3, 4, 5, 6]), False),
],
)
def test_compare_value(value, status):
param = uvp.UVParameter(
"_test1", value=np.array([1, 2, 3]), tols=[0, 1], expected_type=float
)
assert param.compare_value(value) == status

0 comments on commit 5af9238

Please sign in to comment.