assign: Include type promotion in assign
aditya-nambiar committed Sep 6, 2024
1 parent e10f4fa · commit 1d22322
Showing 12 changed files with 712 additions and 419 deletions.
3 changes: 3 additions & 0 deletions fennel/CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## [1.5.18] - 2024-09-05
- Struct initializer + arrow fixes + type promotion in assign

## [1.5.17] - 2024-09-04
- Add support for several more expressions

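The headline change: `assign` no longer requires an expression to evaluate to exactly the declared dtype, only to a type that promotes to it (for example `int` into `float`, or `T` into `Optional[T]`). A minimal sketch of the new behavior, closely following the `OrdersOptional` pipeline added in the tests below (`Events` and `PromotedEvents` are hypothetical names, not part of this commit):

from datetime import datetime
from typing import Optional

from fennel.datasets import Dataset, dataset, pipeline
from fennel.expr.expr import col
from fennel.lib import inputs


@dataset
class Events:
    # Hypothetical upstream dataset; in practice it would be backed
    # by a @source connector as in the tests below.
    uid: int
    timestamp: datetime


@dataset
class PromotedEvents:
    uid: Optional[int]  # reassigned in place: int promoted to Optional[int]
    uid_float: float    # new column: int expression promoted to float
    timestamp: datetime

    @pipeline
    @inputs(Events)
    def promote(cls, ds: Dataset):
        # Both assigns were rejected by the old strict dtype-equality
        # check; matches_type now accepts them via type promotion.
        return ds.assign(
            uid=col("uid").astype(Optional[int]),
            uid_float=col("uid").astype(float),
        )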
151 changes: 150 additions & 1 deletion fennel/client_tests/test_complex_struct.py
@@ -8,6 +8,7 @@
from fennel.connectors import Webhook, source
from fennel.datasets import dataset, Dataset, field, pipeline, LastK
from fennel.dtypes import struct, Continuous
from fennel.expr.expr import col, make_struct
from fennel.featuresets import featureset, feature as F, extractor
from fennel.lib import inputs, outputs
from fennel.testing import mock
@@ -87,6 +88,118 @@ def movie_info(cls, movie: Dataset):
)


@dataset(index=True)
class MovieInfoExpr:
director_id: int = field(key=True)
movie_id: int = field(key=True)
role_list: List[Role]
timestamp: datetime = field(timestamp=True)

@pipeline
@inputs(MovieDS)
def movie_info(cls, movie: Dataset):
return (
movie.assign(
role=make_struct(
{
"role_id": col("role_id"),
"name": col("name"),
"cost": col("cost"),
},
Role,
).astype(Role)
)
.drop(columns=["role_id", "name", "cost"])
.groupby("director_id", "movie_id")
.aggregate(
LastK(
into_field="role_list",
of="role",
window=Continuous("forever"),
limit=3,
dedup=False,
),
)
)


@dataset(index=True)
class MovieInfoExpr2:
director_id: int = field(key=True)
movie_id: int = field(key=True)
role_list: List[Role]
timestamp: datetime = field(timestamp=True)

@pipeline
@inputs(MovieDS)
def movie_info(cls, movie: Dataset):
return (
movie.assign(
role=Role.expr( # type: ignore
role_id=col("role_id"), name=col("name"), cost=col("cost")
).astype(Role)
)
.drop(columns=["role_id", "name", "cost"])
.groupby("director_id", "movie_id")
.aggregate(
LastK(
into_field="role_list",
of="role",
window=Continuous("forever"),
limit=3,
dedup=False,
),
)
)


@struct
class FullName:
first_name: str
last_name: str


@struct
class RoleExtended:
role_id: int
name: FullName
cost: int


@dataset(index=True)
class MovieInfoExprNested:
director_id: int = field(key=True)
movie_id: int = field(key=True)
role_list: List[RoleExtended]
timestamp: datetime = field(timestamp=True)

@pipeline
@inputs(MovieDS)
def movie_info(cls, movie: Dataset):
return (
movie.assign(
role=RoleExtended.expr( # type: ignore
role_id=col("role_id"),
name=FullName.expr( # type: ignore
first_name=col("name"), last_name="rando"
),
cost=col("cost"),
).astype(RoleExtended)
)
.drop(columns=["role_id", "name", "cost"])
.groupby("director_id", "movie_id")
.aggregate(
LastK(
into_field="role_list",
of="role",
window=Continuous("forever"),
limit=3,
dedup=False,
),
)
)


@featureset
class Request:
director_id: int
@@ -218,7 +331,13 @@ def test_complex_struct(client):

client.commit(
message="msg",
datasets=[MovieDS, MovieInfo],
datasets=[
MovieDS,
MovieInfo,
MovieInfoExpr,
MovieInfoExpr2,
MovieInfoExprNested,
],
featuresets=[Request, MovieFeatures],
)

@@ -255,6 +374,36 @@ def test_complex_struct(client):
input_dataframe=input_df,
)

res1, found1 = client.lookup(
"MovieInfo",
keys=pd.DataFrame({"director_id": [1, 2], "movie_id": [1, 3]}),
)
res2, found2 = client.lookup(
"MovieInfoExpr",
keys=pd.DataFrame({"director_id": [1, 2], "movie_id": [1, 3]}),
)
res3, found3 = client.lookup(
"MovieInfoExpr2",
keys=pd.DataFrame({"director_id": [1, 2], "movie_id": [1, 3]}),
)
assert res1.shape == res2.shape
assert res1.shape == res3.shape
for c in res1.columns:
assert res1[c].equals(res2[c])
assert res1[c].equals(res3[c])
assert list(found1) == list(found2)
assert list(found1) == list(found3)

res4, found4 = client.lookup(
"MovieInfoExprNested",
keys=pd.DataFrame({"director_id": [1, 2], "movie_id": [1, 3]}),
)
assert res1.shape == res4.shape
assert list(found1) == list(found4)
for r in res4["role_list"]:
for role in r:
assert role.name.last_name == "rando"

assert df.shape[0] == 4
assert len(df["MovieFeatures.role_list_py"].tolist()[0]) == 3
assert df["MovieFeatures.role_list_py"].tolist()[0][0].as_json() == {
127 changes: 52 additions & 75 deletions fennel/client_tests/test_dataset.py
@@ -606,81 +606,6 @@ class UserInfoDataset:
)


# On demand datasets are not supported for now.

# class TestDocumentDataset(unittest.TestCase):
# @mock_client
# def test_log_to_document_dataset(self, client):
# """Log some data to the dataset and check if it is logged correctly."""
#
# @meta(owner="[email protected]")
# @dataset
# class DocumentContentDataset:
# doc_id: int = field(key=True)
# bert_embedding: Embedding[4]
# fast_text_embedding: Embedding[3]
# num_words: int
# timestamp: datetime = field(timestamp=True)
#
# @on_demand(expires_after="3d")
# @inputs(datetime, int)
# def get_embedding(cls, ts: pd.Series, doc_ids: pd.Series):
# data = []
# doc_ids = doc_ids.tolist()
# for i in range(len(ts)):
# data.append(
# [
# doc_ids[i],
# [0.1, 0.2, 0.3, 0.4],
# [1.1, 1.2, 1.3],
# 10 * i,
# ts[i],
# ]
# )
# columns = [
# str(cls.doc_id),
# str(cls.bert_embedding),
# str(cls.fast_text_embedding),
# str(cls.num_words),
# str(cls.timestamp),
# ]
# return pd.DataFrame(data, columns=columns), pd.Series(
# [True] * len(ts)
# )
#
# # Sync the dataset
# client.commit(datasets=[DocumentContentDataset])
# now = datetime.now(timezone.utc)
# data = [
# [18232, np.array([1, 2, 3, 4]), np.array([1, 2, 3]), 10, now],
# [
# 18234,
# np.array([1, 2.2, 0.213, 0.343]),
# np.array([0.87, 2, 3]),
# 9,
# now,
# ],
# [18934, [1, 2.2, 0.213, 0.343], [0.87, 2, 3], 12, now],
# ]
# columns = [
# "doc_id",
# "bert_embedding",
# "fast_text_embedding",
# "num_words",
# "timestamp",
# ]
# df = pd.DataFrame(data, columns=columns)
# response = client.log("fennel_webhook","DocumentContentDataset", df)
# assert response.status_code == requests.codes.OK, response.json()
#
# # Do some lookups
# doc_ids = pd.Series([18232, 1728, 18234, 18934, 19200, 91012])
# ts = pd.Series([now, now, now, now, now, now])
# df, _ = DocumentContentDataset.lookup(ts, doc_id=doc_ids)
# assert df.shape == (6, 5)
# assert df["num_words"].tolist() == [10.0, 9.0, 12, 0, 10.0, 20.0]


################################################################################
# Dataset & Pipelines Unit Tests
################################################################################
@@ -1012,6 +937,25 @@ class Orders:
timestamp: datetime


@dataset
class OrdersOptional:
uid: Optional[int]
uid_float: float
uid_twice: float
skus: List[int]
prices: List[float]
timestamp: datetime

@pipeline
@inputs(Orders)
def cast(cls, ds: Dataset):
return ds.assign(
uid=col("uid").astype(Optional[int]), # type: ignore
uid_float=col("uid").astype(float), # type: ignore
uid_twice=(col("uid") * 2.0).astype(float), # type: ignore
)


@dataset(index=True)
class Derived:
uid: int = field(key=True)
@@ -1066,6 +1010,39 @@ def test_basic_explode(self, client):
assert df["price"].tolist()[0] == 10.1
assert pd.isna(df["price"].tolist()[1])

@pytest.mark.integration
@mock
def test_basic_cast(self, client):
# Sync the dataset
client.commit(message="msg", datasets=[Orders, OrdersOptional])
# log some rows to the transaction dataset
df = pd.DataFrame(
[
{
"uid": 1,
"skus": [1, 2],
"prices": [10.1, 20.0],
"timestamp": "2021-01-01T00:00:00",
},
{
"uid": 2,
"skus": [],
"prices": [],
"timestamp": "2021-01-01T00:00:00",
},
]
)
client.log("webhook", "Orders", df)
client.sleep()

# inspect the derived OrdersOptional dataset
df = client.inspect("OrdersOptional")
assert df.shape == (2, 6)
assert df["uid"].tolist() == [1, 2]
assert df["uid_float"].tolist() == [1.0, 2.0]
assert df["uid_twice"].tolist() == [2.0, 4.0]
assert df["skus"].tolist() == [[1, 2], []]


class TestBasicAssign(unittest.TestCase):
@pytest.mark.integration
5 changes: 3 additions & 2 deletions fennel/datasets/datasets.py
@@ -2935,15 +2935,16 @@ def visitAssign(self, obj) -> DSSchema:
  raise ValueError(
      f"invalid assign - {output_schema_name} error in expression for column `{col}`: {str(e)}"
  )
- if typed_expr.dtype != expr_type:
+ if not typed_expr.expr.matches_type(
+     typed_expr.dtype, input_schema.schema()
+ ):
      printer = ExprPrinter()
      type_errors.append(
          f"'{col}' is expected to be of type `{dtype_to_string(typed_expr.dtype)}`, but evaluates to `{dtype_to_string(expr_type)}`. Full expression: `{printer.print(typed_expr.expr.root)}`"
      )

  if len(type_errors) > 0:
      joined_errors = "\n\t".join(type_errors)
-     print(joined_errors)
      raise TypeError(
          f"found type errors in assign node of `{self.dsname}.{self.pipeline_name}`:\n\t{joined_errors}"
      )
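The `matches_type` change above is the mechanism behind the promotion: instead of requiring the expression's inferred dtype to equal the declared one, `visitAssign` now checks whether the expression can inhabit the declared type given the input schema. Incompatible assigns still fail at commit time with the same error format. A hedged sketch of both paths (dataset, pipeline, and column names are illustrative):

# Validates: `uid` is int, and int promotes to the declared float dtype.
ds.assign(uid_float=col("uid").astype(float))

# Still rejected, since int does not promote to str; commit raises roughly:
#   TypeError: found type errors in assign node of `MyDS.my_pipeline`:
#       'uid_str' is expected to be of type `str`, but evaluates to `int`. ...
ds.assign(uid_str=col("uid").astype(str))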
12 changes: 12 additions & 0 deletions fennel/dtypes/dtypes.py
@@ -1,4 +1,5 @@
import dataclasses
from functools import partial
import inspect
import sys
from dataclasses import dataclass
@@ -75,6 +76,16 @@ def get_fennel_struct(annotation) -> Any:
return None


def make_struct_expr(cls, **kwargs):
from fennel.expr.expr import make_expr, make_struct

fields = {}
for name, value in kwargs.items():
fields[name] = make_expr(value)

return make_struct(fields, cls)


def struct(cls):
for name, member in inspect.getmembers(cls):
if inspect.isfunction(member) and name in cls.__dict__:
@@ -131,6 +142,7 @@ def struct(cls):
setattr(cls, FENNEL_STRUCT_SRC_CODE, "")
setattr(cls, FENNEL_STRUCT_DEPENDENCIES_SRC_CODE, dependency_code)
cls.as_json = as_json
cls.expr = partial(make_struct_expr, cls)
return dataclasses.dataclass(cls)


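The single added line `cls.expr = partial(make_struct_expr, cls)` is what enables the `Role.expr(...)` and `RoleExtended.expr(...)` calls in the tests above: every `@struct` class now exposes an `.expr(**kwargs)` constructor whose keyword arguments (column expressions or literals) are lifted into expressions by `make_expr`. The two spellings exercised by `MovieInfoExpr2` and `MovieInfoExpr` should therefore be equivalent:

# Struct-initializer sugar added by this commit:
Role.expr(role_id=col("role_id"), name=col("name"), cost=col("cost"))

# ...which desugars to the explicit make_struct form:
make_struct(
    {
        "role_id": col("role_id"),
        "name": col("name"),
        "cost": col("cost"),
    },
    Role,
)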
(7 more changed files not shown)