diff --git a/docs/examples/examples/ecommerce.py b/docs/examples/examples/ecommerce.py
index 289def205..e37d4d67d 100644
--- a/docs/examples/examples/ecommerce.py
+++ b/docs/examples/examples/ecommerce.py
@@ -30,7 +30,13 @@
 # docsnip datasets
-@source(postgres.table("orders", cursor="timestamp"), every="1m", lateness="1d")
+@source(
+    postgres.table("orders", cursor="timestamp"),
+    every="1m",
+    lateness="1d",
+    tier="prod",
+)
+@source(Webhook(name="fennel_webhook").endpoint("Order"), tier="dev")
 @meta(owner="data-eng-oncall@fennel.ai")
 @dataset
 class Order:
@@ -89,15 +95,14 @@ def myextractor(cls, ts: pd.Series, uids: pd.Series, sellers: pd.Series):
 
 # We can write a unit test to verify that the feature is working as expected
 # docsnip test
-fake_webhook = Webhook(name="fennel_webhook")
-
 class TestUserLivestreamFeatures(unittest.TestCase):
     @mock
     def test_feature(self, client):
-        fake_Order = Order.with_source(fake_webhook.endpoint("Order"))
         client.sync(
-            datasets=[fake_Order, UserSellerOrders], featuresets=[UserSeller]
+            datasets=[Order, UserSellerOrders],
+            featuresets=[UserSeller],
+            tier="dev",
         )
         columns = ["uid", "product_id", "seller_id", "timestamp"]
         now = datetime.utcnow()
diff --git a/docs/examples/featuresets/overview.py b/docs/examples/featuresets/overview.py
index fad58b491..39a399533 100644
--- a/docs/examples/featuresets/overview.py
+++ b/docs/examples/featuresets/overview.py
@@ -9,7 +9,7 @@
 from fennel.lib.metadata import meta
 from fennel.lib.schema import inputs, outputs
 from fennel.sources import source, Webhook
-from fennel.test_lib import mock
+from fennel.test_lib import mock, InternalTestClient
 
 webhook = Webhook(name="fennel_webhook")
@@ -70,34 +70,41 @@ def e2(cls, ts: pd.Series, durations: pd.Series) -> pd.Series:
 
 # /docsnip
 
-def test_multiple_extractors_of_same_feature():
-    with pytest.raises(Exception):
-        # docsnip featureset_extractors_of_same_feature
-        @featureset
-        class Movies:
-            duration: int = feature(id=1)
-            over_2hrs: bool = feature(id=2)
-            # invalid: both e1 & e2 output `over_3hrs`
-            over_3hrs: bool = feature(id=3)
-
-            @extractor
-            @inputs(duration)
-            @outputs(over_2hrs, over_3hrs)
-            def e1(cls, ts: pd.Series, durations: pd.Series) -> pd.DataFrame:
-                two_hrs = durations > 2 * 3600
-                three_hrs = durations > 3 * 3600
-                return pd.DataFrame(
-                    {"over_2hrs": two_hrs, "over_3hrs": three_hrs}
-                )
-
-            @extractor
-            @inputs(duration)
-            @outputs(over_3hrs)
-            def e2(cls, ts: pd.Series, durations: pd.Series) -> pd.Series:
-                return pd.Series(name="over_3hrs", data=durations > 3 * 3600)
+@mock
+def test_multiple_extractors_of_same_feature(client):
+    # docsnip featureset_extractors_of_same_feature
+    @meta(owner="aditya@xyz.ai")
+    @featureset
+    class Movies:
+        duration: int = feature(id=1)
+        over_2hrs: bool = feature(id=2)
+        # invalid: both e1 & e2 output `over_3hrs`
+        over_3hrs: bool = feature(id=3)
+
+        @extractor(tier=["default"])
+        @inputs(duration)
+        @outputs(over_2hrs, over_3hrs)
+        def e1(cls, ts: pd.Series, durations: pd.Series) -> pd.DataFrame:
+            two_hrs = durations > 2 * 3600
+            three_hrs = durations > 3 * 3600
+            return pd.DataFrame({"over_2hrs": two_hrs, "over_3hrs": three_hrs})
+
+        @extractor(tier=["non-default"])
+        @inputs(duration)
+        @outputs(over_3hrs)
+        def e2(cls, ts: pd.Series, durations: pd.Series) -> pd.Series:
+            return pd.Series(name="over_3hrs", data=durations > 3 * 3600)
+    # /docsnip
 
-# /docsnip
+    view = InternalTestClient()
+    view.add(Movies)
+    with pytest.raises(Exception) as e:
+        view._get_sync_request_proto()
+    assert (
+        str(e.value)
+        == "Feature `over_3hrs` is extracted by multiple extractors including `e2`."
+    )
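+    # Sketch: with a tier that keeps only one of the clashing extractors
+    # (e1 is pinned to "default", e2 to "non-default"), the same view
+    # syncs cleanly.
+    view._get_sync_request_proto("default")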
"Feature `over_3hrs` is extracted by multiple extractors including `e2`." + ) # docsnip remote_feature_as_input diff --git a/docs/examples/getting-started/quickstart.py b/docs/examples/getting-started/quickstart.py index 233ebd5bc..2e5530781 100644 --- a/docs/examples/getting-started/quickstart.py +++ b/docs/examples/getting-started/quickstart.py @@ -23,6 +23,7 @@ postgres = Postgres.get(name="my_rdbms") warehouse = Snowflake.get(name="my_warehouse") kafka = Kafka.get(name="my_kafka") +webhook = Webhook(name="fennel_webhook") # /docsnip @@ -30,7 +31,12 @@ # docsnip datasets @dataset -@source(postgres.table("product_info", cursor="last_modified"), every="1m") +@source( + postgres.table("product_info", cursor="last_modified"), + every="1m", + tier="prod", +) +@source(webhook.endpoint("Product"), tier="dev") @meta(owner="chris@fennel.ai", tags=["PII"]) class Product: product_id: int = field(key=True) @@ -51,7 +57,8 @@ def get_expectations(cls): # ingesting realtime data from Kafka works exactly the same way @meta(owner="eva@fennel.ai") -@source(kafka.topic("orders"), lateness="1h") +@source(kafka.topic("orders"), lateness="1h", tier="prod") +@source(webhook.endpoint("Order"), tier="dev") @dataset class Order: uid: int @@ -122,15 +129,13 @@ def myextractor(cls, ts: pd.Series, uids: pd.Series, sellers: pd.Series): # docsnip sync from fennel.test_lib import MockClient -webhook = Webhook(name="fennel_webhook") # client = Client('') # uncomment this line to use a real Fennel server client = MockClient() # comment this line to use a real Fennel server -fake_Product = Product.with_source(webhook.endpoint("Product")) -fake_Order = Order.with_source(webhook.endpoint("Order")) client.sync( - datasets=[fake_Order, fake_Product, UserSellerOrders], + datasets=[Order, Product, UserSellerOrders], featuresets=[UserSellerFeatures], + tier="dev", ) now = datetime.utcnow() diff --git a/docs/examples/overview/concepts.py b/docs/examples/overview/concepts.py index 9ef6f5e8a..5c8d1c25c 100644 --- a/docs/examples/overview/concepts.py +++ b/docs/examples/overview/concepts.py @@ -26,11 +26,15 @@ class UserDataset: postgres = Postgres.get(name="postgres") kafka = Kafka.get(name="kafka") +webhook = Webhook(name="fennel_webhook") # docsnip external_data_sources @meta(owner="data-eng-oncall@fennel.ai") -@source(postgres.table("user", cursor="update_timestamp"), every="1m") +@source( + postgres.table("user", cursor="update_timestamp"), every="1m", tier="prod" +) +@source(webhook.endpoint("User"), tier="dev") @dataset class User: uid: int = field(key=True) @@ -40,7 +44,8 @@ class User: @meta(owner="data-eng-oncall@fennel.ai") -@source(kafka.topic("transactions")) +@source(kafka.topic("transactions"), tier="prod") +@source(webhook.endpoint("Transaction"), tier="dev") @dataset class Transaction: uid: int @@ -118,15 +123,13 @@ def get_country(cls, ts: pd.Series, uids: pd.Series): # /docsnip -webhook = Webhook(name="fennel_webhook") - # Tests to ensure that there are no run time errors in the snippets @mock def test_overview(client): - fake_User = User.with_source(webhook.endpoint("User")) - fake_Transaction = Transaction.with_source(webhook.endpoint("Transaction")) - client.sync(datasets=[fake_User, fake_Transaction, UserTransactionsAbroad]) + client.sync( + datasets=[User, Transaction, UserTransactionsAbroad], tier="dev" + ) now = datetime.now() dob = now - timedelta(days=365 * 30) data = [ diff --git a/fennel/CHANGELOG.md b/fennel/CHANGELOG.md index a2a761220..d84217724 100644 --- a/fennel/CHANGELOG.md +++ 
+
 ## [0.18.10] - 2023-10-30
 - Add support for `since` in S3 source.
diff --git a/fennel/client/client.py b/fennel/client/client.py
index 3159cc2eb..172416cbb 100644
--- a/fennel/client/client.py
+++ b/fennel/client/client.py
@@ -60,6 +60,7 @@ def sync(
         self,
         datasets: Optional[List[Dataset]] = None,
         featuresets: Optional[List[Featureset]] = None,
+        tier: Optional[str] = None,
     ):
         """
         Sync the client with the server. This will register any datasets or
@@ -92,7 +93,7 @@
                         f" of type `{type(featureset)}` instead."
                     )
                 self.add(featureset)
-        sync_request = self._get_sync_request_proto()
+        sync_request = self._get_sync_request_proto(tier)
         response = self._post_bytes(
             "{}/sync".format(V1_API),
             sync_request.SerializeToString(),
@@ -634,8 +635,10 @@ def _get_session():
         )
         return http
 
-    def _get_sync_request_proto(self):
-        return to_sync_request_proto(self.to_register_objects)
+    def _get_sync_request_proto(self, tier: Optional[str] = None):
+        if tier is not None and not isinstance(tier, str):
+            raise ValueError(f"Expected tier to be a string, got {tier}")
+        return to_sync_request_proto(self.to_register_objects, tier)
 
     def _get(self, path: str):
         headers = None
diff --git a/fennel/client_tests/test_movie_tickets.py b/fennel/client_tests/test_movie_tickets.py
index 179d8653d..f79d9a7b0 100644
--- a/fennel/client_tests/test_movie_tickets.py
+++ b/fennel/client_tests/test_movie_tickets.py
@@ -50,7 +50,7 @@ class ActorStats:
     revenue: int
     at: datetime
 
-    @pipeline(version=1)
+    @pipeline(version=1, tier="prod")
     @inputs(MovieInfo, TicketSale)
     def pipeline_join(cls, info: Dataset, sale: Dataset):
         uniq = sale.groupby("ticket_id").first()
@@ -73,7 +73,7 @@ def pipeline_join(cls, info: Dataset, sale: Dataset):
             ]
         )
 
-    @pipeline(version=2, active=True)
+    @pipeline(version=2, active=True, tier="prod")
     @inputs(MovieInfo, TicketSale)
     def pipeline_join_v2(cls, info: Dataset, sale: Dataset):
         def foo(df):
@@ -112,25 +112,29 @@ class RequestFeatures:
 class ActorFeatures:
     revenue: int = feature(id=1)
 
-    @extractor(depends_on=[ActorStats])
+    @extractor(depends_on=[ActorStats], tier="prod")
     @inputs(RequestFeatures.name)
     @outputs(revenue)
     def extract_revenue(cls, ts: pd.Series, name: pd.Series):
-        import sys
-
-        print(name, file=sys.stderr)
-        print("##", name.name, file=sys.stderr)
         df, _ = ActorStats.lookup(ts, name=name)  # type: ignore
         df = df.fillna(0)
         return df["revenue"]
 
+    @extractor(depends_on=[ActorStats], tier="staging")
+    @inputs(RequestFeatures.name)
+    @outputs(revenue)
+    def extract_revenue2(cls, ts: pd.Series, name: pd.Series):
+        df, _ = ActorStats.lookup(ts, name=name)  # type: ignore
+        df = df.fillna(0)
+        return df["revenue"] * 2
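+
+    # Note: both extractors output `revenue`. That is only legal because they
+    # are pinned to different tiers; the duplicate-extractor check now runs
+    # at sync time against the tier-selected extractors.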
 
 
 class TestMovieTicketSale(unittest.TestCase):
     @mock
     def test_movie_ticket_sale(self, client):
         datasets = [MovieInfo, TicketSale, ActorStats]  # type: ignore
         featuresets = [ActorFeatures, RequestFeatures]
-        client.sync(datasets=datasets, featuresets=featuresets)  # type: ignore
+        client.sync(datasets=datasets, featuresets=featuresets, tier="prod")  # type: ignore
         client.sleep()
         data = [
             [
diff --git a/fennel/client_tests/test_tier_selector.py b/fennel/client_tests/test_tier_selector.py
new file mode 100644
index 000000000..028fbb0fb
--- /dev/null
+++ b/fennel/client_tests/test_tier_selector.py
@@ -0,0 +1,135 @@
+from fennel import Sum
+from fennel.featuresets import featureset, extractor, feature
+from fennel.lib.schema import outputs
+from datetime import datetime
+
+import pandas as pd
+from google.protobuf.json_format import ParseDict  # type: ignore
+from typing import List
+
+from fennel.datasets import dataset, pipeline, field, Dataset
+from fennel.lib.metadata import meta
+from fennel.lib.schema import inputs
+from fennel.lib.window import Window
+from fennel.sources import source, Webhook
+from fennel.test_lib import *
+
+webhook = Webhook(name="fennel_webhook")
+
+
+@meta(owner="abhay@fennel.ai")
+@source(webhook.endpoint("MovieInfo"), tier="prod")
+@source(webhook.endpoint("MovieInfo2"), tier="staging")
+@dataset
+class MovieInfo:
+    title: str = field(key=True)
+    actors: List[str]  # can be an empty list
+    release: datetime
+
+
+@meta(owner="abhay@fennel.ai")
+@source(webhook.endpoint("TicketSale"), tier="prod")
+@source(webhook.endpoint("TicketSale2"), tier="staging")
+@dataset
+class TicketSale:
+    ticket_id: str
+    title: str
+    price: int
+    at: datetime
+
+
+@meta(owner="abhay@fennel.ai")
+@dataset
+class ActorStats:
+    name: str = field(key=True)
+    revenue: int
+    at: datetime
+
+    @pipeline(version=1, tier="prod")
+    @inputs(MovieInfo, TicketSale)
+    def pipeline_join(cls, info: Dataset, sale: Dataset):
+        uniq = sale.groupby("ticket_id").first()
+        c = (
+            uniq.join(info, how="inner", on=["title"])
+            .explode(columns=["actors"])
+            .rename(columns={"actors": "name"})
+        )
+        c = c.dropnull()
+        return c.groupby("name").aggregate(
+            [
+                Sum(
+                    window=Window("forever"),
+                    of="price",
+                    into_field="revenue",
+                ),
+            ]
+        )
+
+    @pipeline(version=2, active=True, tier="staging")
+    @inputs(MovieInfo, TicketSale)
+    def pipeline_join_v2(cls, info: Dataset, sale: Dataset):
+        def foo(df):
+            df["price"] = df["price"] * 2
+            return df
+
+        uniq = sale.groupby("ticket_id").first()
+        c = (
+            uniq.join(info, how="inner", on=["title"])
+            .explode(columns=["actors"])
+            .rename(columns={"actors": "name"})
+        )
+        c = c.dropnull()
+        return c.groupby("name").aggregate(
+            [
+                Sum(
+                    window=Window("forever"),
+                    of="price",
+                    into_field="revenue",
+                ),
+            ]
+        )
+
+
+@meta(owner="zaki@fennel.ai")
+@featureset
+class RequestFeatures:
+    name: str = feature(id=1)
+
+
+@meta(owner="abhay@fennel.ai")
+@featureset
+class ActorFeatures:
+    revenue: int = feature(id=1)
+
+    @extractor(depends_on=[ActorStats], tier="prod")
+    @inputs(RequestFeatures.name)
+    @outputs(revenue)
+    def extract_revenue(cls, ts: pd.Series, name: pd.Series):
+        df, _ = ActorStats.lookup(ts, name=name)  # type: ignore
+        df = df.fillna(0)
+        return df["revenue"]
+
+    @extractor(depends_on=[ActorStats], tier="staging")
+    @inputs(RequestFeatures.name)
+    @outputs(revenue)
+    def extract_revenue2(cls, ts: pd.Series, name: pd.Series):
+        df, _ = ActorStats.lookup(ts, name=name)  # type: ignore
+        df = df.fillna(0)
+        return df["revenue"] * 2
+
+
+def test_tier_selector():
+    view = InternalTestClient()
+    view.add(MovieInfo)
+    view.add(TicketSale)
+    view.add(ActorStats)
+    view.add(RequestFeatures)
+    view.add(ActorFeatures)
+
+    sync_request = view._get_sync_request_proto("dev")
+    assert len(sync_request.feature_sets) == 2
+    assert len(sync_request.features) == 2
+    assert len(sync_request.datasets) == 3
+    assert len(sync_request.sources) == 0
+    assert len(sync_request.pipelines) == 0
+    assert len(sync_request.extractors) == 0
diff --git a/fennel/datasets/datasets.py b/fennel/datasets/datasets.py
index 02bfa55b0..e7412be3e 100644
--- a/fennel/datasets/datasets.py
+++ b/fennel/datasets/datasets.py
@@ -7,6 +7,7 @@
 import sys
 from dataclasses import dataclass
 import typing
+import logging
 
 import numpy as np
 import pandas as pd
@@ -43,6 +44,7 @@
     duration_to_timedelta,
 )
 from fennel.lib.expectations import Expectations, GE_ATTR_FUNC
+from fennel.lib.includes import TierSelector
 from fennel.lib.metadata import (
     meta,
     get_meta_attr,
@@ -943,7 +945,9 @@ def f_get_type_hints(obj):
 
 
 def pipeline(
-    version: int = 1, active: bool = False
+    version: int = 1,
+    active: bool = False,
+    tier: Optional[Union[str, List[str]]] = None,
 ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     if isinstance(version, Callable) or isinstance(  # type: ignore
         version, Dataset
@@ -1011,6 +1015,7 @@ def wrapper(pipeline_func: Callable) -> Callable:
                 func=pipeline_func,
                 version=version,
                 active=active,
+                tier=tier,
             ),
         )
         return pipeline_func
@@ -1067,6 +1072,7 @@ class Pipeline:
     name: str
     version: int
     active: bool
+    tier: TierSelector
 
     def __init__(
         self,
@@ -1074,12 +1080,14 @@ def __init__(
         func: Callable,
         version: int,
         active: bool = False,
+        tier: Optional[Union[str, List[str]]] = None,
     ):
         self.inputs = inputs
         self.func = func  # type: ignore
         self.name = func.__name__
         self.version = version
         self.active = active
+        self.tier = TierSelector(tier)
 
     # Validate the schema of all intermediate nodes
     # and return the schema of the terminal node.
@@ -1169,7 +1177,12 @@ def with_source(
         every: Optional[Duration] = None,
         starting_from: Optional[datetime.datetime] = None,
         lateness: Optional[Duration] = None,
+        tiers: Optional[Union[str, List[str]]] = None,
     ):
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "with_source is deprecated. Please use tier selector instead."
+        )
         if len(self._pipelines) > 0:
             raise Exception(
                 f"Dataset {self._name} is contains a pipeline. "
@@ -1178,7 +1191,7 @@
         ds_copy = copy.deepcopy(self)
         if hasattr(ds_copy, sources.SOURCE_FIELD):
             delattr(ds_copy, sources.SOURCE_FIELD)
-        src_fn = source(conn, every, starting_from, lateness)
+        src_fn = source(conn, every, starting_from, lateness, None, tiers)
         return src_fn(ds_copy)
 
     def dsschema(self):
@@ -1341,7 +1354,6 @@ def _get_on_demand(self) -> Optional[OnDemand]:
 
     def _get_pipelines(self) -> List[Pipeline]:
         pipelines = []
         dataset_name = self._name
-        versions = set()
         names = set()
         for name, method in inspect.getmembers(self.__fennel_original_cls__):
             if not callable(method):
@@ -1351,11 +1363,6 @@
 
             pipeline = getattr(method, PIPELINE_ATTR)
 
-            if pipeline.version in versions:
-                raise ValueError(
-                    f"Duplicate pipeline id {pipeline.version} for dataset {dataset_name}."
-                )
-            versions.add(pipeline.version)
             if pipeline.name in names:
                 raise ValueError(
                     f"Duplicate pipeline name {pipeline.name} for dataset {dataset_name}."
@@ -1388,25 +1395,13 @@ def _validate_pipelines(self, pipelines: List[Pipeline]):
             timestamp=self.timestamp_field,
         )
 
-        found_active = False
         for pipeline in pipelines:
             pipeline_schema = pipeline.get_terminal_schema()
-            if pipeline.active and found_active:
-                raise ValueError(
-                    f"Multiple active pipelines are not supported for dataset {self._name}."
-                )
-            if pipeline.active:
-                found_active = True
             err = pipeline_schema.matches(
                 ds_schema, f"pipeline {pipeline.name} output", self._name
             )
             if len(err) > 0:
                 exceptions.extend(err)
 
-        if not found_active and len(pipelines) > 1:
-            raise ValueError(
-                f"No active pipeline found for dataset {self._name}."
-            )
-
         if exceptions:
             raise TypeError(exceptions)
 
@@ -1454,6 +1449,33 @@ def fields(self):
         return self._fields
 
 
+def sync_validation_for_pipelines(pipelines: List[Pipeline], ds_name: str):
+    """
+    This validation function contains the checks that are run just before the sync call.
+    It should only contain checks that cannot be run during the
+    registration/compilation phase.
+    """
+    versions = set()
+    for pipeline in pipelines:
+        if pipeline.version in versions:
+            raise ValueError(
+                f"Pipeline {pipeline.fqn} has the same version as another pipeline in the dataset."
+            )
+        versions.add(pipeline.version)
+
+    found_active = False
+    for pipeline in pipelines:
+        if pipeline.active and found_active:
+            raise ValueError(
+                f"Multiple active pipelines are not supported for dataset {ds_name}."
+            )
+        if pipeline.active:
+            found_active = True
+    if not found_active and len(pipelines) > 1:
+        raise ValueError(
+            f"No active pipeline found for dataset {ds_name}. Please mark one of the pipelines as active by setting `active=True`."
+        )
+
+
 # ---------------------------------------------------------------------
 # Visitor
 # ---------------------------------------------------------------------
diff --git a/fennel/datasets/test_dataset.py b/fennel/datasets/test_dataset.py
index b5a8bba69..31cf12bc1 100644
--- a/fennel/datasets/test_dataset.py
+++ b/fennel/datasets/test_dataset.py
@@ -2,6 +2,7 @@
 from datetime import datetime, timedelta
 
 import pandas as pd
+import pytest
 from google.protobuf.json_format import ParseDict  # type: ignore
 from typing import Optional, List
@@ -13,7 +14,7 @@
 from fennel.lib.metadata import meta
 from fennel.lib.schema import Embedding, inputs, oneof
 from fennel.lib.window import Window
-from fennel.sources import source, Webhook
+from fennel.sources import source, Webhook, Kafka
 from fennel.test_lib import *
 
 webhook = Webhook(name="fennel_webhook")
@@ -2528,3 +2529,55 @@ def extract_info(df: pd.DataFrame) -> pd.DataFrame:
         }
 
         return x
+
+
+def test_pipeline_with_tier_selector():
+    kafka = Kafka.get(name="my_kafka")
+
+    @meta(owner="test@test.com")
+    @source(kafka.topic("orders"), lateness="1h")
+    @dataset
+    class A:
+        a1: int = field(key=True)
+        t: datetime
+
+    @meta(owner="test@test.com")
+    @source(kafka.topic("orders2"), lateness="1h")
+    @dataset
+    class B:
+        b1: int = field(key=True)
+        t: datetime
+
+    @meta(owner="aditya@fennel.ai")
+    @dataset
+    class ABCDatasetDefault:
+        a1: int = field(key=True)
+        t: datetime
+
+        @pipeline(version=1, tier="prod")
+        @inputs(A, B)
+        def pipeline1(cls, a: Dataset, b: Dataset):
+            return a.join(b, how="left", left_on=["a1"], right_on=["b1"])
+
+        @pipeline(version=1, tier="staging")
+        @inputs(A, B)
+        def pipeline2(cls, a: Dataset, b: Dataset):
+            return a.join(b, how="inner", left_on=["a1"], right_on=["b1"])
+
+    view = InternalTestClient()
+    view.add(A)  # type: ignore
+    view.add(B)  # type: ignore
+    view.add(ABCDatasetDefault)  # type: ignore
+    with pytest.raises(ValueError) as e:
+        _ = view._get_sync_request_proto()
+    assert (
+        str(e.value)
+        == "Pipeline ABCDatasetDefault-pipeline2 has the same version as another pipeline in the dataset."
+    )
+
+    with pytest.raises(ValueError) as e:
+        _ = view._get_sync_request_proto(tier=["prod"])
+    assert str(e.value) == "Expected tier to be a string, got ['prod']"
+
+    sync_request = view._get_sync_request_proto(tier="prod")
+    pipelines = sync_request.pipelines
+    assert len(pipelines) == 1
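+    # Sketch assertion (assumes the Pipeline proto exposes the function
+    # name as `name`): the surviving pipeline is the one pinned to "prod".
+    assert pipelines[0].name == "pipeline1"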
diff --git a/fennel/featuresets/featureset.py b/fennel/featuresets/featureset.py
index 2a0a44894..e5a155fde 100644
--- a/fennel/featuresets/featureset.py
+++ b/fennel/featuresets/featureset.py
@@ -17,6 +17,7 @@
     List,
     overload,
     Set,
+    Union,
 )
 
@@ -29,6 +30,7 @@
     get_meta_attr,
     set_meta_attr,
 )
+from fennel.lib.includes import TierSelector
 from fennel.lib.schema import FENNEL_INPUTS, FENNEL_OUTPUTS
 from fennel.utils import (
     parse_annotation_comments,
@@ -125,37 +127,11 @@ def featureset(featureset_cls: Type[T]):
     )
 
 
-@overload
 def extractor(
-    func: Callable[..., T],
-):
-    ...
-
-
-@overload
-def extractor(
-    *,
-    depends_on: List[T],
-    version: int,
-):
-    ...
-
-
-@overload
-def extractor(
-    *,
-    depends_on: List[T],
-):
-    ...
-
-
-@overload
-def extractor():
-    ...
-
-
-def extractor(
-    func: Optional[Callable] = None, depends_on: List = [], version: int = 0
+    func: Optional[Callable] = None,
+    depends_on: List = [],
+    version: int = 0,
+    tier: Optional[Union[str, List[str]]] = None,
 ):
     """
     extractor is a decorator for a function that extracts a feature from a
@@ -266,6 +242,7 @@ def _create_extractor(extractor_func: Callable, version: int):
             outputs,
             version,
             func=extractor_func,
+            tier=tier,
         ),
     )
     return classmethod(extractor_func)
@@ -349,6 +326,7 @@ def extract(
         default=None,
         feature: Feature = None,
         version: int = 0,
+        tier: Optional[Union[str, List[str]]] = None,
     ) -> Feature:
         """
         Derives an extractor for the feature using the given params.
@@ -364,6 +342,7 @@
             feature: If provided, this function creates a one way alias from the
                 calling feature to this feature.
             version: the version of this extractor
+            tier: The tiers for which this extractor is enabled. If None, all tiers are enabled.
 
         Returns:
             Feature: This feature
@@ -385,6 +364,7 @@
                 inputs=[feature],
                 outputs=[self.id],
                 version=version,
+                tier=tier,
             )
             return self
 
@@ -418,6 +398,7 @@ def extract(
             version=version,
             derived_extractor_info=Extractor.DatasetLookupInfo(field, default),
             depends_on=[ds] if ds else [],
+            tier=tier,
         )
         return self
 
@@ -551,17 +532,6 @@ def _validate(self):
                 )
             feature_id_set.add(feature.id)
 
-        # Check that each feature is extracted by at max one extractor.
-        extracted_features: Set[int] = set()
-        for extractor in self._extractors:
-            for feature_id in extractor.output_feature_ids:
-                if feature_id in extracted_features:
-                    raise TypeError(
-                        f"Feature `{self._id_to_feature[feature_id].name}` is "
-                        f"extracted by multiple extractors."
-                    )
-                extracted_features.add(feature_id)
-
     def _set_extractors_as_attributes(self):
         for extractor in self._extractors:
             if extractor.extractor_type == ExtractorType.PY_FUNC:
@@ -623,6 +593,8 @@ class Extractor:
     # depended on datasets: used for autogenerated extractors
     depends_on: List[Dataset]
 
+    tiers: TierSelector
+
     def __init__(
         self,
         name: str,
@@ -633,6 +605,7 @@ def __init__(
         func: Optional[Callable] = None,
         derived_extractor_info: Optional[DatasetLookupInfo] = None,
         depends_on: List[Dataset] = [],
+        tier: Optional[Union[str, List[str]]] = None,
     ):
         self.name = name
         self.extractor_type = extractor_type
@@ -642,6 +615,7 @@
         self.output_feature_ids = outputs
         self.version = version
         self.depends_on = depends_on
+        self.tiers = TierSelector(tier)
 
     def fqn(self) -> str:
         """Fully qualified name of the extractor."""
@@ -702,3 +676,19 @@ class DatasetLookupInfo:
 
     def __init__(self, field: Field, default_val: Any):
         self.field = field
         self.default = default_val
+
+
+def sync_validation_for_extractors(extractors: List[Extractor]):
+    """
+    This validation function contains the checks that are run just before the sync call.
+    It should only contain checks that cannot be run during the
+    registration/compilation phase.
+    """
+    extracted_features: Set[str] = set()
+    for extractor in extractors:
+        for feature in extractor.output_features:
+            if feature in extracted_features:
+                raise TypeError(
+                    f"Feature `{feature}` is "
+                    f"extracted by multiple extractors including `{extractor.name}`."
+                )
+            extracted_features.add(feature)
diff --git a/fennel/featuresets/test_featureset.py b/fennel/featuresets/test_featureset.py
index c4a8ad900..b2ed14c89 100644
--- a/fennel/featuresets/test_featureset.py
+++ b/fennel/featuresets/test_featureset.py
@@ -1,6 +1,7 @@
 from datetime import datetime
 
 import pandas as pd
+import pytest
 from google.protobuf.json_format import ParseDict  # type: ignore
 from typing import Optional
@@ -28,6 +29,7 @@ class UserInfoDataset:
     account_creation_date: datetime
     country: Optional[str]
     timestamp: datetime = field(timestamp=True)
+    avg_income: int
 
 
 @meta(owner="test@test.com")
@@ -344,3 +346,75 @@ def get_user_info3(cls, ts: pd.Series, user_id: pd.Series) -> income:
     assert actual_extractor == expected_extractor, error_message(
         actual_extractor, expected_extractor
     )
+
+
+def test_extractor_tier_selector():
+    @meta(owner="aditya@fennel.ai")
+    @featureset
+    class Request:
+        user_id: int = feature(id=1)
+
+    @meta(owner="aditya@fennel.ai")
+    @featureset
+    class UserInfo:
+        userid: int = feature(id=1).extract(
+            feature=Request.user_id, tier=["~staging", "~prod"]
+        )
+        home_geoid: int = feature(id=2)
+        # The user's gender among male/female/non-binary
+        gender: str = feature(id=3)
+        age: int = feature(id=4).meta(owner="aditya@fennel.ai")
+        income: int = feature(id=5).extract(  # type: ignore
+            field=UserInfoDataset.avg_income,
+            provider=Request,
+            default="pluto",
+            tier=["~prod"],
+        )
+
+        @extractor(depends_on=[UserInfoDataset], tier=["~prod", "~dev"])
+        @inputs(User.id)
+        @outputs(userid, home_geoid)
+        def get_user_info1(cls, ts: pd.Series, user_id: pd.Series):
+            pass
+
+        @extractor(depends_on=[UserInfoDataset], tier=["prod"])
+        @inputs(User.id)
+        @outputs(userid, home_geoid)
+        def get_user_info2(cls, ts: pd.Series, user_id: pd.Series):
+            pass
+
+        @extractor(depends_on=[UserInfoDataset], tier=["prod"])
+        @inputs(User.id)
+        @outputs(income)
+        def get_user_income(cls, ts: pd.Series, user_id: pd.Series):
+            pass
+
+    view = InternalTestClient()
+    view.add(UserInfoDataset)
+    view.add(UserInfo)
+    view.add(User)
+    with pytest.raises(TypeError) as e:
+        view._get_sync_request_proto()
+    assert (
+        str(e.value)
+        == "Feature `income` is extracted by multiple extractors including `get_user_income`."
+    )
+
+    sync_request = view._get_sync_request_proto("prod")
+    assert len(sync_request.feature_sets) == 2
+    assert len(sync_request.extractors) == 2
+    assert len(sync_request.features) == 7
+
+    extractor_req = sync_request.extractors[1]
+    assert extractor_req.name == "get_user_info2"
+
+    sync_request = view._get_sync_request_proto("dev")
+    assert len(sync_request.feature_sets) == 2
+    assert len(sync_request.extractors) == 2
+    assert len(sync_request.features) == 7
+
+    extractor_req = sync_request.extractors[0]
+    assert extractor_req.name == "_fennel_alias_user_id"
+
+    extractor_req = sync_request.extractors[1]
+    assert extractor_req.name == "_fennel_lookup_avg_income"
diff --git a/fennel/featuresets/test_invalid_derived_extractors.py b/fennel/featuresets/test_invalid_derived_extractors.py
index 5d9f1f54a..139d4c908 100644
--- a/fennel/featuresets/test_invalid_derived_extractors.py
+++ b/fennel/featuresets/test_invalid_derived_extractors.py
@@ -74,6 +74,7 @@ class UserInfo2:
     # Tests a derived and manual extractor for the same feature
     with pytest.raises(TypeError) as e:
 
+        @meta(owner="user@xyz.ai")
         @featureset
         class UserInfo3:
             user_id: int = feature(id=1).extract(feature=User.id)
@@ -89,7 +90,13 @@ def get_age(cls, ts: pd.Series, user_id: pd.Series):
                 df = UserInfoDataset.lookup(ts, user_id=user_id)  # type: ignore
                 return df.fillna(0)
 
-    assert str(e.value) == "Feature `age` is extracted by multiple extractors."
+        view = InternalTestClient()
+        view.add(UserInfo3)
+        view._get_sync_request_proto()
+    assert (
+        str(e.value)
+        == "Feature `age` is extracted by multiple extractors including `get_age`."
+    )
 
 
 def test_invalid_missing_fields():
diff --git a/fennel/featuresets/test_invalid_featureset.py b/fennel/featuresets/test_invalid_featureset.py
index d5ada4678..9d96d2f14 100644
--- a/fennel/featuresets/test_invalid_featureset.py
+++ b/fennel/featuresets/test_invalid_featureset.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import pytest
 
+from fennel import meta
 from fennel.datasets import dataset, field
 from fennel.featuresets import featureset, extractor, feature
 from fennel.lib.schema import inputs, outputs
@@ -25,6 +26,7 @@ class UserInfoDataset:
     timestamp: datetime = field(timestamp=True)
 
 
+@meta(owner="aditya@fennel.ai")
 @featureset
 class User:
     id: int = feature(id=1)
@@ -54,6 +56,7 @@ def get_user_info1(cls, ts: pd.Series, user: pd.Series):
 
 def test_complex_featureset():
     with pytest.raises(TypeError) as e:
 
+        @meta(owner="aditya@fennel.ai")
         @featureset
         class UserInfo:
             userid: int = feature(id=1)
@@ -81,8 +84,13 @@ def get_user_info2(cls, ts: pd.Series, user_id: pd.Series):
         def get_user_info3(cls, ts: pd.Series, user_id: pd.Series):
             pass
 
+        view = InternalTestClient()
+        view.add(User)
+        view.add(UserInfo)
+        view._get_sync_request_proto()
     assert (
-        str(e.value) == "Feature `gender` is extracted by multiple extractors."
+        str(e.value)
+        == "Feature `gender` is extracted by multiple extractors including `get_user_info3`."
     )
diff --git a/fennel/lib/includes/__init__.py b/fennel/lib/includes/__init__.py
index e8aced87c..fa56a0b15 100644
--- a/fennel/lib/includes/__init__.py
+++ b/fennel/lib/includes/__init__.py
@@ -1 +1,6 @@
-from fennel.lib.includes.include_mod import FENNEL_INCLUDED_MOD, includes
+from fennel.lib.includes.include_mod import (
+    FENNEL_INCLUDED_MOD,
+    FENNEL_TIER_SELECTOR,
+    TierSelector,
+    includes,
+)
diff --git a/fennel/lib/includes/include_mod.py b/fennel/lib/includes/include_mod.py
index 643a592af..9c3cc893f 100644
--- a/fennel/lib/includes/include_mod.py
+++ b/fennel/lib/includes/include_mod.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import functools
-from typing import Callable, Any
+from typing import Callable, Any, List, Optional, Union
+from fennel._vendor.pydantic import BaseModel, validator  # type: ignore
 
 FENNEL_INCLUDED_MOD = "__fennel_included_module__"
+FENNEL_TIER_SELECTOR = "__fennel_tier_selector__"
 
 
 def includes(*args: Any):
@@ -20,3 +22,67 @@ def decorator(func: Callable) -> Callable:
         return func
 
     return decorator
+
+
+class TierSelector(BaseModel):
+    """
+    TierSelector specifies the tiers for which an entity (source, pipeline,
+    or extractor) is registered.
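+
+    Tiers are plain names ("prod") or negations ("~prod"); a selector may use
+    one style or the other, never both. For example (illustrative):
+        tiers="prod"     -> selected when syncing tier "prod" (or no tier)
+        tiers=["~prod"]  -> selected for any tier except "prod"
+        tiers=None       -> selected for every tier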
+    """
+
+    tiers: Optional[Union[str, List[str]]] = None
+
+    def __init__(self, tiers):
+        super().__init__(tiers=tiers)
+
+    @validator("tiers", pre=True, each_item=True, allow_reuse=True)
+    def check_string(cls, v):
+        if v is None:
+            return v
+        if isinstance(v, str):
+            if " " in v:
+                raise ValueError(
+                    f"Tier string must not contain spaces, found {v}"
+                )
+            if len(v.strip()) == 0:
+                raise ValueError(f"Tier string must not be empty, found {v}")
+            if v == "~":
+                raise ValueError(f"Tier string must not be ~, found {v}")
+        return v
+
+    @validator("tiers", allow_reuse=True)
+    def validate_tiers(cls, v):
+        if isinstance(v, str) or v is None:
+            return v
+        # Cannot mix tiers with ~ and without ~; use one style or the other.
+        if (
+            v
+            and any(tier.startswith("~") for tier in v)
+            and any(not tier.startswith("~") for tier in v)
+        ):
+            raise ValueError(
+                f"Cannot mix tiers with ~ and without ~, found {v}"
+            )
+        return v
+
+    def is_entity_selected(self, tier: Optional[str] = None) -> bool:
+        if self.tiers is None or tier is None:
+            return True
+        if tier[0] == "~":
+            raise ValueError(
+                f"Tier selector cannot start with ~, found {tier}"
+            )
+
+        if isinstance(self.tiers, str):
+            if self.tiers[0] == "~":
+                return self.tiers[1:] != tier
+            return self.tiers == tier
+
+        if any(t.startswith("~") for t in self.tiers):
+            if any(t[1:] == tier for t in self.tiers):
+                return False
+            return True
+        return tier in self.tiers
+
+    class Config:
+        arbitrary_types_allowed = True
diff --git a/fennel/lib/includes/test_tier_selector.py b/fennel/lib/includes/test_tier_selector.py
new file mode 100644
index 000000000..0e76e449f
--- /dev/null
+++ b/fennel/lib/includes/test_tier_selector.py
@@ -0,0 +1,50 @@
+import pytest
+from fennel._vendor.pydantic import ValidationError  # type: ignore
+
+from fennel.lib.includes import TierSelector
+
+
+def test_tier_selector():
+    with pytest.raises(ValidationError):
+        TierSelector(tiers="")
+
+    with pytest.raises(ValidationError):
+        TierSelector(tiers="a b")
+
+    with pytest.raises(ValidationError):
+        TierSelector(tiers=["~gold", "silver"])
+
+    try:
+        TierSelector(tiers=["~gold", "~silver"])
+    except ValidationError:
+        pytest.fail("ValidationError should not be raised")
+
+    try:
+        TierSelector(tiers=["gold", "silver"])
+    except ValidationError:
+        pytest.fail("ValidationError should not be raised")
+
+    try:
+        TierSelector(tiers=None)
+    except ValidationError:
+        pytest.fail("ValidationError should not be raised")
+
+
+def test_is_entity_selected():
+    selector = TierSelector(tiers=None)
+    assert selector.is_entity_selected("gold") is True
+
+    selector = TierSelector(tiers="~gold")
+    assert selector.is_entity_selected("silver") is True
+    with pytest.raises(ValueError):
+        selector.is_entity_selected("~silver")
+
+    selector = TierSelector(tiers="gold")
+    assert selector.is_entity_selected("silver") is False
+
+    selector = TierSelector(tiers=["gold", "silver"])
+    assert selector.is_entity_selected("gold") is True
+
+    selector = TierSelector(tiers=["~gold", "~silver"])
+    assert selector.is_entity_selected("gold") is False
+    assert selector.is_entity_selected("bronze") is True
diff --git a/fennel/lib/to_proto/to_proto.py b/fennel/lib/to_proto/to_proto.py
index 0e0160e29..b3c62b8e9 100644
--- a/fennel/lib/to_proto/to_proto.py
+++ b/fennel/lib/to_proto/to_proto.py
@@ -8,7 +8,7 @@
 import google.protobuf.duration_pb2 as duration_proto  # type: ignore
 from google.protobuf.timestamp_pb2 import Timestamp
 from google.protobuf.wrappers_pb2 import BoolValue, StringValue
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Set
 
 import fennel.gen.schema_registry_pb2 as schema_registry_proto
 import fennel.gen.http_auth_pb2 as http_auth_proto
@@ -23,7 +23,9 @@
 import fennel.gen.services_pb2 as services_proto
 import fennel.sources as sources
 from fennel.datasets import Dataset, Pipeline, Field
+from fennel.datasets.datasets import sync_validation_for_pipelines
 from fennel.featuresets import Featureset, Feature, Extractor, ExtractorType
+from fennel.featuresets.featureset import sync_validation_for_extractors
 from fennel.lib.duration import (
     Duration,
     duration_to_timedelta,
@@ -79,8 +81,11 @@ def _expectations_to_proto(
 # ------------------------------------------------------------------------------
 # Sync
 # ------------------------------------------------------------------------------
+
+
 def to_sync_request_proto(
     registered_objs: List[Any],
+    tier: Optional[str] = None,
 ) -> services_proto.SyncRequest:
     datasets = []
     pipelines = []
@@ -108,12 +113,10 @@
             )
 
             datasets.append(dataset_to_proto(obj))
-            pipelines.extend(pipelines_from_ds(obj))
+            pipelines.extend(pipelines_from_ds(obj, tier))
             operators.extend(operators_from_ds(obj))
             expectations.extend(expectations_from_ds(obj))
 
-            res = sources_from_ds(
-                obj, sources.SOURCE_FIELD, obj.timestamp_field
-            )
+            res = sources_from_ds(obj, obj.timestamp_field, tier)
             if res is None:
                 continue
             (ext_db, s) = res
@@ -127,7 +130,7 @@
         elif isinstance(obj, Featureset):
             featuresets.append(featureset_to_proto(obj))
             features.extend(features_from_fs(obj))
-            extractors.extend(extractors_from_fs(obj, featureset_obj_map))
+            extractors.extend(extractors_from_fs(obj, featureset_obj_map, tier))
         else:
             raise ValueError(f"Unknown object type {type(obj)}")
     return services_proto.SyncRequest(
@@ -223,11 +226,20 @@ def _field_to_proto(field: Field) -> schema_proto.Field:
     )
 
 
-def pipelines_from_ds(ds: Dataset) -> List[ds_proto.Pipeline]:
+def pipelines_from_ds(
+    ds: Dataset, tier: Optional[str] = None
+) -> List[ds_proto.Pipeline]:
     pipelines = []
     for pipeline in ds._pipelines:
-        pipelines.append(_pipeline_to_proto(pipeline, ds))
-    return pipelines
+        if pipeline.tier.is_entity_selected(tier):
+            pipelines.append(pipeline)
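+    # Validation and auto-activation below apply only to the tier-selected
+    # pipelines, so a tier that keeps exactly one pipeline does not need an
+    # explicit `active=True`.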
+    sync_validation_for_pipelines(pipelines, ds._name)
+    if len(pipelines) == 1:
+        pipelines[0].active = True
+    pipeline_protos = []
+    for pipeline in pipelines:
+        pipeline_protos.append(_pipeline_to_proto(pipeline, ds))
+    return pipeline_protos
 
 
 def _pipeline_to_proto(pipeline: Pipeline, ds: Dataset) -> ds_proto.Pipeline:
@@ -290,7 +302,7 @@ def expectations_from_ds(ds: Dataset) -> List[exp_proto.Expectations]:
 
 
 def sources_from_ds(
-    ds: Dataset, source_field, timestamp_field: str
+    ds: Dataset, timestamp_field: str, tier: Optional[str] = None
 ) -> Optional[Tuple[connector_proto.ExtDatabase, connector_proto.Source]]:
     """
     Returns the source proto for a dataset if it exists
@@ -300,9 +312,25 @@
     :param timestamp_field: An optional column that can be used to sort the
         data from the source.
     """
-    if hasattr(ds, source_field):
-        source: sources.DataConnector = getattr(ds, source_field)
-        return _conn_to_source_proto(source, ds._name, timestamp_field)
+    if hasattr(ds, sources.SOURCE_FIELD):
+        all_sources: List[sources.DataConnector] = getattr(
+            ds, sources.SOURCE_FIELD
+        )
+        filtered_sources = [
+            source
+            for source in all_sources
+            if source.tiers.is_entity_selected(tier)
+        ]
+        if len(filtered_sources) == 0:
+            return None
+        if len(filtered_sources) > 1:
+            raise ValueError(
+                f"Dataset {ds._name} has multiple sources ({len(filtered_sources)}) defined. "
+                f"Please define only one source per dataset, or check your tier selection."
+            )
+        return _conn_to_source_proto(
+            filtered_sources[0], ds._name, timestamp_field
+        )
     return None  # type: ignore
 
 
@@ -360,12 +388,20 @@ def _feature_to_proto(f: Feature) -> fs_proto.Feature:
 
 
 def extractors_from_fs(
-    fs: Featureset, fs_obj_map: Dict[str, Featureset]
+    fs: Featureset,
+    fs_obj_map: Dict[str, Featureset],
+    tier: Optional[str] = None,
 ) -> List[fs_proto.Extractor]:
     extractors = []
+    extractor_protos = []
     for extractor in fs._extractors:
-        extractors.append(_extractor_to_proto(extractor, fs, fs_obj_map))
-    return extractors
+        if extractor.tiers.is_entity_selected(tier):
+            extractors.append(extractor)
+            extractor_protos.append(
+                _extractor_to_proto(extractor, fs, fs_obj_map)
+            )
+    sync_validation_for_extractors(extractors)
+    return extractor_protos
 
 
 # Feature as input
diff --git a/fennel/sources/sources.py b/fennel/sources/sources.py
index 8d2f28bd4..6f1b5f5b8 100644
--- a/fennel/sources/sources.py
+++ b/fennel/sources/sources.py
@@ -4,13 +4,14 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Any, Callable, List, Optional, TypeVar, Union
 
 from fennel._vendor.pydantic import BaseModel  # type: ignore
 from fennel._vendor.pydantic import validator  # type: ignore
 from fennel.lib.duration import (
     Duration,
 )
+from fennel.lib.includes import TierSelector
 
 T = TypeVar("T")
 SOURCE_FIELD = "__fennel_data_sources__"
@@ -31,6 +32,7 @@ def source(
     since: Optional[datetime] = None,
     lateness: Optional[Duration] = None,
     cdc: Optional[str] = None,
+    tier: Optional[Union[str, List[str]]] = None,
 ) -> Callable[[T], Any]:
     if not isinstance(conn, DataConnector):
         if not isinstance(conn, DataSource):
@@ -48,13 +50,10 @@ def decorator(dataset_cls: T):
         conn.lateness = lateness if lateness is not None else DEFAULT_LATENESS
         conn.cdc = cdc if cdc is not None else DEFAULT_CDC
         conn.starting_from = since
-        if hasattr(dataset_cls, SOURCE_FIELD):
-            raise Exception(
-                "Multiple sources are not supported in dataset `%s`."
-                % dataset_cls.__name__  # type: ignore
-            )
-        else:
-            setattr(dataset_cls, SOURCE_FIELD, conn)
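+        # A dataset may now carry one source per tier: connectors accumulate
+        # here and are narrowed back to a single source at sync time.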
+        conn.tiers = TierSelector(tier)
+        connectors = getattr(dataset_cls, SOURCE_FIELD, [])
+        connectors.append(conn)
+        setattr(dataset_cls, SOURCE_FIELD, connectors)
         return dataset_cls
 
     return decorator
@@ -366,6 +365,7 @@ class DataConnector:
     lateness: Duration
     cdc: str
     starting_from: Optional[datetime] = None
+    tiers: TierSelector
 
     def identifier(self):
         raise NotImplementedError
diff --git a/fennel/sources/test_invalid_sources.py b/fennel/sources/test_invalid_sources.py
index 62cd4b9d6..d681aeddb 100644
--- a/fennel/sources/test_invalid_sources.py
+++ b/fennel/sources/test_invalid_sources.py
@@ -227,7 +227,8 @@ class UserInfoDatasetKinesis:
     )
 
 
-def test_multiple_sources():
+@mock
+def test_multiple_sources(client):
     with pytest.raises(Exception) as e:
 
         @meta(owner="test@test.com")
@@ -254,9 +255,11 @@ class UserInfoDataset:
             country: Optional[str]
             timestamp: datetime = field(timestamp=True)
 
+        client.sync(datasets=[UserInfoDataset], featuresets=[])
+
     assert (
         str(e.value)
-        == "Multiple sources are not supported in dataset `UserInfoDataset`."
+        == "Dataset `UserInfoDataset` has more than one source defined, found 4 sources."
     )
diff --git a/fennel/sources/test_sources.py b/fennel/sources/test_sources.py
index 46493b27b..95a2d5011 100644
--- a/fennel/sources/test_sources.py
+++ b/fennel/sources/test_sources.py
@@ -1,5 +1,6 @@
 from datetime import datetime
 
+import pytest
 from google.protobuf.json_format import ParseDict  # type: ignore
 from typing import Optional
 
@@ -224,6 +225,114 @@ class UserInfoDataset:
     )
 
 
+def test_tier_selector_on_source():
+    @meta(owner="test@test.com")
+    @source(kafka.topic("test_topic"), tier=["dev-2"])
+    @source(
+        mysql.table("users_mysql", cursor="added_on"),
+        every="1h",
+        tier=["prod"],
+    )
+    @source(
+        snowflake.table("users_Sf", cursor="added_on"),
+        every="1h",
+        tier=["staging"],
+    )
+    @source(
+        s3.bucket(
+            bucket_name="all_ratings",
+            prefix="prod/apac/",
+        ),
+        every="1h",
+        lateness="2d",
+        tier=["dev"],
+    )
+    @dataset
+    class UserInfoDataset:
+        user_id: int = field(key=True)
+        name: str
+        gender: str
+        # User's date of birth
+        dob: str
+        age: int
+        account_creation_date: datetime
+        country: Optional[str]
+        timestamp: datetime = field(timestamp=True)
+
+    view = InternalTestClient()
+    view.add(UserInfoDataset)
+    with pytest.raises(ValueError) as e:
+        view._get_sync_request_proto()
+    assert (
+        str(e.value)
+        == "Dataset UserInfoDataset has multiple sources (4) defined. Please define only one source per dataset, or check your tier selection."
+    )
+
+    sync_request = view._get_sync_request_proto("prod")
+    assert len(sync_request.datasets) == 1
+    assert len(sync_request.sources) == 1
+    assert len(sync_request.extdbs) == 1
+    source_request = sync_request.sources[0]
+    s = {
+        "table": {
+            "mysqlTable": {
+                "db": {
+                    "name": "mysql",
+                    "mysql": {
+                        "host": "localhost",
+                        "database": "test",
+                        "user": "root",
+                        "password": "root",
+                        "port": 3306,
+                    },
+                },
+                "tableName": "users_mysql",
+            }
+        },
+        "dataset": "UserInfoDataset",
+        "every": "3600s",
+        "cursor": "added_on",
+        "lateness": "3600s",
+        "timestampField": "timestamp",
+    }
+    expected_source_request = ParseDict(s, connector_proto.Source())
+    assert source_request == expected_source_request, error_message(
+        source_request, expected_source_request
+    )
+
+    sync_request = view._get_sync_request_proto("staging")
+    assert len(sync_request.datasets) == 1
+    assert len(sync_request.sources) == 1
+    assert len(sync_request.extdbs) == 1
+    source_request = sync_request.sources[0]
+    s = {
+        "table": {
+            "snowflakeTable": {
+                "db": {
+                    "name": "snowflake_src",
+                    "snowflake": {
+                        "account": "nhb38793.us-west-2.snowflakecomputing.com",
+                        "user": "",
+                        "password": "",
+                        "schema": "PUBLIC",
+                        "warehouse": "TEST",
+                        "role": "ACCOUNTADMIN",
+                        "database": "MOVIELENS",
+                    },
+                },
+                "tableName": "users_Sf",
+            }
+        },
+        "dataset": "UserInfoDataset",
+        "every": "3600s",
+        "cursor": "added_on",
+        "lateness": "3600s",
+        "timestampField": "timestamp",
+    }
+    expected_source_request = ParseDict(s, connector_proto.Source())
+    assert source_request == expected_source_request, error_message(
+        source_request, expected_source_request
+    )
+
+
 def test_multiple_sources():
     @meta(owner="test@test.com")
     @source(
diff --git a/fennel/test_lib/mock_client.py b/fennel/test_lib/mock_client.py
index 26a4f4b70..b7c9207b1 100644
--- a/fennel/test_lib/mock_client.py
+++ b/fennel/test_lib/mock_client.py
@@ -20,7 +20,9 @@
 from fennel._vendor.requests import Response  # type: ignore
 from fennel.client import Client
 from fennel.datasets import Dataset, field, Pipeline, OnDemand  # noqa
+from fennel.datasets.datasets import sync_validation_for_pipelines
 from fennel.featuresets import Featureset, Feature, Extractor
+from fennel.featuresets.featureset import sync_validation_for_extractors
 from fennel.gen.dataset_pb2 import CoreDataset
 from fennel.gen.featureset_pb2 import CoreFeatureset
 from fennel.gen.featureset_pb2 import (
@@ -329,6 +331,7 @@ def sync(
         self,
         datasets: Optional[List[Dataset]] = None,
         featuresets: Optional[List[Featureset]] = None,
+        tier: Optional[str] = None,
     ):
         self._reset()
         if datasets is None:
@@ -343,7 +346,7 @@
                 )
             self.dataset_requests[dataset._name] = dataset_to_proto(dataset)
             if hasattr(dataset, sources.SOURCE_FIELD):
-                self._process_data_connector(dataset)
+                self._process_data_connector(dataset, tier)
 
             self.datasets[dataset._name] = dataset
             is_source_dataset = hasattr(dataset, sources.SOURCE_FIELD)
@@ -361,10 +364,13 @@
                 raise ValueError(
                     f"Dataset {dataset._name} has no pipelines and is not a source dataset"
                 )
-
-            for pipeline in dataset._pipelines:
-                if not pipeline.active:
-                    continue
+            selected_pipelines = [
+                x
+                for x in dataset._pipelines
+                if x.tier.is_entity_selected(tier) and x.active
+            ]
+            sync_validation_for_pipelines(selected_pipelines, dataset._name)
+            for pipeline in selected_pipelines:
                 for input in pipeline.inputs:
                     self.listeners[input._name].append(pipeline)
 
@@ -382,6 +388,8 @@
             )
             # Check if the dataset used by the extractor is registered
             for extractor in featureset.extractors:
+                if not extractor.tiers.is_entity_selected(tier):
+                    continue
                 datasets = [
                     x._name for x in extractor.get_dataset_dependencies()
                 ]
                 for dataset in datasets:
                     if dataset not in self.dataset_requests:
                         raise ValueError(
                             f"Dataset {dataset} not found in sync call"
                         )
-            self.extractors.extend(featureset.extractors)
+            self.extractors.extend(
+                [
+                    x
+                    for x in featureset.extractors
+                    if x.tiers.is_entity_selected(tier)
+                ]
+            )
 
         fs_obj_map = {
             featureset._name: featureset for featureset in featuresets
         }
 
         for featureset in featuresets:
-            proto_extractors = extractors_from_fs(featureset, fs_obj_map)
+            proto_extractors = extractors_from_fs(featureset, fs_obj_map, tier)
             for extractor in proto_extractors:
                 if extractor.extractor_type != ProtoExtractorType.PY_FUNC:
                     continue
@@ -522,8 +536,17 @@ def is_integration_client(self) -> bool:
 
     # ----------------- Private methods --------------------------------------
 
-    def _process_data_connector(self, dataset: Dataset):
+    def _process_data_connector(self, dataset: Dataset, tier):
         connector = getattr(dataset, sources.SOURCE_FIELD)
+        connector = connector if isinstance(connector, list) else [connector]
+        connector = [x for x in connector if x.tiers.is_entity_selected(tier)]
+        if len(connector) > 1:
+            raise ValueError(
+                f"Dataset `{dataset._name}` has more than one source defined, found {len(connector)} sources."
+            )
+        if len(connector) == 0:
+            return
+        connector = connector[0]
         if isinstance(connector, sources.WebhookConnector):
             src = connector.data_source
             webhook_endpoint = f"{src.name}:{connector.endpoint}"
diff --git a/pyproject.toml b/pyproject.toml
index 616e34202..7ed954ee2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "fennel-ai"
-version = "0.18.9"
+version = "0.18.11"
 description = "The modern realtime feature engineering platform"
 authors = ["Fennel AI "]
 packages = [{ include = "fennel" }]