Skip to content

Commit

Permalink
agg: Support *args for aggregation (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-nambiar authored Nov 9, 2023
1 parent 2a0eddd commit e8cbb7e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 63 deletions.
24 changes: 6 additions & 18 deletions docs/examples/getting-started/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@

# docsnip datasets
@dataset
@source(
postgres.table("product_info", cursor="last_modified"),
every="1m",
tier="prod",
)
@source(postgres.table("product", cursor="updated"), every="1m", tier="prod")
@source(webhook.endpoint("Product"), tier="dev")
@meta(owner="[email protected]", tags=["PII"])
class Product:
Expand Down Expand Up @@ -83,20 +79,12 @@ class UserSellerOrders:
@inputs(Order, Product)
def my_pipeline(cls, orders: Dataset, products: Dataset):
orders = orders.join(products, how="left", on=["product_id"])
orders = orders.transform(
lambda df: df[["uid", "seller_id", "timestamp"]].fillna(0),
schema={
"uid": int,
"seller_id": int,
"timestamp": datetime,
},
)

orders = orders.transform(lambda df: df.fillna(0))
orders = orders.drop("product_id", "desc", "price")
orders = orders.dropnull()
return orders.groupby("uid", "seller_id").aggregate(
[
Count(window=Window("1d"), into_field="num_orders_1d"),
Count(window=Window("1w"), into_field="num_orders_1w"),
]
Count(window=Window("1d"), into_field="num_orders_1d"),
Count(window=Window("1w"), into_field="num_orders_1w"),
)


Expand Down
70 changes: 33 additions & 37 deletions fennel/client_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,43 +1111,39 @@ class MovieRatingWindowed:
@inputs(RatingActivity)
def pipeline_aggregate(cls, activity: Dataset):
return activity.groupby("movie").aggregate(
[
Count(window=Window("3d"), into_field=str(cls.num_ratings_3d)),
Sum(
window=Window("7d"),
of="rating",
into_field=str(cls.sum_ratings_7d),
),
Average(
window=Window("6h"),
of="rating",
into_field=str(cls.avg_rating_6h),
),
Count(
window=Window("forever"), into_field=str(cls.total_ratings)
),
Stddev(
window=Window("3d"),
of="rating",
into_field=str(cls.std_rating_3d),
),
Stddev(
window=Window("7d"),
of="rating",
into_field=str(cls.std_rating_7d),
),
Stddev(
window=Window("10m"),
of="rating",
into_field=str(cls.std_rating_10m),
),
Stddev(
window=Window("10m"),
of="rating",
default=-3.14159,
into_field=str(cls.std_rating_10m_other_default),
),
]
Count(window=Window("3d"), into_field=str(cls.num_ratings_3d)),
Sum(
window=Window("7d"),
of="rating",
into_field=str(cls.sum_ratings_7d),
),
Average(
window=Window("6h"),
of="rating",
into_field=str(cls.avg_rating_6h),
),
Count(window=Window("forever"), into_field=str(cls.total_ratings)),
Stddev(
window=Window("3d"),
of="rating",
into_field=str(cls.std_rating_3d),
),
Stddev(
window=Window("7d"),
of="rating",
into_field=str(cls.std_rating_7d),
),
Stddev(
window=Window("10m"),
of="rating",
into_field=str(cls.std_rating_10m),
),
Stddev(
window=Window("10m"),
of="rating",
default=-3.14159,
into_field=str(cls.std_rating_10m_other_default),
),
)


Expand Down
6 changes: 2 additions & 4 deletions fennel/client_tests/test_fraud_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ class UserTransactionSums:
@inputs(CreditCardTransactions)
def first_pipeline(cls, transactions: Dataset):
return transactions.groupby("cc_num").aggregate(
[
Sum(of="amt", window=Window("1d"), into_field="sum_amt_1d"),
Sum(of="amt", window=Window("7d"), into_field="sum_amt_7d"),
]
Sum(of="amt", window=Window("1d"), into_field="sum_amt_1d"),
Sum(of="amt", window=Window("7d"), into_field="sum_amt_7d"),
)


Expand Down
11 changes: 7 additions & 4 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,12 +443,15 @@ def __init__(self, node: _Node, *args):
self.node = node
self.node.out_edges.append(self)

def aggregate(self, *args) -> "_Node":
    """Return an ``Aggregate`` over this object's node and group-by keys.

    The aggregations may be passed either as separate positional
    arguments (``aggregate(a, b)``) or, for backward compatibility, as a
    single list (``aggregate([a, b])``).

    Raises:
        TypeError: if no aggregation is supplied (no arguments, or an
            empty list), or if a list is mixed with other positional
            arguments.
    """
    if len(args) == 1 and isinstance(args[0], list):
        # Legacy call style: a single list holding all aggregations.
        aggregates = args[0]
    else:
        # A list next to other positional arguments would otherwise end
        # up as a nested list inside `aggregates`; reject it up front.
        if any(isinstance(arg, list) for arg in args):
            raise TypeError(
                "aggregate operator expects either a single list of "
                "aggregations or aggregations as separate arguments, "
                "not a mix of both"
            )
        aggregates = list(args)
    # Covers both `aggregate()` and `aggregate([])`.
    if not aggregates:
        raise TypeError(
            "aggregate operator expects atleast one aggregation operation"
        )
    # Keys may have been given as one list; unwrap it to a flat sequence.
    if len(self.keys) == 1 and isinstance(self.keys[0], list):
        self.keys = self.keys[0]  # type: ignore
    return Aggregate(self.node, list(self.keys), aggregates)
Expand Down

0 comments on commit e8cbb7e

Please sign in to comment.