diff --git a/src/timeseriesflattener/test_aggregators.py b/src/timeseriesflattener/test_aggregators.py index 89cfa9b1..ee7ad1ea 100644 --- a/src/timeseriesflattener/test_aggregators.py +++ b/src/timeseriesflattener/test_aggregators.py @@ -42,6 +42,7 @@ class SingleVarAggregatorExample: aggregator: Aggregator input_values: Sequence[float | None] expected_output_values: Sequence[float] + fallback_str: str = "nan" @property def input_frame(self) -> pl.LazyFrame: @@ -58,7 +59,7 @@ def expected_output(self) -> pl.DataFrame: return pl.DataFrame( { "prediction_time_uuid": [1], - f"value_{self.aggregator.name}_fallback_nan": self.expected_output_values, + f"value_{self.aggregator.name}_fallback_{self.fallback_str}": self.expected_output_values, } ) @@ -90,12 +91,13 @@ def expected_output(self) -> pl.DataFrame: aggregator=VarianceAggregator(), input_values=[1, 2], expected_output_values=[0.5] ), SingleVarAggregatorExample( - aggregator=HasValuesAggregator(), input_values=[1, 2], expected_output_values=[1] + aggregator=HasValuesAggregator(), input_values=[1, 2], expected_output_values=[1], fallback_str="False" ), SingleVarAggregatorExample( aggregator=HasValuesAggregator(), input_values=[None], # type: ignore expected_output_values=[0], + fallback_str="False", ), ComplexAggregatorExample( aggregator=SlopeAggregator(timestamp_col_name="timestamp"), @@ -153,7 +155,7 @@ def test_aggregator(example: AggregatorExampleType): timestamp_col_name="timestamp", ), aggregators=[example.aggregator], - fallback=np.nan, + fallback=np.nan if example.aggregator.name != "bool" else False, ) assert_frame_equal(result.collect(), example.expected_output)