Skip to content

Commit

Permalink
fix: don't fix types in HasValuesAggregator fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse committed May 17, 2024
1 parent 6996013 commit 6e403d5
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/timeseriesflattener/test_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SingleVarAggregatorExample:
aggregator: Aggregator
input_values: Sequence[float | None]
expected_output_values: Sequence[float]
fallback_str: str = "nan"

@property
def input_frame(self) -> pl.LazyFrame:
Expand All @@ -58,7 +59,7 @@ def expected_output(self) -> pl.DataFrame:
return pl.DataFrame(
{
"prediction_time_uuid": [1],
f"value_{self.aggregator.name}_fallback_nan": self.expected_output_values,
f"value_{self.aggregator.name}_fallback_{self.fallback_str}": self.expected_output_values,
}
)

Expand Down Expand Up @@ -90,12 +91,13 @@ def expected_output(self) -> pl.DataFrame:
aggregator=VarianceAggregator(), input_values=[1, 2], expected_output_values=[0.5]
),
SingleVarAggregatorExample(
aggregator=HasValuesAggregator(), input_values=[1, 2], expected_output_values=[1]
aggregator=HasValuesAggregator(), input_values=[1, 2], expected_output_values=[1], fallback_str="False"
),
SingleVarAggregatorExample(
aggregator=HasValuesAggregator(),
input_values=[None], # type: ignore
expected_output_values=[0],
fallback_str="False",
),
ComplexAggregatorExample(
aggregator=SlopeAggregator(timestamp_col_name="timestamp"),
Expand Down Expand Up @@ -153,7 +155,7 @@ def test_aggregator(example: AggregatorExampleType):
timestamp_col_name="timestamp",
),
aggregators=[example.aggregator],
fallback=np.nan,
fallback=np.nan if example.aggregator.name != "bool" else False,
)

assert_frame_equal(result.collect(), example.expected_output)

0 comments on commit 6e403d5

Please sign in to comment.