Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use dateutil.relativedelta instead of timedelta #64

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/user-guide/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,3 @@ for a_train, a_test, b_train, b_test, y_train, y_test in tbs.split(a, b, y, time

!!! warning
Ideally, each array can be of a different type (numpy, pandas, polars, and so on); in practice, a few limitations may arise from mixing different types, so please be aware of that.

We are working to make the library more flexible and to support more types of arrays and more interactions between them.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ requires-python = ">=3.8"
authors = [{name = "Francesco Bruzzesi"}]

dependencies = [
"python-dateutil",
"numpy",
"narwhals>=1.0.0",
"typing-extensions>=4.4.0; python_version < '3.11'",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from datetime import date
from datetime import datetime
from datetime import timedelta

import pandas as pd
import pytest
from dateutil.relativedelta import relativedelta

from timebasedcv.splitstate import SplitState

Expand Down Expand Up @@ -41,7 +41,7 @@
)
@pytest.mark.parametrize(
"expected_train_len, expected_forecast_len, expected_gap_len, expected_total_len",
[(timedelta(days=30), timedelta(days=27), timedelta(days=1), timedelta(days=58))],
[(relativedelta(days=30), relativedelta(days=27), relativedelta(days=1), relativedelta(months=1, days=27))],
)
def test_splitstate_valid(
train_start,
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_splitstate_add():
forecast_end=datetime(2023, 2, 28, 0),
)

delta = timedelta(days=1)
delta = relativedelta(days=1)
expected_split_state = SplitState(
train_start=datetime(2023, 1, 2, 0),
train_end=datetime(2023, 2, 1, 0),
Expand All @@ -131,7 +131,7 @@ def test_splitstate_sub():
forecast_end=datetime(2023, 3, 1, 0),
)

delta = timedelta(days=1)
delta = relativedelta(days=1)
expected_split_state = SplitState(
train_start=datetime(2023, 1, 1, 0),
train_end=datetime(2023, 1, 31, 0),
Expand Down
13 changes: 7 additions & 6 deletions tests/test_timebasedsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from contextlib import nullcontext as does_not_raise
from datetime import date
from datetime import datetime
from datetime import timedelta

import narwhals as nw
import numpy as np
import pandas as pd
import pytest
from dateutil.relativedelta import relativedelta

from timebasedcv import TimeBasedSplit
from timebasedcv.core import _CoreTimeBasedSplit
Expand All @@ -30,7 +30,8 @@
X, y = df[["a", "b"]], df["y"]

err_msg_freq = (
r"`frequency` must be one of \('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks'\)"
"`frequency` must be one of "
r"\('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks', 'months', 'years'\)"
)
err_msg_int = r"\(`train_size_`, `forecast_horizon_`, `gap_`, `stride_`\) arguments must be of type `int`."
err_msg_lower_bound = r"must be greater or equal than \(1, 1, 0, 1\)"
Expand Down Expand Up @@ -111,10 +112,10 @@ def test_core_properties(frequency, train_size, forecast_horizon, gap, stride):
stride=stride,
)

assert cv.train_delta == timedelta(**{frequency: train_size})
assert cv.forecast_delta == timedelta(**{frequency: forecast_horizon})
assert cv.gap_delta == timedelta(**{frequency: gap})
assert cv.stride_delta == timedelta(**{frequency: stride or forecast_horizon})
assert cv.train_delta == relativedelta(**{frequency: train_size})
assert cv.forecast_delta == relativedelta(**{frequency: forecast_horizon})
assert cv.gap_delta == relativedelta(**{frequency: gap})
assert cv.stride_delta == relativedelta(**{frequency: stride or forecast_horizon})


@pytest.mark.parametrize(
Expand Down
34 changes: 17 additions & 17 deletions timebasedcv/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sys
from datetime import timedelta
from itertools import chain
from typing import TYPE_CHECKING
from typing import Generator
Expand All @@ -13,6 +12,7 @@
from typing import overload

import narwhals.stable.v1 as nw
from dateutil.relativedelta import relativedelta

from timebasedcv.splitstate import SplitState
from timebasedcv.utils._backends import BACKEND_TO_INDEXING_METHOD
Expand Down Expand Up @@ -51,8 +51,8 @@ class _CoreTimeBasedSplit:

Arguments:
frequency: The frequency (or time unit) of the time series. Must be one of "days", "seconds", "microseconds",
"milliseconds", "minutes", "hours", "weeks". These are the only valid values for the `unit` argument of
`timedelta` from python `datetime` standard library.
"milliseconds", "minutes", "hours", "weeks", "months" or "years". These are the valid values for the
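keyword name passed to `relativedelta` from the Python `dateutil` library.
NOTE(review): `relativedelta` documents no `milliseconds` keyword; confirm how that unit is handled.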
train_size: Defines the minimum number of time units required to be in the train set.
forecast_horizon: Specifies the number of time units to forecast.
gap: Sets the number of time units to skip between the end of the train set and the start of the forecast set.
Expand Down Expand Up @@ -172,24 +172,24 @@ def __repr__(self: Self) -> str:
return f"{self.name_}" "(\n " f"{_new_line_tab.join(f'{s} = {v}' for s, v in zip(_attrs, _values))}" "\n)"

@property
def train_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `train_size`."""
return timedelta(**{str(self.frequency_): self.train_size_})
def train_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `train_size`."""
return relativedelta(**{str(self.frequency_): self.train_size_}) # type: ignore[arg-type]

@property
def forecast_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `forecast_horizon`."""
return timedelta(**{str(self.frequency_): self.forecast_horizon_})
def forecast_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `forecast_horizon`."""
return relativedelta(**{str(self.frequency_): self.forecast_horizon_}) # type: ignore[arg-type]

@property
def gap_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `gap` and `frequency`."""
return timedelta(**{str(self.frequency_): self.gap_})
def gap_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `gap` and `frequency`."""
return relativedelta(**{str(self.frequency_): self.gap_}) # type: ignore[arg-type]

@property
def stride_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to `stride`."""
return timedelta(**{str(self.frequency_): self.stride_})
def stride_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to `stride`."""
return relativedelta(**{str(self.frequency_): self.stride_}) # type: ignore[arg-type]

def _splits_from_period(
self: Self,
Expand Down Expand Up @@ -338,8 +338,8 @@ class TimeBasedSplit(_CoreTimeBasedSplit):

Arguments:
frequency: The frequency (or time unit) of the time series. Must be one of "days", "seconds", "microseconds",
"milliseconds", "minutes", "hours", "weeks". These are the only valid values for the `unit` argument of
`timedelta` from python `datetime` standard library.
"milliseconds", "minutes", "hours", "weeks", "months" or "years". These are the valid values for the
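keyword name passed to `relativedelta` from the Python `dateutil` library.
NOTE(review): `relativedelta` documents no `milliseconds` keyword; confirm how that unit is handled.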
train_size: Defines the minimum number of time units required to be in the train set.
forecast_horizon: Specifies the number of time units to forecast.
gap: Sets the number of time units to skip between the end of the train set and the start of the forecast set.
Expand Down
4 changes: 2 additions & 2 deletions timebasedcv/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class TimeBasedCVSplitter(_BaseKFold): # type: ignore[no-any-unimported]

Arguments:
frequency: The frequency (or time unit) of the time series. Must be one of "days", "seconds", "microseconds",
"milliseconds", "minutes", "hours", "weeks". These are the only valid values for the `unit` argument of
`timedelta` from python `datetime` standard library.
"milliseconds", "minutes", "hours", "weeks", "months" or "years". These are the valid values for the
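keyword name passed to `relativedelta` from the Python `dateutil` library.
NOTE(review): `relativedelta` documents no `milliseconds` keyword; confirm how that unit is handled.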
train_size: Defines the minimum number of time units required to be in the train set.
forecast_horizon: Specifies the number of time units to forecast.
time_series: The time series used to create boolean mask for splits. It is not required to be sorted, but it
Expand Down
32 changes: 17 additions & 15 deletions timebasedcv/splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Generic
from typing import Union

from dateutil.relativedelta import relativedelta

from timebasedcv.utils._funcs import pairwise
from timebasedcv.utils._funcs import pairwise_comparison
from timebasedcv.utils._types import DateTimeLike
Expand All @@ -27,7 +29,7 @@

@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
"""A `SplitState` represents the state of a split in terms of its 4 cut/split points.
"""A `SplitState` represents the state of a split in terms of its four cut/split points.

Namely these are start and end of training set, start and end of forecasting/test set.

Expand Down Expand Up @@ -91,42 +93,42 @@ def __post_init__(self: Self) -> None:
raise ValueError(msg)

@property
def train_length(self: Self) -> timedelta:
def train_length(self: Self) -> relativedelta:
"""Returns the time between `train_start` and `train_end`.

Returns:
A `timedelta` object representing the time between `train_start` and `train_end`.
A `relativedelta` object representing the time between `train_start` and `train_end`.
"""
return self.train_end - self.train_start
return relativedelta(self.train_end, self.train_start)

@property
def forecast_length(self: Self) -> timedelta:
def forecast_length(self: Self) -> relativedelta:
"""Returns the time between `forecast_start` and `forecast_end`.

Returns:
A `timedelta` object representing the time between `forecast_start` and `forecast_end`.
A `relativedelta` object representing the time between `forecast_start` and `forecast_end`.
"""
return self.forecast_end - self.forecast_start
return relativedelta(self.forecast_end, self.forecast_start)

@property
def gap_length(self: Self) -> timedelta:
def gap_length(self: Self) -> relativedelta:
"""Returns the time between `train_end` and `forecast_start`.

Returns:
A `timedelta` object representing the time between `train_end` and `forecast_start`.
A `relativedelta` object representing the time between `train_end` and `forecast_start`.
"""
return self.forecast_start - self.train_end
return relativedelta(self.forecast_start, self.train_end)

@property
def total_length(self: Self) -> timedelta:
def total_length(self: Self) -> relativedelta:
"""Returns the time between `train_start` and `forecast_end`.

Returns:
A `timedelta` object representing the time between `train_start` and `forecast_end`.
A `relativedelta` object representing the time between `train_start` and `forecast_end`.
"""
return self.forecast_end - self.train_start
return relativedelta(self.forecast_end, self.train_start)

def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
def __add__(self: Self, other: Union[timedelta, relativedelta, pd.Timedelta]) -> SplitState:
"""Adds `other` to each value of the state."""
return SplitState(
train_start=self.train_start + other,
Expand All @@ -135,7 +137,7 @@ def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
forecast_end=self.forecast_end + other,
)

def __sub__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
def __sub__(self: Self, other: Union[timedelta, relativedelta, pd.Timedelta]) -> SplitState:
"""Subtracts `other` from each value of the state."""
return SplitState(
train_start=self.train_start - other,
Expand Down
4 changes: 3 additions & 1 deletion timebasedcv/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
DateTimeLike = TypeVar("DateTimeLike", datetime, date, "pd.Timestamp")
NullableDatetime = Union[DateTimeLike, None]

FrequencyUnit: TypeAlias = Literal["days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks"]
FrequencyUnit: TypeAlias = Literal[
"days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks", "months", "years"
]
WindowType: TypeAlias = Literal["rolling", "expanding"]
ModeType: TypeAlias = Literal["forward", "backward"]

Expand Down