Skip to content

Commit

Permalink
update temporal.py and test_flattener.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Apr 12, 2024
1 parent bb88ed9 commit f8c31bf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/timeseriesflattener/spec_processors/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,17 @@ def _create_stride_chunks(
)
)

return PredictionTimeFrame(init_df=step_predictiontime_df), ValueFrame(step_value_df)
vf = spec.value_frame
return PredictionTimeFrame(
init_df=step_predictiontime_df,
entity_id_col_name=predictiontime_frame.entity_id_col_name,
timestamp_col_name=predictiontime_frame.timestamp_col_name,
pred_time_uuid_col_name=predictiontime_frame.pred_time_uuid_col_name,
), ValueFrame(
init_df=step_value_df,
entity_id_col_name=vf.entity_id_col_name,
value_timestamp_col_name=vf.value_timestamp_col_name,
)


def process_temporal_spec(
Expand Down
31 changes: 31 additions & 0 deletions src/timeseriesflattener/test_flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,34 @@ def test_add_static_spec():
1-2022-01-01 00:00:00.000000,1"""
)
assert_frame_equal(result.collect(), expected, ignore_colums=["entity_id", "pred_timestamp"])


def test_add_features_with_non_default_entity_id_col_name():
prediction_times_df_str = """dw_ek_borger,pred_timestamp,
1,2022-01-01 00:00:00
"""
outcome_df_str = """dw_ek_borger,timestamp,value,
1,2022-01-02 00:00:01, 2
1,2022-01-15 00:00:00, 1
"""
result = flattener.Flattener(
predictiontime_frame=PredictionTimeFrame(
init_df=str_to_pl_df(prediction_times_df_str), entity_id_col_name="dw_ek_borger"
)
).aggregate_timeseries(
specs=[
OutcomeSpec(
value_frame=ValueFrame(
init_df=str_to_pl_df(outcome_df_str), entity_id_col_name="dw_ek_borger"
),
lookahead_distances=[(dt.timedelta(days=5), dt.timedelta(days=30))],
fallback=np.NaN,
aggregators=[MeanAggregator()],
)
]
)
expected = str_to_pl_df(
"""pred_time_uuid,outc_value_within_5_to_30_days_mean_fallback_nan
1-2022-01-01 00:00:00.000000,1"""
)
assert_frame_equal(result.collect(), expected, ignore_colums=["dw_ek_borger", "pred_timestamp"])

0 comments on commit f8c31bf

Please sign in to comment.