diff --git a/docs/examples/datasets/lookups.py b/docs/examples/datasets/lookups.py
index b51557306..db204775f 100644
--- a/docs/examples/datasets/lookups.py
+++ b/docs/examples/datasets/lookups.py
@@ -36,7 +36,8 @@ class UserFeature:
     def func(cls, ts: pd.Series, uid: pd.Series):
         df, _found = User.lookup(ts, uid=uid)
         return pd.Series(
-            name="in_home_city", data=df["home_city"] == df["cur_city"]
+            name="in_home_city",
+            data=df["home_city"] == df["cur_city"],
         )
diff --git a/fennel/CHANGELOG.md b/fennel/CHANGELOG.md
index 0e38c0d06..ddc166bd7 100644
--- a/fennel/CHANGELOG.md
+++ b/fennel/CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog
+## [0.18.14] - 2023-11-11
+- Use pandas nullable dtypes (Int64, Float64, String, Boolean) rather than native Python types
+
 ## [0.18.12] - 2023-11-08
 - Add support for strings in extract_features and extract_historical_features
diff --git a/fennel/client/client.py b/fennel/client/client.py
index ecbdbbe5d..c6194f801 100644
--- a/fennel/client/client.py
+++ b/fennel/client/client.py
@@ -492,6 +492,10 @@ def extract_historical_features(
            raise Exception(
                f"Timestamp column {timestamp_column} not found in input dataframe."
            )
+        # Convert timestamp column to string to make it JSON serializable
+        input_dataframe[timestamp_column] = input_dataframe[
+            timestamp_column
+        ].astype(str)
        extract_historical_input["Pandas"] = input_dataframe.to_dict(
            orient="list"
        )
diff --git a/fennel/client_tests/test_dataset.py b/fennel/client_tests/test_dataset.py
index 35516a51f..576b4950a 100644
--- a/fennel/client_tests/test_dataset.py
+++ b/fennel/client_tests/test_dataset.py
@@ -178,7 +178,9 @@ def test_simple_drop_null(self, client):
        if client.is_integration_client():
            client.sleep()
        ts = pd.Series([now, now, now, now, now, now])
-        user_id_keys = pd.Series([18232, 18234, 18235, 18236, 18237, 18238])
+        user_id_keys = pd.Series(
+            [18232, 18234, 18235, 18236, 18237, 18238], dtype="Int64"
+        )
        df, found = UserInfoDataset.lookup(ts, user_id=user_id_keys)
        assert df.shape == (6, 5)
@@ -294,12 +296,11 @@ def test_log_to_dataset(self, client):
        assert (
            response.json()["error"]
            == "Schema validation failed during data insertion to "
-            "`UserInfoDataset` [ValueError('Field `age` is of type int, but the column "
-            "in the dataframe is of type `object`. 
Error found during checking schema for `UserInfoDataset`.')]" + """`UserInfoDataset`: Failed to cast data logged to column `age` of type `optional(int)`: Unable to parse string "32yrs" at position 0""" ) client.sleep(10) # Do some lookups - user_ids = pd.Series([18232, 18234, 1920]) + user_ids = pd.Series([18232, 18234, 1920], dtype="Int64") lookup_now = datetime.now() + pd.Timedelta(minutes=1) ts = pd.Series([lookup_now, lookup_now, lookup_now]) df, found = UserInfoDataset.lookup( @@ -323,7 +324,7 @@ def test_log_to_dataset(self, client): pd.Timestamp(yday_rounded), ] # Do some lookups with a timestamp - user_ids = pd.Series([18232, 18234]) + user_ids = pd.Series([18232, 18234], dtype="Int64") six_hours_ago = now - pd.Timedelta(hours=6) ts = pd.Series([six_hours_ago, six_hours_ago]) df, found = UserInfoDataset.lookup( diff --git a/fennel/client_tests/test_invalid.py b/fennel/client_tests/test_invalid.py index 3f078861b..5b4dfd677 100644 --- a/fennel/client_tests/test_invalid.py +++ b/fennel/client_tests/test_invalid.py @@ -34,7 +34,7 @@ class MemberActivityDataset: domain: str = field(key=True) hasShortcut: bool country: str - DOMAIN_USED_COUNT: int + domain_used_count: int @meta(owner="test@fennel.ai") @@ -53,7 +53,7 @@ class MemberDataset: @dataset class MemberActivityDatasetCopy: domain: str = field(key=True) - DOMAIN_USED_COUNT: int + domain_used_count: int time: datetime = field(timestamp=True) url: str uid: str @@ -71,11 +71,11 @@ def copy(cls, ds: Dataset): @featureset class DomainFeatures: domain: str = feature(id=1) - DOMAIN_USED_COUNT: int = feature(id=2) + domain_used_count: int = feature(id=2) @extractor(depends_on=[MemberActivityDatasetCopy]) @inputs(Query.domain) - @outputs(domain, DOMAIN_USED_COUNT) + @outputs(domain, domain_used_count) def get_domain_feature(cls, ts: pd.Series, domain: pd.Series): df, found = MemberActivityDatasetCopy.lookup( # type: ignore ts, domain=domain @@ -106,16 +106,16 @@ def test_invalid_sync(self, client): @featureset class DomainFeatures2: domain: str = feature(id=1) - DOMAIN_USED_COUNT: int = feature(id=2) + domain_used_count: int = feature(id=2) @extractor() @inputs(Query.domain) - @outputs(domain, DOMAIN_USED_COUNT) + @outputs(domain, domain_used_count) def get_domain_feature(cls, ts: pd.Series, domain: pd.Series): df, found = MemberActivityDatasetCopy.lookup( # type: ignore ts, domain=domain ) - return df[[str(cls.domain), str(cls.DOMAIN_USED_COUNT)]] + return df[[str(cls.domain), str(cls.domain_used_count)]] class TestInvalidExtractorDependsOn(unittest.TestCase): @@ -133,7 +133,7 @@ class MemberActivityDataset: domain: str = field(key=True) hasShortcut: bool country: str - DOMAIN_USED_COUNT: int + domain_used_count: int @meta(owner="test@fennel.ai") @source(webhook.endpoint("MemberDataset")) @@ -150,7 +150,7 @@ class MemberDataset: @dataset class MemberActivityDatasetCopy: domain: str = field(key=True) - DOMAIN_USED_COUNT: int + domain_used_count: int time: datetime = field(timestamp=True) url: str uid: str @@ -167,11 +167,11 @@ def copy(cls, ds: Dataset): @featureset class DomainFeatures: domain: str = feature(id=1) - DOMAIN_USED_COUNT: int = feature(id=2) + domain_used_count: int = feature(id=2) @extractor(depends_on=[MemberActivityDatasetCopy]) @inputs(Query.domain) - @outputs(domain, DOMAIN_USED_COUNT) + @outputs(domain, domain_used_count) def get_domain_feature(cls, ts: pd.Series, domain: pd.Series): df, found = MemberActivityDatasetCopy.lookup( # type: ignore ts, domain=domain @@ -188,7 +188,7 @@ def get_domain_feature(cls, ts: 
pd.Series, domain: pd.Series): ) client.extract_features( output_feature_list=[DomainFeatures2], - input_feature_list=[Query], + input_feature_list=[Query.member_id], input_dataframe=pd.DataFrame( { "Query.domain": [ @@ -262,9 +262,8 @@ def test_no_access(self, client): ) else: assert ( - "Extractor `get_domain_feature` in `DomainFeatures2` " - "failed to run with error: name " - "'MemberActivityDatasetCopy' is not defined. " == str(e.value) + str(e.value) + == """Dataset `MemberActivityDataset` is an input to the pipelines: `['copy']` but is not synced. Please add it to the sync call.""" ) @mock diff --git a/fennel/client_tests/test_movie_tickets.py b/fennel/client_tests/test_movie_tickets.py index f79d9a7b0..6df3c1fdd 100644 --- a/fennel/client_tests/test_movie_tickets.py +++ b/fennel/client_tests/test_movie_tickets.py @@ -8,16 +8,15 @@ from fennel import featureset, extractor, feature from fennel.datasets import dataset, field from fennel.lib.metadata import meta -from fennel.lib.schema import inputs, outputs +from fennel.lib.schema import inputs, outputs, between from fennel.sources import source from fennel.datasets import pipeline, Dataset -from fennel.lib.aggregate import Sum +from fennel.lib.aggregate import Sum, LastK, Distinct from fennel.lib.window import Window from fennel.sources import Webhook from fennel.test_lib import mock, MockClient -from typing import List - +from typing import List, Optional client = MockClient() @@ -29,7 +28,7 @@ @dataset class MovieInfo: title: str = field(key=True) - actors: List[str] # can be an empty list + actors: List[Optional[str]] # can be an empty list release: datetime @@ -39,7 +38,7 @@ class MovieInfo: class TicketSale: ticket_id: str title: str - price: int + price: between(int, 0, 1000) # type: ignore at: datetime @@ -47,7 +46,7 @@ class TicketSale: @dataset class ActorStats: name: str = field(key=True) - revenue: int + revenue: between(int, 0, 1000) # type: ignore at: datetime @pipeline(version=1, tier="prod") @@ -101,6 +100,46 @@ def foo(df): ) +@meta(owner="abhay@fennel.ai") +@dataset +class ActorStatsList: + name: str = field(key=True) + revenue: List[between(int, 0, 1000)] # type: ignore + revenue_distinct: List[between(int, 0, 1000)] # type: ignore + at: datetime + + @pipeline(version=1, tier="prod") + @inputs(MovieInfo, TicketSale) + def pipeline_join(cls, info: Dataset, sale: Dataset): + uniq = sale.groupby("ticket_id").first() + c = ( + uniq.join(info, how="inner", on=["title"]) + .explode(columns=["actors"]) + .rename(columns={"actors": "name"}) + ) + # name -> Option[str] + schema = c.schema() + schema["name"] = str + c = c.transform(lambda x: x, schema) + return c.groupby("name").aggregate( + [ + LastK( + window=Window("forever"), + of="price", + into_field="revenue", + limit=10, + dedup=False, + ), + Distinct( + window=Window("forever"), + of="price", + into_field="revenue_distinct", + unordered=True, + ), + ] + ) + + @meta(owner="zaki@fennel.ai") @featureset class RequestFeatures: @@ -132,7 +171,7 @@ def extract_revenue2(cls, ts: pd.Series, name: pd.Series): class TestMovieTicketSale(unittest.TestCase): @mock def test_movie_ticket_sale(self, client): - datasets = [MovieInfo, TicketSale, ActorStats] # type: ignore + datasets = [MovieInfo, TicketSale, ActorStats, ActorStatsList] # type: ignore featuresets = [ActorFeatures, RequestFeatures] client.sync(datasets=datasets, featuresets=featuresets, tier="prod") # type: ignore client.sleep() @@ -166,10 +205,10 @@ def test_movie_ticket_sale(self, client): two_hours_ago = now - 
timedelta(hours=2) columns = ["ticket_id", "title", "price", "at"] data = [ - ["1", "Titanic", 50, one_hour_ago], - ["2", "Titanic", 100, one_day_ago], - ["3", "Jumanji", 25, one_hour_ago], - ["4", "The Matrix", 50, two_hours_ago], # no match + ["1", "Titanic", "50", one_hour_ago], + ["2", "Titanic", "100", one_day_ago], + ["3", "Jumanji", "25", one_hour_ago], + ["4", "The Matrix", "50", two_hours_ago], # no match ["5", "Great Gatbsy", 49, one_hour_ago], ] df = pd.DataFrame(data, columns=columns) diff --git a/fennel/client_tests/test_outbrain.py b/fennel/client_tests/test_outbrain.py index a9f994329..191d3e2e8 100644 --- a/fennel/client_tests/test_outbrain.py +++ b/fennel/client_tests/test_outbrain.py @@ -26,7 +26,9 @@ @source( s3.bucket("fennel-demo-data", prefix="outbrain/page_views_filter.csv"), every="1d", + tier="prod", ) +@source(webhook.endpoint("PageViews"), tier="dev") @meta(owner="xiao@fennel.ai") @dataset class PageViews: @@ -101,16 +103,16 @@ def extract(cls, ts: pd.Series, uuids: pd.Series): @pytest.mark.integration @mock def test_outbrain(client): - fake_PageViews = PageViews.with_source(webhook.endpoint("PageViews")) client.sync( datasets=[ - fake_PageViews, + PageViews, PageViewsByUser, ], featuresets=[ Request, UserPageViewFeatures, ], + tier="dev", ) df = pd.read_csv("fennel/client_tests/data/page_views_sample.csv") # Current time in ms @@ -143,7 +145,7 @@ def test_outbrain(client): Request, UserPageViewFeatures, ], - input_feature_list=[Request], + input_feature_list=[Request.uuid, Request.document_id], input_dataframe=input_df, ) assert feature_df.shape[0] == 347 diff --git a/fennel/client_tests/test_search.py b/fennel/client_tests/test_search.py index e7bebe01e..3e98caac0 100644 --- a/fennel/client_tests/test_search.py +++ b/fennel/client_tests/test_search.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd import pytest -from typing import Dict, List +from typing import Dict, List, Optional import fennel._vendor.requests as requests from fennel import sources @@ -14,7 +14,7 @@ from fennel.lib.aggregate import Count, Sum from fennel.lib.includes import includes from fennel.lib.metadata import meta -from fennel.lib.schema import Embedding +from fennel.lib.schema import Embedding, oneof from fennel.lib.schema import inputs, outputs from fennel.lib.window import Window from fennel.sources import source @@ -217,7 +217,7 @@ def top_words_count(cls, ds: Dataset): class UserActivity: user_id: int doc_id: int - action_type: str + action_type: oneof(str, ["view", "edit"]) # type: ignore view_time: float timestamp: datetime @@ -597,7 +597,7 @@ def test_search_e2e(self, client): DocumentFeatures, DocumentContentFeatures, ], - input_feature_list=[Query], + input_feature_list=[Query.doc_id, Query.user_id], input_dataframe=input_df, ) assert df.shape == (2, 15) diff --git a/fennel/client_tests/test_social_network.py b/fennel/client_tests/test_social_network.py index 7f35316d9..68e8241f1 100644 --- a/fennel/client_tests/test_social_network.py +++ b/fennel/client_tests/test_social_network.py @@ -36,7 +36,7 @@ class UserInfo: @meta(owner="data-eng@myspace.com") class PostInfo: title: str - category: str + category: str # type: ignore post_id: int = field(key=True) timestamp: datetime @@ -168,7 +168,7 @@ def test_social_network(client): user_data_df = pd.read_csv("fennel/client_tests/data/user_data.csv") post_data_df = pd.read_csv("fennel/client_tests/data/post_data.csv") view_data_df = pd.read_csv("fennel/client_tests/data/view_data_sampled.csv") - ts = datetime(2018, 1, 1, 0, 0, 
0) + ts = "2018-01-01 00:00:00" user_data_df["timestamp"] = ts post_data_df["timestamp"] = ts view_data_df["time_stamp"] = view_data_df["time_stamp"].apply( @@ -194,7 +194,7 @@ def test_social_network(client): feature_df = client.extract_features( output_feature_list=[UserFeatures], - input_feature_list=[Request], + input_feature_list=[Request.user_id, Request.category], input_dataframe=pd.DataFrame( { "Request.user_id": [ diff --git a/fennel/datasets/datasets.py b/fennel/datasets/datasets.py index bfd4e668b..96dd088c7 100644 --- a/fennel/datasets/datasets.py +++ b/fennel/datasets/datasets.py @@ -53,10 +53,14 @@ from fennel.lib.schema import ( dtype_to_string, get_primitive_dtype, + fennel_is_optional, + fennel_get_optional_inner, + get_pd_dtype, FENNEL_INPUTS, is_hashable, parse_json, get_fennel_struct, + get_python_type_from_pd, FENNEL_STRUCT_SRC_CODE, FENNEL_STRUCT_DEPENDENCIES_SRC_CODE, ) @@ -88,6 +92,7 @@ "fqn", ] +primitive_numeric_types = [int, float, pd.Int64Dtype, pd.Float64Dtype] # --------------------------------------------------------------------- # Field @@ -327,7 +332,7 @@ def dsschema(self): return DSSchema( keys=inp_keys, values={ - f: dtype + f: get_pd_dtype(dtype) for f, dtype in self.new_schema.items() if f not in inp_keys.keys() and f != input_schema.timestamp }, @@ -363,7 +368,7 @@ def signature(self): def dsschema(self): input_schema = self.node.dsschema() - input_schema.update_column(self.column, self.output_type) + input_schema.update_column(self.column, get_pd_dtype(self.output_type)) return input_schema @@ -407,26 +412,31 @@ def dsschema(self): values = {} for agg in self.aggregates: if isinstance(agg, Count): - values[agg.into_field] = int + values[agg.into_field] = pd.Int64Dtype elif isinstance(agg, Sum): dtype = input_schema.get_type(agg.of) - if dtype not in [int, float]: + dtype = get_primitive_dtype(dtype) + if dtype not in primitive_numeric_types: raise TypeError( f"Cannot sum field {agg.of} of type {dtype_to_string(dtype)}" ) - values[agg.into_field] = dtype # type: ignore + values[agg.into_field] = dtype elif isinstance(agg, Min) or isinstance(agg, Max): - values[agg.into_field] = input_schema.get_type(agg.of) + dtype = input_schema.get_type(agg.of) + dtype = get_primitive_dtype(dtype) + values[agg.into_field] = dtype elif isinstance(agg, Distinct): dtype = input_schema.get_type(agg.of) - values[agg.into_field] = List[dtype] # type: ignore + list_type = get_python_type_from_pd(dtype) + values[agg.into_field] = List[list_type] # type: ignore elif isinstance(agg, Average): - values[agg.into_field] = float # type: ignore + values[agg.into_field] = pd.Float64Dtype # type: ignore elif isinstance(agg, LastK): dtype = input_schema.get_type(agg.of) - values[agg.into_field] = List[dtype] # type: ignore + list_type = get_python_type_from_pd(dtype) + values[agg.into_field] = List[list_type] # type: ignore elif isinstance(agg, Stddev): - values[agg.into_field] = float # type: ignore + values[agg.into_field] = pd.Float64Dtype # type: ignore else: raise TypeError(f"Unknown aggregate type {type(agg)}") return DSSchema( @@ -498,6 +508,7 @@ def dsschema(self): for c in self.columns: # extract type T from List[t] dsschema.values[c] = Optional[get_args(dsschema.values[c])[0]] + dsschema.values[c] = get_pd_dtype(dsschema.values[c]) return dsschema @@ -918,17 +929,6 @@ def wrap(c: Type[T]) -> Dataset: return wrap(cls) -def fennel_is_optional(type_): - return ( - typing.get_origin(type_) is Union - and type(None) is typing.get_args(type_)[1] - ) - - -def 
fennel_get_optional_inner(type_): - return typing.get_args(type_)[0] - - # Fennel implementation of get_type_hints which does not error on forward # references not being types such as Embedding[4]. def f_get_type_hints(obj): @@ -1202,9 +1202,9 @@ def with_source( def dsschema(self): return DSSchema( - keys={f.name: f.dtype for f in self._fields if f.key}, + keys={f.name: get_pd_dtype(f.dtype) for f in self._fields if f.key}, values={ - f.name: f.dtype + f.name: get_pd_dtype(f.dtype) for f in self._fields if not f.key and f.name != self._timestamp_field }, @@ -1391,15 +1391,7 @@ def _get_pipelines(self) -> List[Pipeline]: def _validate_pipelines(self, pipelines: List[Pipeline]): exceptions = [] - ds_schema = DSSchema( - keys={f.name: f.dtype for f in self.fields if f.key}, - values={ - f.name: f.dtype - for f in self.fields - if not f.key and f.name != self._timestamp_field - }, - timestamp=self.timestamp_field, - ) + ds_schema = self.dsschema() for pipeline in pipelines: pipeline_schema = pipeline.get_terminal_schema() @@ -1571,7 +1563,11 @@ class DSSchema: name: str = "" def schema(self) -> Dict[str, Type]: - return {**self.keys, **self.values, self.timestamp: datetime.datetime} + schema = {**self.keys, **self.values, self.timestamp: datetime.datetime} + # Convert int -> Int64, float -> Float64 and string -> String + for k, v in schema.items(): + schema[k] = get_pd_dtype(v) + return schema def fields(self) -> List[str]: return ( @@ -1647,17 +1643,17 @@ def drop_column(self, name: str): else: raise Exception(f"field {name} not found in schema of {self.name}") - def update_column(self, name: str, tpe: Type): + def update_column(self, name: str, type: Type): if name in self.keys: - self.keys[name] = tpe + self.keys[name] = type elif name in self.values: - self.values[name] = tpe + self.values[name] = type elif name == self.timestamp: raise Exception( f"cannot assign timestamp field {name} from {self.name}" ) else: - self.values[name] = tpe # Add to values + self.values[name] = type # Add to values def matches( self, other_schema: DSSchema, this_name: str, other_name: str @@ -1734,16 +1730,7 @@ def visit(self, obj) -> DSSchema: return vis def visitDataset(self, obj) -> DSSchema: - return DSSchema( - keys={f.name: f.dtype for f in obj.fields if f.key}, - values={ - f.name: f.dtype - for f in obj.fields - if not f.key and f.name != obj.timestamp_field - }, - timestamp=obj.timestamp_field, - name=f"'[Dataset:{obj._name}]'", - ) + return obj.dsschema() def visitTransform(self, obj) -> DSSchema: input_schema = self.visit(obj.node) @@ -1762,7 +1749,7 @@ def visitTransform(self, obj) -> DSSchema: f"Key field {name} must be present in schema of " f"{node_name}." 
) - if dtype != obj.new_schema[name]: + if dtype != get_pd_dtype(obj.new_schema[name]): raise TypeError( f"Key field {name} has type {dtype_to_string(dtype)} in " f"input schema " @@ -1774,7 +1761,7 @@ def visitTransform(self, obj) -> DSSchema: return DSSchema( keys=inp_keys, values={ - f: dtype + f: get_pd_dtype(dtype) for f, dtype in obj.new_schema.items() if f not in inp_keys.keys() and f != input_schema.timestamp }, @@ -1808,7 +1795,7 @@ def visitAggregate(self, obj) -> DSSchema: f"type `{dtype_to_string(dtype)}`, as it is not " # type: ignore f"hashable" ) - values[agg.into_field] = int + values[agg.into_field] = pd.Int64Dtype elif isinstance(agg, Distinct): if agg.of is None: raise ValueError( @@ -1821,31 +1808,33 @@ def visitAggregate(self, obj) -> DSSchema: f"type `{dtype_to_string(dtype)}`, as it is not hashable" # type: ignore ) - values[agg.into_field] = List[dtype] # type: ignore + list_type = get_python_type_from_pd(dtype) + values[agg.into_field] = List[list_type] # type: ignore elif isinstance(agg, Sum): dtype = input_schema.get_type(agg.of) - if get_primitive_dtype(dtype) not in [int, float]: + if get_primitive_dtype(dtype) not in primitive_numeric_types: raise TypeError( f"Cannot sum field `{agg.of}` of type `{dtype_to_string(dtype)}`" ) values[agg.into_field] = dtype # type: ignore elif isinstance(agg, Average): dtype = input_schema.get_type(agg.of) - if get_primitive_dtype(dtype) not in [int, float]: + if get_primitive_dtype(dtype) not in primitive_numeric_types: raise TypeError( f"Cannot take average of field `{agg.of}` of type `{dtype_to_string(dtype)}`" ) - values[agg.into_field] = float # type: ignore + values[agg.into_field] = pd.Float64Dtype # type: ignore elif isinstance(agg, LastK): dtype = input_schema.get_type(agg.of) - values[agg.into_field] = List[dtype] # type: ignore + list_type = get_python_type_from_pd(dtype) + values[agg.into_field] = List[list_type] # type: ignore elif isinstance(agg, Min): dtype = input_schema.get_type(agg.of) - if get_primitive_dtype(dtype) not in [int, float]: + if get_primitive_dtype(dtype) not in primitive_numeric_types: raise TypeError( f"invalid min: type of field `{agg.of}` is not int or float" ) - if get_primitive_dtype(dtype) == int and ( + if get_primitive_dtype(dtype) == pd.Int64Dtype and ( int(agg.default) != agg.default ): raise TypeError( @@ -1854,11 +1843,11 @@ def visitAggregate(self, obj) -> DSSchema: values[agg.into_field] = dtype # type: ignore elif isinstance(agg, Max): dtype = input_schema.get_type(agg.of) - if get_primitive_dtype(dtype) not in [int, float]: + if get_primitive_dtype(dtype) not in primitive_numeric_types: raise TypeError( f"invalid max: type of field `{agg.of}` is not int or float" ) - if get_primitive_dtype(dtype) == int and ( + if get_primitive_dtype(dtype) == pd.Int64Dtype and ( int(agg.default) != agg.default ): raise TypeError( @@ -1867,11 +1856,11 @@ def visitAggregate(self, obj) -> DSSchema: values[agg.into_field] = dtype # type: ignore elif isinstance(agg, Stddev): dtype = input_schema.get_type(agg.of) - if get_primitive_dtype(dtype) not in [int, float]: + if get_primitive_dtype(dtype) not in primitive_numeric_types: raise TypeError( f"Cannot get standard deviation of field {agg.of} of type {dtype_to_string(dtype)}" ) - values[agg.into_field] = float # type: ignore + values[agg.into_field] = pd.Float64Dtype # type: ignore else: raise TypeError(f"Unknown aggregate type {type(agg)}") return DSSchema( @@ -2040,7 +2029,7 @@ def visitDropNull(self, obj): output_schema_name = 
f"'[Pipeline:{self.pipeline_name}]->dropnull node'" if obj.columns is None or len(obj.columns) == 0: raise ValueError( - f"invalid dropnull - {output_schema_name} must have at least one column" + f"invalid dropnull - `{output_schema_name}` must have at least one column" ) for field in obj.columns: if ( @@ -2048,11 +2037,11 @@ def visitDropNull(self, obj): or field == input_schema.timestamp ): raise ValueError( - f"invalid dropnull column {field} not present in {input_schema.name}" + f"invalid dropnull column `{field}` not present in `{input_schema.name}`" ) if not fennel_is_optional(input_schema.get_type(field)): raise ValueError( - f"invalid dropnull {field} has type {input_schema.get_type(field)} expected Optional type" + f"invalid dropnull `{field}` has type `{dtype_to_string(input_schema.get_type(field))}` expected Optional type" ) output_schema = obj.dsschema() output_schema.name = output_schema_name diff --git a/fennel/datasets/test_dataset.py b/fennel/datasets/test_dataset.py index 31cf12bc1..7e00166f9 100644 --- a/fennel/datasets/test_dataset.py +++ b/fennel/datasets/test_dataset.py @@ -2473,10 +2473,10 @@ def extract_info(df: pd.DataFrame) -> pd.DataFrame: ] assert activity.schema() == { - "action_type": float, + "action_type": pd.Float64Dtype, "timestamp": datetime, - "amount": Optional[float], - "user_id": int, + "amount": Optional[pd.Float64Dtype], + "user_id": pd.Int64Dtype, } filtered_ds = activity.filter( @@ -2484,26 +2484,26 @@ def extract_info(df: pd.DataFrame) -> pd.DataFrame: ) assert filtered_ds.schema() == { - "action_type": float, - "amount": Optional[float], + "action_type": pd.Float64Dtype, + "amount": Optional[pd.Float64Dtype], "timestamp": datetime, - "user_id": int, + "user_id": pd.Int64Dtype, } x = filtered_ds.transform( extract_info, schema={ - "transaction_amount": float, - "merchant_id": int, - "user_id": int, + "transaction_amount": pd.Float64Dtype, + "merchant_id": pd.Int64Dtype, + "user_id": pd.Int64Dtype, "timestamp": datetime, }, ) assert x.schema() == { - "merchant_id": int, - "transaction_amount": float, - "user_id": int, + "merchant_id": pd.Int64Dtype, + "transaction_amount": pd.Float64Dtype, + "user_id": pd.Int64Dtype, "timestamp": datetime, } @@ -2511,21 +2511,21 @@ def extract_info(df: pd.DataFrame) -> pd.DataFrame: "transaction_amount", int, lambda df: df["user_id"] * 2 ) assert assign_ds.schema() == { - "action_type": float, + "action_type": pd.Float64Dtype, "timestamp": datetime, - "transaction_amount": int, - "user_id": int, - "amount": Optional[float], + "transaction_amount": pd.Int64Dtype, + "user_id": pd.Int64Dtype, + "amount": Optional[pd.Float64Dtype], } assign_ds_str = activity.assign( "user_id", str, lambda df: str(df["user_id"]) ) assert assign_ds_str.schema() == { - "action_type": float, + "action_type": pd.Float64Dtype, "timestamp": datetime, - "user_id": str, - "amount": Optional[float], + "user_id": pd.StringDtype, + "amount": Optional[pd.Float64Dtype], } return x diff --git a/fennel/datasets/test_invalid_dataset.py b/fennel/datasets/test_invalid_dataset.py index 4ff396aec..0e4b65be3 100644 --- a/fennel/datasets/test_invalid_dataset.py +++ b/fennel/datasets/test_invalid_dataset.py @@ -2,7 +2,6 @@ import pytest from typing import Optional, List, Union - from fennel.datasets import dataset, pipeline, field, Dataset from fennel.lib.aggregate import Count, Average, Stddev, Distinct from fennel.lib.expectations import ( @@ -590,6 +589,36 @@ def create_pipeline(a: Dataset): # type: ignore ) +@mock +def 
test_pipeline_input_validation_during_sync(client): + with pytest.raises(ValueError) as e: + + @meta(owner="eng@fennel.ai") + @dataset + class XYZ: + user_id: int + name: str + timestamp: datetime + + @meta(owner="eng@fennel.ai") + @dataset + class ABCDataset: + user_id: int + name: str + timestamp: datetime + + @pipeline(version=1) + @inputs(XYZ) + def create_pipeline(cls, a: Dataset): + return a + + client.sync(datasets=[ABCDataset]) + assert ( + str(e.value) + == "Dataset `XYZ` is an input to the pipelines: `['create_pipeline']` but is not synced. Please add it to the sync call." + ) + + def test_dataset_incorrect_join(): with pytest.raises(ValueError) as e: diff --git a/fennel/datasets/test_schema_validator.py b/fennel/datasets/test_schema_validator.py index 18f38b282..174eca82a 100644 --- a/fennel/datasets/test_schema_validator.py +++ b/fennel/datasets/test_schema_validator.py @@ -918,7 +918,6 @@ class C1: def drop_null_noargs(cls, c: Dataset): return c.dropnull() - print(e.value) assert ( str(e.value) == """[TypeError('Field `b2` has type `int` in `pipeline drop_null_noargs output value` schema but type `Optional[int]` in `C1 value` schema.')]""" @@ -940,7 +939,7 @@ def drop_null_non_opt(cls, c: Dataset): assert ( str(e.value) - == "invalid dropnull b1 has type expected Optional type" + == "invalid dropnull `b1` has type `int` expected Optional type" ) with pytest.raises(ValueError) as e: @@ -959,7 +958,7 @@ def drop_null_non_present(cls, c: Dataset): assert ( str(e.value) - == "invalid dropnull column b4 not present in '[Dataset:C]'" + == "invalid dropnull column `b4` not present in `'[Dataset:C]'`" ) diff --git a/fennel/lib/schema/__init__.py b/fennel/lib/schema/__init__.py index 913e46b21..112d4bd62 100644 --- a/fennel/lib/schema/__init__.py +++ b/fennel/lib/schema/__init__.py @@ -1,9 +1,12 @@ from fennel.lib.schema.schema import ( get_datatype, + get_pd_dtype, between, oneof, dtype_to_string, data_schema_check, + fennel_is_optional, + fennel_get_optional_inner, FENNEL_INPUTS, FENNEL_OUTPUTS, Embedding, @@ -11,6 +14,7 @@ inputs, outputs, get_primitive_dtype, + get_python_type_from_pd, is_hashable, get_fennel_struct, parse_json, diff --git a/fennel/lib/schema/schema.py b/fennel/lib/schema/schema.py index 3f521492b..0d5c55cbc 100644 --- a/fennel/lib/schema/schema.py +++ b/fennel/lib/schema/schema.py @@ -4,6 +4,7 @@ import inspect import re import sys +import typing from dataclasses import dataclass from datetime import datetime from textwrap import dedent @@ -20,6 +21,8 @@ get_origin, get_args, ForwardRef, + Type, + Optional, ) import fennel.gen.schema_pb2 as schema_proto @@ -53,6 +56,14 @@ def _optional_inner(type_): def dtype_to_string(type_: Any) -> str: if _is_optional(type_): return f"Optional[{dtype_to_string(_optional_inner(type_))}]" + if type_ == pd.Int64Dtype: + return "int" + if type_ == pd.Float64Dtype: + return "float" + if type_ == pd.StringDtype: + return "str" + if type_ == pd.BooleanDtype: + return "bool" if isinstance(type_, type): return type_.__name__ return str(type_) @@ -63,7 +74,7 @@ def get_primitive_dtype(dtype): if isinstance(dtype, oneof) or isinstance(dtype, between): return dtype.dtype if isinstance(dtype, regex): - return str + return pd.StringDtype return dtype @@ -376,24 +387,34 @@ def to_proto(self): ) +def fennel_is_optional(type_): + return ( + typing.get_origin(type_) is Union + and type(None) is typing.get_args(type_)[1] + ) + + +def fennel_get_optional_inner(type_): + return typing.get_args(type_)[0] + + def get_datatype(type_: Any) -> 
schema_proto.DataType: - # typing.Optional[x] is an alias for typing.Union[x, None] - if _get_origin(type_) is Union and type(None) is _get_args(type_)[1]: + if fennel_is_optional(type_): dtype = get_datatype(_get_args(type_)[0]) return schema_proto.DataType( optional_type=schema_proto.OptionalType(of=dtype) ) - elif type_ is int: + elif type_ is int or type_ is np.int64 or type_ == pd.Int64Dtype: return schema_proto.DataType(int_type=schema_proto.IntType()) - elif type_ is float: + elif type_ is float or type_ is np.float64 or type_ == pd.Float64Dtype: return schema_proto.DataType(double_type=schema_proto.DoubleType()) - elif type_ is str: + elif type_ is str or type_ is np.str_ or type_ == pd.StringDtype: return schema_proto.DataType(string_type=schema_proto.StringType()) - elif type_ is datetime: + elif type_ is datetime or type_ is np.datetime64: return schema_proto.DataType( timestamp_type=schema_proto.TimestampType() ) - elif type_ is bool: + elif type_ is bool or type_ == pd.BooleanDtype: return schema_proto.DataType(bool_type=schema_proto.BoolType()) elif _get_origin(type_) is list: return schema_proto.DataType( @@ -439,7 +460,41 @@ def get_datatype(type_: Any) -> schema_proto.DataType: raise ValueError(f"Cannot serialize type {type_}.") -# TODO(Aditya): Add support for nested schema checks for arrays and maps +def get_pd_dtype(type: Type): + """ + Convert int -> Int64, float -> Float64 and string -> String, bool -> Bool + """ + if type == int: + return pd.Int64Dtype + elif type == float: + return pd.Float64Dtype + elif type == str: + return pd.StringDtype + elif type == bool: + return pd.BooleanDtype + elif fennel_is_optional(type): + return Optional[get_pd_dtype(fennel_get_optional_inner(type))] + else: + return type + + +def get_python_type_from_pd(type): + if type == pd.Int64Dtype: + return int + elif type == pd.Float64Dtype: + return float + elif type == pd.StringDtype: + return str + elif type == pd.BooleanDtype: + return bool + elif fennel_is_optional(type): + return Optional[ + get_python_type_from_pd(fennel_get_optional_inner(type)) + ] + return type + + +# TODO(Aditya): Add support for nested schema checks for arrays and maps and structs def _validate_field_in_df( field: schema_proto.Field, df: pd.DataFrame, @@ -466,7 +521,7 @@ def _validate_field_in_df( if not is_nullable and df[name].isnull().any(): raise ValueError( f"Field `{name}` is not nullable, but the " - f"column in the dataframe has null values. Error found during" + f"column in the dataframe has null values. Error found during " f"checking schema for `{entity_name}`." ) @@ -529,7 +584,7 @@ def _validate_field_in_df( f"checking schema for `{entity_name}`." 
) elif dtype == schema_proto.DataType(bool_type=schema_proto.BoolType()): - if df[name].dtype != np.bool_: + if df[name].dtype != np.bool_ and df[name].dtype != pd.BooleanDtype(): raise ValueError( f"Field `{name}` is of type bool, but the " f"column in the dataframe is of type " @@ -718,7 +773,14 @@ def is_hashable(dtype: Any) -> bool: and type(None) is _get_args(primitive_type)[1] ): return is_hashable(_get_args(primitive_type)[0]) - elif primitive_type in [int, str, bool]: + elif primitive_type in [ + int, + str, + bool, + pd.Int64Dtype, + pd.StringDtype, + pd.BooleanDtype, + ]: return True elif _get_origin(primitive_type) is list: return is_hashable(_get_args(primitive_type)[0]) diff --git a/fennel/lib/to_proto/to_proto.py b/fennel/lib/to_proto/to_proto.py index 69da639f0..dbcc20243 100644 --- a/fennel/lib/to_proto/to_proto.py +++ b/fennel/lib/to_proto/to_proto.py @@ -176,7 +176,7 @@ def dataset_to_proto(ds: Dataset) -> ds_proto.CoreDataset: return ds_proto.CoreDataset( name=ds.__name__, metadata=get_metadata_proto(ds), - dsschema=_fields_to_dsschema(ds.fields), + dsschema=fields_to_dsschema(ds.fields), history=history, retention=retention, field_metadata=_field_metadata(ds._fields), @@ -193,7 +193,7 @@ def dataset_to_proto(ds: Dataset) -> ds_proto.CoreDataset: ) -def _fields_to_dsschema(fields: List[Field]) -> schema_proto.DSSchema: +def fields_to_dsschema(fields: List[Field]) -> schema_proto.DSSchema: keys = [] values = [] ts = None diff --git a/fennel/test_lib/execute_aggregation.py b/fennel/test_lib/execute_aggregation.py index b940c9ce8..0c1475331 100644 --- a/fennel/test_lib/execute_aggregation.py +++ b/fennel/test_lib/execute_aggregation.py @@ -4,11 +4,12 @@ from math import sqrt import pandas as pd -from typing import Dict, List +from typing import Dict, List, Type from fennel.lib.aggregate import AggregateType, Distinct from fennel.lib.aggregate import Count, Sum, Average, LastK, Min, Max, Stddev from fennel.lib.duration import duration_to_timedelta +from fennel.lib.schema import get_pd_dtype # Type of data, 1 indicates insert -1 indicates delete. 
FENNEL_ROW_TYPE = "__fennel_row_type__" @@ -333,6 +334,7 @@ def get_aggregated_df( aggregate: AggregateType, ts_field: str, key_fields: List[str], + output_dtype: Type, ) -> pd.DataFrame: df = input_df.copy() df[FENNEL_ROW_TYPE] = 1 @@ -417,4 +419,12 @@ def get_aggregated_df( subset = key_fields + [ts_field] df = df.drop_duplicates(subset=subset, keep="last") df = df.reset_index(drop=True) + pd_dtype = get_pd_dtype(output_dtype) + if pd_dtype in [ + pd.Int64Dtype, + pd.BooleanDtype, + pd.Float64Dtype, + pd.StringDtype, + ]: + df[aggregate.into_field] = df[aggregate.into_field].astype(pd_dtype()) return df diff --git a/fennel/test_lib/executor.py b/fennel/test_lib/executor.py index 1bddf5836..fda1fcabb 100644 --- a/fennel/test_lib/executor.py +++ b/fennel/test_lib/executor.py @@ -114,9 +114,15 @@ def visitTransform(self, obj) -> Optional[NodeRet]: sorted_df = t_df.sort_values(input_ret.timestamp_field) # Cast sorted_df to obj.schema() for col_name, col_type in obj.schema().items(): - if col_type in [float, int, str, bool]: + if col_type in [ + pd.BooleanDtype, + pd.Int64Dtype, + pd.Float64Dtype, + pd.StringDtype, + ]: + sorted_df[col_name] = sorted_df[col_name].astype(col_type()) + elif col_type in [int, float, bool, str]: sorted_df[col_name] = sorted_df[col_name].astype(col_type) - return NodeRet( sorted_df, input_ret.timestamp_field, input_ret.key_fields ) @@ -156,6 +162,7 @@ def visitAggregate(self, obj): # of fields to the dataframe that contains the aggregate values # for each timestamp for that field. result = {} + output_schema = obj.dsschema() for aggregate in obj.aggregates: # Select the columns that are needed for the aggregate # and drop the rest @@ -163,8 +170,13 @@ def visitAggregate(self, obj): if not isinstance(aggregate, Count) or aggregate.unique: fields.append(aggregate.of) filtered_df = df[fields] + result[aggregate.into_field] = get_aggregated_df( - filtered_df, aggregate, input_ret.timestamp_field, obj.keys + filtered_df, + aggregate, + input_ret.timestamp_field, + obj.keys, + output_schema.values[aggregate.into_field], ) return NodeRet( pd.DataFrame(), input_ret.timestamp_field, obj.keys, result, True @@ -321,7 +333,6 @@ def emited_ts(row): columns=[tmp_ts_low, ts_query_field, tmp_left_ts, tmp_right_ts], inplace=True, ) - # sort the dataframe by the timestamp sorted_df = merged_df.sort_values(left_timestamp_field) return NodeRet(sorted_df, left_timestamp_field, input_ret.key_fields) diff --git a/fennel/test_lib/integration_client.py b/fennel/test_lib/integration_client.py index 90820d4d2..d8446bd8f 100644 --- a/fennel/test_lib/integration_client.py +++ b/fennel/test_lib/integration_client.py @@ -9,6 +9,12 @@ try: import pyarrow as pa + import sys + + sys.path.insert( + 0, + "/nix/store/wrkjic4qykdb8gkg271b388cdqhzxf7d-python3-3.11.5-env/lib/python3.11/site-packages", + ) from fennel_client_lib import RustClient # type: ignore from fennel_dataset import lookup # type: ignore except ImportError: @@ -69,7 +75,9 @@ def log(self, webhook: str, endpoint: str, df: pd.DataFrame): return FakeResponse(200, "OK") def sync( - self, datasets: List[Dataset] = [], featuresets: List[Featureset] = [] + self, + datasets: List[Dataset] = [], + featuresets: List[Featureset] = [], ): self.to_register_objects = [] self.to_register = set() @@ -79,7 +87,7 @@ def sync( self.add(featureset) sync_request = self._get_sync_request_proto() - self._client.sync(sync_request.SerializeToString()) + self._client.sync(sync_request.SerializeToString(), _dry_run=False) time.sleep(1.1) return 
FakeResponse(200, "OK") diff --git a/fennel/test_lib/mock_client.py b/fennel/test_lib/mock_client.py index 98e28e51c..6bde64c52 100644 --- a/fennel/test_lib/mock_client.py +++ b/fennel/test_lib/mock_client.py @@ -36,7 +36,7 @@ is_extractor_graph_cyclic, ) from fennel.lib.includes import includes # noqa -from fennel.lib.schema import data_schema_check +from fennel.lib.schema import data_schema_check, get_datatype from fennel.lib.to_proto import ( dataset_to_proto, features_from_fs, @@ -142,6 +142,10 @@ def dataset_lookup_impl( for col, right_df in data_dict.items(): right_df[FENNEL_LOOKUP] = True right_df[FENNEL_TIMESTAMP] = right_df[timestamp_field] + # Cast the column in keys to the same dtype as the column in right_df + for col in keys: + if col in right_df and keys[col].dtype != right_df[col].dtype: + keys[col] = keys[col].astype(right_df[col].dtype) try: df = pd.merge_asof( left=keys, @@ -155,8 +159,8 @@ def dataset_lookup_impl( raise ValueError( f"Error while performing lookup on dataset {cls_name} " f"with key fields {join_columns}, key length " - f"{keys.shape}, and shape of dataset being" - f"looked up{right_df.shape}: {e} " + f"{keys.shape}, and shape of dataset being " + f"looked up {right_df.shape}: {e} " ) df.drop(timestamp_field, axis=1, inplace=True) df = df.set_index(FENNEL_ORDER).loc[np.arange(len(df)), :] @@ -176,6 +180,9 @@ def dataset_lookup_impl( right_df = data[cls_name] right_df[FENNEL_LOOKUP] = True right_df[FENNEL_TIMESTAMP] = right_df[timestamp_field] + for col in keys: + if col in right_df and keys[col].dtype != right_df[col].dtype: + keys[col] = keys[col].astype(right_df[col].dtype) try: df = pd.merge_asof( left=keys, @@ -338,6 +345,8 @@ def sync( datasets = [] if featuresets is None: featuresets = [] + + input_datasets_for_pipelines = defaultdict(list) for dataset in datasets: if not isinstance(dataset, Dataset): raise TypeError( @@ -370,10 +379,20 @@ def sync( if x.tier.is_entity_selected(tier) and x.active ] sync_validation_for_pipelines(selected_pipelines, dataset._name) + for pipeline in selected_pipelines: for input in pipeline.inputs: + input_datasets_for_pipelines[input._name].append( + pipeline.name + ) self.listeners[input._name].append(pipeline) + # Check that input_datasets_for_pipelines is a subset of self.datasets. + for ds, pipelines in input_datasets_for_pipelines.items(): + if ds not in self.datasets: + raise ValueError( + f"Dataset `{ds}` is an input to the pipelines: `{pipelines}` but is not synced. Please add it to the sync call." + ) for featureset in featuresets: if not isinstance(featureset, Featureset): raise TypeError( @@ -445,8 +464,9 @@ def extract_features( ): input_feature_names.append(input_feature) elif isinstance(input_feature, Featureset): - input_feature_names.extend( - [f.fqn_ for f in input_feature.features] + raise Exception( + "Providing a featureset as input is deprecated. " + f"List the features instead. {[f.fqn() for f in input_feature.features]}." ) # Check if the input dataframe has all the required features if not set(input_feature_names).issubset(set(input_dataframe.columns)): @@ -455,6 +475,16 @@ def extract_features( f"Required features: {input_feature_names}. 
" f"Input dataframe columns: {input_dataframe.columns}" ) + for input_col, feature in zip( + input_dataframe.columns, input_feature_list + ): + if isinstance(feature, str): + continue + col_type = get_datatype(feature.dtype) # type: ignore + input_dataframe[input_col] = cast_col_to_dtype( + input_dataframe[input_col], col_type + ) + extractors = get_extractor_order( input_feature_list, output_feature_list, self.extractors ) @@ -493,6 +523,7 @@ def extract_historical_features( if input_dataframe.empty: return pd.DataFrame() timestamps = input_dataframe[timestamp_column] + timestamps = pd.to_datetime(timestamps) input_feature_names = [] for inp_feature in input_feature_list: if isinstance(inp_feature, Feature): @@ -500,8 +531,9 @@ def extract_historical_features( elif isinstance(inp_feature, str) and is_valid_feature(inp_feature): input_feature_names.append(inp_feature) elif isinstance(inp_feature, Featureset): - input_feature_names.extend( - [f.fqn_ for f in inp_feature.features] + raise Exception( + "Providing a featureset as input is deprecated. " + f"List the features instead. {[f.fqn() for f in inp_feature.features]}." ) # Check if the input dataframe has all the required features if not set(input_feature_names).issubset(set(input_dataframe.columns)): @@ -510,6 +542,15 @@ def extract_historical_features( f"Required features: {input_feature_names}. " f"Input dataframe columns: {input_dataframe.columns}" ) + for input_col, feature in zip( + input_dataframe.columns, input_feature_list + ): + if isinstance(feature, str): + continue + col_type = get_datatype(feature.dtype) # type: ignore + input_dataframe[input_col] = cast_col_to_dtype( + input_dataframe[input_col], col_type + ) extractors = get_extractor_order( input_feature_list, output_feature_list, self.extractors ) @@ -574,6 +615,19 @@ def _internal_log(self, dataset_name: str, df: pd.DataFrame): f"Timestamp field {timestamp_field} not found in dataframe " f"while logging to dataset `{dataset_name}`", ) + for col in df.columns: + # If any of the columns is a dictionary, convert it to a frozen dict + if df[col].apply(lambda x: isinstance(x, dict)).any(): + df[col] = df[col].apply(lambda x: frozendict(x)) + # Check if the dataframe has the same schema as the dataset + schema = dataset_req.dsschema + try: + df = cast_df_to_schema(df, schema) + except Exception as e: + return FakeResponse( + 400, + f"Schema validation failed during data insertion to `{dataset_name}`: {str(e)}", + ) if str(df[timestamp_field].dtype) != "datetime64[ns]": return FakeResponse( 400, @@ -581,26 +635,6 @@ def _internal_log(self, dataset_name: str, df: pd.DataFrame): f"datetime64[ns] but found {df[timestamp_field].dtype} in " f"dataset {dataset_name}", ) - for col in df.columns: - # If any of the columns is a dictionary, convert it to a frozen dict - if df[col].apply(lambda x: isinstance(x, dict)).any(): - df[col] = df[col].apply(lambda x: frozendict(x)) - # Check if the dataframe has the same schema as the dataset - schema = dataset_req.dsschema - # TODO(mohit, aditya): Instead of validating data schema, we should attempt to cast the - # df returned to the likely pd dtypes for a DF corresponding to the Dataset and re-raise - # the exception from it. - # - # The following scenario is currently possible - # Assume D1, D2 and D3 are datasets. - # - D3 = transform(D1.left_join(D2)) - # Since entries of D1 may not be present in D2, few columns could have NaN values. 
- # Say one of these columns with NaN is a key field for D3, which as part of transform is filled - # with a default value with a schema to match the field type in D3. - # Since left_join will automatically convert the int columns to a float type and - # this code not enforcing type casting at the Transform layer, the resulting DF will have a float - # column. Future lookups on this DF will fail as well since Lookup impl uses `merge_asof` - # and requires the merge columns to have the same type. exceptions = data_schema_check(schema, df, dataset_name) if len(exceptions) > 0: return FakeResponse( @@ -949,6 +983,96 @@ def _reset(self): self.extractors: List[Extractor] = [] +def proto_to_dtype(proto_dtype) -> str: + if proto_dtype.HasField("int_type"): + return "int" + elif proto_dtype.HasField("double_type"): + return "float" + elif proto_dtype.HasField("string_type"): + return "string" + elif proto_dtype.HasField("bool_type"): + return "bool" + elif proto_dtype.HasField("timestamp_type"): + return "timestamp" + elif proto_dtype.HasField("optional_type"): + return f"optional({proto_to_dtype(proto_dtype.optional_type.of)})" + else: + return str(proto_dtype) + + +def cast_col_to_dtype(series: pd.Series, dtype) -> pd.Series: + if not dtype.HasField("optional_type"): + if series.isnull().any(): + raise ValueError("Null values found in non-optional field.") + if dtype.HasField("int_type"): + return pd.to_numeric(series).astype(pd.Int64Dtype()) + elif dtype.HasField("double_type"): + return pd.to_numeric(series).astype(pd.Float64Dtype()) + elif dtype.HasField("string_type") or dtype.HasField("regex_type"): + return pd.Series([str(x) for x in series]).astype(pd.StringDtype()) + elif dtype.HasField("bool_type"): + return series.astype(pd.BooleanDtype()) + elif dtype.HasField("timestamp_type"): + return pd.to_datetime(series) + elif dtype.HasField("optional_type"): + # Those fields which are not null should be casted to the right type + if series.notnull().any(): + # collect the non-null values + tmp_series = series[series.notnull()] + non_null_idx = tmp_series.index + tmp_series = cast_col_to_dtype( + tmp_series, + dtype.optional_type.of, + ) + tmp_series.index = non_null_idx + # set the non-null values with the casted values using the index + series.loc[non_null_idx] = tmp_series + series.replace({np.nan: None}, inplace=True) + if callable(tmp_series.dtype): + series = series.astype(tmp_series.dtype()) + else: + series = series.astype(tmp_series.dtype) + return series + elif dtype.HasField("one_of_type"): + return cast_col_to_dtype(series, dtype.one_of_type.of) + elif dtype.HasField("between_type"): + return cast_col_to_dtype(series, dtype.between_type.dtype) + return series + + +def cast_df_to_schema(df: pd.DataFrame, dsschema: DSSchema) -> pd.DataFrame: + # Handle fields in keys and values + fields = list(dsschema.keys.fields) + list(dsschema.values.fields) + df = df.copy() + df = df.reset_index(drop=True) + for f in fields: + if f.name not in df.columns: + raise ValueError( + f"Field {f.name} not found in dataframe while logging to dataset" + ) + try: + series = cast_col_to_dtype(df[f.name], f.dtype) + series.name = f.name + df[f.name] = series + except Exception as e: + raise ValueError( + f"Failed to cast data logged to column `{f.name}` of type `{proto_to_dtype(f.dtype)}`: {e}" + ) + if dsschema.timestamp not in df.columns: + raise ValueError( + f"Timestamp column `{dsschema.timestamp}` not found in dataframe while logging to dataset" + ) + try: + df[dsschema.timestamp] = 
pd.to_datetime(df[dsschema.timestamp]).astype( + "datetime64[ns]" + ) + except Exception as e: + raise ValueError( + f"Failed to cast data logged to timestamp column {dsschema.timestamp}: {e}" + ) + return df + + def mock(test_func): def wrapper(*args, **kwargs): f = True diff --git a/fennel/test_lib/test_cast_df_to_schema.py b/fennel/test_lib/test_cast_df_to_schema.py new file mode 100644 index 000000000..39dd79412 --- /dev/null +++ b/fennel/test_lib/test_cast_df_to_schema.py @@ -0,0 +1,225 @@ +from datetime import datetime +from typing import Optional + +import numpy as np +import pandas as pd +import pytest + +from fennel import dataset, field +from fennel.gen.schema_pb2 import ( + DSSchema, + DataType, + Field, + IntType, + StringType, + TimestampType, +) +from fennel.lib.schema import between, oneof, regex +from fennel.lib.to_proto.to_proto import fields_to_dsschema +from fennel.test_lib.mock_client import cast_df_to_schema + + +# Example tests +def test_cast_int(): + @dataset + class TestDataset: + int_field: int + created_ts: datetime + + df = pd.DataFrame( + { + "int_field": ["1", "2", "3"], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + assert all(result_df["int_field"] == pd.Series([1, 2, 3], dtype="Int64")) + + +def test_cast_string(): + @dataset + class TestDataset: + string_field: str + created_ts: datetime + + df = pd.DataFrame( + { + "string_field": [123, 456, 789], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + assert all(result_df["string_field"] == pd.Series(["123", "456", "789"])) + + +def test_cast_optional_string(): + @dataset + class TestDataset: + int_field: Optional[int] + created_ts: datetime + + df = pd.DataFrame( + { + "int_field": ["123", None, "789"], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + expected = pd.Series([123, None, 789], name="int_field", dtype="Int64") + for i in range(3): + if pd.isna(expected.iloc[i]): + assert pd.isna(result_df["int_field"].iloc[i]) + else: + assert result_df["int_field"].iloc[i] == expected.iloc[i] + + +def test_cast_bool(): + @dataset + class TestDataset: + bool_field: bool + created_ts: datetime + + df = pd.DataFrame( + { + "bool_field": [1, 0, 1], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + assert all(result_df["bool_field"] == pd.Series([True, False, True])) + + +def test_cast_type_restrictions(): + @dataset + class TestDataset: + age: between(int, min=0, max=100) + gender: oneof(str, ["male", "female"]) + email: regex("^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]") + created_ts: datetime + + df = pd.DataFrame( + { + "age": ["21", "22", "23"], + "gender": [1, 2, 3], + "email": [1223423, 1223423, 1223423], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + assert all(result_df["age"] == pd.Series([21, 22, 23], dtype="Int64")) + assert all(result_df["gender"] == pd.Series(["1", "2", "3"])) + assert all( + result_df["email"] == pd.Series(["1223423", "1223423", "1223423"]) + ) + + +def test_cast_timestamp(): + @dataset + class TestDataset: + created_ts: datetime + + df = pd.DataFrame( + { + "created_ts": 
["2021-01-01", "2021-01-02"], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + expected_timestamps = pd.Series( + [datetime(2021, 1, 1), datetime(2021, 1, 2)] + ) + assert all(result_df["created_ts"] == expected_timestamps) + assert result_df["created_ts"].dtype == expected_timestamps.dtype + + +def test_cast_timestamp_with_timezone(): + @dataset + class TestDataset: + created_ts: datetime + + df = pd.DataFrame( + { + "created_ts": [ + "2021-01-01T00:00:00.000Z", + "2021-01-02T00:00:00.000Z", + ], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + result_df = cast_df_to_schema(df, schema) + expected_timestamps = pd.Series( + [datetime(2021, 1, 1), datetime(2021, 1, 2)] + ) + assert all(result_df["created_ts"] == expected_timestamps) + assert result_df["created_ts"].dtype == expected_timestamps.dtype + + +def test_cast_invalid_timestamp(): + @dataset + class TestDataset: + created_ts: datetime + + df = pd.DataFrame( + { + "created_ts": [ + "2021-01-01T00:00:00.000Z", + "2021-01-02T00:00:00.000Z", + "2021-01-02T00:00:00.000Z", + "not a timestamp", + ], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + with pytest.raises(ValueError) as e: + cast_df_to_schema(df, schema) + assert ( + str(e.value) + == """Failed to cast data logged to timestamp column created_ts: Unknown string format: not a timestamp present at position 3""" + ) + + +def test_null_in_non_optional_field(): + @dataset + class TestDataset: + non_optional_field: int + created_ts: datetime + + df = pd.DataFrame( + { + "non_optional_field": [1, None, 2], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + with pytest.raises(ValueError) as e: + cast_df_to_schema(df, schema) + assert ( + str(e.value) + == """Failed to cast data logged to column `non_optional_field` of type `int`: Null values found in non-optional field.""" + ) + + +def test_cast_failure_for_incorrect_type(): + @dataset + class TestDataset: + int_field: int + created_ts: datetime + + df = pd.DataFrame( + { + "int_field": ["not_an_int", "123", "456"], + "created_ts": [datetime.now() for _ in range(3)], + } + ) + schema = fields_to_dsschema(TestDataset.fields) + with pytest.raises(ValueError) as e: + cast_df_to_schema(df, schema) + assert ( + str(e.value) + == """Failed to cast data logged to column `int_field` of type `int`: Unable to parse string "not_an_int" at position 0""" + ) diff --git a/pyproject.toml b/pyproject.toml index 8a8754ad7..309a51914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fennel-ai" -version = "0.18.13" +version = "0.18.14" description = "The modern realtime feature engineering platform" authors = ["Fennel AI "] packages = [{ include = "fennel" }]