Skip to content

Commit

Permalink
feat: use relativedelta instead of timedelta
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Nov 9, 2024
1 parent 96e0d79 commit f381419
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 42 deletions.
2 changes: 0 additions & 2 deletions docs/user-guide/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,3 @@ for a_train, a_test, b_train, b_test, y_train, y_test in tbs.split(a, b, y, time

!!! warning
Ideally each array can be a different type (numpy, pandas, polars, and so on); in practice, a few limitations might arise from the different types, so please be aware of that.

We are working to make the library more flexible and to support more types of arrays and more interactions between them.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ requires-python = ">=3.8"
authors = [{name = "Francesco Bruzzesi"}]

dependencies = [
"python-dateutil",
"numpy",
"narwhals>=1.0.0",
"typing-extensions>=4.4.0; python_version < '3.11'",
Expand All @@ -25,6 +26,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: MIT License",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from datetime import date
from datetime import datetime
from datetime import timedelta

import pandas as pd
import pytest
from dateutil.relativedelta import relativedelta

from timebasedcv.splitstate import SplitState

Expand Down Expand Up @@ -41,7 +41,7 @@
)
@pytest.mark.parametrize(
"expected_train_len, expected_forecast_len, expected_gap_len, expected_total_len",
[(timedelta(days=30), timedelta(days=27), timedelta(days=1), timedelta(days=58))],
[(relativedelta(days=30), relativedelta(days=27), relativedelta(days=1), relativedelta(months=1, days=27))],
)
def test_splitstate_valid(
train_start,
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_splitstate_add():
forecast_end=datetime(2023, 2, 28, 0),
)

delta = timedelta(days=1)
delta = relativedelta(days=1)
expected_split_state = SplitState(
train_start=datetime(2023, 1, 2, 0),
train_end=datetime(2023, 2, 1, 0),
Expand All @@ -131,7 +131,7 @@ def test_splitstate_sub():
forecast_end=datetime(2023, 3, 1, 0),
)

delta = timedelta(days=1)
delta = relativedelta(days=1)
expected_split_state = SplitState(
train_start=datetime(2023, 1, 1, 0),
train_end=datetime(2023, 1, 31, 0),
Expand Down
13 changes: 7 additions & 6 deletions tests/test_timebasedsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from contextlib import nullcontext as does_not_raise
from datetime import date
from datetime import datetime
from datetime import timedelta

import numpy as np
import pandas as pd
import pytest
from dateutil.relativedelta import relativedelta

from timebasedcv import TimeBasedSplit
from timebasedcv.core import _CoreTimeBasedSplit
Expand All @@ -29,7 +29,8 @@
X, y = df[["a", "b"]], df["y"]

err_msg_freq = (
r"`frequency` must be one of \('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks'\)"
"`frequency` must be one of "
r"\('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks', 'months', 'years'\)"
)
err_msg_int = r"\(`train_size_`, `forecast_horizon_`, `gap_`, `stride_`\) arguments must be of type `int`."
err_msg_lower_bound = r"must be greater or equal than \(1, 1, 0, 1\)"
Expand Down Expand Up @@ -110,10 +111,10 @@ def test_core_properties(frequency, train_size, forecast_horizon, gap, stride):
stride=stride,
)

assert cv.train_delta == timedelta(**{frequency: train_size})
assert cv.forecast_delta == timedelta(**{frequency: forecast_horizon})
assert cv.gap_delta == timedelta(**{frequency: gap})
assert cv.stride_delta == timedelta(**{frequency: stride or forecast_horizon})
assert cv.train_delta == relativedelta(**{frequency: train_size})
assert cv.forecast_delta == relativedelta(**{frequency: forecast_horizon})
assert cv.gap_delta == relativedelta(**{frequency: gap})
assert cv.stride_delta == relativedelta(**{frequency: stride or forecast_horizon})


@pytest.mark.parametrize(
Expand Down
30 changes: 15 additions & 15 deletions timebasedcv/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sys
from datetime import timedelta
from itertools import chain
from typing import TYPE_CHECKING
from typing import Generator
Expand All @@ -13,6 +12,7 @@
from typing import overload

import narwhals.stable.v1 as nw
from dateutil.relativedelta import relativedelta

from timebasedcv.splitstate import SplitState
from timebasedcv.utils._backends import BACKEND_TO_INDEXING_METHOD
Expand Down Expand Up @@ -52,7 +52,7 @@ class _CoreTimeBasedSplit:
Arguments:
frequency: The frequency (or time unit) of the time series. Must be one of "days", "seconds", "microseconds",
"milliseconds", "minutes", "hours", "weeks", "months", "years". These are the only valid values for the `unit` argument of
`timedelta` from python `datetime` standard library.
`relativedelta` from the third-party `python-dateutil` package (`dateutil.relativedelta`).
train_size: Defines the minimum number of time units required to be in the train set.
forecast_horizon: Specifies the number of time units to forecast.
gap: Sets the number of time units to skip between the end of the train set and the start of the forecast set.
Expand Down Expand Up @@ -172,24 +172,24 @@ def __repr__(self: Self) -> str:
return f"{self.name_}" "(\n " f"{_new_line_tab.join(f'{s} = {v}' for s, v in zip(_attrs, _values))}" "\n)"

@property
def train_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `train_size`."""
return timedelta(**{str(self.frequency_): self.train_size_})
def train_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `train_size`."""
return relativedelta(**{str(self.frequency_): self.train_size_}) # type: ignore[arg-type]

@property
def forecast_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `forecast_horizon`."""
return timedelta(**{str(self.frequency_): self.forecast_horizon_})
def forecast_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `forecast_horizon`."""
return relativedelta(**{str(self.frequency_): self.forecast_horizon_}) # type: ignore[arg-type]

@property
def gap_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to the `gap` and `frequency`."""
return timedelta(**{str(self.frequency_): self.gap_})
def gap_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to the `gap` and `frequency`."""
return relativedelta(**{str(self.frequency_): self.gap_}) # type: ignore[arg-type]

@property
def stride_delta(self: Self) -> timedelta:
"""Returns the `timedelta` object corresponding to `stride`."""
return timedelta(**{str(self.frequency_): self.stride_})
def stride_delta(self: Self) -> relativedelta:
"""Returns the `relativedelta` object corresponding to `stride`."""
return relativedelta(**{str(self.frequency_): self.stride_}) # type: ignore[arg-type]

def _splits_from_period(
self: Self,
Expand Down Expand Up @@ -339,7 +339,7 @@ class TimeBasedSplit(_CoreTimeBasedSplit):
Arguments:
frequency: The frequency (or time unit) of the time series. Must be one of "days", "seconds", "microseconds",
"milliseconds", "minutes", "hours", "weeks", "months", "years". These are the only valid values for the `unit` argument of
`timedelta` from python `datetime` standard library.
`relativedelta` from the third-party `python-dateutil` package (`dateutil.relativedelta`).
train_size: Defines the minimum number of time units required to be in the train set.
forecast_horizon: Specifies the number of time units to forecast.
gap: Sets the number of time units to skip between the end of the train set and the start of the forecast set.
Expand Down
30 changes: 16 additions & 14 deletions timebasedcv/splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Generic
from typing import Union

from dateutil.relativedelta import relativedelta

from timebasedcv.utils._funcs import pairwise
from timebasedcv.utils._funcs import pairwise_comparison
from timebasedcv.utils._types import DateTimeLike
Expand Down Expand Up @@ -91,42 +93,42 @@ def __post_init__(self: Self) -> None:
raise ValueError(msg)

@property
def train_length(self: Self) -> timedelta:
def train_length(self: Self) -> relativedelta:
"""Returns the time between `train_start` and `train_end`.
Returns:
A `timedelta` object representing the time between `train_start` and `train_end`.
A `relativedelta` object representing the time between `train_start` and `train_end`.
"""
return self.train_end - self.train_start
return relativedelta(self.train_end, self.train_start)

@property
def forecast_length(self: Self) -> timedelta:
def forecast_length(self: Self) -> relativedelta:
"""Returns the time between `forecast_start` and `forecast_end`.
Returns:
A `timedelta` object representing the time between `forecast_start` and `forecast_end`.
A `relativedelta` object representing the time between `forecast_start` and `forecast_end`.
"""
return self.forecast_end - self.forecast_start
return relativedelta(self.forecast_end, self.forecast_start)

@property
def gap_length(self: Self) -> timedelta:
def gap_length(self: Self) -> relativedelta:
"""Returns the time between `train_end` and `forecast_start`.
Returns:
A `timedelta` object representing the time between `train_end` and `forecast_start`.
A `relativedelta` object representing the time between `train_end` and `forecast_start`.
"""
return self.forecast_start - self.train_end
return relativedelta(self.forecast_start, self.train_end)

@property
def total_length(self: Self) -> timedelta:
def total_length(self: Self) -> relativedelta:
"""Returns the time between `train_start` and `forecast_end`.
Returns:
A `timedelta` object representing the time between `train_start` and `forecast_end`.
A `relativedelta` object representing the time between `train_start` and `forecast_end`.
"""
return self.forecast_end - self.train_start
return relativedelta(self.forecast_end, self.train_start)

def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
def __add__(self: Self, other: Union[timedelta, relativedelta, pd.Timedelta]) -> SplitState:
"""Adds `other` to each value of the state."""
return SplitState(
train_start=self.train_start + other,
Expand All @@ -135,7 +137,7 @@ def __add__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
forecast_end=self.forecast_end + other,
)

def __sub__(self: Self, other: Union[timedelta, pd.Timedelta]) -> SplitState:
def __sub__(self: Self, other: Union[timedelta, relativedelta, pd.Timedelta]) -> SplitState:
"""Subtracts `other` from each value of the state."""
return SplitState(
train_start=self.train_start - other,
Expand Down
4 changes: 3 additions & 1 deletion timebasedcv/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
DateTimeLike = TypeVar("DateTimeLike", datetime, date, "pd.Timestamp")
NullableDatetime = Union[DateTimeLike, None]

FrequencyUnit: TypeAlias = Literal["days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks"]
FrequencyUnit: TypeAlias = Literal[
"days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks", "months", "years"
]
WindowType: TypeAlias = Literal["rolling", "expanding"]
ModeType: TypeAlias = Literal["forward", "backward"]

Expand Down

0 comments on commit f381419

Please sign in to comment.