Skip to content

Commit

Permalink
ensure coreforecast is installed for AutoDifferences (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Feb 16, 2024
1 parent 6295a56 commit 02a2a85
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 40 deletions.
2 changes: 1 addition & 1 deletion mlforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.11.7"
__version__ = "0.11.8"
__all__ = ['MLForecast']
from mlforecast.forecast import MLForecast
3 changes: 1 addition & 2 deletions mlforecast/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
core_scalers = None
CoreGroupedArray = None

class BaseLagTransform:
...
class BaseLagTransform: ...

Lag = None

Expand Down
39 changes: 13 additions & 26 deletions mlforecast/lag_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,16 @@ def __init__(self, window_size: int, min_samples: Optional[int] = None):
self.min_samples = min_samples

# %% ../nbs/lag_transforms.ipynb 7
class RollingMean(_RollingBase):
...
class RollingMean(_RollingBase): ...


class RollingStd(_RollingBase):
...
class RollingStd(_RollingBase): ...


class RollingMin(_RollingBase):
...
class RollingMin(_RollingBase): ...


class RollingMax(_RollingBase):
...
class RollingMax(_RollingBase): ...


class RollingQuantile(_RollingBase):
Expand Down Expand Up @@ -119,20 +115,16 @@ def __init__(
self.min_samples = min_samples

# %% ../nbs/lag_transforms.ipynb 10
class SeasonalRollingMean(_Seasonal_RollingBase):
...
class SeasonalRollingMean(_Seasonal_RollingBase): ...


class SeasonalRollingStd(_Seasonal_RollingBase):
...
class SeasonalRollingStd(_Seasonal_RollingBase): ...


class SeasonalRollingMin(_Seasonal_RollingBase):
...
class SeasonalRollingMin(_Seasonal_RollingBase): ...


class SeasonalRollingMax(_Seasonal_RollingBase):
...
class SeasonalRollingMax(_Seasonal_RollingBase): ...


class SeasonalRollingQuantile(_Seasonal_RollingBase):
Expand All @@ -154,24 +146,19 @@ def __init__(
class _ExpandingBase(BaseLagTransform):
"""Expanding statistic"""

def __init__(self):
...
def __init__(self): ...

# %% ../nbs/lag_transforms.ipynb 13
class ExpandingMean(_ExpandingBase):
...
class ExpandingMean(_ExpandingBase): ...


class ExpandingStd(_ExpandingBase):
...
class ExpandingStd(_ExpandingBase): ...


class ExpandingMin(_ExpandingBase):
...
class ExpandingMin(_ExpandingBase): ...


class ExpandingMax(_ExpandingBase):
...
class ExpandingMax(_ExpandingBase): ...


class ExpandingQuantile(_ExpandingBase):
Expand Down
30 changes: 20 additions & 10 deletions mlforecast/target_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ def update(self, df: DataFrame) -> DataFrame:
raise NotImplementedError

@abc.abstractmethod
def fit_transform(self, df: DataFrame) -> DataFrame:
...
def fit_transform(self, df: DataFrame) -> DataFrame: ...

@abc.abstractmethod
def inverse_transform(self, df: DataFrame) -> DataFrame:
...
def inverse_transform(self, df: DataFrame) -> DataFrame: ...

# %% ../nbs/target_transforms.ipynb 6
class BaseGroupedArrayTargetTransform(abc.ABC):
Expand All @@ -59,16 +57,13 @@ def set_num_threads(self, num_threads: int) -> None:
self.num_threads = num_threads

@abc.abstractmethod
def update(self, ga: GroupedArray) -> GroupedArray:
...
def update(self, ga: GroupedArray) -> GroupedArray: ...

@abc.abstractmethod
def fit_transform(self, ga: GroupedArray) -> GroupedArray:
...
def fit_transform(self, ga: GroupedArray) -> GroupedArray: ...

@abc.abstractmethod
def inverse_transform(self, ga: GroupedArray) -> GroupedArray:
...
def inverse_transform(self, ga: GroupedArray) -> GroupedArray: ...

def inverse_transform_fitted(self, ga: GroupedArray) -> GroupedArray:
return self.inverse_transform(ga)
Expand Down Expand Up @@ -137,6 +132,11 @@ class AutoDifferences(BaseGroupedArrayTargetTransform):
Maximum number of differences to apply."""

def __init__(self, max_diffs: int):
if not CORE_INSTALLED:
raise ImportError(
"coreforecast is required for this transformation. "
"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation"
)
self.scaler_ = core_scalers.AutoDifferences(max_diffs)

def fit_transform(self, ga: GroupedArray) -> GroupedArray:
Expand Down Expand Up @@ -172,6 +172,11 @@ class AutoSeasonalDifferences(AutoDifferences):
def __init__(
self, season_length: int, max_diffs: int, n_seasons: Optional[int] = 10
):
if not CORE_INSTALLED:
raise ImportError(
"coreforecast is required for this transformation. "
"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation"
)
self.scaler_ = core_scalers.AutoSeasonalDifferences(
season_length=season_length,
max_diffs=max_diffs,
Expand All @@ -196,6 +201,11 @@ class AutoSeasonalityAndDifferences(AutoDifferences):
def __init__(
self, max_season_length: int, max_diffs: int, n_seasons: Optional[int] = 10
):
if not CORE_INSTALLED:
raise ImportError(
"coreforecast is required for this transformation. "
"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation"
)
self.scaler_ = core_scalers.AutoSeasonalityAndDifferences(
max_season_length=max_season_length,
max_diffs=max_diffs,
Expand Down
18 changes: 18 additions & 0 deletions nbs/target_transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@
" max_diffs: int\n",
" Maximum number of differences to apply.\"\"\"\n",
" def __init__(self, max_diffs: int):\n",
" if not CORE_INSTALLED:\n",
" raise ImportError(\n",
" \"coreforecast is required for this transformation. \"\n",
" \"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation\"\n",
" )\n",
" self.scaler_ = core_scalers.AutoDifferences(max_diffs)\n",
"\n",
" def fit_transform(self, ga: GroupedArray) -> GroupedArray:\n",
Expand All @@ -301,6 +306,7 @@
"outputs": [],
"source": [
"#| hide\n",
"#| core\n",
"sc = AutoDifferences(1)\n",
"ga = GroupedArray(np.arange(10), np.array([0, 10]))\n",
"transformed = sc.fit_transform(ga)\n",
Expand Down Expand Up @@ -335,6 +341,11 @@
" If `None` will use all samples, otherwise `season_length` * `n_seasons samples` will be used for the test.\n",
" Smaller values will be faster but could be less accurate.\"\"\"\n",
" def __init__(self, season_length: int, max_diffs: int, n_seasons: Optional[int] = 10):\n",
" if not CORE_INSTALLED:\n",
" raise ImportError(\n",
" \"coreforecast is required for this transformation. \"\n",
" \"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation\"\n",
" ) \n",
" self.scaler_ = core_scalers.AutoSeasonalDifferences(\n",
" season_length=season_length,\n",
" max_diffs=max_diffs,\n",
Expand All @@ -361,6 +372,7 @@
],
"source": [
"#| hide\n",
"#| core\n",
"sc = AutoSeasonalDifferences(season_length=5, max_diffs=1)\n",
"ga = GroupedArray(np.arange(5)[np.arange(10) % 5], np.array([0, 10]))\n",
"transformed = sc.fit_transform(ga)\n",
Expand Down Expand Up @@ -390,6 +402,11 @@
" If `None` will use all samples, otherwise `max_season_length` * `n_seasons samples` will be used for the test.\n",
" Smaller values will be faster but could be less accurate.\"\"\"\n",
" def __init__(self, max_season_length: int, max_diffs: int, n_seasons: Optional[int] = 10):\n",
" if not CORE_INSTALLED:\n",
" raise ImportError(\n",
" \"coreforecast is required for this transformation. \"\n",
" \"Please follow the installation instructions at https://github.com/Nixtla/coreforecast/#installation\"\n",
" ) \n",
" self.scaler_ = core_scalers.AutoSeasonalityAndDifferences(\n",
" max_season_length=max_season_length,\n",
" max_diffs=max_diffs,\n",
Expand All @@ -416,6 +433,7 @@
],
"source": [
"#| hide\n",
"#| core\n",
"sc = AutoSeasonalityAndDifferences(max_season_length=5, max_diffs=1)\n",
"ga = GroupedArray(np.arange(5)[np.arange(10) % 5], np.array([0, 10]))\n",
"transformed = sc.fit_transform(ga)\n",
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ author = José Morales
author_email = [email protected]
copyright = Nixtla
branch = main
version = 0.11.7
version = 0.11.8
min_python = 3.8
audience = Developers
language = English
Expand Down

0 comments on commit 02a2a85

Please sign in to comment.