Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SPI performance #1311

Merged
merged 89 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
be3b9c6
pr_cal now created within SPI + refactoring
coxipi Feb 28, 2023
2d5b942
Better variable name
coxipi Feb 28, 2023
a939edb
params can be used as input, and computed beforehand (`get_params`)
coxipi Mar 1, 2023
aeb8038
Clearer var names and better doc
coxipi Mar 1, 2023
3c5ac71
Remove uses_range
coxipi Mar 1, 2023
ed4164d
cal_range type changed, new warnings
coxipi Mar 1, 2023
8f22ff1
cal_range type change 2/2
coxipi Mar 1, 2023
cbca142
fitting/rolling/resampling defined in a separate function
coxipi Mar 1, 2023
16f4b6b
resampling/rolling/fitting out of SPI
coxipi Mar 2, 2023
e990406
Update doc and formatting
coxipi Mar 2, 2023
ca577c2
indexer support
coxipi Mar 2, 2023
fde2441
update doc
coxipi Mar 2, 2023
f9c28c6
Merge branch 'master' into fix_spi_performance
Zeitsperre Mar 8, 2023
f9c73f4
Merge branch 'master' into fix_spi_performance
Zeitsperre Mar 10, 2023
0fa83f4
typo function isnull
coxipi Mar 15, 2023
5b3bbe6
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Mar 15, 2023
f747dea
More documentation, remove useless step
coxipi Jun 19, 2023
2b6dfe9
Refactoring, simplifications, shorter msgs
coxipi Jun 22, 2023
09d72b3
More documentation, some cleaning
coxipi Jun 22, 2023
17c1463
SPI accepts da & params as sufficient arguments
coxipi Jun 28, 2023
d6d11d8
Merge branch 'master' into fix_spi_performance
Zeitsperre Jun 28, 2023
d2dd196
Remove problem in declare_units and move functions
coxipi Jun 30, 2023
50b7f71
remvoe pr_cal from declare_units
coxipi Jun 30, 2023
fdef084
Refactor functions names, no more "spx"
coxipi Jul 1, 2023
3f90a30
faster _get_standardized_index & refactoring
coxipi Jul 11, 2023
b02750b
fix variable names
coxipi Jul 11, 2023
e81ad73
Fixing SPEI units
coxipi Jul 11, 2023
5b7d483
add issue number
coxipi Jul 11, 2023
f8bae10
cal_range -> cal_{start|end} && offset in SPEI public
coxipi Jul 14, 2023
900225e
remove DATE_TUPLE type, etc
coxipi Jul 14, 2023
b60cd2d
Add offset as possible input in standardized_fit_params
coxipi Jul 20, 2023
4a293a3
new test on modularity and fix API error in SPI
coxipi Jul 20, 2023
fec71df
update CHANGES
coxipi Jul 21, 2023
993b639
remove comments
coxipi Jul 21, 2023
10fc30b
Merge branch 'master' into fix_spi_performance
Zeitsperre Jul 24, 2023
f3c0b0f
Merge branch 'master' into fix_spi_performance
Zeitsperre Jul 25, 2023
c8891c9
clipped std_index values
coxipi Jul 27, 2023
410c86d
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Jul 27, 2023
486d5fe
format docstring
coxipi Jul 27, 2023
852bb07
Better description CHANGES.rst
coxipi Jul 27, 2023
fcbca5e
Merge branch 'master' into fix_spi_performance
Zeitsperre Jul 31, 2023
f62cf04
remove unrelated comments
coxipi Aug 1, 2023
85b815f
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Aug 1, 2023
64052a5
more docuemntation on std_index bounds
coxipi Aug 1, 2023
c69c9cf
Merge branch 'master' of https://github.com/Ouranosinc/xclim into fix…
coxipi Aug 2, 2023
9b5b3b3
Merge branch 'master' into fix_spi_performance
coxipi Aug 14, 2023
abdc00b
Broadcast params if needed
coxipi Aug 16, 2023
7782481
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Aug 16, 2023
2b26e88
Use template to broadcast
coxipi Aug 16, 2023
6bba201
better notation
coxipi Aug 16, 2023
ea1d3ea
correct function name in docstring test
coxipi Aug 16, 2023
35f3897
Merge branch 'master' into fix_spi_performance
Zeitsperre Aug 25, 2023
485e718
Improve description of SPEI offset
coxipi Sep 11, 2023
8f1868e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2023
06271eb
add notes about NaNs probable origin in SPEI
coxipi Sep 11, 2023
bd1d122
Merge branch 'master' of https://github.com/Ouranosinc/xclim into fix…
coxipi Sep 12, 2023
18f346d
Update CHANGES
coxipi Sep 12, 2023
b15e566
Improve documentation (review)
coxipi Sep 12, 2023
25c1d01
Improve documentation (review) 2/2
coxipi Sep 12, 2023
a2283bf
add docstring for _get_standardized_index
coxipi Sep 12, 2023
02f3821
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Sep 12, 2023
693f588
typo in docsring
coxipi Sep 12, 2023
272d479
Merge branch 'master' into fix_spi_performance
Zeitsperre Oct 10, 2023
211c955
simpler loop over group_idx
coxipi Oct 18, 2023
a71963f
Merge branch 'fix_spi_performance' of https://github.com/Ouranosinc/x…
coxipi Oct 18, 2023
e0f008e
Merge branch 'master' of https://github.com/Ouranosinc/xclim into fix…
coxipi Oct 18, 2023
b3d3713
more comments, update docstrings
coxipi Oct 18, 2023
f6acc45
more simple management of the SPEI offset
coxipi Oct 18, 2023
1553208
Merge branch 'master' into fix_spi_performance
Zeitsperre Oct 18, 2023
2ca3f1c
Merge branch 'master' into fix_spi_performance
Zeitsperre Oct 18, 2023
c717169
Better doc & simplifications (review)
coxipi Oct 18, 2023
19deb92
only accept "D" & "MS" freqs
coxipi Oct 19, 2023
2aefe81
put std_index functions in stats.py, get rid of group_idx loop
coxipi Oct 19, 2023
96e320a
non-hardcoded definition of group indices
coxipi Oct 19, 2023
ffa1e10
correct infer_freq usage
coxipi Oct 20, 2023
3e28bbb
optimized dist_method to simplify std_index functions
coxipi Oct 21, 2023
b809ab6
remove uneeded function name in header
coxipi Oct 21, 2023
cc8b3e6
fix xci -> xci.stat where appropriate
coxipi Oct 23, 2023
3b4b8f2
typo
coxipi Oct 23, 2023
9ec645a
remove useless input/output core_dims from xr.apply_ufunc
coxipi Oct 23, 2023
3279515
Merge branch 'master' of https://github.com/Ouranosinc/xclim into fix…
coxipi Oct 23, 2023
845cac4
no need to broadcast `params_norm`
coxipi Oct 23, 2023
e461450
Merge branch 'master' into fix_spi_performance
Zeitsperre Oct 23, 2023
c702e48
fix merge of CHANGES.rst
Zeitsperre Oct 23, 2023
b3b9d7c
docstring and typing adjustments
Zeitsperre Oct 23, 2023
6929929
Merge remote-tracking branch 'origin/fix_spi_performance' into fix_sp…
Zeitsperre Oct 23, 2023
69ab14a
whitespace
Zeitsperre Oct 23, 2023
7f46856
Update CHANGES.rst (review)
coxipi Oct 23, 2023
59e7df6
only chunk if necessary, check if dask used (review)
coxipi Oct 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xclim/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def unprefix_attrs(source: dict, keys: Sequence, prefix: str):
InputKind.STRING: "str",
InputKind.DAY_OF_YEAR: "date (string, MM-DD)",
InputKind.DATE: "date (string, YYYY-MM-DD)",
InputKind.DATE_TUPLE: "date tuple ((string_1, YYYY-MM-DD), (string_2, YYYY-MM-DD))",
InputKind.BOOL: "boolean",
InputKind.DATASET: "Dataset, optional",
InputKind.KWARGS: "",
Expand Down
6 changes: 6 additions & 0 deletions xclim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,12 @@ class InputKind(IntEnum):

Annotation : ``bool``, may be optional.
"""
DATE_TUPLE = 10
"""A tuple of dates in the YYYY-MM-DD format, may include a time.

!!! not sure how to write this !!!
Annotation : `Tuple[xclim.core.utils.DateStr, xclim.core.utils.DateStr]`
"""
KWARGS = 50
"""A mapping from argument name to value.

Expand Down
273 changes: 171 additions & 102 deletions xclim/indices/_agro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import warnings
from typing import Tuple

import numpy as np
import xarray
Expand All @@ -16,7 +17,7 @@
rate2amount,
to_agg_units,
)
from xclim.core.utils import DayOfYearStr, Quantified, uses_dask
from xclim.core.utils import DateStr, DayOfYearStr, Quantified, uses_dask
from xclim.indices._threshold import (
first_day_temperature_above,
first_day_temperature_below,
Expand Down Expand Up @@ -884,28 +885,97 @@ def water_budget(
return out


def _preprocess_spx(da, freq, window, **indexer):
_, base, _, _ = parse_offset(freq or xarray.infer_freq(da.time))
try:
group = {"D": "time.dayofyear", "M": "time.month"}[base]
coxipi marked this conversation as resolved.
Show resolved Hide resolved
except KeyError():
raise ValueError(f"Standardized index with frequency `{freq}` not supported.")
if freq:
da = da.resample(time=freq).mean(keep_attrs=True)
if window > 1:
da = da.rolling(time=window).mean(skipna=False, keep_attrs=True)
da = select_time(da, **indexer)
return da, group


def _compute_spx_fit_params(
da, cal_range, freq, window, dist, method, group=None, **indexer
):
# "WPM" method doesn't seem to work for gamma or pearson3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, PWM should work with Gamma.

dist_and_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]}
if dist not in dist_and_methods:
raise NotImplementedError(f"The distribution `{dist}` is not supported.")
if method not in dist_and_methods[dist]:
raise NotImplementedError(
f"The method `{method}` is not supported for distribution `{dist}`."
)

if group is None:
da, group = _preprocess_spx(da, freq, window)

if cal_range:
da = da.sel(time=slice(cal_range[0], cal_range[1]))

if uses_dask(da) and len(da.chunks[da.get_axis_num("time")]) > 1:
warnings.warn(
"The input data is chunked on time dimension and must be fully rechunked to"
" run `fit` on groups ."
" Beware, this operation can significantly increase the number of tasks dask"
" has to handle.",
stacklevel=2,
)
da = da.chunk({"time": -1})

da = select_time(da, **indexer)

def wrap_fit(da):
if indexer != {}:
if da.isnull.all():
select_dims = {d: 0 for d in da.dims if d != "time"}
with xarray.set_options(keep_attrs=True):
params = (
fit(da.isel(time=slice(0, 2))[select_dims], dist, method)
* da.isel(time=0, drop=True)
* np.NaN
)
return params
coxipi marked this conversation as resolved.
Show resolved Hide resolved
return fit(da, dist, method)

params = da.groupby(group).map(wrap_fit)
params.attrs["Calibration period"] = str(cal_range)

return params


@declare_units(
pr="[precipitation]",
pr_cal="[precipitation]",
params="[]",
)
def standardized_precipitation_index(
pr: xarray.DataArray,
pr_cal: Quantified,
freq: str = "MS",
coxipi marked this conversation as resolved.
Show resolved Hide resolved
cal_range: tuple[DateStr, DateStr] | None = None,
coxipi marked this conversation as resolved.
Show resolved Hide resolved
params: Quantified | None = None,
freq: str | None = "MS",
window: int = 1,
dist: str = "gamma",
method: str = "APP",
**indexer,
) -> xarray.DataArray:
r"""Standardized Precipitation Index (SPI).

Parameters
----------
pr : xarray.DataArray
Daily precipitation.
pr_cal : xarray.DataArray
Daily precipitation used for calibration. Usually this is a temporal subset of `pr` over some reference period.
freq : str
Resampling frequency. A monthly or daily frequency is expected.
cal_range: Tuple[DateStr, DateStr] | None
Dates used to take a subset the input dataset for calibration. The tuple is formed by two `DateStr`,
i.e. a `str` in format `"YYYY-MM-DD"`. Default option `None` means that the full range of the input dataset is used.
params: xarray.DataArray
Fit parameters.
freq : str | None
Resampling frequency. A monthly or daily frequency is expected. Option `None` assumes that desired resampling
has already been applied input dataset and will skip the resampling step.
window : int
Averaging window length relative to the resampling frequency. For example, if `freq="MS"`,
i.e. a monthly resampling, the window is an integer number of months.
Expand All @@ -915,6 +985,9 @@ def standardized_precipitation_index(
method : {'APP', 'ML'}
Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method
uses a deterministic function that doesn't involve any optimization.
indexer
Indexing parameters to compute the indicator on a temporal subset of the data.
It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`.

Returns
-------
Expand All @@ -923,123 +996,110 @@ def standardized_precipitation_index(

Notes
-----
The length `N` of the N-month SPI is determined by choosing the `window = N`.
Supported statistical distributions are: ["gamma"]
* The length `N` of the N-month SPI is determined by choosing the `window = N`.
* Supported statistical distributions are: ["gamma", "fisk"], where "fisk" is scipy's implementation of
a log-logistic distribution
* If `params` is given as input, it overrides the `cal_range` option.

Example
-------
>>> from datetime import datetime
>>> from xclim.indices import standardized_precipitation_index
>>> ds = xr.open_dataset(path_to_pr_file)
>>> pr = ds.pr
>>> pr_cal = pr.sel(time=slice(datetime(1990, 5, 1), datetime(1990, 8, 31)))
>>> cal_range = ("1990-05-01", "1990-08-31")
>>> spi_3 = standardized_precipitation_index(
... pr, pr_cal, freq="MS", window=3, dist="gamma", method="ML"
... pr,
... cal_range=cal_range,
... freq="MS",
... window=3,
... dist="gamma",
... method="ML",
... ) # Computing SPI-3 months using a gamma distribution for the fit
>>> # Fitting parameters can also be obtained ...
>>> params = _compute_spx_fit_params(
... pr,
... cal_range,
... freq="MS",
... window=3,
... dist="gamma",
... method="ML",
... ) # First getting params
>>> # ... and used as input
>>> spi_3 = standardized_precipitation_index(pr, params=params)

References
----------
:cite:cts:`mckee_relationship_1993`

"""
# "WPM" method doesn't seem to work for gamma or pearson3
dist_and_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]}
if dist not in dist_and_methods:
raise NotImplementedError(f"The distribution `{dist}` is not supported.")
if method not in dist_and_methods[dist]:
raise NotImplementedError(
f"The method `{method}` is not supported for distribution `{dist}`."
uses_input_params = params is not None
if cal_range and uses_input_params:
raise ValueError(
"Inputing both calibration dates (`cal_range`) and calibration parameters (`params`) is not accepted,"
"input only one or neither of those options (the latter case reverts to default behaviour which performs a calibration"
"with the full input dataset)."
)

# calibration period
cal_period = pr_cal.time[[0, -1]].dt.strftime("%Y-%m-%dT%H:%M:%S").values.tolist()

# Determine group type
if freq == "D" or freq is None:
freq = "D"
group = "time.dayofyear"
else:
_, base, _, _ = parse_offset(freq)
if base in ["M"]:
group = "time.month"
else:
raise NotImplementedError(f"Resampling frequency `{freq}` not supported.")

# Resampling precipitations
if freq != "D":
pr = pr.resample(time=freq).mean(keep_attrs=True)
pr_cal = pr_cal.resample(time=freq).mean(keep_attrs=True)

def needs_rechunking(da):
if uses_dask(da) and len(da.chunks[da.get_axis_num("time")]) > 1:
warnings.warn(
"The input data is chunked on time dimension and must be fully rechunked to"
" run `fit` on groups ."
" Beware, this operation can significantly increase the number of tasks dask"
" has to handle.",
stacklevel=2,
pr, group = _preprocess_spx(pr, freq, window, **indexer)
if uses_input_params is False:
params = _compute_spx_fit_params(
pr, cal_range, freq, window, dist, method, group=group
)
params_dict = dict(params.groupby(group.rsplit(".")[1]))

def ppf_to_cdf(da, params):
# ppf to cdf
if dist in ["gamma", "fisk"]:
prob_pos = dist_method("cdf", params, pr.where(pr > 0))
prob_zero = (pr == 0).sum("time") / pr.notnull().sum("time")
prob = prob_zero + (1 - prob_zero) * prob_pos
# Invert to normal distribution with ppf and obtain SPI
params_norm = xarray.DataArray(
[0, 1],
dims=["dparams"],
coords=dict(dparams=(["loc", "scale"])),
attrs=dict(scipy_dist="norm"),
)
sub_spi = dist_method("ppf", params_norm, prob)
return sub_spi

def get_sub_spi(pr):
group_key = pr[group][0].values.item()
sub_params = params_dict[group_key]
if indexer != {}:
if pr.isnull().all():
select_dims = {d: 0 for d in pr.dims if d != "time"}
sub_spi = ppf_to_cdf(
pr.isel(time=slice(0, 2))[select_dims], sub_params[select_dims]
)
return True
return False

if needs_rechunking(pr):
pr = pr.chunk({"time": -1})
coxipi marked this conversation as resolved.
Show resolved Hide resolved
if needs_rechunking(pr_cal):
pr_cal = pr_cal.chunk({"time": -1})

# Rolling precipitations
if window > 1:
pr = pr.rolling(time=window).mean(skipna=False, keep_attrs=True)
pr_cal = pr_cal.rolling(time=window).mean(skipna=False, keep_attrs=True)

# Obtain fitting params and expand along time dimension
def resample_to_time(da, da_ref):
if freq == "D":
da = resample_doy(da, da_ref)
else:
da = da.rename(month="time").reindex(time=da_ref.time.dt.month)
da["time"] = da_ref.time
return da

params = pr_cal.groupby(group).map(lambda x: fit(x, dist, method))
params = resample_to_time(params, pr)

# ppf to cdf
if dist in ["gamma", "fisk"]:
prob_pos = dist_method("cdf", params, pr.where(pr > 0))
prob_zero = resample_to_time(
pr.groupby(group).map(
lambda x: (x == 0).sum("time") / x.notnull().sum("time")
),
pr,
with xarray.set_options(keep_attrs=True):
sub_spi = sub_spi * pr * np.NaN
return sub_spi
return ppf_to_cdf(pr, sub_params)

spi = pr.groupby(group).map(get_sub_spi)
spi.attrs = params.attrs
if uses_input_params:
spi.attrs["Calibration period"] = (
spi.attrs["Calibration period"] + "(input parameters)"
)
prob = prob_zero + (1 - prob_zero) * prob_pos

# Invert to normal distribution with ppf and obtain SPI
params_norm = xarray.DataArray(
[0, 1],
dims=["dparams"],
coords=dict(dparams=(["loc", "scale"])),
attrs=dict(scipy_dist="norm"),
)
spi = dist_method("ppf", params_norm, prob)
spi.attrs["units"] = ""
spi.attrs["calibration_period"] = cal_period

return spi


@declare_units(
wb="[precipitation]",
wb_cal="[precipitation]",
params="[]",
)
def standardized_precipitation_evapotranspiration_index(
wb: xarray.DataArray,
wb_cal: Quantified,
cal_range: tuple[DateStr, DateStr] | None = None,
params: Quantified | None = None,
freq: str = "MS",
window: int = 1,
dist: str = "gamma",
method: str = "APP",
dist: str = "fisk",
method: str = "ML",
**indexer,
) -> xarray.DataArray:
r"""Standardized Precipitation Evapotranspiration Index (SPEI).

Expand All @@ -1051,10 +1111,14 @@ def standardized_precipitation_evapotranspiration_index(
----------
wb : xarray.DataArray
Daily water budget (pr - pet).
wb_cal : xarray.DataArray
Daily water budget used for calibration.
freq : str
Resampling frequency. A monthly or daily frequency is expected.
cal_range: Tuple[DateStr, DateStr] | None
Dates used to take a subset the input dataset for calibration. The tuple is formed by two `DateStr`,
i.e. a `str` in format `"YYYY-MM-DD"`. Default option `None` means that the full range of the input dataset is used.
params: xarray.DataArray
Fit parameters.
freq : str | None
Resampling frequency. A monthly or daily frequency is expected. Option `None` assumes that desired resampling
has already been applied input dataset and will skip the resampling step.
window : int
Averaging window length relative to the resampling frequency. For example, if `freq="MS"`, i.e. a monthly
resampling, the window is an integer number of months.
Expand All @@ -1064,6 +1128,9 @@ def standardized_precipitation_evapotranspiration_index(
Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method
uses a deterministic function that doesn't involve any optimization. Available methods
vary with the distribution: 'gamma':{'APP', 'ML'}, 'fisk':{'ML'}
indexer
Indexing parameters to compute the indicator on a temporal subset of the data.
It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`.

Returns
-------
Expand All @@ -1085,9 +1152,11 @@ def standardized_precipitation_evapotranspiration_index(
# library is taken
offset = convert_units_to("1 mm/d", wb.units, context="hydro")
with xarray.set_options(keep_attrs=True):
wb, wb_cal = wb + offset, wb_cal + offset
wb = wb + offset

spei = standardized_precipitation_index(wb, wb_cal, freq, window, dist, method)
spei = standardized_precipitation_index(
wb, cal_range, params, freq, window, dist, method, **indexer
)

return spei

Expand Down