diff --git a/docs/examples/api-reference/operators/join.py b/docs/examples/api-reference/operators/join.py
index c590fd3e..8b6438c7 100644
--- a/docs/examples/api-reference/operators/join.py
+++ b/docs/examples/api-reference/operators/join.py
@@ -8,7 +8,7 @@
__owner__ = "aditya@fennel.ai"
-class TestAssignSnips(unittest.TestCase):
+class TestJoinSnips(unittest.TestCase):
@mock
def test_basic(self, client):
# docsnip basic
@@ -118,3 +118,121 @@ def join_pipeline(cls, tx: Dataset, merchant_category: Dataset):
df["timestamp"].tolist()
== [datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)] * 3
)
+
+ @mock
+ def test_optional(self, client):
+ # docsnip optional_join
+ from fennel.datasets import dataset, field, pipeline, Dataset
+ from fennel.lib import inputs
+ from fennel.connectors import source, Webhook
+ from typing import Optional
+
+        webhook = Webhook(name="webhook")
+
+ @source(webhook.endpoint("Transaction"), disorder="14d", cdc="append")
+ @dataset
+ class Transaction:
+ uid: int
+ merchant: Optional[int]
+ amount: int
+ timestamp: datetime
+
+ @source(
+ webhook.endpoint("MerchantCategory"), disorder="14d", cdc="upsert"
+ )
+ @dataset(index=True)
+ class MerchantCategory:
+ # docsnip-highlight start
+ # right side of the join can only be on key fields
+ merchant: int = field(key=True)
+ # docsnip-highlight end
+ category: str
+ updated_at: datetime # won't show up in joined dataset
+
+ @dataset
+ class WithCategory:
+ uid: int
+ merchant: Optional[int]
+ amount: int
+ timestamp: datetime
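+            # Optional: the left join may find no matching merchant category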
+ category: Optional[str]
+
+ @pipeline
+ @inputs(Transaction, MerchantCategory)
+ def join_pipeline(cls, tx: Dataset, merchant_category: Dataset):
+ # docsnip-highlight next-line
+ return tx.join(merchant_category, on=["merchant"], how="left")
+
+ # /docsnip
+
+ # log some rows to both datasets
+ client.commit(
+ message="some msg",
+ datasets=[Transaction, MerchantCategory, WithCategory],
+ )
+ client.log(
+ "webhook",
+ "Transaction",
+ pd.DataFrame(
+ [
+ {
+ "uid": 1,
+ "merchant": 4,
+ "amount": 10,
+ "timestamp": "2021-01-01T00:00:00",
+ },
+ {
+ "uid": 1,
+ "merchant": None,
+ "amount": 15,
+ "timestamp": "2021-01-01T00:00:00",
+ },
+ {
+ "uid": 2,
+ "merchant": 5,
+ "amount": 20,
+ "timestamp": "2021-01-01T00:00:00",
+ },
+ {
+ "uid": 3,
+ "merchant": 4,
+ "amount": 30,
+ "timestamp": "2021-01-01T00:00:00",
+ },
+ {
+ "uid": 3,
+ "merchant": 6,
+ "amount": 30,
+ "timestamp": "2021-01-01T00:00:00",
+ },
+ ]
+ ),
+ )
+ client.log(
+ "webhook",
+ "MerchantCategory",
+ pd.DataFrame(
+ [
+ {
+ "merchant": 4,
+ "category": "grocery",
+ "updated_at": "2021-01-01T00:00:00",
+ },
+ {
+ "merchant": 5,
+ "category": "electronics",
+ "updated_at": "2021-01-01T00:00:00",
+ },
+ ]
+ ),
+ )
+
+        import numpy as np
+ df = client.get_dataset_df("WithCategory")
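+        # normalize pandas NaN (from unmatched rows) to None for comparison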
+ df = df.replace({np.nan: None})
+ assert df["uid"].tolist() == [1, 1, 2, 3, 3]
+ assert df["merchant"].tolist() == [4, None, 5, 4, 6]
+ assert df["amount"].tolist() == [10, 15, 20, 30, 30]
+ assert df["category"].tolist() == ["grocery", None, "electronics", "grocery", None]
+ assert (
+ df["timestamp"].tolist()
+ == [datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)] * 5
+ )
diff --git a/docs/pages/api-reference/operators/join.md b/docs/pages/api-reference/operators/join.md
index 6e9648e3..99c062c9 100644
--- a/docs/pages/api-reference/operators/join.md
+++ b/docs/pages/api-reference/operators/join.md
@@ -26,8 +26,8 @@ a row even if there is no matching row on the right side.
Kwarg that specifies the list of fields along which join should happen. If present,
both left and right side datasets must have fields with these names and matching
-data types. This list must be identical to the names of all key columns of the
-right hand side.
+data types (the left hand side types may additionally be wrapped in `Optional`).
+This list must be identical to the names of all key columns of the right hand side.
If this isn't set, `left_on` and `right_on` must be set instead.
@@ -73,6 +73,11 @@ dataset's timestamp field.
message="Inner join on 'merchant'">
+
+<pre snippet="api-reference/operators/join#optional_join"
+    status="success" message="Left join with optional join keys">
+</pre>
#### Returns
Returns a dataset representing the joined dataset having the same keys & timestamp
@@ -84,6 +89,10 @@ non-timestamp columns from the right dataset.
If the join was of type `inner`, the type of a joined
RHS column of type `T` stays `T` but if the join was of type `left`, the type in
the output dataset becomes `Optional[T]` if it was `T` on the RHS side.
+
+For LHS columns, if the join type is `left`, the output type is the same as the
+type in the LHS dataset. If the join type is `inner`, a join column of type
+`Optional[T]` on the LHS becomes `T` in the output dataset (the `Optional` is
+dropped, since rows with null join keys can never match).
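+
+For example, using the `tx` and `merchant_category` datasets from the snippets
+above (a hypothetical sketch, not a complete runnable snippet):
+
+```python
+# tx.merchant is Optional[int]; MerchantCategory.merchant is a key of type int
+
+# left join: the output keeps merchant as Optional[int]; category is Optional[str]
+tx.join(merchant_category, on=["merchant"], how="left")
+
+# inner join: rows with merchant=None can never match, so the output has
+# merchant: int (Optional dropped) and category: str
+tx.join(merchant_category, on=["merchant"], how="inner")
+```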
#### Errors
diff --git a/fennel/datasets/datasets.py b/fennel/datasets/datasets.py
index d2245b89..b49d9ac7 100644
--- a/fennel/datasets/datasets.py
+++ b/fennel/datasets/datasets.py
@@ -1062,6 +1062,9 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:
rhs_keys = set(self.dataset.dsschema().keys)
join_keys = set(self.on) if self.on is not None else set(self.right_on)
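+        # join key columns as they will be named in the output (left-side names)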
+ final_join_cols = (
+ set(self.on) if self.on is not None else set(self.left_on)
+ )
# Ensure on or right_on are the keys of the right dataset
if join_keys != rhs_keys:
raise ValueError(
@@ -1130,6 +1133,11 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:
else:
joined_dsschema.append_value_column(right_ts, datetime.datetime)
+        # Inner joins drop rows whose join keys are null, so the join key
+        # columns become non-optional in the output schema
+ if self.how == "inner":
+ for key in final_join_cols:
+ joined_dsschema.drop_null_column(key)
+
return joined_dsschema
@@ -2949,13 +2957,9 @@ def is_subset(subset: List[str], superset: List[str]) -> bool:
)
# Check the schemas of the keys
for key in obj.on:
- if fennel_is_optional(left_schema.get_type(key)):
- raise TypeError(
- f"Fields used in a join operator must not be optional in left schema, "
- f"found `{key}` of type `{dtype_to_string(left_schema.get_type(key))}` "
- f"in `{output_schema_name}`"
- )
- if left_schema.get_type(key) != right_schema.get_type(key):
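+        # Optional[T] on the left is allowed to match T on the right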
+ if fennel_get_optional_inner(
+ left_schema.get_type(key)
+ ) != right_schema.get_type(key):
raise TypeError(
f"Key field `{key}` has type `{dtype_to_string(left_schema.get_type(key))}` "
f"in left schema but type "
@@ -2978,14 +2982,9 @@ def is_subset(subset: List[str], superset: List[str]) -> bool:
)
# Check the schemas of the keys
for lkey, rkey in zip(obj.left_on, obj.right_on):
- if fennel_is_optional(left_schema.get_type(lkey)):
- raise TypeError(
- f"Fields used in a join operator must not be optional "
- f"in left schema, found `{lkey}` of type "
- f"`{dtype_to_string(left_schema.get_type(lkey))}` "
- f"in `{output_schema_name}`"
- )
- if left_schema.get_type(lkey) != right_schema.get_type(rkey):
+ if fennel_get_optional_inner(
+ left_schema.get_type(lkey)
+ ) != right_schema.get_type(rkey):
raise TypeError(
f"Key field `{lkey}` has type"
f" `{dtype_to_string(left_schema.get_type(lkey))}` "
diff --git a/fennel/datasets/test_invalid_dataset.py b/fennel/datasets/test_invalid_dataset.py
index 7d61f8d9..d29c2b85 100644
--- a/fennel/datasets/test_invalid_dataset.py
+++ b/fennel/datasets/test_invalid_dataset.py
@@ -800,41 +800,6 @@ def create_pipeline(cls, a: Dataset):
== "Cannot join with an intermediate dataset, i.e something defined inside a pipeline. Only joining against keyed datasets is permitted."
)
- with pytest.raises(TypeError) as e:
-
- @dataset
- class XYZ:
- user_id: Optional[int]
- agent_id: int
- name: str
- timestamp: datetime
-
- @dataset(index=True)
- class ABC:
- user_id: int = field(key=True)
- agent_id: int = field(key=True)
- age: int
- timestamp: datetime
-
- @dataset
- class XYZJoinedABC:
- user_id: int
- name: str
- age: int
- timestamp: datetime
-
- @pipeline
- @inputs(XYZ, ABC)
- def create_pipeline(cls, a: Dataset, b: Dataset):
- c = a.join(b, how="inner", on=["user_id", "agent_id"]) # type: ignore
- return c
-
- assert (
- str(e.value)
- == "Fields used in a join operator must not be optional in left schema, found `user_id` of "
- "type `Optional[int]` in `'[Pipeline:create_pipeline]->join node'`"
- )
-
def test_dataset_incorrect_join_fields():
with pytest.raises(ValueError) as e:
diff --git a/fennel/datasets/test_schema_validator.py b/fennel/datasets/test_schema_validator.py
index 8d1e6f00..f2b4dafa 100644
--- a/fennel/datasets/test_schema_validator.py
+++ b/fennel/datasets/test_schema_validator.py
@@ -1996,3 +1996,175 @@ def pipeline_window(cls, event: Dataset):
return event.groupby("id").aggregate(
count=Count(window=Tumbling("1h")), along="ts2"
)
+
+
+def test_optional_join():
+
+    # Test that an Optional column on the left can be used as a join key;
+    # an inner join drops rows with null keys, so the Optional is removed
+    # from the join key columns in the output
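+    # (`if True:` just opens a block so this case mirrors the pytest.raises ones)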
+ if True:
+
+ @dataset
+ class XYZ:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ timestamp: datetime
+
+ @dataset(index=True)
+ class ABC:
+ user_id_2: int = field(key=True)
+ agent_id_2: int = field(key=True)
+ age: int
+ timestamp: datetime
+
+ @dataset
+ class XYZJoinedABC:
+ user_id: int
+ agent_id: int
+ name: str
+ age: int
+ timestamp: datetime
+
+ @pipeline
+ @inputs(XYZ, ABC)
+ def create_pipeline(cls, a: Dataset, b: Dataset):
+ c = a.join(b, how="inner", left_on=["user_id", "agent_id"], right_on=["user_id_2", "agent_id_2"]) # type: ignore
+ return c
+
+    # Test that an Optional column on the left can be used as a join key;
+    # a left join keeps such rows, so the Optional type is preserved
+ if True:
+
+ @dataset
+ class XYZ:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ timestamp: datetime
+
+ @dataset(index=True)
+ class ABC:
+ user_id_2: int = field(key=True)
+ agent_id_2: int = field(key=True)
+ age: int
+ timestamp: datetime
+
+ @dataset
+ class XYZJoinedABC:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ age: Optional[int]
+ timestamp: datetime
+
+ @pipeline
+ @inputs(XYZ, ABC)
+ def create_pipeline(cls, a: Dataset, b: Dataset):
+ c = a.join(b, how="left", left_on=["user_id", "agent_id"], right_on=["user_id_2", "agent_id_2"]) # type: ignore
+ return c
+
+    # Test a type mismatch: Optional[str] on the left vs an int key on the right
+ with pytest.raises(TypeError) as e:
+
+ @dataset
+ class XYZ:
+ user_id: Optional[str]
+ agent_id: int
+ name: str
+ timestamp: datetime
+
+ @dataset(index=True)
+ class ABC:
+ user_id: int = field(key=True)
+ agent_id: int = field(key=True)
+ age: int
+ timestamp: datetime
+
+ @dataset
+ class XYZJoinedABC:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ age: int
+ timestamp: datetime
+
+ @pipeline
+ @inputs(XYZ, ABC)
+ def create_pipeline(cls, a: Dataset, b: Dataset):
+ c = a.join(b, how="inner", on=["user_id", "agent_id"]) # type: ignore
+ return c
+
+ assert str(e.value) == (
+ "Key field `user_id` has type `Optional[str]` in left schema but type `int` in right schema for `'[Pipeline:create_pipeline]->join node'`"
+ )
+
+    # After an inner join, the Optional type should be removed
+ with pytest.raises(TypeError) as e:
+
+ @dataset
+ class XYZ:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ timestamp: datetime
+
+ @dataset(index=True)
+ class ABC:
+ user_id: int = field(key=True)
+ agent_id: int = field(key=True)
+ age: int
+ timestamp: datetime
+
+ @dataset
+ class XYZJoinedABC:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ age: int
+ timestamp: datetime
+
+ @pipeline
+ @inputs(XYZ, ABC)
+ def create_pipeline(cls, a: Dataset, b: Dataset):
+ c = a.join(b, how="inner", on=["user_id", "agent_id"]) # type: ignore
+ return c
+
+ assert str(e.value) == (
+ "[TypeError('Field `user_id` has type `int` in `pipeline create_pipeline output value` schema but type `Optional[int]` in `XYZJoinedABC value` schema.')]"
+ )
+
+    # For a left join, the Optional type should not be dropped
+ with pytest.raises(TypeError) as e:
+
+ @dataset
+ class XYZ:
+ user_id: Optional[int]
+ agent_id: int
+ name: str
+ timestamp: datetime
+
+ @dataset(index=True)
+ class ABC:
+ user_id: int = field(key=True)
+ agent_id: int = field(key=True)
+ age: int
+ timestamp: datetime
+
+ @dataset
+ class XYZJoinedABC:
+ user_id: int
+ agent_id: int
+ name: str
+ age: int
+ timestamp: datetime
+
+ @pipeline
+ @inputs(XYZ, ABC)
+ def create_pipeline(cls, a: Dataset, b: Dataset):
+ c = a.join(b, how="left", on=["user_id", "agent_id"]) # type: ignore
+ return c
+
+ assert str(e.value) == (
+ "[TypeError('Field `user_id` has type `Optional[int]` in `pipeline create_pipeline output value` schema but type `int` in `XYZJoinedABC value` schema.')]"
+ )