Skip to content

Commit

Permalink
updates following review
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse committed Feb 22, 2024
1 parent c2b9529 commit fa95168
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/timeseriesflattenerv2/spec_processors/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
def process_static_spec(
spec: StaticSpec, predictiontime_frame: PredictionTimeFrame
) -> ProcessedFrame:
new_col_names = [
f"{spec.column_prefix}_{value_col_name}_fallback_{spec.fallback}"
old2new_colname = {
value_col_name: f"{spec.column_prefix}_{value_col_name}_fallback_{spec.fallback}"
for value_col_name in spec.value_frame.value_col_names
]
}
prediction_times_with_time_from_event = (
predictiontime_frame.df.join(
spec.value_frame.df, on=predictiontime_frame.entity_id_col_name, how="left"
)
.rename(dict(zip(spec.value_frame.value_col_names, new_col_names)))
.select(predictiontime_frame.pred_time_uuid_col_name, *new_col_names)
.rename(old2new_colname)
.select(predictiontime_frame.pred_time_uuid_col_name, *old2new_colname.values())
)

return ProcessedFrame(
Expand Down
24 changes: 24 additions & 0 deletions src/timeseriesflattenerv2/spec_processors/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,30 @@ def test_multiple_aggregators():
assert_frame_equal(aggregated_values.collect(), expected)


def test_masking_multiple_values_multiple_aggregators():
masked_frame = TimeMaskedFrame(
validate_cols_exist=False,
init_df=str_to_pl_df(
"""pred_time_uuid,value_1,value_2
1-2021-01-03,1,np.nan
1-2021-01-03,2,np.nan
2-2021-01-03,2,np.nan
2-2021-01-03,4,np.nan"""
).lazy(),
value_col_names=["value_1", "value_2"],
)

aggregated_values = process_spec._aggregate_masked_frame(
masked_frame=masked_frame, aggregators=[MeanAggregator(), MaxAggregator()], fallback=0
)

expected = str_to_pl_df(
"""pred_time_uuid,value_1_mean_fallback_0,value_2_mean_fallback_0,value_1_max_fallback_0,value_2_max_fallback_0
1-2021-01-03,1.5,0,2,0
2-2021-01-03,3,0,4,0"""
)


def test_process_time_from_event_spec():
pred_frame = str_to_pl_df(
"""entity_id,pred_timestamp
Expand Down

0 comments on commit fa95168

Please sign in to comment.