Skip to content

Commit

Permalink
Use scipy's decimate. (#115)
Browse files Browse the repository at this point in the history
* change decimate to use scipy decimate

* add decimate doc example

* add doc page and test
  • Loading branch information
d-chambers authored Feb 1, 2023
1 parent 7d3c542 commit a3effc6
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 72 deletions.
18 changes: 13 additions & 5 deletions dascore/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import tempfile
from pathlib import Path
from typing import Sequence, Union

import numpy as np

Expand Down Expand Up @@ -58,10 +59,11 @@ def _random_patch(starttime="2017-09-18", network="", station="", tag="random"):
@register_func(EXAMPLE_PATCHES, key="sin_wav")
def sin_wave_patch(
sample_rate=44100,
frequency=100,
frequency: Union[Sequence[float], float] = 100.0,
time_min="2020-01-01",
channel_count=3,
duration=1,
amplitude=10,
):
"""
Return a Patch composed of simple 1 second sin waves.
Expand All @@ -80,13 +82,19 @@ def sin_wave_patch(
The number of distance channels to include.
duration
Duration of signal in seconds.
amplitude
The amplitude of the sin wave.
"""
t_array = np.linspace(0.0, duration, sample_rate * duration)
sin_data = 10 * np.sin(2.0 * np.pi * frequency * t_array)
data = np.stack([sin_data] * channel_count).T
time = to_timedelta64(t_array) + np.datetime64(time_min)
# Get time and distance coords
distance = np.arange(1, channel_count + 1, 1)
time = to_timedelta64(t_array) + np.datetime64(time_min)
freqs = [frequency] if isinstance(frequency, (float, int)) else frequency
# init empty data and add frequencies.
data = np.zeros((len(time), len(distance)))
for freq in freqs:
sin_data = amplitude * np.sin(2.0 * np.pi * freq * t_array)
data += sin_data[..., np.newaxis]

patch = dc.Patch(
data=data,
Expand Down
33 changes: 1 addition & 32 deletions dascore/proc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@
import numpy as np
import pandas as pd
from scipy import ndimage
from scipy.signal import (
cheb2ord,
cheby2,
iirfilter,
medfilt2d,
sosfilt,
sosfiltfilt,
zpk2sos,
)
from scipy.signal import iirfilter, medfilt2d, sosfilt, sosfiltfilt, zpk2sos

import dascore
from dascore.constants import PatchType
Expand Down Expand Up @@ -224,29 +216,6 @@ def sobel_filter(patch: PatchType, dim: str, mode="reflect", cval=0.0) -> PatchT
# return dascore.Patch(data=out, coords=patch.coords, attrs=patch.attrs)


def _lowpass_cheby_2(data, freq, df, maxorder=12, axis=0):
"""
Cheby2-Lowpass Filter used for pre-conditioning decimation.
Based on Obspy's implementation found here:
https://docs.obspy.org/master/_modules/obspy/signal/filter.html#lowpass_cheby_2
"""
nyquist = df * 0.5
# rp - maximum ripple of passband, rs - attenuation of stopband
rp, rs, order = 1, 96, 1e99
ws = freq / nyquist # stop band frequency
wp = ws # pass band frequency
# raise for some bad scenarios
while True:
if order <= maxorder:
break
wp = wp * 0.99
order, wn = cheb2ord(wp, ws, rp, rs, analog=0)
z, p, k = cheby2(order, rs, wn, btype="low", analog=0, output="zpk")
sos = zpk2sos(z, p, k)
return sosfilt(sos, data, axis=axis)


@patch_function()
def median_filter(patch: PatchType, kernel_size=3) -> PatchType:
"""
Expand Down
75 changes: 43 additions & 32 deletions dascore/proc/resample.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""
Module for re-sampling patches.
"""
from typing import Union
from typing import Literal, Union

import numpy as np
from scipy.signal import decimate as scipy_decimate

import dascore
import dascore as dc
import dascore.compat as compat
from dascore.constants import PatchType
from dascore.exceptions import FilterValueError
from dascore.proc.filter import _get_sampling_rate, _lowpass_cheby_2
from dascore.utils.misc import check_evenly_sampled
from dascore.utils.patch import (
get_dim_value_from_kwargs,
Expand All @@ -22,7 +22,7 @@
@patch_function()
def decimate(
patch: PatchType,
lowpass: bool = True,
filter_type: Literal["iir", "fir", None] = "iir",
copy=True,
**kwargs,
) -> PatchType:
Expand All @@ -31,44 +31,55 @@ def decimate(
Parameters
----------
lowpass
If True, first apply a low-pass (anti-alis) filter.
filter_type
filter type to use to avoid aliasing. Options are:
iir - infinite impulse response
fir - finite impulse response
None - No pre-filtering
copy
If True, copy the decimated data array. This is needed if you want
the old array to get gc'ed to free memory otherwise a view is returned.
Only applies when filter_type == None.
**kwargs
Used to pass dimension and factor. For example time=10 is 10x
Used to pass dimension and factor. For example `time=10` is 10x
decimation along the time axis.
Notes
-----
Simply uses scipy.signal.decimate if filter_type is specified. Otherwise,
just slice data long specified dimension only including every n samples.
Examples
--------
# Simple example using iir
>>> import dascore as dc
>>> patch = dc.get_example_patch()
>>> decimated_irr = patch.decimate(time=10, filter_type='iir')
>>> # Example using fir along distance dimension
>>> decimated_fir = patch.decimate(distance=10, filter_type='fir')
"""
# Note: We can't simply use scipy.signal.decimate due to this issue:
# https://github.com/scipy/scipy/issues/15072
dim, axis, factor = get_dim_value_from_kwargs(patch, kwargs)
if lowpass:
# get new niquest
if factor > 16:
# Apply scipy.signal.decimate and geet new coords
if filter_type:
if filter_type == "IRR" and factor > 13:
msg = (
"Automatic filter design is unstable for decimation "
+ "factors above 16. Manual decimation is necessary."
"IRR filter is unstable for decimation factors above"
" 13. Call decimate multiple times."
)
raise FilterValueError(msg)
sr = _get_sampling_rate(patch, dim)
freq = sr * 0.5 / float(factor)
fdata = _lowpass_cheby_2(patch.data, freq, sr, axis=axis)
patch = dascore.Patch(
fdata, coords=patch.coords, attrs=patch.attrs, dims=patch.dims
)

kwargs = {dim: slice(None, None, factor)}
dar = patch._data_array.sel(**kwargs)
# need to create a new xarray so the old, probably large, numpy array
# gets gc'ed, otherwise it stays in memory (if lowpass isn't called)
data = dar.data if not copy else dar.data.copy()
attrs = dar.attrs
# update delta_dim since spacing along dimension has changed
d_attr = f"d_{dim}"
attrs[d_attr] = patch.attrs[d_attr] * factor

return dascore.Patch(data=data, coords=dar.coords, attrs=dar.attrs, dims=dar.dims)
data = scipy_decimate(patch.data, factor, ftype=filter_type, axis=axis)
coords = {x: patch.coords[x] for x in patch.dims}
coords[dim] = coords[dim][::factor]
else: # No filter, simply slice along specified dimension.
dar = patch._data_array.sel(**{dim: slice(None, None, factor)})
# Need to copy so array isn't a slice and holds onto reference of parent
data = dar.data if not copy else dar.data.copy()
coords = dar.coords
# Update delta_dim since spacing along dimension has changed.
attrs = dict(patch.attrs)
attrs[f"d_{dim}"] = patch.attrs[f"d_{dim}"] * factor
out = dc.Patch(data=data, coords=coords, attrs=attrs, dims=patch.dims)
return out


@patch_function()
Expand Down
54 changes: 54 additions & 0 deletions docs/tutorial/processing.qmd
Original file line number Diff line number Diff line change
@@ -1,3 +1,57 @@
---
title: Processing
---
The following shows some simple examples of patch processing. See the
[proc module documentation](`dascore.proc`) for a list of processing functions.

# Decimate

The [decimate patch function](`dascore.Patch.decimate`) decimates a `Patch`
along a given axis while by default performing low-pass filtering to avoid
[aliasing](https://en.wikipedia.org/wiki/Aliasing).

## Data creation

First, we create a patch composed of two sine waves; one above the new
decimation frequency and one below.

```{python}
import dascore as dc
patch = dc.examples.sin_wave_patch(
sample_rate=1000,
frequency=[200, 10],
channel_count=2,
)
_ = patch.viz.wiggle(show=True)
```

## IIR filter

Next we decimate by 10x using IIR filter

```{python}
decimated_iir = patch.decimate(time=10, filter_type='iir')
_ = decimated_iir.viz.wiggle(show=True)
```

Notice the lowpass filter removed the 200 Hz signal and only
the 10Hz wave remains.

## FIR filter

Next we decimate by 10x using FIR filter.

```{python}
decimated_fir = patch.decimate(time=10, filter_type='fir')
_ = decimated_fir.viz.wiggle(show=True)
```

## No Filter

Next, we decimate without a filter to purposely induce aliasing.

```{python}
decimated_no_filt = patch.decimate(time=10, filter_type=None)
_ = decimated_no_filt.viz.wiggle(show=True)
```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dependencies = [
"pooch>=1.2",
"pydantic>=1.9.0",
"rich",
"scipy",
"scipy>=1.10.0",
"tables",
"typing_extensions",
"xarray",
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
pytest configuration for dascore
"""
import os
import shutil
from pathlib import Path

Expand All @@ -20,7 +21,7 @@
test_data_path = Path(__file__).parent.absolute() / "test_data"

# A list to register functions that return general spools or patches
# These are to be used for running many different patches/spools through
# These are to be used for running many patches/spools through
# Generic tests.
SPOOL_FIXTURES = []
PATCH_FIXTURES = []
Expand Down Expand Up @@ -80,7 +81,11 @@ def pytest_sessionstart(session):

import dascore as dc

matplotlib.use("Agg")
# If running in CI make sure to turn off matplotlib.
if os.environ.get("CI", False):
matplotlib.use("Agg")

# Ensure debug is set. This disables progress bars which disrupt debugging.
dc._debug = True


Expand Down
28 changes: 28 additions & 0 deletions tests/test_proc/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Tests for decimation
"""
import numpy as np
import pandas as pd
import pytest

import dascore as dc
from dascore.exceptions import ParameterError
from dascore.utils.patch import get_start_stop_step

Expand Down Expand Up @@ -90,6 +92,32 @@ def test_update_delta_dim(self, random_patch):
out = random_patch.decimate(time=10)
assert out.attrs["d_time"] == dt1 * 10

def test_float_32_stability(self, random_patch):
"""
Ensure float32 works for decimation.
See scipy#15072.
"""
ar = np.random.random((10_000, 2)).astype("float32")
dt = dc.to_timedelta64(0.001)
t1 = dc.to_datetime64("2020-01-01")
coords = {
"distance": [1, 2],
"time": np.arange(0, ar.shape[0]) * dt + t1,
}
dims = ("time", "distance")
attrs = {"d_time": dt, "time_min": t1}
patch = dc.Patch(data=ar, coords=coords, dims=dims, attrs=attrs)
# ensure all modes of decimation don't produce NaN values.
decimated_iir = patch.decimate(time=10, filter_type="iir")
assert not np.any(pd.isnull(decimated_iir.data))

decimated_fir = patch.decimate(time=10, filter_type="fir")
assert not np.any(pd.isnull(decimated_fir.data))

decimated_none = patch.decimate(time=10, filter_type=None)
assert not np.any(pd.isnull(decimated_none.data))


class TestResample:
"""
Expand Down

0 comments on commit a3effc6

Please sign in to comment.