
Improve SPI performance #1311

Merged: 89 commits into master from fix_spi_performance, Oct 23, 2023
Conversation

@coxipi (Contributor) commented on Feb 28, 2023

Pull Request Checklist:

What kind of change does this PR introduce?

  • Make SPI/SPEI faster.
  • Fit params are now modular and can be computed before computing SPI/SPEI. This gives more options to segment the computations and makes it possible to obtain the fitting params when troubleshooting is needed.
  • Time indexing is now possible.
  • dist_method now avoids vectorize=True in its xr.apply_ufunc. This is the main improvement in SPI/SPEI.
  • Better document the limits of usage of standardized indices. Standardized indices are now capped at the extreme values ±8.21; the upper bound is a limit resulting from the use of float64 (see the short check below).
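As a quick illustration of where that bound comes from (not from the PR, just a sanity check with numpy/scipy):

import numpy as np
from scipy.stats import norm

# the largest float64 probability strictly below 1
p_max = np.nextafter(1.0, 0.0)
print(norm.ppf(p_max))  # ~8.21: the largest finite standardized value that float64 probabilities allow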

Does this PR introduce a breaking change?

Yes.

  • pr_cal or wb_cal will not be input options in the future:

Inputting pr_cal will be deprecated in xclim==0.46.0. If pr_cal is a subset of pr, then instead of
standardized_precipitation_index(pr=pr, pr_cal=pr.sel(time=slice(t0, t1)), ...), one can call
standardized_precipitation_index(pr=pr, cal_range=(t0, t1), ...).
If for some reason pr_cal is not a subset of pr, the following approach will still be possible:
params = standardized_index_fit_params(da=pr_cal, freq=freq, window=window, dist=dist, method=method)
spi = standardized_precipitation_index(pr=pr, params=params)
This approach can be used in both scenarios to break the computation into two steps, i.e. first get params, then compute the standardized index (see the sketch below).
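For illustration, a minimal sketch of the two call patterns described above (keyword names follow this description; the exact signatures in the released version may differ):

import xclim
from xclim.indices.stats import standardized_index_fit_params

# 1) the calibration period is a subset of `pr`: pass the calibration range directly
spi = xclim.indices.standardized_precipitation_index(
    pr=pr, cal_range=("1950", "1980"), freq="MS", window=1, dist="gamma", method="APP"
)

# 2) the calibration data is separate, or the work is split in two steps
params = standardized_index_fit_params(
    da=pr_cal, freq="MS", window=1, dist="gamma", method="APP"
)
spi = xclim.indices.standardized_precipitation_index(pr=pr, params=params)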

I could revert this breaking change if we prefer. This was a first attempt to make the computation faster, but the improvements are now independent of this change. We could also keep the modular structure for params but revert to pr_cal instead of cal_range. That is a bit less efficient when pr_cal is simply a subset of pr, because the resampling/rolling ends up being done twice on the calibration range for nothing. When params are computed first and SPI is then obtained in a second step, it makes no difference.

Other information:

@github-actions bot added the indicators (Climate indices and indicators) label on Feb 28, 2023
@coxipi (Contributor, Author) commented on Mar 2, 2023

I think the computation is still quite heavy, but a few things should help. There is less redundant computation: we can either reuse params, or at least we no longer resample/roll twice (as it used to do for pr and pr_cal).

The user can now dissect the computation more easily:

  • Produce params beforehand and reuse them for many computations.
  • Use an indexer to obtain SPI only for certain periods (see the sketch after this list). This is not like selecting, say, the month of June from the onset: the indexing must be done after the rolling. Said another way, you still need the months January to June even if you only want SPI-6 for June, so the indexer must be applied mid-computation, which is what is done.
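For example, a hypothetical sketch of getting SPI-6 only for June (the exact way the indexer keywords are plumbed through may differ from the merged API):

# `month=[6]` is assumed to be applied internally *after* the 6-month rolling window,
# so data from January to June is still used even though only June values are returned
spi6_june = xclim.indices.standardized_precipitation_index(
    pr=pr, cal_range=("1950", "1980"),
    freq="MS", window=6, dist="gamma", method="APP",
    month=[6],
)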

Comment on lines 1236 to +1239
@declare_units(
pr="[precipitation]",
pr_cal="[precipitation]",
params="[]",
Collaborator:

But yes, params needs an entry in declare_units if it is a Quantified.


def wrap_fit(da):
    # if all values are null, obtain a template of the fit and broadcast
    if indexer != {} and da.isnull().all():
Collaborator:


I'm confused, how does this work with dask? Won't da.isnull().all() as a conditional trigger the computation?

And why indexer != {}? Why would this fastpath only be used with indexing?

Contributor (Author):


Good point about Dask. There is no way, then, to achieve a fastpath with lazy computing, is there?

For the other question: yes, indeed, spatial regions with only NaNs, independently of time selection, would also benefit from this speedup (ignoring the Dask issue). I wanted to reserve this check for cases where I was sure there was a potential for all-NaN slices, e.g. when there was time selection. I'm not sure how costly it is to check .isnull().all(), but it is probably very small in comparison to the whole algorithm.
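A toy illustration of the point about Dask (not from the PR): the `if` coerces the lazy reduction to a Python bool, which forces a compute.

import numpy as np
import xarray as xr

da = xr.DataArray(np.full(10, np.nan), dims="time").chunk({"time": 5})
check = da.isnull().all()  # still lazy: a 0-d dask-backed DataArray
if check:                  # bool() coercion triggers the computation right here
    print("all-NaN fastpath taken, but the graph was evaluated to decide this")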

return np.zeros_like(da) * np.NaN

spi = np.zeros_like(da)
# loop over groups (similar to the numpy-groupies idea, but numpy-groupies can't be used with custom functions)
Collaborator:

According to the docs (https://github.com/ml31415/numpy-groupies/), numpy-groupies accepts a callable as func; would it make sense to use that?

@coxipi (Contributor, Author) commented on Oct 18, 2023:

I had tried, but got a confusing error message leading me to believe this claim was maybe false. Maybe I just messed up. I will retry; it would be worth it to clean up the function.
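For reference, a minimal sketch of the callable-func path in numpy-groupies (illustrative toy data, not the PR's actual grouping):

import numpy as np
import numpy_groupies as npg

values = np.array([0.0, 1.2, 0.0, 3.4, 0.0, np.nan])
group_idx = np.array([0, 0, 0, 1, 1, 1])  # e.g. one group per calendar month

# numpy-groupies accepts an arbitrary callable as `func`, falling back to a generic
# (slower) loop over groups; here: probability of zero among non-NaN values per group
prob_zero = npg.aggregate(group_idx, values, func=lambda x: np.mean(x[~np.isnan(x)] == 0))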

params: xarray.DataArray
    Fit parameters. The `params` can be computed using ``xclim.indices.standardized_index_fit_params`` in advance.
    The output can be given here as input, and it overrides other options.
offset: Quantified
Collaborator:

Suggestion:

Suggested change
offset: Quantified
offset: Quantified | None = None,

# pseudo-code
if offset is None:
    if params is not None:
        offset = params.offset
    elif dist in bounded_distributions:
        offset = 1000 mm / d
    else:
        offset = 0
else:
    if params is not None and params.offset != offset:
        warning

That is: put a default of None and clearly explain in the docstring what the default behaviour is. I know I just suggested something different for the snow thing, but that was in a "temporary" perspective!

Collaborator:

Also, could there be some option where offset = -wb.min()?

@coxipi (Contributor, Author) commented on Oct 18, 2023:

I will check, but I think there is actually no restriction in scipy against using gamma/fisk distributions with zero-bounded data. A parameter can also be fitted for this offset. That was an idea for another PR, though. I think the problem could be: if there are more negative values outside of the calibration period, then the offset could be too weak.

The problem with offset = -wb.min() is that you become sensitive to what data is included. Imagine your reference data is 1980-2010 and you fit two cases:

  1. 1980-2020
  2. 1980-2050

You could have a different minimum in cases 1 and 2, hence a different offset. Ideally, since this is a trick, we would like our computation not to be too sensitive to it. In any case, it would be nice if the results of 1. were an exact subset of 2. when the same calibration data / methods are used for both computations.

@coxipi (Contributor, Author):

Ideally, I would eventually like to get rid of the 1 mm/day default. It was just copied from monocongo's climate_indices. It's a good rule of thumb, but users have found cases where a bigger offset was needed. Maybe we can keep it but warn users about having NaNs? Maybe just in the docstring. Not sure. I could check the R implementation to see if I can dig up ideas.

coxipi and others added 2 commits October 18, 2023 17:40
Co-authored-by: Pascal Bourgault <[email protected]>
@coxipi (Contributor, Author) commented on Oct 19, 2023

Showing problems with dist_method that explain the previous slow behaviour

I think it may be worthwhile to benchmark the previous implementation a bit to see where the bottleneck is. The main culprit, I believe, is dist_method. For instance, let us first obtain some parameters.

from xclim.testing import open_dataset
import xclim
import scipy

pr = open_dataset("sdba/CanESM2_1950-2100.nc").pr
freq, window, dist, method = "MS", 1, "gamma", "APP"
params = xclim.indices.stats.standardized_index_fit_params(
    pr.sel(time=slice("1950", "1980")), freq=freq, window=window, dist=dist, method=method
)
# broadcast params like pr
params = params.rename(month="time").reindex(time=pr.time.dt.month)
params["time"] = pr.time

Now, we either use a numpy/ scipy approach (fast):

# ~0.5 seconds
import xarray as xr
from xclim.indices.stats import get_dist


dist = get_dist(params.attrs["scipy_dist"])
def wrap_cdf(da, pars):
    return dist.cdf(da[:], *pars)

out1 = xr.apply_ufunc(
    wrap_cdf,
    pr.where(pr>0), 
    params,
    input_core_dims=[["time"], ["dparams","time"]],
    output_core_dims=[["time"]],
    vectorize=True,
)
out1.values

or we use dist_method (about 50 times slower):

# about 25 seconds
from xclim.indices.stats import dist_method
out2 = dist_method("cdf", params, pr.where(pr > 0))
out2.values

Both methods give the same results. It remains to be seen if the speed-up persists as we scale things up, but this is already a weird result.
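As a quick sanity check (not part of the PR), the two outputs can be compared directly:

import xarray as xr

# raises if the fast apply_ufunc path and dist_method disagree beyond float tolerance
xr.testing.assert_allclose(out1, out2)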

full spi computation

I can use the approach outlined above for the full SPI computation. It is much faster than the implementation before this PR. It looks like:

    # assumes the imports from the example above, plus `from scipy.stats import norm`;
    # `resample_to_time` broadcasts the grouped (per-month) values back onto the `time` axis
    probs_of_zero = da.groupby(group).map(lambda x: (x == 0).sum("time") / x.notnull().sum("time"))
    params, probs_of_zero = [resample_to_time(dax, da) for dax in [params, probs_of_zero]]

    def wrap_cdf_ppf(da, pars, probs_of_zero):
        dist_probs = get_dist(params.attrs["scipy_dist"]).cdf(da[:], *pars)
        probs = probs_of_zero + ((1 - probs_of_zero) * dist_probs)
        return norm.ppf(probs)

    std_index = xr.apply_ufunc(
        wrap_cdf_ppf,
        da, 
        params,
        probs_of_zero,
        input_core_dims=[["time"], ["dparams","time"],["time"]],
        output_core_dims=[["time"]],
        vectorize=True,
        dask="parallelized",
    )

It remains about 2x slower than the approach with group_idxs, I think mainly because of the probs_of_zero computation. I will try a different approach.

2nd try:

Using flox, I get almost as fast as the weird loop over group_idx that I implemented:

import flox.xarray as floxx
probs_of_zero = floxx.xarray_reduce((da == 0), idxs, func="sum") / floxx.xarray_reduce(da.notnull(), idxs, func="sum")

  • weird loop with group_idx: 45 s
  • using floxx: 53 s
  • no floxx: 80 s

I don't think flox supports custom functions. That would be ideal, because then I could write:

def func(da):
    return (da==0).sum(dim="time")/da.notnull().sum(dim="time")

and have only one flox call. There is this Aggregation thing I'm trying to get the hang of...

3rd try

A little dirty trick that seems performant, but a bit more difficult to understand:

floxx.xarray_reduce((pr + 1) * (pr == 0), idxs, func="mean")

It's equivalent to:

floxx.xarray_reduce((pr == 0), idxs, func="sum") / floxx.xarray_reduce(pr.notnull(), idxs, func="sum")
floxx.xarray_reduce((pr == 0).where(pr.notnull()), idxs, func="mean")

since (pr + 1) * (pr == 0) is 1 where pr == 0, 0 where pr > 0, and NaN where pr is NaN, and the mean ignores the NaNs, but it avoids the double flox call or the "where" call. It doesn't seem much faster than the 2nd try, though.

By the way ...

I should try to improve dist_method instead of having a specific method for SPI/SPEI computations

@coxipi (Contributor, Author) commented on Oct 20, 2023

Sabotaging my previous example to mimic dist_method performance

The performance issue seems to stem from the input_core_dims used in the xr.apply_ufunc call of dist_method. With my little example, I can reproduce the slowness of dist_method:

    # like `dist_method`, about 25 s
    def wrap_cdf(da, pars):
        return dist.cdf(da, *pars)

    out = xr.apply_ufunc(
        wrap_cdf,
        pr.where(pr>0), 
        params,
        input_core_dims=[[], ["dparams"]],
        output_core_dims=[[]],
        vectorize=True,
    )

while using "time" in the core dims goes much faster. This remains true if I try using dask etc., but maybe I'm not using it well. At this point we could chunk the time dimension, but earlier steps in stats usually require an unchunked time dimension. If using "time" as a core dimension is always appropriate in this scenario, this could be a way to improve dist_method easily.
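For example, a hedged sketch of what that could look like, reusing wrap_cdf from the fast example above; it assumes a non-time dimension to chunk over (here called location, as in the test file), with "time" kept unchunked as a core dimension:

out = xr.apply_ufunc(
    wrap_cdf,
    pr.where(pr > 0).chunk({"location": 1}),
    params.chunk({"location": 1}),
    input_core_dims=[["time"], ["dparams", "time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[pr.dtype],
)
out.values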

@Zeitsperre (Collaborator) left a comment:

Well done!

@aulemahal (Collaborator) left a comment:

Two small comments, but it looks good!

coxipi and others added 2 commits October 23, 2023 14:36
@Zeitsperre (Collaborator) commented:

Ignore the cancelled builds, merging when docs clear.

@coxipi merged commit dae1ffd into master on Oct 23, 2023 (9 checks passed)
@coxipi deleted the fix_spi_performance branch on October 23, 2023 at 18:55
Labels: approved (Approved for additional tests), indicators (Climate indices and indicators)
Projects: None yet
4 participants