Skip to content

Commit

Permalink
aggregation: allow specifying `None` as the default value in aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
nonibansal committed Dec 10, 2024
1 parent af3c39e commit 686d759
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 65 deletions.
1 change: 1 addition & 0 deletions fennel/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def commit(
)
self.add(featureset)
sync_request = self._get_sync_request_proto(message, env)
print(sync_request)
response = self._post_bytes(
f"{V1_API}/commit?preview={str(preview).lower()}&incremental={str(incremental).lower()}&backfill={str(backfill).lower()}",
sync_request.SerializeToString(),
Expand Down
139 changes: 139 additions & 0 deletions fennel/client_tests/test_complex_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timezone, date
from decimal import Decimal as PythonDecimal
from typing import Optional

import pandas as pd
import pytest
Expand All @@ -13,6 +14,8 @@
Dataset,
Max,
Min,
Average,
Stddev,
)
from fennel.dtypes import Decimal, Continuous
from fennel.featuresets import featureset, feature as F
Expand Down Expand Up @@ -202,3 +205,139 @@ class CountryFeatures:
datetime(1987, 6, 6, tzinfo=timezone.utc),
datetime(1970, 1, 2, tzinfo=timezone.utc),
]


@pytest.mark.integration
@mock
def test_none_default(client):
    """Verify that Min/Max/Average/Stddev aggregations accept ``default=None``.

    With a None default the aggregated columns must be declared Optional, and
    feature extraction for a missing key either falls back to the
    feature-level default (min/max/avg here) or surfaces NA (stddev, which
    declares no default).
    """

    # Source dataset: raw per-user income events ingested via webhook.
    @source(webhook.endpoint("UserInfoDataset"), disorder="14d", cdc="append")
    @dataset
    class UserInfoDataset:
        user_id: int
        country: str
        income: Decimal[2]
        timestamp: datetime = field(timestamp=True)

    # Aggregated dataset keyed by country. Every aggregation below uses
    # default=None, so every value column is declared Optional.
    @dataset(index=True)
    class CountryDS:
        country: str = field(key=True)
        min_income: Optional[Decimal[2]]
        max_income: Optional[Decimal[2]]
        avg_income: Optional[float]
        stddev_income: Optional[float]
        timestamp: datetime = field(timestamp=True)

        @pipeline
        @inputs(UserInfoDataset)
        def avg_income_pipeline(cls, event: Dataset):
            # Continuous("forever") aggregates over the entire event history.
            return event.groupby("country").aggregate(
                min_income=Min(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                max_income=Max(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                avg_income=Average(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
                stddev_income=Stddev(
                    of="income",
                    window=Continuous("forever"),
                    default=None,
                ),
            )

    # Featureset: min/max/avg declare feature-level defaults that are used
    # when a looked-up key is absent; stddev has no default and stays
    # Optional, so a missing key must produce NA.
    @featureset
    class CountryFeatures:
        country: str
        min_income: Decimal[2] = F(
            CountryDS.min_income, default=PythonDecimal("1.20")
        )
        max_income: Decimal[2] = F(
            CountryDS.max_income, default=PythonDecimal("2.20")
        )
        avg_income: float = F(CountryDS.avg_income, default=1.20)
        stddev_income: Optional[float] = F(CountryDS.stddev_income)

    # Sync the dataset
    response = client.commit(
        message="msg",
        datasets=[UserInfoDataset, CountryDS],
        featuresets=[CountryFeatures],
    )
    assert response.status_code == requests.codes.OK, response.json()

    client.sleep(30)

    # Log five income events across three countries. "China" is deliberately
    # never logged, so its lookups miss and exercise the default paths.
    now = datetime.now(timezone.utc)
    df = pd.DataFrame(
        {
            "user_id": [1, 2, 3, 4, 5],
            "country": ["India", "USA", "India", "USA", "UK"],
            "income": [
                PythonDecimal("1200.10"),
                PythonDecimal("1000.10"),
                PythonDecimal("1400.10"),
                PythonDecimal("90.10"),
                PythonDecimal("1100.10"),
            ],
            "timestamp": [now, now, now, now, now],
        }
    )
    response = client.log("fennel_webhook", "UserInfoDataset", df)
    assert response.status_code == requests.codes.OK, response.json()

    client.sleep()

    df = client.query(
        inputs=[CountryFeatures.country],
        outputs=[CountryFeatures],
        input_dataframe=pd.DataFrame(
            {"CountryFeatures.country": ["India", "USA", "UK", "China"]}
        ),
    )
    # 4 requested countries x 5 feature columns.
    assert df.shape == (4, 5)
    assert df["CountryFeatures.country"].tolist() == [
        "India",
        "USA",
        "UK",
        "China",
    ]
    # "China" was never logged: min/max fall back to the feature defaults
    # (1.20 / 2.20) rather than the aggregation defaults (which are None).
    assert df["CountryFeatures.min_income"].tolist() == [
        PythonDecimal("1200.10"),
        PythonDecimal("90.10"),
        PythonDecimal("1100.10"),
        PythonDecimal("1.20"),
    ]
    assert df["CountryFeatures.max_income"].tolist() == [
        PythonDecimal("1400.10"),
        PythonDecimal("1000.10"),
        PythonDecimal("1100.10"),
        PythonDecimal("2.20"),
    ]
    # Float averaging differs slightly between the integration backend and
    # the local mock client, hence the two expected variants.
    if client.is_integration_client():
        assert df["CountryFeatures.avg_income"].tolist() == [
            1300.1,
            545.1,
            1100.1,
            1.2,
        ]
    else:
        assert df["CountryFeatures.avg_income"].tolist() == [
            1300.1000000000001,
            545.1,
            1100.1000000000001,
            1.2,
        ]
    # stddev has no feature-level default, so the missing "China" key
    # surfaces as pd.NA instead of a fallback value.
    assert df["CountryFeatures.stddev_income"].tolist() == [
        100.0,
        455.0,
        0,
        pd.NA,
    ]
80 changes: 66 additions & 14 deletions fennel/datasets/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
from typing import List, Union, Optional

import fennel.gen.spec_pb2 as spec_proto
import fennel.gen.schema_pb2 as schema_proto
from fennel._vendor.pydantic import BaseModel, Extra, validator # type: ignore
from fennel.dtypes import Continuous, Tumbling, Hopping
from fennel.internal_lib.duration import Duration, duration_to_timedelta
from fennel.internal_lib.utils.utils import (
to_timestamp_proto,
to_date_proto,
to_decimal_proto,
)

ItemType = Union[str, List[str]]

Expand Down Expand Up @@ -109,7 +115,7 @@ class Sum(AggregateType):

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
raise ValueError("Window must be specified for Sum")
return spec_proto.PreSpec(
sum=spec_proto.Sum(
window=self.window.to_proto(),
Expand All @@ -124,17 +130,24 @@ def signature(self):

class Average(AggregateType):
of: str
default: float = 0.0
default: Optional[float] = 0.0

def to_proto(self):
    """Serialize this Average aggregation to its ``PreSpec`` proto.

    A ``default`` of None is encoded as ``default_null=True`` with a
    placeholder 0.0 in ``default``, so the backend can distinguish
    "no default" from an explicit 0.0.

    Raises:
        ValueError: if no window was specified.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Average")
    if self.default is None:
        # None default: signal null explicitly; 0.0 is only a placeholder.
        default = 0.0
        default_null = True
    else:
        default = self.default
        default_null = False
    return spec_proto.PreSpec(
        average=spec_proto.Average(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_null=default_null,
        )
    )

Expand All @@ -150,7 +163,7 @@ class Quantile(AggregateType):

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
raise ValueError("Window must be specified for Quantile")
return spec_proto.PreSpec(
quantile=spec_proto.Quantile(
window=self.window.to_proto(),
Expand Down Expand Up @@ -183,6 +196,8 @@ class ExpDecaySum(AggregateType):
half_life: Duration

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for ExpDecaySum")
half_life = duration_to_timedelta(self.half_life)
return spec_proto.PreSpec(
exp_decay=spec_proto.ExponentialDecayAggregate(
Expand Down Expand Up @@ -214,25 +229,40 @@ def validate(self):

class Max(AggregateType):
of: str
default: Union[float, int, date, datetime, PythonDecimal]
default: Optional[Union[float, int, date, datetime, PythonDecimal]]

def to_proto(self):
    """Serialize this Max aggregation to its ``PreSpec`` proto.

    The default is encoded twice for backward compatibility: the legacy
    float ``default`` field and the typed ``default_value``. A ``default``
    of None maps to a none-valued ``default_value`` (with the legacy field
    set to a placeholder 0).

    Raises:
        ValueError: if no window was specified.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Max")
    if isinstance(self.default, datetime):
        # Legacy float default is microseconds since the Unix epoch.
        default = float(self.default.timestamp() * 1000000.0)
        default_value = schema_proto.Value(
            timestamp=to_timestamp_proto(self.default)
        )
    elif isinstance(self.default, date):
        # Legacy float default is days since the Unix epoch.
        default = float((self.default - date(1970, 1, 1)).days)
        default_value = schema_proto.Value(date=to_date_proto(self.default))
    elif isinstance(self.default, PythonDecimal):
        default = float(self.default)
        default_value = schema_proto.Value(
            decimal=to_decimal_proto(self.default)
        )
    elif isinstance(self.default, float):
        default = self.default
        default_value = schema_proto.Value(float=self.default)
    elif isinstance(self.default, int):
        default = float(self.default)
        default_value = schema_proto.Value(int=self.default)
    else:
        # default is None: encode an explicit null; 0 is only a placeholder
        # for the legacy float field.
        default_value = schema_proto.Value(none=schema_proto.Value().none)
        default = 0
    return spec_proto.PreSpec(
        max=spec_proto.Max(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_value=default_value,
        )
    )

Expand All @@ -245,25 +275,40 @@ def agg_type(self):

class Min(AggregateType):
of: str
default: Union[float, int, date, datetime, PythonDecimal]
default: Optional[Union[float, int, date, datetime, PythonDecimal]]

def to_proto(self):
    """Serialize this Min aggregation to its ``PreSpec`` proto.

    The default is encoded twice for backward compatibility: the legacy
    float ``default`` field and the typed ``default_value``. A ``default``
    of None maps to a none-valued ``default_value`` (with the legacy field
    set to a placeholder 0).

    Raises:
        ValueError: if no window was specified.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Min")
    if isinstance(self.default, datetime):
        # Legacy float default is microseconds since the Unix epoch.
        default = float(self.default.timestamp() * 1000000.0)
        default_value = schema_proto.Value(
            timestamp=to_timestamp_proto(self.default)
        )
    elif isinstance(self.default, date):
        # Legacy float default is days since the Unix epoch.
        default = float((self.default - date(1970, 1, 1)).days)
        default_value = schema_proto.Value(date=to_date_proto(self.default))
    elif isinstance(self.default, PythonDecimal):
        default = float(self.default)
        default_value = schema_proto.Value(
            decimal=to_decimal_proto(self.default)
        )
    elif isinstance(self.default, float):
        default = self.default
        default_value = schema_proto.Value(float=self.default)
    elif isinstance(self.default, int):
        default = float(self.default)
        default_value = schema_proto.Value(int=self.default)
    else:
        # default is None: encode an explicit null; 0 is only a placeholder
        # for the legacy float field.
        default_value = schema_proto.Value(none=schema_proto.Value().none)
        default = 0
    return spec_proto.PreSpec(
        min=spec_proto.Min(
            window=self.window.to_proto(),
            name=self.into_field,
            of=self.of,
            default=default,
            default_value=default_value,
        )
    )

Expand Down Expand Up @@ -324,16 +369,23 @@ def signature(self):

class Stddev(AggregateType):
of: str
default: float = -1.0
default: Optional[float] = -1.0

def to_proto(self):
    """Serialize this Stddev aggregation to its ``PreSpec`` proto.

    A ``default`` of None is encoded as ``default_null=True`` with a
    placeholder 0.0 in ``default``, so the backend can distinguish
    "no default" from an explicit 0.0.

    Raises:
        ValueError: if no window was specified.
    """
    if self.window is None:
        raise ValueError("Window must be specified for Stddev")
    if self.default is None:
        # None default: signal null explicitly; 0.0 is only a placeholder.
        default = 0.0
        default_null = True
    else:
        default = self.default
        default_null = False
    return spec_proto.PreSpec(
        stddev=spec_proto.Stddev(
            window=self.window.to_proto(),
            name=self.into_field,
            default=default,
            default_null=default_null,
            of=self.of,
        )
    )
Expand Down
20 changes: 16 additions & 4 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2799,7 +2799,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"Cannot take average of field `{agg.of}` of type `{dtype_to_string(dtype)}`"
)
values[agg.into_field] = pd.Float64Dtype # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[pd.Float64Dtype] # type: ignore
else:
values[agg.into_field] = pd.Float64Dtype # type: ignore
elif isinstance(agg, LastK):
dtype = input_schema.get_type(agg.of)
if agg.dropnull:
Expand Down Expand Up @@ -2839,7 +2842,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"invalid min: default value `{agg.default}` not of type `int`"
)
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[fennel_get_optional_inner(dtype)] # type: ignore
else:
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
elif isinstance(agg, Max):
dtype = input_schema.get_type(agg.of)
primtive_dtype = get_primitive_dtype_with_optional(dtype)
Expand All @@ -2857,7 +2863,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"invalid max: default value `{agg.default}` not of type `int`"
)
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[fennel_get_optional_inner(dtype)] # type: ignore
else:
values[agg.into_field] = fennel_get_optional_inner(dtype) # type: ignore
elif isinstance(agg, Stddev):
dtype = input_schema.get_type(agg.of)
if (
Expand All @@ -2867,7 +2876,10 @@ def visitAggregate(self, obj) -> DSSchema:
raise TypeError(
f"Cannot get standard deviation of field {agg.of} of type {dtype_to_string(dtype)}"
)
values[agg.into_field] = pd.Float64Dtype # type: ignore
if agg.default is None:
values[agg.into_field] = Optional[pd.Float64Dtype] # type: ignore
else:
values[agg.into_field] = pd.Float64Dtype # type: ignore
elif isinstance(agg, Quantile):
dtype = input_schema.get_type(agg.of)
if (
Expand Down
Loading

0 comments on commit 686d759

Please sign in to comment.