Skip to content

Commit

Permalink
Merge pull request #489 from Aarhus-Psychiatry-Research/fix/480/pure_…
Browse files Browse the repository at this point in the history
…NaNs_when_flattening_with_lookbehind-tuple_in_01_basic

fix(#480): pure NaNs when flattening with lookbehind-tuple in 01_basic
  • Loading branch information
HLasse authored Feb 22, 2024
2 parents 70da39a + 3ade969 commit 6ffbd92
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/01_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand Down
11 changes: 8 additions & 3 deletions src/timeseriesflattenerv2/spec_processors/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@ def _get_timedelta_frame(
predictiontime_frame: PredictionTimeFrame, value_frame: ValueFrame
) -> TimeDeltaFrame:
# Join the prediction time dataframe
joined_frame = predictiontime_frame.df.join(
value_frame.df, on=predictiontime_frame.entity_id_col_name, how="left"
# ensure that the timestamp col names are different to avoid conflicts
unique_predictiontime_frame_timestamp_col_name = (
f"__{predictiontime_frame.timestamp_col_name}__"
)

joined_frame = predictiontime_frame.df.rename(
{predictiontime_frame.timestamp_col_name: unique_predictiontime_frame_timestamp_col_name}
).join(value_frame.df, on=predictiontime_frame.entity_id_col_name, how="left")

# Get timedelta
timedelta_frame = joined_frame.with_columns(
(
pl.col(value_frame.value_timestamp_col_name)
- pl.col(predictiontime_frame.timestamp_col_name)
- pl.col(unique_predictiontime_frame_timestamp_col_name)
).alias("time_from_prediction_to_value")
)

Expand Down
25 changes: 25 additions & 0 deletions src/timeseriesflattenerv2/spec_processors/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,31 @@ def test_get_timedelta_frame():
assert result.get_timedeltas() == expected_timedeltas


def test_get_timedelta_frame_same_timestamp_col_names():
pred_frame = str_to_pl_df(
"""entity_id,timestamp
1,2021-01-03"""
)

value_frame = str_to_pl_df(
"""entity_id,value,timestamp
1,1,2021-01-01
1,2,2021-01-02
1,3,2021-01-03"""
)

expected_timedeltas = [dt.timedelta(days=-2), dt.timedelta(days=-1), dt.timedelta(days=0)]

result = process_spec._get_timedelta_frame(
predictiontime_frame=PredictionTimeFrame(
init_df=pred_frame.lazy(), timestamp_col_name="timestamp"
),
value_frame=ValueFrame(init_df=value_frame.lazy()),
)

assert result.get_timedeltas() == expected_timedeltas


def test_slice_without_any_within_window():
timedelta_frame = TimeDeltaFrame(
df=pl.LazyFrame(
Expand Down

0 comments on commit 6ffbd92

Please sign in to comment.