diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 19d6dac99..0ca0ae6be 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -13,8 +13,8 @@ Announcements
 New indicators
 ^^^^^^^^^^^^^^
-* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length`` : spell length statistics on a bivariate condition that uses the average over a window by default. (:pull:`1885`).
-* New ``hot_spell_max_magnitude``: yields the magnitude of the most intensive heat wave. (:pull:`1926`).
+* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length`` : spell length statistics on a bivariate condition that uses the average over a window by default. (:pull:`1885`, :pull:`1778`).
+* New ``hot_spell_max_magnitude`` : yields the magnitude of the most intense heat wave. (:pull:`1926`).
 * New ``chill_portion`` and ``chill_unit``: chill portion based on the Dynamic Model and chill unit based on the Utah model indicators. (:issue:`1753`, :pull:`1909`).
 * New ``water_cycle_intensity``: yields the sum of precipitation and actual evapotranspiration. (:issue:`410`, :pull:`1947`).
 
@@ -25,6 +25,9 @@ New features and enhancements
 * ``xclim.indices.run_length.windowed_max_run_sum`` accumulates positive values across runs and yields the maximum valued run. (:pull:`1926`).
 * Helper function ``xclim.indices.helpers.make_hourly_temperature`` to estimate hourly temperatures from daily min and max temperatures. (:pull:`1909`).
 * New global option ``resample_map_blocks`` to wrap all ``resample().map()`` code inside a ``xr.map_blocks`` to lower the number of dask tasks. Uses utility ``xclim.indices.helpers.resample_map`` and requires ``flox`` to ensure the chunking allows such block-mapping. Defaults to False. (:pull:`1848`).
+* ``xclim.indices.run_length.runs_with_holes`` accepts a condition that must be met for a run to start and a second condition that must be met for the run to stop. (:pull:`1778`).
+* New generic compute function ``xclim.indices.generic.thresholded_events`` that finds events based on a threshold condition and returns basic statistics for each. See also ``xclim.indices.run_length.find_events``. (:pull:`1778`).
+* ``xclim.core.units.rate2amount`` and ``xclim.core.units.amount2rate`` can now also accept quantities (pint objects or strings), in which case the ``dim`` argument must be the ``time`` coordinate, from which the sampling rate is inferred. (:pull:`1778`).
 
 Bug fixes
 ^^^^^^^^^
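As an illustration of the new ``thresholded_events`` entry point, a minimal usage sketch distilled from the tests below (not part of the patch; the input values are made up):

    import pandas as pd
    import xarray as xr
    from xclim.indices import generic

    time = pd.date_range("2000-01-01", periods=10, freq="D")
    pr = xr.DataArray(
        [0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0],
        coords={"time": time},
        dims="time",
        attrs={"units": "mm"},
    )
    # An event needs >= 1 mm on 3 consecutive days and ends after 3 consecutive
    # days below the threshold: days 3-5 form the only event, the 2-day wet
    # spell at the end is too short to start one.
    events = generic.thresholded_events(
        pr, thresh="1 mm", op=">=", window=3, window_stop=3
    )
    events = events.load().dropna("event", how="all")
    # events.event_length -> [3], events.event_sum -> [6] mm,
    # events.event_start -> [2000-01-03]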
diff --git a/tests/test_generic.py b/tests/test_generic.py
index 0054dd935..c1c4f9f4e 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -8,6 +8,7 @@
 from xclim.core.calendar import doy_to_days_since, select_time
 from xclim.indices import generic
+from xclim.testing.helpers import assert_lazy
 
 K2C = 273.15
 
@@ -768,3 +769,126 @@ def test_spell_length_statistics_multi(tasmin_series, tasmax_series):
     )
     xr.testing.assert_equal(outs, outm)
     np.testing.assert_allclose(outc, 1)
+
+
+class TestThresholdedEvents:
+
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_simple(self, pr_series, use_dask):
+        arr = np.array([0, 0, 0, 1, 2, 3, 0, 3, 3, 10, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 3, 3, 2, 0, 0, 0, 2, 0, 0, 0, 0])  # fmt: skip
+        pr = pr_series(arr, start="2000-01-01", units="mm")
+        if use_dask:
+            pr = pr.chunk(-1)
+
+        with assert_lazy:
+            out = generic.thresholded_events(
+                pr,
+                thresh="1 mm",
+                op=">=",
+                window=3,
+            )
+
+        assert out.event.size == np.ceil(arr.size / (3 + 1))
+        out = out.load().dropna("event", how="all")
+
+        np.testing.assert_array_equal(out.event_length, [3, 3, 4, 4])
+        np.testing.assert_array_equal(out.event_effective_length, [3, 3, 4, 4])
+        np.testing.assert_array_equal(out.event_sum, [6, 16, 7, 9])
+        np.testing.assert_array_equal(
+            out.event_start,
+            np.array(
+                ["2000-01-04", "2000-01-08", "2000-01-16", "2000-01-26"],
+                dtype="datetime64[ns]",
+            ),
+        )
+
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_diff_windows(self, pr_series, use_dask):
+        arr = np.array([0, 0, 0, 1, 2, 3, 0, 3, 3, 10, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 3, 3, 2, 0, 0, 0, 2, 0, 0, 0, 0])  # fmt: skip
+        pr = pr_series(arr, start="2000-01-01", units="mm")
+        if use_dask:
+            pr = pr.chunk(-1)
+
+        # different window stop
+        out = (
+            generic.thresholded_events(
+                pr, thresh="2 mm", op=">=", window=3, window_stop=4
+            )
+            .load()
+            .dropna("event", how="all")
+        )
+
+        np.testing.assert_array_equal(out.event_length, [3, 3, 7])
+        np.testing.assert_array_equal(out.event_effective_length, [3, 3, 4])
+        np.testing.assert_array_equal(out.event_sum, [16, 6, 10])
+        np.testing.assert_array_equal(
+            out.event_start,
+            np.array(
+                ["2000-01-08", "2000-01-17", "2000-01-27"], dtype="datetime64[ns]"
+            ),
+        )
+
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_cftime(self, pr_series, use_dask):
+        arr = np.array([0, 0, 0, 1, 2, 3, 0, 3, 3, 10, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 3, 3, 2, 0, 0, 0, 2, 0, 0, 0, 0])  # fmt: skip
+        pr = pr_series(arr, start="2000-01-01", units="mm").convert_calendar("noleap")
+        if use_dask:
+            pr = pr.chunk(-1)
+
+        # cftime
+        with assert_lazy:
+            out = generic.thresholded_events(
+                pr,
+                thresh="1 mm",
+                op=">=",
+                window=3,
+                window_stop=3,
+            )
+        out = out.load().dropna("event", how="all")
+
+        np.testing.assert_array_equal(out.event_length, [7, 4, 4])
+        np.testing.assert_array_equal(out.event_effective_length, [6, 4, 4])
+        np.testing.assert_array_equal(out.event_sum, [22, 7, 9])
+        exp = xr.DataArray(
+            [1, 2, 3],
+            dims=("time",),
+            coords={
+                "time": np.array(
+                    ["2000-01-04", "2000-01-16", "2000-01-26"], dtype="datetime64[ns]"
+                )
+            },
+        )
+        np.testing.assert_array_equal(
+            out.event_start, exp.convert_calendar("noleap").time
+        )
+
+    @pytest.mark.parametrize("use_dask", [True, False])
+    def test_freq(self, pr_series, use_dask):
+        jan = [0, 0, 0, 1, 2, 3, 0, 3, 3, 10, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 3, 2, 3, 2]  # fmt: skip
+        fev = [2, 2, 1, 0, 0, 0, 3, 3, 4, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # fmt: skip
+        pr = pr_series(np.array(jan + fev), start="2000-01-01", units="mm")
+        if use_dask:
+            pr = pr.chunk(-1)
+
+        with assert_lazy:
+            out = generic.thresholded_events(
+                pr, thresh="1 mm", op=">=", window=3, freq="MS", window_stop=3
+            )
+        assert out.event_length.shape == (2, 6)
+        out = out.load().dropna("event", how="all")
+
+        np.testing.assert_array_equal(out.event_length, [[7, 6, 4], [3, 5, np.nan]])
+        np.testing.assert_array_equal(
+            out.event_effective_length, [[6, 6, 4], [3, 5, np.nan]]
+        )
+        np.testing.assert_array_equal(out.event_sum, [[22, 12, 10], [5, 17, np.nan]])
+        np.testing.assert_array_equal(
+            out.event_start,
+            np.array(
+                [
+                    ["2000-01-04", "2000-01-17", "2000-01-28"],
+                    ["2000-02-01", "2000-02-07", "NaT"],
+                ],
+                dtype="datetime64[ns]",
+            ),
+        )
diff --git a/tests/test_run_length.py b/tests/test_run_length.py
index 9d3d94144..0c64e3a84 100644
--- a/tests/test_run_length.py
+++ b/tests/test_run_length.py
@@ -126,9 +126,8 @@ def test_rle(ufunc, use_dask, index):
 
 @pytest.mark.parametrize("use_dask", [True, False])
 @pytest.mark.parametrize("index", ["first", "last"])
-def test_extract_events_identity(use_dask, index):
-    # implement more tests, this is just to show that this reproduces the behaviour
-    # of rle
+def test_runs_with_holes_identity(use_dask, index):
+    # This test reproduces the behaviour of `rle`
     values = np.zeros((10, 365, 4, 4))
     time = pd.date_range("2000-01-01", periods=365, freq="D")
     values[:, 1:11, ...] = 1
@@ -137,19 +136,19 @@
     if use_dask:
         da = da.chunk({"a": 1, "b": 2})
 
-    events = rl.extract_events(da != 0, 1, da == 0, 1)
+    events = rl.runs_with_holes(da != 0, 1, da == 0, 1)
     expected = da
     np.testing.assert_array_equal(events, expected)
 
 
-def test_extract_events():
+def test_runs_with_holes():
     values = np.zeros(365)
     time = pd.date_range("2000-01-01", periods=365, freq="D")
     a = [0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0]
     values[0 : len(a)] = a
     da = xr.DataArray(values, coords={"time": time}, dims=("time"))
 
-    events = rl.extract_events(da == 1, 1, da == 0, 3)
+    events = rl.runs_with_holes(da == 1, 1, da == 0, 3)
 
     expected = values * 0
     expected[1:11] = 1
diff --git a/tests/test_temperature.py b/tests/test_temperature.py
index 00042e912..e0585bf3d 100644
--- a/tests/test_temperature.py
+++ b/tests/test_temperature.py
@@ -640,6 +640,20 @@ def test_1d(self, tasmax_series, tasmin_series):
         )
         np.testing.assert_allclose(hsf.values[:1], 0)
 
+    def test_gap(self, tasmax_series, tasmin_series):
+        tn1 = np.zeros(366)
+        tx1 = np.zeros(366)
+        tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23])
+        tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31])
+
+        tn = tasmin_series(tn1 + K2C, start="1/1/2000")
+        tx = tasmax_series(tx1 + K2C, start="1/1/2000")
+
+        hsf = atmos.heat_spell_frequency(
+            tn, tx, thresh_tasmin="22.1 C", thresh_tasmax="30.1 C", freq="YS", min_gap=3
+        )
+        np.testing.assert_allclose(hsf.values[:1], 1)
+
 
 class TestHeatSpellMaxLength:
     def test_1d(self, tasmax_series, tasmin_series):
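The units changes below let the converters work on scalar quantities; a minimal sketch of the intended usage, assuming a uniform daily time coordinate (illustration only, not part of the patch; exact return units are simplified by pint):

    import pandas as pd
    import xarray as xr
    from xclim.core.units import amount2rate, rate2amount

    # `dim` now also accepts the time coordinate itself, which is how a scalar
    # quantity (pint object or string) learns the sampling rate.
    time = xr.DataArray(
        pd.date_range("2000-01-01", periods=365, freq="D"),
        dims=("time",),
        name="time",
    )
    amount = rate2amount("1 mm/d", dim=time)  # pint Quantity, ~1 mm per daily step
    rate = amount2rate("1 mm", dim=time)  # pint Quantity, ~1 mm/d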
diff --git a/xclim/core/units.py b/xclim/core/units.py
index 0e393079d..0dcdfbb38 100644
--- a/xclim/core/units.py
+++ b/xclim/core/units.py
@@ -413,6 +413,7 @@ def cf_conversion(
 FREQ_UNITS = {
     "D": "d",
     "W": "week",
+    "h": "h",
 }
 """
 Resampling frequency units for :py:func:`xclim.core.units.infer_sampling_units`.
@@ -622,8 +623,8 @@ def to_agg_units(
 
 
 def _rate_and_amount_converter(
-    da: xr.DataArray,
-    dim: str = "time",
+    da: Quantified,
+    dim: str | xr.DataArray = "time",
     to: str = "amount",
     sampling_rate_from_coord: bool = False,
     out_units: str | None = None,
@@ -632,10 +633,27 @@
     m = 1
     u = None  # Default to assume a non-uniform axis
     label: Literal["lower", "upper"] = "lower"  # Default to "lower" label for diff
-    time = da[dim]
+    if isinstance(dim, str):
+        if not isinstance(da, xr.DataArray):
+            raise ValueError(
+                "If `dim` is a string, the data to convert must be a DataArray."
+            )
+        time = da[dim]
+    else:
+        time = dim
+        dim = time.name
+
+    # We accept str, Quantity or DataArray
+    # Ensure the code below has a DataArray, so it's simpler
+    # We convert back at the end
+    orig_da = da
+    if isinstance(da, str):
+        da = str2pint(da)
+    if isinstance(da, units.Quantity):
+        da = xr.DataArray(da.magnitude, attrs={"units": f"{da.units:~cf}"})
     try:
-        freq = xr.infer_freq(da[dim])
+        freq = xr.infer_freq(time)
     except ValueError as err:
         if sampling_rate_from_coord:
             freq = None
@@ -669,7 +687,7 @@
             ),
             dims=(dim,),
             name=dim,
-            attrs=da[dim].attrs,
+            attrs=time.attrs,
         )
     else:
         m, u = multi, FREQ_UNITS[base]
@@ -683,7 +701,7 @@
         # and `label` has been updated accordingly.
         dt = (
             time.diff(dim, label=label)
-            .reindex({dim: da[dim]}, method="ffill")
+            .reindex({dim: time}, method="ffill")
             .astype(float)
         )
         dt = dt / 1e9  # Convert to seconds
@@ -716,15 +734,17 @@
         out = out.assign_attrs(standard_name=new_name)
 
     if out_units:
-        out = cast(xr.DataArray, convert_units_to(out, out_units))
+        out = convert_units_to(out, out_units)
 
+    if not isinstance(orig_da, xr.DataArray):
+        out = units.Quantity(out.item(), out.attrs["units"])
     return out
 
 
 @_register_conversion("amount2rate", "from")
 def rate2amount(
-    rate: xr.DataArray,
-    dim: str = "time",
+    rate: Quantified,
+    dim: str | xr.DataArray = "time",
     sampling_rate_from_coord: bool = False,
     out_units: str | None = None,
 ) -> xr.DataArray:
@@ -738,10 +758,10 @@
 
     Parameters
     ----------
-    rate : xr.DataArray
+    rate : xr.DataArray, pint.Quantity or string
        "Rate" variable, with units of "amount" per time. Ex: Precipitation in "mm / d".
-    dim : str
-        The time dimension.
+    dim : str or xr.DataArray
+        The name of the time dimension or the time coordinate itself.
     sampling_rate_from_coord : boolean
         For data with irregular time coordinates.
         If True, the diff of the time coordinate will be used as the sampling rate,
         meaning each data point will be assumed to apply for the interval ending at the next point. See notes.
@@ -756,7 +776,7 @@
 
     Returns
     -------
-    xr.DataArray
+    xr.DataArray or Quantity
 
     Examples
     --------
@@ -804,8 +824,8 @@
 @_register_conversion("amount2rate", "to")
 def amount2rate(
-    amount: xr.DataArray,
-    dim: str = "time",
+    amount: Quantified,
+    dim: str | xr.DataArray = "time",
     sampling_rate_from_coord: bool = False,
     out_units: str | None = None,
 ) -> xr.DataArray:
@@ -819,10 +839,10 @@
 
     Parameters
     ----------
-    amount : xr.DataArray
+    amount : xr.DataArray, pint.Quantity or string
         "amount" variable. Ex: Precipitation amount in "mm".
-    dim : str
-        The time dimension.
+    dim : str or xr.DataArray
+        The name of the time dimension or the time coordinate itself.
     sampling_rate_from_coord : boolean
         For data with irregular time coordinates.
         If True, the diff of the time coordinate will be used as the sampling rate,
@@ -839,7 +859,7 @@
 
     Returns
     -------
-    xr.DataArray
+    xr.DataArray or Quantity
 
     See Also
     --------
@@ -1157,12 +1177,16 @@
     )
 
 
-def _check_output_has_units(out: xr.DataArray | tuple[xr.DataArray]) -> None:
+def _check_output_has_units(
+    out: xr.DataArray | tuple[xr.DataArray] | xr.Dataset,
+) -> None:
     """Perform very basic sanity check on the output.
 
     Indices are responsible for unit management.
     If this fails, it's a developer's error.
     """
-    if not isinstance(out, tuple):
+    if isinstance(out, xr.Dataset):
+        out = out.data_vars.values()
+    elif not isinstance(out, tuple):
         out = (out,)
 
     for outd in out:
diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py
index bb5a06615..2f0cc1e87 100644
--- a/xclim/indicators/atmos/_temperature.py
+++ b/xclim/indicators/atmos/_temperature.py
@@ -307,7 +307,8 @@ class TempHourlyWithIndexing(ResamplingIndicatorWithIndexing):
     long_name="Number of heat spells",
     description="{freq} number of heat spell events. A heat spell occurs when the {window}-day "
     "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. "
-    "All days of the {window}-day period are considered part of the spell.",
+    "All days of the {window}-day period are considered part of the spell. Gaps of fewer than {min_gap} day(s) are allowed "
+    "within a spell.",
     abstract="Number of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given "
     "thresholds for a number of days.",
     cell_methods="",
@@ -341,7 +342,8 @@ class TempHourlyWithIndexing(ResamplingIndicatorWithIndexing):
     long_name="Longest heat spell",
     description="{freq} maximum length of heat spells. A heat spell occurs when the {window}-day "
     "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. "
-    "All days of the {window}-day period are considered part of the spell.",
+    "All days of the {window}-day period are considered part of the spell. Gaps of fewer than {min_gap} day(s) are allowed "
+    "within a spell.",
     abstract="The longest heat spell of a period. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given "
     "thresholds for a number of days.",
     compute=indices.generic.bivariate_spell_length_statistics,
@@ -373,7 +375,8 @@ class TempHourlyWithIndexing(ResamplingIndicatorWithIndexing):
     long_name="Total length of heat spells.",
     description="{freq} total length of heat spell events. "
     "A heat spell occurs when the {window}-day averages of daily minimum and maximum temperatures "
-    "each exceed {thresh_tasmin} and {thresh_tasmax}. All days of the {window}-day period are considered part of the spell.",
+    "each exceed {thresh_tasmin} and {thresh_tasmax}. All days of the {window}-day period are considered part of the spell. "
+    "Gaps of fewer than {min_gap} day(s) are allowed within a spell.",
     abstract="Total length of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given "
     "thresholds for a number of days.",
     compute=indices.generic.bivariate_spell_length_statistics,
diff --git a/xclim/indices/_agro.py b/xclim/indices/_agro.py
index 719df0357..02a1a32ad 100644
--- a/xclim/indices/_agro.py
+++ b/xclim/indices/_agro.py
@@ -1038,7 +1038,7 @@ def _get_first_run_start(_pram):
         raise ValueError(f"Unknown method_dry_start: {method_dry_start}.")
 
     # First and second condition combined in a run length
-    events = rl.extract_events(da_start, 1, da_stop, window_dry)
+    events = rl.runs_with_holes(da_start, 1, da_stop, window_dry)
     run_positions = rl.rle(events) >= (window_not_dry_start + window_wet_start)
 
     return _get_first_run(run_positions, date_min_start, date_max_start)
diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 89ceaf18d..9595b93b0 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -362,6 +362,7 @@ def spell_mask(
     win_reducer: str,
     op: str,
     thresh: float | Sequence[float],
+    min_gap: int = 1,
     weights: Sequence[float] = None,
     var_reducer: str = "all",
 ) -> xr.DataArray:
@@ -384,6 +385,9 @@
         The threshold to compare the rolling statistics against, as ``window_stats op threshold``.
         If data is a list, this must be a list of the same length with a threshold for each variable.
         This function does not handle units and can't accept Quantified objects.
+    min_gap: int
+        The shortest possible gap between two spells. Spells closer than this are merged by assigning
+        the gap steps to the merged spell.
     weights: sequence of floats
         A list of weights of the same length as the window.
         Only supported if `win_reducer` is "mean".
@@ -434,7 +438,7 @@
         mask = getattr(mask, var_reducer)("variable")
     # We need to filter out the spells shorter than "window"
     # find sequences of consecutive respected constraints
-    cs_s = rl._cumsum_reset_on_zero(mask)
+    cs_s = rl._cumsum_reset(mask)
     # end of these sequences
     cs_s = cs_s.where(mask.shift({"time": -1}, fill_value=0) == 0)
     # propagate these end of sequences
@@ -453,9 +457,17 @@
         if not np.isscalar(thresh):
             mask = getattr(mask, var_reducer)("variable")
         # True for all days part of a spell that respected the condition (shift because of the two rollings)
-        is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1))
+        is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(
+            time=-(window - 1), fill_value=False
+        )
     # Cut back to the original size
     is_in_spell = is_in_spell.isel(time=slice(0, data.time.size))
+
+    if min_gap > 1:
+        is_in_spell = rl.runs_with_holes(is_in_spell, 1, ~is_in_spell, min_gap).astype(
+            bool
+        )
+
     return is_in_spell
@@ -467,12 +479,15 @@ def _spell_length_statistics(
     op: str,
     spell_reducer: str | Sequence[str],
     freq: str,
+    min_gap: int = 1,
     resample_before_rl: bool = True,
     **indexer,
 ) -> xr.DataArray | Sequence[xr.DataArray]:
     if isinstance(spell_reducer, str):
         spell_reducer = [spell_reducer]
-    is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32)
+    is_in_spell = spell_mask(
+        data, window, win_reducer, op, thresh, min_gap=min_gap
+    ).astype(np.float32)
     is_in_spell = select_time(is_in_spell, **indexer)
 
     outs = []
@@ -512,6 +527,7 @@ def spell_length_statistics(
     op: str,
     spell_reducer: str,
     freq: str,
+    min_gap: int = 1,
     resample_before_rl: bool = True,
     **indexer,
 ):
@@ -537,6 +553,9 @@ def spell_length_statistics(
         Statistic on the spell lengths. If a list, multiple statistics are computed.
     freq : str
         Resampling frequency.
+    min_gap : int
+        The shortest possible gap between two spells. Spells closer than this are merged by assigning
+        the gap steps to the merged spell.
     resample_before_rl : bool
         Determines if the resampling should take place before or after the run
         length encoding (or a similar algorithm) is applied to runs.
@@ -588,7 +607,8 @@
         op,
         spell_reducer,
         freq,
-        resample_before_rl,
+        min_gap=min_gap,
+        resample_before_rl=resample_before_rl,
         **indexer,
     )
@@ -604,6 +624,7 @@ def bivariate_spell_length_statistics(
     op: str,
     spell_reducer: str,
     freq: str,
+    min_gap: int = 1,
     resample_before_rl: bool = True,
     **indexer,
 ):
@@ -633,6 +654,9 @@
         Statistic on the spell lengths. If a list, multiple statistics are computed.
     freq : str
         Resampling frequency.
+    min_gap : int
+        The shortest possible gap between two spells. Spells closer than this are merged by assigning
+        the gap steps to the merged spell.
     resample_before_rl : bool
         Determines if the resampling should take place before or after the run
         length encoding (or a similar algorithm) is applied to runs.
@@ -656,6 +680,7 @@
         op,
         spell_reducer,
         freq,
+        min_gap,
         resample_before_rl,
         **indexer,
     )
@@ -1478,3 +1503,68 @@ def detrend(
     trend = xr.polyval(ds[dim], coeff.polyfit_coefficients)
     with xr.set_options(keep_attrs=True):
         return ds - trend
+
+
+@declare_relative_units(thresh="<data>")
+def thresholded_events(
+    data: xr.DataArray,
+    thresh: Quantified,
+    op: str,
+    window: int,
+    thresh_stop: Quantified | None = None,
+    op_stop: str | None = None,
+    window_stop: int = 1,
+    freq: str | None = None,
+) -> xr.Dataset:
+    r"""Thresholded events.
+
+    Finds all events along the time dimension. An event starts if the start condition is fulfilled for a given number of consecutive time steps.
+    It ends when the end condition is fulfilled for a given number of consecutive time steps.
+
+    Conditions are simple comparisons of the data with a threshold: ``cond = data op thresh``.
+    The end condition defaults to the negation of the start condition.
+
+    The resulting ``event`` dimension always has its maximal possible size: ``ceil(data.size / (window + window_stop))``.
+
+    Parameters
+    ----------
+    data : xr.DataArray
+        Variable.
+    thresh : Quantified
+        Threshold defining the event.
+    op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+        Logical operator defining the event, e.g. arr > thresh.
+    window : int
+        Number of time steps where the event condition must be true to start an event.
+    thresh_stop : Quantified, optional
+        Threshold defining the end of an event. Defaults to `thresh`.
+    op_stop : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}, optional
+        Logical operator for the end of an event. Defaults to the opposite of `op`.
+    window_stop : int, optional
+        Number of time steps where the end condition must be true to end an event. Defaults to 1.
+    freq : str, optional
+        A frequency to divide the data into periods. If absent, the output has no time dimension.
+        If given, the events are searched within each resampling period independently.
+
+    Returns
+    -------
+    xr.Dataset, same shape as the data except the time dimension is replaced by an "event" dimension.
+        event_length: The number of time steps in each event.
+        event_effective_length: The number of time steps of each event where the start condition is true.
+        event_sum: The sum within each event, only considering the steps where the start condition is true.
+        event_start: The datetime of the start of the event.
+    """
+    thresh = convert_units_to(thresh, data)
+
+    # Start and end conditions
+    da_start = compare(data, op, thresh)
+    if thresh_stop is None and op_stop is None:
+        da_stop = ~da_start
+    else:
+        thresh_stop = convert_units_to(thresh_stop or thresh, data)
+        if op_stop is not None:
+            da_stop = compare(data, op_stop, thresh_stop)
+        else:
+            da_stop = ~compare(data, op, thresh_stop)
+
+    return rl.find_events(da_start, window, da_stop, window_stop, data, freq)
diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py
index 0ea34a85e..5d5e71d97 100644
--- a/xclim/indices/run_length.py
+++ b/xclim/indices/run_length.py
@@ -12,6 +12,7 @@
 from warnings import warn
 
 import numpy as np
+import pandas as pd
 import xarray as xr
 from numba import njit
 from xarray.core.utils import get_temp_dimname
@@ -124,10 +125,11 @@ def resample_and_rl(
     return out
 
 
-def _cumsum_reset_on_zero(
+def _cumsum_reset(
     da: xr.DataArray,
     dim: str = "time",
     index: str = "last",
+    reset_on_zero: bool = True,
 ) -> xr.DataArray:
     """Compute the cumulative sum for each series of numbers separated by zero.
 
@@ -140,6 +142,9 @@
     index : {'first', 'last'}
         If 'first', the largest value of the cumulative sum is indexed with the first element in the run.
         If 'last' (default), with the last element in the run.
+    reset_on_zero : bool
+        If True, the cumulative sum is reset on each zero value of `da`. Otherwise, the cumulative sum resets
+        on NaNs. Default is True.
 
     Returns
     -------
@@ -151,7 +156,10 @@
     # Example: da == 100110111 -> cs_s == 100120123
     cs = da.cumsum(dim=dim)  # cumulative sum e.g. 111233456
-    cs2 = cs.where(da == 0)  # keep only numbers at positions of zeroes e.g. N11NN3NNN
+    cond = da == 0 if reset_on_zero else da.isnull()  # reset condition
+    cs2 = cs.where(
+        cond
+    )  # keep only numbers at positions of zeroes e.g. N11NN3NNN (default)
     cs2[{dim: 0}] = 0  # put a zero in front e.g. 011NN3NNN
     cs2 = cs2.ffill(dim=dim)  # e.g. 011113333
     out = cs - cs2
@@ -200,7 +208,7 @@ def rle(
         da = da[{dim: slice(None, None, -1)}]
 
     # Get cumulative sum for each series of 1, e.g. da == 100110111 -> cs_s == 100120123
-    cs_s = _cumsum_reset_on_zero(da, dim)
+    cs_s = _cumsum_reset(da, dim)
 
     # Keep total length of each series (and also keep 0's), e.g. 100120123 -> 100N20NN3
     # Keep numbers with a 0 to the right and also the last number
@@ -553,7 +561,7 @@ def find_boundary_run(runs, position):
     else:
         # _cumsum_reset() is an intermediate step in rle, which is sufficient here
-        d = _cumsum_reset_on_zero(da, dim=dim, index=position)
+        d = _cumsum_reset(da, dim=dim, index=position)
         d = xr.where(d >= window, 1, 0)
         # for "first" run, return "first" element in the run (and conversely for "last" run)
         if freq is not None:
@@ -769,7 +777,7 @@ def get_out(rls):
     return da.copy(data=out.transpose(*da.dims).data)
 
 
-def extract_events(
+def runs_with_holes(
     da_start: xr.DataArray,
     window_start: int,
     da_stop: xr.DataArray,
@@ -799,19 +807,18 @@ def runs_with_holes(
     Notes
     -----
     A season (as defined in ``season``) could be considered as an event with `window_stop == window_start` and `da_stop == 1 - da_start`,
-    although it has more constraints on when to start and stop a run through the `date` argument.
+    although it has more constraints on when to start and stop a run through the `date` argument, and only one season can be found.
""" da_start = da_start.astype(int).fillna(0) da_stop = da_stop.astype(int).fillna(0) - start_runs = _cumsum_reset_on_zero(da_start, dim=dim, index="first") - stop_runs = _cumsum_reset_on_zero(da_stop, dim=dim, index="first") + start_runs = _cumsum_reset(da_start, dim=dim, index="first") + stop_runs = _cumsum_reset(da_stop, dim=dim, index="first") start_positions = xr.where(start_runs >= window_start, 1, np.nan) stop_positions = xr.where(stop_runs >= window_stop, 0, np.nan) # start positions (1) are f-filled until a stop position (0) is met runs = stop_positions.combine_first(start_positions).ffill(dim=dim).fillna(0) - return runs @@ -1737,3 +1744,146 @@ def suspicious_run( keep_attrs=True, kwargs={"window": window, "op": op, "thresh": thresh}, ) + + +def _find_events(da_start, da_stop, data, window_start, window_stop): + """Actual finding of events for each period. + + Get basic blocks to work with, our runs with holes and the lengths of those runs. + Series of ones indicating where we have continuous runs with pauses + not exceeding `window_stop` + """ + runs = runs_with_holes(da_start, window_start, da_stop, window_stop) + + # Compute the length of freezing rain events + # I think int16 is safe enough, fillna first to suppress warning + ds = rle(runs).fillna(0).astype(np.int16).to_dataset(name="event_length") + # Time duration where the precipitation threshold is exceeded during an event + # (duration of complete run - duration of holes in the run ) + ds["event_effective_length"] = _cumsum_reset( + da_start.where(runs == 1), index="first", reset_on_zero=False + ).astype(np.int16) + + if data is not None: + # Ex: Cumulated precipitation in a given freezing rain event + ds["event_sum"] = _cumsum_reset( + data.where(runs == 1), index="first", reset_on_zero=False + ) + + # Keep time as a variable, it will be used to keep start of events + ds["event_start"] = ds["time"].broadcast_like(ds) # .astype(int) + # We convert to an integer for the filtering, time object won't do well in the apply_ufunc/vectorize + time_min = ds.event_start.min() + ds["event_start"] = ds.event_start.copy( + data=(ds.event_start - time_min).values.astype("timedelta64[s]").astype(int) + ) + + # Filter events: Reduce time dimension + def _filter_events(da, rl, max_event_number): + out = np.full(max_event_number, np.nan) + events_start = da[rl > 0] + out[: len(events_start)] = events_start + return out + + # Dask inputs need to be told their length before computing anything. + max_event_number = int(np.ceil(da_start.time.size / (window_start + window_stop))) + ds = xr.apply_ufunc( + _filter_events, + ds, + ds.event_length, + input_core_dims=[["time"], ["time"]], + output_core_dims=[["event"]], + kwargs=dict(max_event_number=max_event_number), + dask_gufunc_kwargs=dict(output_sizes={"event": max_event_number}), + dask="parallelized", + vectorize=True, + ) + + # convert back start to a time + if time_min.dtype == "O": + # Can't add a numpy array of timedeltas to an array of cftime (because they have non-compatible dtypes) + # Also, we can't add cftime to NaTType. 
So we fill with negative timedeltas and mask them after the addition + + def _get_start_cftime(deltas, time_min=None): + starts = time_min + pd.to_timedelta(deltas, "s").to_pytimedelta() + starts[starts < time_min] = np.nan + return starts + + ds["event_start"] = xr.apply_ufunc( + _get_start_cftime, + ds.event_start.fillna(-1), + dask="parallelized", + kwargs={"time_min": time_min.item()}, + output_dtypes=[time_min.dtype], + ) + else: + ds["event_start"] = ds.event_start.copy( + data=time_min.values + ds.event_start.data.astype("timedelta64[s]") + ) + + ds["event"] = np.arange(1, ds.event.size + 1) + ds["event_length"].attrs["units"] = "1" + ds["event_effective_length"].attrs["units"] = "1" + ds["event_start"].attrs["units"] = "" + if data is not None: + ds["event_sum"].attrs["units"] = data.units + return ds + + +# TODO: Implement more event stats ? (max, effective sum, etc) +def find_events( + condition: xr.DataArray, + window: int, + condition_stop: xr.DataArray | None = None, + window_stop: int = 1, + data: xr.DataArray | None = None, + freq: str | None = None, +): + """Find events (runs). + + An event starts with a run of ``window`` consecutive True values in the condition + and stops with ``window_stop`` consecutive True values in the stop condition. + + This returns a Dataset with each event along an `event` dimension. It does not + perform statistics over the events like other function in this module do. + + Parameters + ---------- + condition : DataArray of boolean values + The boolean mask, true where the start condition of the event is fulfilled. + window : int + The number of consecutive True values for an event to start. + condition_stop : DataArray of boolean values, optional + The stopping boolean mask, true where the end condition of the event is fulfilled. + Defaults to the opposite of ``condition``. + window_stop : int + The number of consecutive True values in ``condition_stop`` for an event to end. + Defaults to 1. + data : DataArray, optional + The actual data. If present, its sum within each event is added to the output. + freq : str, optional + A frequency to divide the data into periods. If absent, the output has not time dimension. + If given, the events are searched within in each resample period independently. + + Returns + ------- + xr.Dataset, same shape as the data it has a new "event" dimension (and the time dimension is resample or removed, according to ``freq``). + event_length: The number of time steps in each event + event_effective_length: The number of time steps of even event where the start condition is true. + event_start: The datetime of the start of the run. + event_sum: The sum within each event, only considering the steps where start condition is true. Only present if ``data`` is given. + """ + if condition_stop is None: + condition_stop = ~condition + + if freq is None: + return _find_events(condition, condition_stop, data, window, window_stop) + + ds = xr.Dataset({"da_start": condition, "da_stop": condition_stop}) + if data is not None: + ds = ds.assign(data=data) + return ds.resample(time=freq).map( + lambda grp: _find_events( + grp.da_start, grp.da_stop, grp.get("data", None), window, window_stop + ) + )
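Finally, a sketch of ``find_events`` on a bare boolean mask (illustration only, not part of the patch; the mask values are made up):

    import numpy as np
    import pandas as pd
    import xarray as xr
    from xclim.indices import run_length as rl

    time = pd.date_range("2000-01-01", periods=12, freq="D")
    wet = xr.DataArray(
        np.array([0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], dtype=bool),
        coords={"time": time},
        dims="time",
    )
    # An event starts after 2 consecutive True steps and only ends once the stop
    # condition (by default the negation) holds for 3 consecutive steps, so the
    # single False on day 4 is absorbed into the event and the lone True on
    # day 9 never starts one.
    events = rl.find_events(wet, window=2, window_stop=3)
    # events.event_length -> [4, nan, nan]: one event covering days 2-5, with
    # event_effective_length 3 (the hole counts toward the length, not here).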