Skip to content

Commit

Permalink
Extend STLDecomposer to Support Multiseries (#4253)
Browse files Browse the repository at this point in the history
* creates multiple graphs

* reset condition for period

* take dataframe as y input and fix indexing

* update inverse_transform and get_trend_dataframe

* update get_trend_prediction_intervals

* add ms seasonal data

* add multiseries tests

* add plot test

* add periods parameters

* add unstacking and test

---------

Co-authored-by: Becca McBrayer <[email protected]>
Co-authored-by: Christopher Bunn <[email protected]>
  • Loading branch information
3 people authored Aug 31, 2023
1 parent 90033c5 commit 69344b2
Show file tree
Hide file tree
Showing 7 changed files with 1,079 additions and 326 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Release Notes
-------------
**Future Releases**
* Enhancements
* Extended STLDecomposer to Support Multiseries :pr:`4253`
* Fixes
* Changes
* Documentation Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import re
from abc import abstractmethod
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -324,9 +325,9 @@ def _project_seasonal(
def plot_decomposition(
self,
X: pd.DataFrame,
y: pd.Series,
y: Union[pd.Series, pd.DataFrame],
show: bool = False,
) -> tuple[plt.Figure, list]:
) -> Union[tuple[plt.Figure, list], dict[str, tuple[plt.Figure]]]:
"""Plots the decomposition of the target signal.
Args:
Expand All @@ -336,24 +337,49 @@ def plot_decomposition(
show (bool): Whether to display the plot or not. Defaults to False.
Returns:
matplotlib.pyplot.Figure, list[matplotlib.pyplot.Axes]: The figure and axes that have the decompositions
(Single series) matplotlib.pyplot.Figure, list[matplotlib.pyplot.Axes]: The figure and axes that have the decompositions
plotted on them
(Multi series) dict[str, (matplotlib.pyplot.Figure, list[matplotlib.pyplot.Axes])]: A dictionary that maps the series id
to the figure and axes that have the decompositions plotted on them
"""
if isinstance(y, pd.Series):
y = y.to_frame()

plot_info = {}
if self.frequency and self.time_index and len(y.columns) > 1:
X.index = pd.DatetimeIndex(X[self.time_index], freq=self.frequency)
decomposition_results = self.get_trend_dataframe(X, y)
fig, axs = plt.subplots(4)
fig.set_size_inches(18.5, 14.5)
axs[0].plot(decomposition_results[0]["signal"], "r")
axs[0].set_title("signal")
axs[1].plot(decomposition_results[0]["trend"], "b")
axs[1].set_title("trend")
axs[2].plot(decomposition_results[0]["seasonality"], "g")
axs[2].set_title("seasonality")
axs[3].plot(decomposition_results[0]["residual"], "y")
axs[3].set_title("residual")
if show: # pragma: no cover
plt.show()
return fig, axs

# Iterate through each series id
for id in y.columns:
fig, axs = plt.subplots(4)
fig.set_size_inches(18.5, 14.5)

if len(y.columns) > 1:
results = decomposition_results[id][0]
else:
results = decomposition_results[0]
axs[0].plot(results["signal"], "r")
axs[0].set_title("signal")
axs[1].plot(results["trend"], "b")
axs[1].set_title("trend")
axs[2].plot(results["seasonality"], "g")
axs[2].set_title("seasonality")
axs[3].plot(results["residual"], "y")
axs[3].set_title("residual")

# If multiseries, return a dictionary of tuples
if len(y.columns) > 1:
fig.suptitle("Decomposition for Series {}".format(id))
plot_info[id] = (fig, axs)
else:
plot_info = (fig, axs)

if show: # pragma: no cover
plt.show()

return plot_info

def _check_target(self, X: pd.DataFrame, y: pd.Series):
"""Function to ensure target is not None and has a pandas.DatetimeIndex."""
Expand Down
Loading

0 comments on commit 69344b2

Please sign in to comment.