diff --git a/docs/examples/getting-started/quickstart.py b/docs/examples/getting-started/quickstart.py
index 2e5530781..89dc198d5 100644
--- a/docs/examples/getting-started/quickstart.py
+++ b/docs/examples/getting-started/quickstart.py
@@ -31,11 +31,7 @@
 # docsnip datasets
 @dataset
-@source(
-    postgres.table("product_info", cursor="last_modified"),
-    every="1m",
-    tier="prod",
-)
+@source(postgres.table("product", cursor="updated"), every="1m", tier="prod")
 @source(webhook.endpoint("Product"), tier="dev")
 @meta(owner="chris@fennel.ai", tags=["PII"])
 class Product:
@@ -83,20 +79,12 @@ class UserSellerOrders:
     @inputs(Order, Product)
     def my_pipeline(cls, orders: Dataset, products: Dataset):
         orders = orders.join(products, how="left", on=["product_id"])
-        orders = orders.transform(
-            lambda df: df[["uid", "seller_id", "timestamp"]].fillna(0),
-            schema={
-                "uid": int,
-                "seller_id": int,
-                "timestamp": datetime,
-            },
-        )
-
+        orders = orders.transform(lambda df: df.fillna(0))
+        orders = orders.drop("product_id", "desc", "price")
+        orders = orders.dropnull()
         return orders.groupby("uid", "seller_id").aggregate(
-            [
-                Count(window=Window("1d"), into_field="num_orders_1d"),
-                Count(window=Window("1w"), into_field="num_orders_1w"),
-            ]
+            Count(window=Window("1d"), into_field="num_orders_1d"),
+            Count(window=Window("1w"), into_field="num_orders_1w"),
         )
diff --git a/fennel/client_tests/test_dataset.py b/fennel/client_tests/test_dataset.py
index 2a8c87206..35516a51f 100644
--- a/fennel/client_tests/test_dataset.py
+++ b/fennel/client_tests/test_dataset.py
@@ -1111,43 +1111,39 @@ class MovieRatingWindowed:
     @inputs(RatingActivity)
     def pipeline_aggregate(cls, activity: Dataset):
         return activity.groupby("movie").aggregate(
-            [
-                Count(window=Window("3d"), into_field=str(cls.num_ratings_3d)),
-                Sum(
-                    window=Window("7d"),
-                    of="rating",
-                    into_field=str(cls.sum_ratings_7d),
-                ),
-                Average(
-                    window=Window("6h"),
-                    of="rating",
-                    into_field=str(cls.avg_rating_6h),
-                ),
-                Count(
-                    window=Window("forever"), into_field=str(cls.total_ratings)
-                ),
-                Stddev(
-                    window=Window("3d"),
-                    of="rating",
-                    into_field=str(cls.std_rating_3d),
-                ),
-                Stddev(
-                    window=Window("7d"),
-                    of="rating",
-                    into_field=str(cls.std_rating_7d),
-                ),
-                Stddev(
-                    window=Window("10m"),
-                    of="rating",
-                    into_field=str(cls.std_rating_10m),
-                ),
-                Stddev(
-                    window=Window("10m"),
-                    of="rating",
-                    default=-3.14159,
-                    into_field=str(cls.std_rating_10m_other_default),
-                ),
-            ]
+            Count(window=Window("3d"), into_field=str(cls.num_ratings_3d)),
+            Sum(
+                window=Window("7d"),
+                of="rating",
+                into_field=str(cls.sum_ratings_7d),
+            ),
+            Average(
+                window=Window("6h"),
+                of="rating",
+                into_field=str(cls.avg_rating_6h),
+            ),
+            Count(window=Window("forever"), into_field=str(cls.total_ratings)),
+            Stddev(
+                window=Window("3d"),
+                of="rating",
+                into_field=str(cls.std_rating_3d),
+            ),
+            Stddev(
+                window=Window("7d"),
+                of="rating",
+                into_field=str(cls.std_rating_7d),
+            ),
+            Stddev(
+                window=Window("10m"),
+                of="rating",
+                into_field=str(cls.std_rating_10m),
+            ),
+            Stddev(
+                window=Window("10m"),
+                of="rating",
+                default=-3.14159,
+                into_field=str(cls.std_rating_10m_other_default),
+            ),
         )
diff --git a/fennel/client_tests/test_fraud_detection.py b/fennel/client_tests/test_fraud_detection.py
index b8ddaba94..831e5201a 100644
--- a/fennel/client_tests/test_fraud_detection.py
+++ b/fennel/client_tests/test_fraud_detection.py
@@ -65,10 +65,8 @@ class UserTransactionSums:
     @inputs(CreditCardTransactions)
     def first_pipeline(cls, transactions: Dataset):
         return transactions.groupby("cc_num").aggregate(
-            [
-                Sum(of="amt", window=Window("1d"), into_field="sum_amt_1d"),
-                Sum(of="amt", window=Window("7d"), into_field="sum_amt_7d"),
-            ]
+            Sum(of="amt", window=Window("1d"), into_field="sum_amt_1d"),
+            Sum(of="amt", window=Window("7d"), into_field="sum_amt_7d"),
         )
diff --git a/fennel/datasets/datasets.py b/fennel/datasets/datasets.py
index e7412be3e..37d355182 100644
--- a/fennel/datasets/datasets.py
+++ b/fennel/datasets/datasets.py
@@ -443,12 +443,15 @@ def __init__(self, node: _Node, *args):
         self.node = node
         self.node.out_edges.append(self)
 
-    def aggregate(self, aggregates: List[AggregateType], *args) -> _Node:
-        if len(args) > 0 or not isinstance(aggregates, list):
+    def aggregate(self, *args) -> _Node:
+        if len(args) == 0:
             raise TypeError(
-                "aggregate operator, takes a list of aggregates "
-                "found: {}".format(type(aggregates))
+                "aggregate operator expects at least one aggregation operation"
             )
+        if len(args) == 1 and isinstance(args[0], list):
+            aggregates = args[0]
+        else:
+            aggregates = list(args)
         if len(self.keys) == 1 and isinstance(self.keys[0], list):
            self.keys = self.keys[0]  # type: ignore
         return Aggregate(self.node, list(self.keys), aggregates)
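
Usage note (not part of the patch): with the datasets.py change above, aggregate() now takes the aggregation specs as positional arguments, still unpacks a single list for backward compatibility, and raises a TypeError when called with no arguments. Below is a minimal sketch of both call styles; the dataset and field names are illustrative, and the import paths are assumptions based on how Count, Sum, and Window are used in the files touched by this diff.

    from fennel.lib.aggregate import Count, Sum   # assumed import path
    from fennel.lib.window import Window          # assumed import path

    # New style: pass aggregations as varargs.
    events.groupby("uid").aggregate(
        Count(window=Window("1d"), into_field="num_events_1d"),
        Sum(of="amount", window=Window("7d"), into_field="amount_7d"),
    )

    # Old style still accepted: a single list is unpacked by the new signature.
    events.groupby("uid").aggregate(
        [
            Count(window=Window("1d"), into_field="num_events_1d"),
            Sum(of="amount", window=Window("7d"), into_field="amount_7d"),
        ]
    )

    # Calling aggregate() with no arguments now raises:
    # TypeError: aggregate operator expects at least one aggregation operation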