Skip to content

Commit

Permalink
aggregation: Allow date, datetime and decimal dtypes in min/max aggre…
Browse files Browse the repository at this point in the history
…gation
  • Loading branch information
nonibansal committed Dec 5, 2024
1 parent d86d1ae commit 1981b9c
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 10 deletions.
3 changes: 3 additions & 0 deletions fennel/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## [1.5.58] - 2024-11-24
- Allow min/max aggregation on date, datetime and decimal dtypes

## [1.5.57] - 2024-12-02
- Fix signature method for filter class

Expand Down
204 changes: 204 additions & 0 deletions fennel/client_tests/test_complex_aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from datetime import datetime, timezone, date
from decimal import Decimal as PythonDecimal

import pandas as pd
import pytest

from fennel._vendor import requests
from fennel.connectors import Webhook, source
from fennel.datasets import (
dataset,
field,
pipeline,
Dataset,
Max,
Min,
)
from fennel.dtypes import Decimal, Continuous
from fennel.featuresets import featureset, feature as F
from fennel.lib import inputs
from fennel.testing import mock

webhook = Webhook(name="fennel_webhook")
__owner__ = "[email protected]"


@pytest.mark.integration
@mock
def test_complex_min_max(client):
@source(webhook.endpoint("UserInfoDataset"), disorder="14d", cdc="append")
@dataset
class UserInfoDataset:
user_id: int
birthdate: date
birthtime: datetime
country: str
income: Decimal[2]
timestamp: datetime = field(timestamp=True)

@dataset(index=True)
class CountryDS:
country: str = field(key=True)
min_income: Decimal[2]
max_income: Decimal[2]
min_birthdate: date
max_birthdate: date
min_birthtime: datetime
max_birthtime: datetime
timestamp: datetime = field(timestamp=True)

@pipeline
@inputs(UserInfoDataset)
def avg_income_pipeline(cls, event: Dataset):
return event.groupby("country").aggregate(
min_income=Min(
of="income",
window=Continuous("forever"),
default=PythonDecimal("1.20"),
),
max_income=Max(
of="income",
window=Continuous("forever"),
default=PythonDecimal("2.20"),
),
min_birthdate=Min(
of="birthdate",
window=Continuous("forever"),
default=date(1970, 1, 1),
),
max_birthdate=Max(
of="birthdate",
window=Continuous("forever"),
default=date(1970, 1, 2),
),
min_birthtime=Min(
of="birthtime",
window=Continuous("forever"),
default=datetime(1970, 1, 1, tzinfo=timezone.utc),
),
max_birthtime=Max(
of="birthtime",
window=Continuous("forever"),
default=datetime(1970, 1, 2, tzinfo=timezone.utc),
),
)

@featureset
class CountryFeatures:
country: str
min_income: Decimal[2] = F(
CountryDS.min_income, default=PythonDecimal("1.20")
)
max_income: Decimal[2] = F(
CountryDS.max_income, default=PythonDecimal("2.20")
)
min_birthdate: date = F(
CountryDS.min_birthdate, default=date(1970, 1, 1)
)
max_birthdate: date = F(
CountryDS.max_birthdate, default=date(1970, 1, 2)
)
min_birthtime: datetime = F(
CountryDS.min_birthtime,
default=datetime(1970, 1, 1, tzinfo=timezone.utc),
)
max_birthtime: datetime = F(
CountryDS.max_birthtime,
default=datetime(1970, 1, 2, tzinfo=timezone.utc),
)

# Sync the dataset
response = client.commit(
message="msg",
datasets=[UserInfoDataset, CountryDS],
featuresets=[CountryFeatures],
)
assert response.status_code == requests.codes.OK, response.json()

client.sleep(30)

now = datetime.now(timezone.utc)
df = pd.DataFrame(
{
"user_id": [1, 2, 3, 4, 5],
"country": ["India", "USA", "India", "USA", "UK"],
"birthdate": [
date(1980, 1, 1),
date(1990, 2, 11),
date(2000, 3, 15),
date(1997, 5, 22),
date(1987, 6, 6),
],
"birthtime": [
datetime(1980, 1, 1, tzinfo=timezone.utc),
datetime(1990, 2, 11, tzinfo=timezone.utc),
datetime(2000, 3, 15, tzinfo=timezone.utc),
datetime(1997, 5, 22, tzinfo=timezone.utc),
datetime(1987, 6, 6, tzinfo=timezone.utc),
],
"income": [
PythonDecimal("1200.10"),
PythonDecimal("1000.10"),
PythonDecimal("1400.10"),
PythonDecimal("90.10"),
PythonDecimal("1100.10"),
],
"timestamp": [now, now, now, now, now],
}
)
response = client.log("fennel_webhook", "UserInfoDataset", df)
assert response.status_code == requests.codes.OK, response.json()

client.sleep()

# Querying UserInfoFeatures
df = client.query(
inputs=[CountryFeatures.country],
outputs=[CountryFeatures],
input_dataframe=pd.DataFrame(
{"CountryFeatures.country": ["India", "USA", "UK", "China"]}
),
)
assert df.shape == (4, 7)
assert df["CountryFeatures.country"].tolist() == [
"India",
"USA",
"UK",
"China",
]
assert df["CountryFeatures.min_income"].tolist() == [
PythonDecimal("1200.10"),
PythonDecimal("90.10"),
PythonDecimal("1100.10"),
PythonDecimal("1.20"),
]
assert df["CountryFeatures.max_income"].tolist() == [
PythonDecimal("1400.10"),
PythonDecimal("1000.10"),
PythonDecimal("1100.10"),
PythonDecimal("2.20"),
]
assert df["CountryFeatures.min_birthdate"].tolist() == [
date(1980, 1, 1),
date(1990, 2, 11),
date(1987, 6, 6),
date(1970, 1, 1),
]
assert df["CountryFeatures.max_birthdate"].tolist() == [
date(2000, 3, 15),
date(1997, 5, 22),
date(1987, 6, 6),
date(1970, 1, 2),
]
assert df["CountryFeatures.min_birthtime"].tolist() == [
datetime(1980, 1, 1, tzinfo=timezone.utc),
datetime(1990, 2, 11, tzinfo=timezone.utc),
datetime(1987, 6, 6, tzinfo=timezone.utc),
datetime(1970, 1, 1, tzinfo=timezone.utc),
]
assert df["CountryFeatures.max_birthtime"].tolist() == [
datetime(2000, 3, 15, tzinfo=timezone.utc),
datetime(1997, 5, 22, tzinfo=timezone.utc),
datetime(1987, 6, 6, tzinfo=timezone.utc),
datetime(1970, 1, 2, tzinfo=timezone.utc),
]
26 changes: 22 additions & 4 deletions fennel/datasets/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import date, datetime
from decimal import Decimal as PythonDecimal
from typing import List, Union, Optional

import fennel.gen.spec_pb2 as spec_proto
Expand Down Expand Up @@ -212,17 +214,25 @@ def validate(self):

class Max(AggregateType):
of: str
default: float
default: Union[float, int, date, datetime, PythonDecimal]

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
if isinstance(self.default, datetime):
default = float(self.default.timestamp() * 1000000.0)
elif isinstance(self.default, date):
default = float((self.default - date(1970, 1, 1)).days)
elif isinstance(self.default, PythonDecimal):
default = float(self.default)
else:
default = float(self.default)
return spec_proto.PreSpec(
max=spec_proto.Max(
window=self.window.to_proto(),
name=self.into_field,
of=self.of,
default=self.default,
default=default,
)
)

Expand All @@ -235,17 +245,25 @@ def agg_type(self):

class Min(AggregateType):
of: str
default: float
default: Union[float, int, date, datetime, PythonDecimal]

def to_proto(self):
if self.window is None:
raise ValueError("Window must be specified for Distinct")
if isinstance(self.default, datetime):
default = float(self.default.timestamp() * 1000000.0)
elif isinstance(self.default, date):
default = float((self.default - date(1970, 1, 1)).days)
elif isinstance(self.default, PythonDecimal):
default = float(self.default)
else:
default = float(self.default)
return spec_proto.PreSpec(
min=spec_proto.Min(
window=self.window.to_proto(),
name=self.into_field,
of=self.of,
default=self.default,
default=default,
)
)

Expand Down
4 changes: 2 additions & 2 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,7 +2834,7 @@ def visitAggregate(self, obj) -> DSSchema:
f"invalid min: type of field `{agg.of}` is not int, float, date or datetime"
)
if primtive_dtype == pd.Int64Dtype and (
int(agg.default) != agg.default
int(agg.default) != agg.default # type: ignore
):
raise TypeError(
f"invalid min: default value `{agg.default}` not of type `int`"
Expand All @@ -2852,7 +2852,7 @@ def visitAggregate(self, obj) -> DSSchema:
f"invalid max: type of field `{agg.of}` is not int, float, date or datetime"
)
if primtive_dtype == pd.Int64Dtype and (
int(agg.default) != agg.default
int(agg.default) != agg.default # type: ignore
):
raise TypeError(
f"invalid max: default value `{agg.default}` not of type `int`"
Expand Down
7 changes: 4 additions & 3 deletions fennel/testing/execute_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import math
from abc import ABC, abstractmethod
from collections import Counter
from datetime import datetime, timezone, timedelta
from datetime import datetime, timezone, timedelta, date
from decimal import Decimal
from math import sqrt
from typing import Dict, List, Type, Union, Any

Expand Down Expand Up @@ -304,7 +305,7 @@ def top(self):


class MinState(AggState):
def __init__(self, default: float):
def __init__(self, default: Union[float, int, date, datetime, Decimal]):
self.counter = Counter() # type: ignore
self.min_heap = Heap(heap_type="min")
self.default = default
Expand Down Expand Up @@ -335,7 +336,7 @@ def get_val(self):


class MaxState(AggState):
def __init__(self, default: float):
def __init__(self, default: Union[float, int, date, datetime, Decimal]):
self.counter = Counter() # type: ignore
self.max_heap = Heap(heap_type="max")
self.default = default
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fennel-ai"
version = "1.5.57"
version = "1.5.58"
description = "The modern realtime feature engineering platform"
authors = ["Fennel AI <[email protected]>"]
packages = [{ include = "fennel" }]
Expand Down

0 comments on commit 1981b9c

Please sign in to comment.