Skip to content

Commit

Permalink
join: Handle conflict in LHS value and RHS key
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-nambiar committed Nov 13, 2023
1 parent 5e3504d commit 9da3915
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 1 deletion.
142 changes: 142 additions & 0 deletions fennel/client_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

webhook = Webhook(name="fennel_webhook")

__owner__ = "[email protected]"


@meta(owner="[email protected]")
@source(webhook.endpoint("UserInfoDataset"))
Expand Down Expand Up @@ -2934,3 +2936,143 @@ def TransactionsCreditInternetBanking_wrapper_4d45b34b11_filter(df: pd.DataFrame
assert del_spaces_tabs_and_newlines(
sync_request.operators[3].filter.pycode.generated_code
) == del_spaces_tabs_and_newlines(expected_code)


@mock
def test_inner_join_column_name_collision(client):
    """Regression test: a chain of inner joins where a LHS value column name
    collides with the RHS join key must still produce correct results.

    The chain joins ``PaymentEventDataset`` -> ``PaymentAccountDataset`` ->
    ``PaymentAccountAssociationDataset`` -> ``AccountDataset``; both sides of
    the intermediate joins carry ``id``/``created`` columns, exercising the
    conflict handling between LHS values and RHS keys. With a single event
    of risk score 0.5 flowing through, both the max and min aggregates must
    come out as exactly 0.5.
    """
    # Reuse the module-level `webhook` (name="fennel_webhook") rather than
    # shadowing it with an identical local instance.

    @dataset
    @source(webhook.endpoint("PaymentEventDataset"), tier="local")
    class PaymentEventDataset:
        customer: int = field(key=True)
        created: datetime
        outcome_risk_score: float

    @dataset
    @source(webhook.endpoint("PaymentAccountDataset"), tier="local")
    class PaymentAccountDataset:
        id: int
        created: datetime
        customer_id: int = field(key=True)

    @dataset
    @source(webhook.endpoint("PaymentAccountAssociationDataset"), tier="local")
    class PaymentAccountAssociationDataset:
        id: int = field(key=True)
        created: datetime
        account_id: int

    @dataset
    @source(webhook.endpoint("AccountDataset"), tier="local")
    class AccountDataset:
        id: int = field(key=True)
        created: datetime
        primary_rider_id: int

    @dataset
    class RiderAggRiskScore:
        primary_rider_id: int = field(key=True)
        created: datetime
        max_risk_score: float
        min_risk_score: float

        @pipeline()
        @inputs(
            PaymentEventDataset,
            PaymentAccountDataset,
            PaymentAccountAssociationDataset,
            AccountDataset,
        )
        def stripe_enrichment(
            cls,
            stripe_charge: Dataset,
            payment_account: Dataset,
            payment_account_association: Dataset,
            account: Dataset,
        ):
            # `customer` (LHS value-carrying key) joins against the RHS key
            # `customer_id`; the subsequent joins collide on `id`, which is
            # exactly the conflict case this test guards.
            return (
                stripe_charge.join(
                    payment_account,
                    how="inner",
                    left_on=["customer"],
                    right_on=["customer_id"],
                )
                .join(payment_account_association, how="inner", on=["id"])
                .join(
                    account,
                    left_on=["account_id"],
                    right_on=["id"],
                    how="inner",
                )
                .groupby("primary_rider_id")
                .aggregate(
                    Max(
                        of="outcome_risk_score",
                        window=Window("forever"),
                        into_field="max_risk_score",
                        default=-1.0,
                    ),
                    Min(
                        of="outcome_risk_score",
                        window=Window("forever"),
                        into_field="min_risk_score",
                        default=-1.0,
                    ),
                )
            )

    initial = client.sync(
        datasets=[
            RiderAggRiskScore,
            PaymentEventDataset,
            PaymentAccountDataset,
            PaymentAccountAssociationDataset,
            AccountDataset,
        ]
    )
    assert initial.status_code == 200
    now = datetime.now()

    def log_rows(endpoint: str, df: pd.DataFrame) -> None:
        # Every webhook log must be accepted for the pipeline to be fed.
        response = client.log(
            webhook="fennel_webhook", endpoint=endpoint, df=df
        )
        assert response.status_code == 200

    log_rows(
        "PaymentEventDataset",
        pd.DataFrame(
            {"customer": [1], "created": [now], "outcome_risk_score": [0.5]}
        ),
    )
    log_rows(
        "PaymentAccountDataset",
        pd.DataFrame({"id": [1], "created": [now], "customer_id": [1]}),
    )
    log_rows(
        "PaymentAccountAssociationDataset",
        pd.DataFrame({"id": [1], "created": [now], "account_id": [1]}),
    )
    log_rows(
        "AccountDataset",
        pd.DataFrame({"id": [1], "created": [now], "primary_rider_id": [1]}),
    )

    extracted_df = client.get_dataset_df("RiderAggRiskScore")
    assert extracted_df.shape[0] == 1
    assert extracted_df["max_risk_score"].iloc[0] == 0.5
    # With a single event, min must equal max; the original test never
    # checked the Min aggregate even though the pipeline computes it.
    assert extracted_df["min_risk_score"].iloc[0] == 0.5
6 changes: 6 additions & 0 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ def __init__(
on: Optional[List[str]] = None,
left_on: Optional[List[str]] = None,
right_on: Optional[List[str]] = None,
# Currently not supported
lsuffix: str = "",
rsuffix: str = "",
):
Expand Down Expand Up @@ -615,6 +616,9 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:
right_value_schema: Dict[str, Type] = copy.deepcopy(
self.dataset.dsschema().values
)
right_key_schema: Dict[str, Type] = copy.deepcopy(
self.dataset.dsschema().keys
)

common_cols = set(left_schema.keys()) & set(right_value_schema.keys())
# for common values, suffix column name in left_schema with lsuffix and right_schema with rsuffix
Expand Down Expand Up @@ -1923,6 +1927,8 @@ def is_subset(subset: List[str], superset: List[str]) -> bool:
f"in left schema but type "
f"{dtype_to_string(right_schema.get_type(key))} in right schema."
)
# Check that none of the other fields collide

else:
# obj.right_on should be the keys of the right dataset
if set(obj.right_on) != set(right_schema.keys.keys()):
Expand Down
6 changes: 6 additions & 0 deletions fennel/datasets/test_invalid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import pytest
from typing import Optional, List, Union
import pandas as pd

from fennel import Min, Max
from fennel.datasets import dataset, pipeline, field, Dataset
from fennel.lib.aggregate import Count, Average, Stddev, Distinct
from fennel.lib.expectations import (
Expand All @@ -11,8 +14,11 @@
from fennel.lib.metadata import meta
from fennel.lib.schema import inputs, struct
from fennel.lib.window import Window
from fennel.sources import Webhook, source
from fennel.test_lib import *

__owner__ = "[email protected]"


def test_multiple_date_time():
with pytest.raises(ValueError) as e:
Expand Down
9 changes: 9 additions & 0 deletions fennel/test_lib/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ def sub_within_low(row):
left_by = copy.deepcopy(obj.left_on)
right_by = copy.deepcopy(obj.right_on)

# Rename the right_by columns to avoid conflicts with any of the left columns.
# We don't need to worry about right value columns conflicting with left key columns,
# because we have already verified that.
if set(right_by).intersection(set(left_df.columns)):
right_df = right_df.rename(
columns={col: f"__@@__{col}" for col in right_by}
)
right_by = [f"__@@__{col}" for col in right_by]

left_df = left_df.sort_values(by=ts_query_field)
right_df = right_df.sort_values(by=ts_query_field)
merged_df = pd.merge_asof(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fennel-ai"
version = "0.18.16"
version = "0.18.17"
description = "The modern realtime feature engineering platform"
authors = ["Fennel AI <[email protected]>"]
packages = [{ include = "fennel" }]
Expand Down

0 comments on commit 9da3915

Please sign in to comment.