Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix_463 #464

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dascore/proc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_filter_units,
get_inverted_quant,
invert_quantity,
quant_sequence_to_quant_array,
)
from dascore.utils.docs import compose_docstring
from dascore.utils.misc import (
Expand Down Expand Up @@ -562,6 +563,14 @@ def _get_slope_array(dft_patch, directional, freq_dims):

def _maybe_transform_units(filt, dft_patch, freq_dims):
"""Handle units on filter."""
# Hand the units/partial units in sequence.
units = getattr(filt, "units", None)
try:
filt = np.array(filt)
except ValueError:
filt = quant_sequence_to_quant_array(filt)
if units:
filt = filt * dc.get_quantity(units)
if not isinstance(filt, dc.units.Quantity):
return filt
array, units = filt.magnitude, filt.units
Expand Down
37 changes: 37 additions & 0 deletions dascore/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import cache
from typing import TypeVar

Expand All @@ -11,6 +12,7 @@
from pint import DimensionalityError, Quantity, UndefinedUnitError, Unit

import dascore as dc
from dascore.compat import is_array
from dascore.exceptions import UnitError
from dascore.utils.misc import unbyte
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float
Expand Down Expand Up @@ -304,6 +306,41 @@ def _check_to_units(to_unit, dim):
return out1, out2


def quant_sequence_to_quant_array(sequence: Sequence[Quantity]) -> Quantity:
"""
Convert a sequence of Quantities (eg list) to a Quantity array.

Will simplify all quantities. Raises an error if not all elements have
the same units.

Parameters
----------
sequence
A sequence of Quantities.

Notes
-----
This is probably not efficient for large lists.
"""
if is_array(sequence):
# This is a numpy array, just return multiplied by quantity.
return sequence * get_quantity("dimensionless")
# iterate the sequence and manually convert to base units.
try:
base_unit_sequence = [x.to_base_units() for x in sequence]
except AttributeError:
msg = "Not all values in sequence are quantities."
raise UnitError(msg)
if not len(base_unit_sequence):
return np.array([]) * get_quantity("dimensionless")
units = {x.units for x in base_unit_sequence}
if len(units) != 1:
msg = "Not all values in sequence have compatible units."
raise UnitError(msg)
array = np.array([x.magnitude for x in base_unit_sequence])
return array * next(iter(units))


def __getattr__(name):
"""
Allows arbitrary units (quantities) to be imported from this module.
Expand Down
10 changes: 9 additions & 1 deletion tests/test_proc/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def test_bad_dims(self, example_patch):
patch.slope_filter(filt=filt, dims=("time", "distance"))

def test_units_raise_no_unit_coords(self, example_patch):
"""Ensure A UnitError is raised if one of hte coords does't have units."""
"""Ensure A UnitError is raised if one of the coords doesn't have units."""
patch = example_patch.set_units(distance="")
filt = np.array([1e3, 1.5e3, 5e3, 10e3]) * get_unit("m/s")
with pytest.raises(UnitError):
Expand All @@ -476,3 +476,11 @@ def test_inverted_units(self, example_patch):
out1 = example_patch.slope_filter(filt=slowness)
out2 = example_patch.slope_filter(filt=filt * get_unit("m/s"))
assert np.allclose(out1.data, out2.data)

def test_units_list(self, example_patch):
"""Ensure units as a list still work (see #463)."""
speed = 5_000 * dc.get_quantity("m/s")
filt = [speed * 0.90, speed * 0.95, speed * 1.05, speed * 1.1]
# The test passes if this line doesn't raise an error.
out = example_patch.slope_filter(filt)
assert isinstance(out, dc.Patch)
50 changes: 50 additions & 0 deletions tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_quantity_str,
get_unit,
invert_quantity,
quant_sequence_to_quant_array,
)


Expand Down Expand Up @@ -303,3 +304,52 @@ def test_array_quantity(self):
array = np.arange(10) * get_quantity("m")
out = convert_units(array, to_units="ft")
np.allclose(array.magnitude, out * 3.28084)


class TestQuantSequenceToQuantArray:
"""Ensure we can convert a quantity sequence to an array."""

def test_valid_sequence_same_units(self):
"""Test with a valid sequence of quantities with the same units."""
meter = get_quantity("m")
sequence = [1 * meter, 2 * meter, 3 * meter]
result = quant_sequence_to_quant_array(sequence)
expected = np.array([1, 2, 3]) * meter
np.testing.assert_array_equal(result.magnitude, expected.magnitude)
assert result.units == expected.units

def test_valid_sequence_different_units(self):
"""Test sequence of quantities with compatible but different units."""
m, cm, km = get_quantity("m"), get_quantity("cm"), get_quantity("km")

sequence = [1 * m, 100 * cm, 0.001 * km]
result = quant_sequence_to_quant_array(sequence)
expected = np.array([1, 1, 1]) * m
assert np.allclose(result.magnitude, expected.magnitude)
assert result.units == expected.units

def test_incompatible_units(self):
"""Test with a sequence of quantities with incompatible units."""
sequence = [1 * get_quantity("m"), 1 * get_quantity("s")]
msg = "Not all values in sequence have compatible units."
with pytest.raises(UnitError, match=msg):
quant_sequence_to_quant_array(sequence)

def test_non_quantity_elements(self):
"""Test with a sequence containing non-quantity elements."""
sequence = [1 * get_quantity("m"), 5]
msg = "Not all values in sequence are quantities."
with pytest.raises(UnitError, match=msg):
quant_sequence_to_quant_array(sequence)

def test_empty_sequence(self):
"""Test with an empty sequence."""
sequence = []
out = quant_sequence_to_quant_array(sequence)
assert isinstance(out, Quantity)

def test_numpy_array_input(self):
"""Test with a numpy array input."""
sequence = np.array([1, 2, 3])
out = quant_sequence_to_quant_array(sequence)
assert isinstance(out, Quantity)
Loading