-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
assign: Include type promotion in assign
- Loading branch information
1 parent
e10f4fa
commit 1d22322
Showing
12 changed files
with
712 additions
and
419 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -606,81 +606,6 @@ class UserInfoDataset: | |
) | ||
|
||
|
||
# On demand datasets are not supported for now. | ||
|
||
# class TestDocumentDataset(unittest.TestCase): | ||
# @mock_client | ||
# def test_log_to_document_dataset(self, client): | ||
# """Log some data to the dataset and check if it is logged correctly.""" | ||
# | ||
# @meta(owner="[email protected]") | ||
# @dataset | ||
# class DocumentContentDataset: | ||
# doc_id: int = field(key=True) | ||
# bert_embedding: Embedding[4] | ||
# fast_text_embedding: Embedding[3] | ||
# num_words: int | ||
# timestamp: datetime = field(timestamp=True) | ||
# | ||
# @on_demand(expires_after="3d") | ||
# @inputs(datetime, int) | ||
# def get_embedding(cls, ts: pd.Series, doc_ids: pd.Series): | ||
# data = [] | ||
# doc_ids = doc_ids.tolist() | ||
# for i in range(len(ts)): | ||
# data.append( | ||
# [ | ||
# doc_ids[i], | ||
# [0.1, 0.2, 0.3, 0.4], | ||
# [1.1, 1.2, 1.3], | ||
# 10 * i, | ||
# ts[i], | ||
# ] | ||
# ) | ||
# columns = [ | ||
# str(cls.doc_id), | ||
# str(cls.bert_embedding), | ||
# str(cls.fast_text_embedding), | ||
# str(cls.num_words), | ||
# str(cls.timestamp), | ||
# ] | ||
# return pd.DataFrame(data, columns=columns), pd.Series( | ||
# [True] * len(ts) | ||
# ) | ||
# | ||
# # Sync the dataset | ||
# client.commit(datasets=[DocumentContentDataset]) | ||
# now = datetime.now(timezone.utc) | ||
# data = [ | ||
# [18232, np.array([1, 2, 3, 4]), np.array([1, 2, 3]), 10, now], | ||
# [ | ||
# 18234, | ||
# np.array([1, 2.2, 0.213, 0.343]), | ||
# np.array([0.87, 2, 3]), | ||
# 9, | ||
# now, | ||
# ], | ||
# [18934, [1, 2.2, 0.213, 0.343], [0.87, 2, 3], 12, now], | ||
# ] | ||
# columns = [ | ||
# "doc_id", | ||
# "bert_embedding", | ||
# "fast_text_embedding", | ||
# "num_words", | ||
# "timestamp", | ||
# ] | ||
# df = pd.DataFrame(data, columns=columns) | ||
# response = client.log("fennel_webhook","DocumentContentDataset", df) | ||
# assert response.status_code == requests.codes.OK, response.json() | ||
# | ||
# # Do some lookups | ||
# doc_ids = pd.Series([18232, 1728, 18234, 18934, 19200, 91012]) | ||
# ts = pd.Series([now, now, now, now, now, now]) | ||
# df, _ = DocumentContentDataset.lookup(ts, doc_id=doc_ids) | ||
# assert df.shape == (6, 5) | ||
# assert df["num_words"].tolist() == [10.0, 9.0, 12, 0, 10.0, 20.0] | ||
|
||
|
||
################################################################################ | ||
# Dataset & Pipelines Unit Tests | ||
################################################################################ | ||
|
@@ -1012,6 +937,25 @@ class Orders: | |
timestamp: datetime | ||
|
||
|
||
@dataset | ||
class OrdersOptional: | ||
uid: Optional[int] | ||
uid_float: float | ||
uid_twice: float | ||
skus: List[int] | ||
prices: List[float] | ||
timestamp: datetime | ||
|
||
@pipeline | ||
@inputs(Orders) | ||
def cast(cls, ds: Dataset): | ||
return ds.assign( | ||
uid=col("uid").astype(Optional[int]), # type: ignore | ||
uid_float=col("uid").astype(float), # type: ignore | ||
uid_twice=(col("uid") * 2.0).astype(float), # type: ignore | ||
) | ||
|
||
|
||
@dataset(index=True) | ||
class Derived: | ||
uid: int = field(key=True) | ||
|
@@ -1066,6 +1010,39 @@ def test_basic_explode(self, client): | |
assert df["price"].tolist()[0] == 10.1 | ||
assert pd.isna(df["price"].tolist()[1]) | ||
|
||
@pytest.mark.integration | ||
@mock | ||
def test_basic_cast(self, client): | ||
# # Sync the dataset | ||
client.commit(message="msg", datasets=[Orders, OrdersOptional]) | ||
# log some rows to the transaction dataset | ||
df = pd.DataFrame( | ||
[ | ||
{ | ||
"uid": 1, | ||
"skus": [1, 2], | ||
"prices": [10.1, 20.0], | ||
"timestamp": "2021-01-01T00:00:00", | ||
}, | ||
{ | ||
"uid": 2, | ||
"skus": [], | ||
"prices": [], | ||
"timestamp": "2021-01-01T00:00:00", | ||
}, | ||
] | ||
) | ||
client.log("webhook", "Orders", df) | ||
client.sleep() | ||
|
||
# do lookup on the WithSquare dataset | ||
df = client.inspect("OrdersOptional") | ||
assert df.shape == (2, 6) | ||
assert df["uid"].tolist() == [1, 2] | ||
assert df["uid_float"].tolist() == [1.0, 2.0] | ||
assert df["uid_twice"].tolist() == [2.0, 4.0] | ||
assert df["skus"].tolist() == [[1, 2], []] | ||
|
||
|
||
class TestBasicAssign(unittest.TestCase): | ||
@pytest.mark.integration | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.