Skip to content

Commit

Permalink
aggregation: allow specifying none as default in aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
nonibansal committed Dec 10, 2024
1 parent af3c39e commit c71eda0
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 67 deletions.
3 changes: 3 additions & 0 deletions fennel/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## [1.6.0] - 2024-12-10
- Allow None as default value for min/max/avg/stddev aggregations.

## [1.5.58] - 2024-11-24
- Allow min/max aggregation on date, datetime and decimal dtypes

Expand Down
127 changes: 127 additions & 0 deletions fennel/client_tests/test_complex_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timezone, date
from decimal import Decimal as PythonDecimal
from typing import Optional

import pandas as pd
import pytest
Expand All @@ -13,6 +14,8 @@
Dataset,
Max,
Min,
Average,
Stddev,
)
from fennel.dtypes import Decimal, Continuous
from fennel.featuresets import featureset, feature as F
Expand Down Expand Up @@ -202,3 +205,127 @@ class CountryFeatures:
datetime(1987, 6, 6, tzinfo=timezone.utc),
datetime(1970, 1, 2, tzinfo=timezone.utc),
]


@pytest.mark.integration
@mock
def test_none_default(client):
    """Exercise Min/Max/Average/Stddev aggregations declared with ``default=None``.

    Incomes are logged for India, USA and UK, then features are queried for
    those three countries plus "China", which has no logged rows. For the
    missing key the featureset-level defaults are expected for min/max/avg,
    and ``pd.NA`` is expected for the Optional stddev feature.
    """

    @source(webhook.endpoint("UserInfoDataset"), disorder="14d", cdc="append")
    @dataset
    class UserInfoDataset:
        user_id: int
        country: str
        income: float
        timestamp: datetime = field(timestamp=True)

    @dataset(index=True)
    class CountryDS:
        country: str = field(key=True)
        # Every aggregate is Optional because its aggregation default is None.
        min_income: Optional[float]
        max_income: Optional[float]
        avg_income: Optional[float]
        stddev_income: Optional[float]
        timestamp: datetime = field(timestamp=True)

        @pipeline
        @inputs(UserInfoDataset)
        def avg_income_pipeline(cls, event: Dataset):
            per_country = event.groupby("country")
            return per_country.aggregate(
                min_income=Min(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                max_income=Max(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                avg_income=Average(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                stddev_income=Stddev(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
            )

    @featureset
    class CountryFeatures:
        country: str
        min_income: float = F(CountryDS.min_income, default=1.20)
        max_income: float = F(CountryDS.max_income, default=2.20)
        avg_income: float = F(CountryDS.avg_income, default=1.20)
        # No feature-level default: a missing key stays null.
        stddev_income: Optional[float] = F(CountryDS.stddev_income)

    # Register the datasets and the featureset.
    commit_response = client.commit(
        message="msg",
        datasets=[UserInfoDataset, CountryDS],
        featuresets=[CountryFeatures],
    )
    assert (
        commit_response.status_code == requests.codes.OK
    ), commit_response.json()

    client.sleep(30)

    # Log five income events that all share one timestamp.
    event_time = datetime.now(timezone.utc)
    events = pd.DataFrame(
        {
            "user_id": [1, 2, 3, 4, 5],
            "country": ["India", "USA", "India", "USA", "UK"],
            "income": [
                1200.10,
                1000.10,
                1400.10,
                90.10,
                1100.10,
            ],
            "timestamp": [event_time] * 5,
        }
    )
    log_response = client.log("fennel_webhook", "UserInfoDataset", events)
    assert log_response.status_code == requests.codes.OK, log_response.json()

    client.sleep()

    # "China" was never logged, so it exercises the default handling.
    queried_countries = ["India", "USA", "UK", "China"]
    features = client.query(
        inputs=[CountryFeatures.country],
        outputs=[CountryFeatures],
        input_dataframe=pd.DataFrame(
            {"CountryFeatures.country": queried_countries}
        ),
    )
    assert features.shape == (4, 5)

    expected_by_column = {
        "CountryFeatures.country": ["India", "USA", "UK", "China"],
        "CountryFeatures.min_income": [1200.10, 90.10, 1100.10, 1.20],
        "CountryFeatures.max_income": [1400.10, 1000.10, 1100.10, 2.20],
        "CountryFeatures.avg_income": [1300.1, 545.1, 1100.1, 1.2],
        "CountryFeatures.stddev_income": [100.0, 455.0, 0, pd.NA],
    }
    for column, expected in expected_by_column.items():
        assert features[column].tolist() == expected
86 changes: 71 additions & 15 deletions fennel/datasets/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from datetime import date, datetime
from decimal import Decimal as PythonDecimal
from typing import List, Union, Optional
from typing import List, Union, Optional, Any

import fennel.gen.spec_pb2 as spec_proto
import fennel.gen.schema_pb2 as schema_proto
from fennel._vendor.pydantic import BaseModel, Extra, validator # type: ignore
from fennel.dtypes import Continuous, Tumbling, Hopping
from fennel.internal_lib.duration import Duration, duration_to_timedelta
from fennel.internal_lib.utils.utils import (
to_timestamp_proto,
to_date_proto,
to_decimal_proto,
)

ItemType = Union[str, List[str]]

Expand Down Expand Up @@ -109,7 +115,7 @@ class Sum(AggregateType):

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
raise ValueError("Window must be specified for Sum")
return spec_proto.PreSpec(
sum=spec_proto.Sum(
window=self.window.to_proto(),
Expand All @@ -124,17 +130,24 @@ def signature(self):

class Average(AggregateType):
of: str
default: float = 0.0
default: Optional[float] = 0.0

def to_proto(self):
    """Serialize this Average aggregation into a ``spec_proto.PreSpec``.

    A ``None`` default is sent as ``default_null=True`` together with a
    placeholder ``default`` of 0.0; any other default is sent as-is with
    ``default_null=False``.

    Raises:
        ValueError: if no window has been set on the aggregation.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Average")
    default_null = self.default is None
    default = 0.0 if default_null else self.default
    return spec_proto.PreSpec(
        average=spec_proto.Average(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_null=default_null,
        )
    )

Expand All @@ -150,7 +163,7 @@ class Quantile(AggregateType):

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
raise ValueError("Window must be specified for Quantile")
return spec_proto.PreSpec(
quantile=spec_proto.Quantile(
window=self.window.to_proto(),
Expand Down Expand Up @@ -183,6 +196,8 @@ class ExpDecaySum(AggregateType):
half_life: Duration

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for ExpDecaySum")
half_life = duration_to_timedelta(self.half_life)
return spec_proto.PreSpec(
exp_decay=spec_proto.ExponentialDecayAggregate(
Expand Down Expand Up @@ -214,25 +229,42 @@ def validate(self):

class Max(AggregateType):
of: str
default: Union[float, int, date, datetime, PythonDecimal]
default: Any

def to_proto(self):
    """Serialize this Max aggregation into a ``spec_proto.PreSpec``.

    The default is emitted in two forms: a float in ``default`` and a typed
    ``schema_proto.Value`` in ``default_value``. A ``None`` default is sent
    as the proto ``none`` value with 0.0 as the float placeholder.

    Raises:
        ValueError: if no window has been set, or if the default is not a
            datetime, date, Decimal, float, int or None.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Max")
    # datetime is checked before date because datetime is a subclass of date.
    if isinstance(self.default, datetime):
        # Float form: epoch seconds scaled by 1e6.
        default = float(self.default.timestamp() * 1000000.0)
        default_value = schema_proto.Value(
            timestamp=to_timestamp_proto(self.default)
        )
    elif isinstance(self.default, date):
        # Float form: whole days elapsed since 1970-01-01.
        default = float((self.default - date(1970, 1, 1)).days)
        default_value = schema_proto.Value(date=to_date_proto(self.default))
    elif isinstance(self.default, PythonDecimal):
        default = float(self.default)
        default_value = schema_proto.Value(
            decimal=to_decimal_proto(self.default)
        )
    elif isinstance(self.default, float):
        default = self.default
        default_value = schema_proto.Value(float=self.default)
    elif isinstance(self.default, int):
        default = float(self.default)
        default_value = schema_proto.Value(int=self.default)
    elif self.default is None:
        default_value = schema_proto.Value(none=schema_proto.Value().none)
        default = 0.0
    else:
        # The message names Max; it previously said "Min" by copy-paste.
        raise ValueError(f"invalid default value for Max: `{self.default}`")
    return spec_proto.PreSpec(
        max=spec_proto.Max(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_value=default_value,
        )
    )

Expand All @@ -245,25 +277,42 @@ def agg_type(self):

class Min(AggregateType):
of: str
default: Union[float, int, date, datetime, PythonDecimal]
default: Any

def to_proto(self):
    """Serialize this Min aggregation into a ``spec_proto.PreSpec``.

    The default is emitted in two forms: a float in ``default`` and a typed
    ``schema_proto.Value`` in ``default_value``. A ``None`` default is sent
    as the proto ``none`` value with 0.0 as the float placeholder.

    Raises:
        ValueError: if no window has been set, or if the default is not a
            datetime, date, Decimal, float, int or None.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Min")
    # datetime is checked before date because datetime is a subclass of date.
    if isinstance(self.default, datetime):
        # Float form: epoch seconds scaled by 1e6.
        default = float(self.default.timestamp() * 1000000.0)
        default_value = schema_proto.Value(
            timestamp=to_timestamp_proto(self.default)
        )
    elif isinstance(self.default, date):
        # Float form: whole days elapsed since 1970-01-01.
        default = float((self.default - date(1970, 1, 1)).days)
        default_value = schema_proto.Value(date=to_date_proto(self.default))
    elif isinstance(self.default, PythonDecimal):
        default = float(self.default)
        default_value = schema_proto.Value(
            decimal=to_decimal_proto(self.default)
        )
    elif isinstance(self.default, float):
        default = self.default
        default_value = schema_proto.Value(float=self.default)
    elif isinstance(self.default, int):
        default = float(self.default)
        default_value = schema_proto.Value(int=self.default)
    elif self.default is None:
        default_value = schema_proto.Value(none=schema_proto.Value().none)
        default = 0.0
    else:
        raise ValueError(f"invalid default value for Min: `{self.default}`")
    return spec_proto.PreSpec(
        min=spec_proto.Min(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_value=default_value,
        )
    )

Expand Down Expand Up @@ -324,16 +373,23 @@ def signature(self):

class Stddev(AggregateType):
of: str
default: float = -1.0
default: Optional[float] = -1.0

def to_proto(self):
    """Serialize this Stddev aggregation into a ``spec_proto.PreSpec``.

    A ``None`` default is sent as ``default_null=True`` together with a
    placeholder ``default`` of 0.0; any other default is sent as-is with
    ``default_null=False``.

    Raises:
        ValueError: if no window has been set on the aggregation.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Stddev")
    default_null = self.default is None
    default = 0.0 if default_null else self.default
    return spec_proto.PreSpec(
        stddev=spec_proto.Stddev(
            window=self.window.to_proto(),
            name=self.into_field,
            default=default,
            default_null=default_null,
            of=self.of,
        )
    )
Expand Down
20 changes: 16 additions & 4 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2799,7 +2799,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"Cannot take average of field `{agg.of}` of type `{dtype_to_string(dtype)}`"
)
values[agg.into_field] = pd.Float64Dtype # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[pd.Float64Dtype] # type: ignore
else:
values[agg.into_field] = pd.Float64Dtype # type: ignore
elif isinstance(agg, LastK):
dtype = input_schema.get_type(agg.of)
if agg.dropnull:
Expand Down Expand Up @@ -2839,7 +2842,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"invalid min: default value `{agg.default}` not of type `int`"
)
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[fennel_get_optional_inner(dtype)] # type: ignore
else:
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
elif isinstance(agg, Max):
dtype = input_schema.get_type(agg.of)
primtive_dtype = get_primitive_dtype_with_optional(dtype)
Expand All @@ -2857,7 +2863,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"invalid max: default value `{agg.default}` not of type `int`"
)
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[fennel_get_optional_inner(dtype)] # type: ignore
else:
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
elif isinstance(agg, Stddev):
dtype = input_schema.get_type(agg.of)
if (
Expand All @@ -2867,7 +2876,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"Cannot get standard deviation of field {agg.of} of type {dtype_to_string(dtype)}"
)
values[agg.into_field] = pd.Float64Dtype # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[pd.Float64Dtype] # type: ignore
else:
values[agg.into_field] = pd.Float64Dtype # type: ignore
elif isinstance(agg, Quantile):
dtype = input_schema.get_type(agg.of)
if (
Expand Down
Loading

0 comments on commit c71eda0

Please sign in to comment.