Skip to content

Commit

Permalink
fix: scrambled features with step size
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse committed May 1, 2024
1 parent ec49982 commit 6bdddbd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/timeseriesflattener/flattener.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import datetime as dt
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import TYPE_CHECKING, Union
import datetime as dt

import polars as pl
import tqdm
Expand Down Expand Up @@ -128,6 +128,10 @@ def aggregate_timeseries(
for spec in specs:
spec.value_frame.df = spec.value_frame.df.lazy()

self.predictiontime_frame.df = self.predictiontime_frame.df.sort(
self.predictiontime_frame.timestamp_col_name
) # type: ignore

# Process and collect the specs. One-by-one, to get feedback on progress.
dfs: Sequence[pl.LazyFrame] = []
if self.n_workers is None:
Expand Down
47 changes: 47 additions & 0 deletions src/timeseriesflattener/test_flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .feature_specs.outcome import OutcomeSpec
from .feature_specs.prediction_times import PredictionTimeFrame
from .feature_specs.predictor import PredictorSpec
from .feature_specs.static import StaticSpec, StaticFrame

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -315,3 +316,49 @@ def test_add_features_with_non_default_entity_id_col_name():
1-2022-01-01 00:00:00.000000,1"""
)
assert_frame_equal(result.collect(), expected, ignore_colums=["dw_ek_borger", "pred_timestamp"])


@pytest.mark.parametrize("step_size", [None, dt.timedelta(days=30)])
def test_multiple_features_with_unordered_prediction_times(step_size):
    """Regression test: features must line up when prediction times are unsorted.

    The prediction-time rows are deliberately listed out of chronological
    order. A predictor spec (1-day lookbehind, mean, fallback 0) and a static
    spec are aggregated together, once without a step size and once with a
    30-day step size. Both the result and the expected frame are sorted by
    ``pred_time_uuid`` before comparison, so only the pairing of feature
    values with prediction times is checked, not the output row order.
    """
    # Prediction times are intentionally not sorted by timestamp.
    prediction_times_csv = """entity_id,pred_timestamp,
2,2022-01-02 00:00:00
1,2022-01-01 00:00:00
1,2020-01-01 00:00:00
"""
    # Single observation, one second inside the 1-day lookbehind of
    # entity 1's 2022-01-01 prediction time.
    predictor_csv = """entity_id,timestamp,value,
1,2021-12-31 00:00:01, 1
"""
    static_csv = """entity_id,static
1,1
2,2
"""
    expected_csv = """pred_time_uuid,pred_value_within_0_to_1_days_mean_fallback_0,pred_static_fallback_0
2-2022-01-02 00:00:00.000000,0.0,2
1-2022-01-01 00:00:00.000000,1.0,1
1-2020-01-01 00:00:00.000000,0.0,1
"""

    prediction_times = PredictionTimeFrame(init_df=str_to_pl_df(prediction_times_csv))
    flattener_under_test = flattener.Flattener(predictiontime_frame=prediction_times)

    mean_predictor = PredictorSpec(
        value_frame=ValueFrame(init_df=str_to_pl_df(predictor_csv)),
        lookbehind_distances=[dt.timedelta(days=1)],
        fallback=0,
        aggregators=[MeanAggregator()],
    )
    static_feature = StaticSpec(
        value_frame=StaticFrame(init_df=str_to_pl_df(static_csv)),
        column_prefix="pred",
        fallback=0,
    )

    result = flattener_under_test.aggregate_timeseries(
        specs=[mean_predictor, static_feature],
        step_size=step_size,
    )

    # Sort both sides by the uuid so the comparison is independent of row order.
    sort_key = "pred_time_uuid"
    expected = str_to_pl_df(expected_csv).sort(sort_key)
    actual = result.df.collect().sort(sort_key)
    assert_frame_equal(
        actual,
        expected,
        ignore_colums=["entity_id", "pred_timestamp"],
    )

0 comments on commit 6bdddbd

Please sign in to comment.