
Commit

Add nullable join
Hoang Phan authored and nikhilgarg28 committed Oct 21, 2024
1 parent 6f6d2c6 commit 96bd4f1
Showing 5 changed files with 316 additions and 53 deletions.
120 changes: 119 additions & 1 deletion docs/examples/api-reference/operators/join.py
@@ -8,7 +8,7 @@
__owner__ = "[email protected]"


-class TestAssignSnips(unittest.TestCase):
+class TestJoinSnips(unittest.TestCase):
    @mock
    def test_basic(self, client):
        # docsnip basic
@@ -118,3 +118,121 @@ def join_pipeline(cls, tx: Dataset, merchant_category: Dataset):
            df["timestamp"].tolist()
            == [datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)] * 3
        )

    @mock
    def test_optional(self, client):
        # docsnip optional_join
        from fennel.datasets import dataset, field, pipeline, Dataset
        from fennel.lib import inputs
        from fennel.connectors import source, Webhook
        from typing import Optional
        webhook = Webhook(name="webhook")

        @source(webhook.endpoint("Transaction"), disorder="14d", cdc="append")
        @dataset
        class Transaction:
            uid: int
            merchant: Optional[int]
            amount: int
            timestamp: datetime

        @source(
            webhook.endpoint("MerchantCategory"), disorder="14d", cdc="upsert"
        )
        @dataset(index=True)
        class MerchantCategory:
            # docsnip-highlight start
            # right side of the join can only be on key fields
            merchant: int = field(key=True)
            # docsnip-highlight end
            category: str
            updated_at: datetime  # won't show up in joined dataset

        @dataset
        class WithCategory:
            uid: int
            merchant: Optional[int]
            amount: int
            timestamp: datetime
            category: Optional[str]

            @pipeline
            @inputs(Transaction, MerchantCategory)
            def join_pipeline(cls, tx: Dataset, merchant_category: Dataset):
                # docsnip-highlight next-line
                return tx.join(merchant_category, on=["merchant"], how="left")

        # /docsnip

        # log some rows to both datasets
        client.commit(
            message="some msg",
            datasets=[Transaction, MerchantCategory, WithCategory],
        )
        client.log(
            "webhook",
            "Transaction",
            pd.DataFrame(
                [
                    {
                        "uid": 1,
                        "merchant": 4,
                        "amount": 10,
                        "timestamp": "2021-01-01T00:00:00",
                    },
                    {
                        "uid": 1,
                        "merchant": None,
                        "amount": 15,
                        "timestamp": "2021-01-01T00:00:00",
                    },
                    {
                        "uid": 2,
                        "merchant": 5,
                        "amount": 20,
                        "timestamp": "2021-01-01T00:00:00",
                    },
                    {
                        "uid": 3,
                        "merchant": 4,
                        "amount": 30,
                        "timestamp": "2021-01-01T00:00:00",
                    },
                    {
                        "uid": 3,
                        "merchant": 6,
                        "amount": 30,
                        "timestamp": "2021-01-01T00:00:00",
                    },
                ]
            ),
        )
        client.log(
            "webhook",
            "MerchantCategory",
            pd.DataFrame(
                [
                    {
                        "merchant": 4,
                        "category": "grocery",
                        "updated_at": "2021-01-01T00:00:00",
                    },
                    {
                        "merchant": 5,
                        "category": "electronics",
                        "updated_at": "2021-01-01T00:00:00",
                    },
                ]
            ),
        )
        import numpy as np
        df = client.get_dataset_df("WithCategory")
        df = df.replace({np.nan: None})
        assert df["uid"].tolist() == [1, 1, 2, 3, 3]
        assert df["merchant"].tolist() == [4, None, 5, 4, 6]
        assert df["amount"].tolist() == [10, 15, 20, 30, 30]
        assert df["category"].tolist() == ["grocery", None, "electronics", "grocery", None]
        assert (
            df["timestamp"].tolist()
            == [datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)] * 5
        )
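For reference, the rows that the asserts above expect in `WithCategory`, reassembled row-wise (timestamps omitted; they are all 2021-01-01T00:00:00 UTC). `category` is `None` wherever the transaction had a null merchant or the merchant had no row in `MerchantCategory`:

    uid  merchant  amount  category
    1    4         10      grocery
    1    None      15      None
    2    5         20      electronics
    3    4         30      grocery
    3    6         30      None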
13 changes: 11 additions & 2 deletions docs/pages/api-reference/operators/join.md
@@ -26,8 +26,8 @@ a row even if there is no matching row on the right side.
<Expandable title="on" type="Optional[List[str]]" defaultVal="None">
Kwarg that specifies the list of fields along which join should happen. If present,
both left and right side datasets must have fields with these names and matching
-data types. This list must be identical to the names of all key columns of the
-right hand side.
+data types (data types on left hand side can be optional). This list must be identical
+to the names of all key columns of the right hand side.

If this isn't set, `left_on` and `right_on` must be set instead.
</Expandable>
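For illustration, a sketch reusing the `tx` and `merchant_category` datasets from the snippets referenced below (not an excerpt from the docs source): the two calls are equivalent when the join column has the same name on both sides.

```python
# Shorthand: `on` requires identically named key columns on both sides.
joined = tx.join(merchant_category, on=["merchant"], how="left")

# Equivalent spelling with explicit left/right column lists; this form is
# needed when the column names differ between the two datasets.
joined = tx.join(
    merchant_category,
    left_on=["merchant"],
    right_on=["merchant"],
    how="left",
)
```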
@@ -73,6 +73,11 @@ dataset's timestamp field.
message="Inner join on 'merchant'">
</pre>


<pre snippet="api-reference/operators/join#optional_join" status="success"
message="Left join on 'merchant' with optional LHS fields">
</pre>

#### Returns
<Expandable type="Dataset">
Returns a dataset representing the joined dataset having the same keys & timestamp
@@ -84,6 +89,10 @@ non-timestamp columns from the right dataset.
If the join was of type `inner`, the type of a joined
RHS column of type `T` stays `T` but if the join was of type `left`, the type in
the output dataset becomes `Optional[T]` if it was `T` on the RHS side.

For LHS columns, the type is the same as the type in the LHS dataset if the join type is `left`.
If the join type is `inner` and a join column on the LHS is `Optional[T]`, then the type
in the output dataset is `T` (i.e., the `Optional` is dropped).
</Expandable>
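A rough sketch of the resulting column types, based on the `optional_join` snippet above (output schemas shown in comments; this is illustrative, not an additional API):

```python
# Left join: LHS columns keep their types (merchant stays Optional[int]);
# RHS value columns become Optional because a match may be missing.
with_category = tx.join(merchant_category, on=["merchant"], how="left")
# -> uid: int, merchant: Optional[int], amount: int,
#    category: Optional[str], timestamp: datetime

# Inner join: rows without a match (including rows with a null merchant) are
# dropped, so the Optional is stripped from the LHS join key and the RHS
# columns keep their original types.
matched_only = tx.join(merchant_category, on=["merchant"], how="inner")
# -> uid: int, merchant: int, amount: int,
#    category: str, timestamp: datetime
```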

#### Errors
29 changes: 14 additions & 15 deletions fennel/datasets/datasets.py
@@ -1062,6 +1062,9 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:

        rhs_keys = set(self.dataset.dsschema().keys)
        join_keys = set(self.on) if self.on is not None else set(self.right_on)
        final_join_cols = (
            set(self.on) if self.on is not None else set(self.left_on)
        )
        # Ensure on or right_on are the keys of the right dataset
        if join_keys != rhs_keys:
            raise ValueError(
@@ -1130,6 +1133,11 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:
        else:
            joined_dsschema.append_value_column(right_ts, datetime.datetime)

        # Drop null on join keys if how is inner
        if self.how == "inner":
            for key in final_join_cols:
                joined_dsschema.drop_null_column(key)

        return joined_dsschema


@@ -2949,13 +2957,9 @@ def is_subset(subset: List[str], superset: List[str]) -> bool:
                )
            # Check the schemas of the keys
            for key in obj.on:
-                if fennel_is_optional(left_schema.get_type(key)):
-                    raise TypeError(
-                        f"Fields used in a join operator must not be optional in left schema, "
-                        f"found `{key}` of type `{dtype_to_string(left_schema.get_type(key))}` "
-                        f"in `{output_schema_name}`"
-                    )
-                if left_schema.get_type(key) != right_schema.get_type(key):
+                if fennel_get_optional_inner(
+                    left_schema.get_type(key)
+                ) != right_schema.get_type(key):
                    raise TypeError(
                        f"Key field `{key}` has type `{dtype_to_string(left_schema.get_type(key))}` "
                        f"in left schema but type "
@@ -2978,14 +2982,9 @@ def is_subset(subset: List[str], superset: List[str]) -> bool:
                )
            # Check the schemas of the keys
            for lkey, rkey in zip(obj.left_on, obj.right_on):
-                if fennel_is_optional(left_schema.get_type(lkey)):
-                    raise TypeError(
-                        f"Fields used in a join operator must not be optional "
-                        f"in left schema, found `{lkey}` of type "
-                        f"`{dtype_to_string(left_schema.get_type(lkey))}` "
-                        f"in `{output_schema_name}`"
-                    )
-                if left_schema.get_type(lkey) != right_schema.get_type(rkey):
+                if fennel_get_optional_inner(
+                    left_schema.get_type(lkey)
+                ) != right_schema.get_type(rkey):
                    raise TypeError(
                        f"Key field `{lkey}` has type"
                        f" `{dtype_to_string(left_schema.get_type(lkey))}` "
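For readers unfamiliar with the helper used above: a minimal sketch of what an `Optional`-unwrapping function like `fennel_get_optional_inner` presumably does (the real implementation lives in Fennel's internals and may differ):

```python
from typing import Optional, Union, get_args, get_origin


def get_optional_inner(dtype):
    # Hypothetical stand-in: return T for Optional[T], otherwise return the
    # type unchanged, so Optional[int] on the LHS compares equal to an int
    # key column on the RHS.
    if get_origin(dtype) is Union:
        non_none = [a for a in get_args(dtype) if a is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return dtype


assert get_optional_inner(Optional[int]) is int
assert get_optional_inner(int) is int
```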
35 changes: 0 additions & 35 deletions fennel/datasets/test_invalid_dataset.py
@@ -800,41 +800,6 @@ def create_pipeline(cls, a: Dataset):
== "Cannot join with an intermediate dataset, i.e something defined inside a pipeline. Only joining against keyed datasets is permitted."
)

-    with pytest.raises(TypeError) as e:
-
-        @dataset
-        class XYZ:
-            user_id: Optional[int]
-            agent_id: int
-            name: str
-            timestamp: datetime
-
-        @dataset(index=True)
-        class ABC:
-            user_id: int = field(key=True)
-            agent_id: int = field(key=True)
-            age: int
-            timestamp: datetime
-
-        @dataset
-        class XYZJoinedABC:
-            user_id: int
-            name: str
-            age: int
-            timestamp: datetime
-
-            @pipeline
-            @inputs(XYZ, ABC)
-            def create_pipeline(cls, a: Dataset, b: Dataset):
-                c = a.join(b, how="inner", on=["user_id", "agent_id"])  # type: ignore
-                return c
-
-    assert (
-        str(e.value)
-        == "Fields used in a join operator must not be optional in left schema, found `user_id` of "
-        "type `Optional[int]` in `'[Pipeline:create_pipeline]->join node'`"
-    )


def test_dataset_incorrect_join_fields():
    with pytest.raises(ValueError) as e:
(The diff for the fifth changed file was not loaded on the page.)
