From f8c31bf84ca8067c45c5fcaa6c53dd98b95edad8 Mon Sep 17 00:00:00 2001
From: Martin Bernstorff
Date: Fri, 12 Apr 2024 11:07:17 +0200
Subject: [PATCH] update temporal.py and test_flattener.py

_create_stride_chunks rebuilt the per-stride PredictionTimeFrame and
ValueFrame from their dataframes alone, so the column names configured on
the incoming frames were not passed on. Forward them explicitly and add a
test that uses a non-default entity id column ("dw_ek_borger").

---
Review notes (this section is dropped by `git am`):
 * Restored one patch line per physical line. The indentation of the
   context lines and of the rows inside the test's CSV literals had been
   collapsed in the copy under review and is a best-effort reconstruction;
   verify against the tree before applying.
 * The new test uses np.nan instead of np.NaN (the alias was removed in
   NumPy 2.0) and gained a docstring, so the post-image blob hash on the
   test file's `index` line is stale and the counts below are one higher
   than in the original export.

 .../spec_processors/temporal.py           | 12 ++++++-
 src/timeseriesflattener/test_flattener.py | 32 +++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/timeseriesflattener/spec_processors/temporal.py b/src/timeseriesflattener/spec_processors/temporal.py
index 3c867ecc..b1b3793d 100644
--- a/src/timeseriesflattener/spec_processors/temporal.py
+++ b/src/timeseriesflattener/spec_processors/temporal.py
@@ -212,7 +212,17 @@ def _create_stride_chunks(
         )
     )
 
-    return PredictionTimeFrame(init_df=step_predictiontime_df), ValueFrame(step_value_df)
+    vf = spec.value_frame
+    return PredictionTimeFrame(
+        init_df=step_predictiontime_df,
+        entity_id_col_name=predictiontime_frame.entity_id_col_name,
+        timestamp_col_name=predictiontime_frame.timestamp_col_name,
+        pred_time_uuid_col_name=predictiontime_frame.pred_time_uuid_col_name,
+    ), ValueFrame(
+        init_df=step_value_df,
+        entity_id_col_name=vf.entity_id_col_name,
+        value_timestamp_col_name=vf.value_timestamp_col_name,
+    )
 
 
 def process_temporal_spec(
diff --git a/src/timeseriesflattener/test_flattener.py b/src/timeseriesflattener/test_flattener.py
index 7ac4611e..d10e832a 100644
--- a/src/timeseriesflattener/test_flattener.py
+++ b/src/timeseriesflattener/test_flattener.py
@@ -284,3 +284,35 @@ def test_add_static_spec():
 1-2022-01-01 00:00:00.000000,1"""
     )
     assert_frame_equal(result.collect(), expected, ignore_colums=["entity_id", "pred_timestamp"])
+
+
+def test_add_features_with_non_default_entity_id_col_name():
+    """Flatten when both frames use a non-default entity id column ("dw_ek_borger")."""
+    prediction_times_df_str = """dw_ek_borger,pred_timestamp,
+                            1,2022-01-01 00:00:00
+                            """
+    outcome_df_str = """dw_ek_borger,timestamp,value,
+                        1,2022-01-02 00:00:01, 2
+                        1,2022-01-15 00:00:00, 1
+                        """
+    result = flattener.Flattener(
+        predictiontime_frame=PredictionTimeFrame(
+            init_df=str_to_pl_df(prediction_times_df_str), entity_id_col_name="dw_ek_borger"
+        )
+    ).aggregate_timeseries(
+        specs=[
+            OutcomeSpec(
+                value_frame=ValueFrame(
+                    init_df=str_to_pl_df(outcome_df_str), entity_id_col_name="dw_ek_borger"
+                ),
+                lookahead_distances=[(dt.timedelta(days=5), dt.timedelta(days=30))],
+                fallback=np.nan,
+                aggregators=[MeanAggregator()],
+            )
+        ]
+    )
+    expected = str_to_pl_df(
+        """pred_time_uuid,outc_value_within_5_to_30_days_mean_fallback_nan
+1-2022-01-01 00:00:00.000000,1"""
+    )
+    assert_frame_equal(result.collect(), expected, ignore_colums=["dw_ek_borger", "pred_timestamp"])