auto extractor: Bug fixes for code gen and default values
aditya-nambiar committed Nov 12, 2023
1 parent cedf5b4 commit a391a84
Showing 9 changed files with 427 additions and 42 deletions.
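The new test exercises the two auto-generated extractor forms this commit fixes: aliasing a feature from another featureset, and looking up a dataset field through a provider featureset with a default for missing keys. A minimal sketch, distilled from the test below (`ExampleFeatures` is an illustrative name, not part of the commit):

@featureset
class ExampleFeatures:
    # Alias: an auto-generated extractor copies another featureset's feature.
    rider_id: int = feature(id=1).extract(feature=RequestFeatures0.rider_id)
    # Lookup: an auto-generated extractor fetches a dataset field, keyed by the
    # provider's features, and returns the default when the key is missing.
    score: float = feature(id=2).extract(
        field=RiderCreditScoreDataset.score,
        provider=RequestFeatures0,
        default=-1.0,
    )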
293 changes: 293 additions & 0 deletions fennel/client_tests/test_complex_autogen_extractor.py
@@ -0,0 +1,293 @@
from datetime import datetime, timedelta
from typing import Optional

import pandas as pd
import pytest

from fennel import meta, Count, Window, featureset, feature, extractor
from fennel.lib.schema import inputs, outputs
from fennel.datasets import dataset, field, pipeline, Dataset
from fennel.sources import Webhook, source
from fennel.test_lib import mock

webhook = Webhook(name="fennel_webhook")

__owner__ = "[email protected]"
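# Module-level default owner, applied to the datasets and featuresets defined below.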


@dataset
@source(webhook.endpoint("RiderDataset"), tier="local")
class RiderDataset:
rider_id: int = field(key=True)
created: datetime = field(timestamp=True)
birthdate: datetime


@dataset
@source(webhook.endpoint("RiderCreditScoreDataset"), tier="local")
class RiderCreditScoreDataset:
rider_id: int = field(key=True)
created: datetime
score: float


@dataset
@source(webhook.endpoint("CountryLicenseDataset"), tier="local")
@meta(owner="[email protected]")
class CountryLicenseDataset:
rider_id: int = field(key=True)
created: datetime
country_code: str


@dataset
@source(webhook.endpoint("ReservationsDataset"), tier="local")
class ReservationsDataset:
rider_id: int
vehicle_id: int
is_completed_trip: int
created: datetime


@dataset
@source(webhook.endpoint("NumCompletedTripsDataset"), tier="local")
class NumCompletedTripsDataset:
rider_id: int = field(key=True)
count_num_completed_trips: int
created: datetime

@pipeline()
@inputs(ReservationsDataset)
def my_pipeline(cls, reservations: Dataset):
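        # Keep only completed trips, then compute an approximate distinct count
        # of vehicle_ids per rider over an unbounded window.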
completed = reservations.filter(lambda df: df["is_completed_trip"] == 1)
return completed.groupby("rider_id").aggregate(
Count(
of="vehicle_id",
unique=True,
approx=True,
window=Window("forever"),
into_field="count_num_completed_trips",
),
)


@featureset
class RequestFeatures0:
ts: datetime = feature(id=1)
rider_id: int = feature(id=2)


@featureset
class RequestFeatures1:
ts: datetime = feature(id=1)
id1: int = feature(id=2).extract(feature=RequestFeatures0.rider_id) # type: ignore
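    # id1 gets an auto-generated alias extractor that copies RequestFeatures0.rider_id.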


@featureset
class RequestFeatures2:
ts: datetime = feature(id=1)
id2: int = feature(id=2).extract(feature=RequestFeatures1.id1) # type: ignore
const: int = feature(id=3)
num_trips: int = feature(id=4).extract( # type: ignore
field=NumCompletedTripsDataset.count_num_completed_trips,
provider=RequestFeatures0,
default=0,
)
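    # num_trips gets an auto-generated lookup extractor: the provider's rider_id
    # supplies the dataset key, and 0 is returned on a lookup miss.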

@extractor()
@inputs(id2)
@outputs(const)
def extract_const(cls, ts: pd.Series, id2: pd.Series) -> pd.DataFrame:
return pd.DataFrame({"const": [1] * len(ts)})


@featureset
class RequestFeatures3:
ts: datetime = feature(id=1)
rider_id: int = feature(id=2).extract(feature=RequestFeatures2.id2) # type: ignore
vehicle_id: int = feature(id=3)
reservation_id: Optional[int] = feature(id=4)


@featureset
class RiderFeatures:
id: int = feature(id=1).extract(feature=RequestFeatures2.id2) # type: ignore
created: datetime = feature(id=2).extract( # type: ignore
field=RiderDataset.created,
provider=RequestFeatures3,
default=datetime(2000, 1, 1, 0, 0, 0),
)
birthdate: datetime = feature(id=3).extract( # type: ignore
field=RiderDataset.birthdate,
provider=RequestFeatures3,
default=datetime(2000, 1, 1, 0, 0, 0),
)
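    # The datetime defaults above are returned whenever the rider_id lookup
    # finds no matching RiderDataset row.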
age_years: int = feature(id=4)
ais_score: float = feature(id=5).extract( # type: ignore
field=RiderCreditScoreDataset.score,
provider=RequestFeatures3,
default=-1.0,
)
dl_state: str = feature(id=6).extract( # type: ignore
field=CountryLicenseDataset.country_code,
provider=RequestFeatures3,
default="Unknown",
)
is_us_dl: bool = feature(id=7)
num_past_completed_trips: int = feature(id=8).extract( # type: ignore
field=NumCompletedTripsDataset.count_num_completed_trips,
provider=RequestFeatures3,
default=0,
)

@extractor
@inputs(dl_state)
@outputs(is_us_dl)
def extract_is_us_dl(
cls, ts: pd.Series, dl_state: pd.Series
) -> pd.DataFrame:
is_us_dl = dl_state == "US"
return pd.DataFrame({"is_us_dl": is_us_dl})

@extractor
@inputs(birthdate)
@outputs(age_years)
def extract_age_years(
cls, ts: pd.Series, birthdate: pd.Series
) -> pd.DataFrame:
age_years = (datetime.now() - birthdate).dt.total_seconds() / (
60 * 60 * 24 * 365
)
age_years = age_years.astype(int)
return pd.DataFrame({"age_years": age_years})


@mock
def test_complex_auto_gen_extractors(client):
with pytest.raises(ValueError) as e:
_ = client.sync(
datasets=[
RiderDataset,
RiderCreditScoreDataset,
CountryLicenseDataset,
ReservationsDataset,
NumCompletedTripsDataset,
],
featuresets=[
RiderFeatures,
RequestFeatures1,
RequestFeatures2,
RequestFeatures3,
],
)
error_msg1 = "Featureset `RequestFeatures0` is required by `RequestFeatures1` but is not present in the sync call. Please ensure that all featuresets are present in the sync call."
error_msg2 = error_msg1.replace("RequestFeatures1", "RequestFeatures2")
assert str(e.value) == error_msg1 or str(e.value) == error_msg2

with pytest.raises(ValueError) as e:
_ = client.sync(
datasets=[
RiderDataset,
RiderCreditScoreDataset,
CountryLicenseDataset,
ReservationsDataset,
NumCompletedTripsDataset,
],
featuresets=[
RiderFeatures,
RequestFeatures0,
RequestFeatures1,
RequestFeatures3,
],
)
error_msg1 = "Featureset `RequestFeatures2` is required by `RiderFeatures` but is not present in the sync call. Please ensure that all featuresets are present in the sync call."
error_msg2 = error_msg1.replace("RiderFeatures", "RequestFeatures3")
assert str(e.value) == error_msg1 or str(e.value) == error_msg2

with pytest.raises(ValueError) as e:
_ = client.sync(
datasets=[
RiderDataset,
RiderCreditScoreDataset,
CountryLicenseDataset,
ReservationsDataset,
],
featuresets=[
RiderFeatures,
RequestFeatures0,
RequestFeatures1,
RequestFeatures2,
RequestFeatures3,
],
)
assert (
str(e.value)
== "Dataset NumCompletedTripsDataset not found in sync call"
)

resp = client.sync(
datasets=[
RiderDataset,
RiderCreditScoreDataset,
CountryLicenseDataset,
ReservationsDataset,
NumCompletedTripsDataset,
],
featuresets=[
RiderFeatures,
RequestFeatures0,
RequestFeatures1,
RequestFeatures2,
RequestFeatures3,
],
)

assert resp.status_code == 200

rider_df = pd.DataFrame(
{
"rider_id": [1],
"created": [datetime.now()],
"birthdate": [datetime.now() - timedelta(days=365 * 30)],
"country_code": ["US"],
}
)

log_response = client.log(
webhook="fennel_webhook", endpoint="RiderDataset", df=rider_df
)
assert log_response.status_code == 200

reservation_df = pd.DataFrame(
{
"rider_id": [1],
"vehicle_id": [1],
"is_completed_trip": [1],
"created": [datetime.now()],
}
)
log_response = client.log(
webhook="fennel_webhook",
endpoint="ReservationsDataset",
df=reservation_df,
)
assert log_response.status_code == 200

extracted_df = client.extract_features(
input_feature_list=[RequestFeatures0.rider_id],
output_feature_list=[RiderFeatures],
input_dataframe=pd.DataFrame({"RequestFeatures0.rider_id": [1]}),
)
assert extracted_df.shape[0] == 1
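    # created matches the logged row; dl_state falls back to its default
    # "Unknown" because no CountryLicenseDataset rows were logged.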
assert (
extracted_df["RiderFeatures.created"].iloc[0]
== rider_df["created"].iloc[0]
)
assert extracted_df["RiderFeatures.dl_state"].iloc[0] == "Unknown"
41 changes: 31 additions & 10 deletions fennel/featuresets/featureset.py
@@ -322,6 +322,7 @@ def fqn(self) -> str:
 
     def extract(
         self,
+        *,
         field: Field = None,
         provider: Featureset = None,
         default=None,
@@ -484,7 +485,7 @@ def __init__(
         setattr(self, OWNER, owner)
         propogate_fennel_attributes(featureset_cls, self)
 
-    def get_dataset_dependencies(self):
+    def get_dataset_dependencies(self) -> List[Dataset]:
         """
         This function gets the list of datasets the Featureset depends upon.
         This dependency is introduced by features that directly lookup a dataset
@@ -509,6 +510,24 @@ def get_dataset_dependencies(self):
 
         return depended_datasets
 
+    def get_featureset_dependencies(self) -> List[str]:
+        """
+        This function gets the list of featuresets the Featureset depends upon.
+        This dependency is introduced by features that directly lookup a featureset
+        via the FS-FS route, while specifying a provider.
+        """
+        depended_featuresets = set()
+        for f in self._features:
+            if f.extractor is None:
+                continue
+            if f.extractor.extractor_type == ExtractorType.ALIAS:
+                # Alias extractors have exactly one input feature
+                depended_featuresets.add(f.extractor.inputs[0].featureset_name)
+            elif f.extractor.extractor_type == ExtractorType.LOOKUP:
+                for inp_feature in f.extractor.inputs:
+                    depended_featuresets.add(inp_feature.featureset_name)
+        return list(depended_featuresets)
+
     # ------------------- Private Methods ----------------------------------
 
     def _add_feature_names_as_attributes(self):
@@ -527,7 +546,7 @@ def _get_extractors(self) -> List[Extractor]:
             if extractor.extractor_type == ExtractorType.LOOKUP and (
                 extractor.inputs is None or len(extractor.inputs) == 0
             ):
-                feature.extractor.set_inputs_from_featureset(self)
+                feature.extractor.set_inputs_from_featureset(self, feature)
             extractors.append(extractor)
 
         # user defined extractors
@@ -678,7 +697,9 @@ def get_included_modules(self) -> List[Callable]:
             return getattr(self.func, FENNEL_INCLUDED_MOD)
         return []
 
-    def set_inputs_from_featureset(self, featureset: Featureset):
+    def set_inputs_from_featureset(
+        self, featureset: Featureset, feature: Feature
+    ):
         if self.inputs and len(self.inputs) > 0:
             return
         if self.extractor_type != ExtractorType.LOOKUP:
@@ -696,22 +717,22 @@ def set_inputs_from_featureset(self, featureset: Featureset):
         ds = field.dataset
         if not ds:
             raise ValueError(
-                f"Dataset {field.dataset_name} not found for field {field}"
+                f"Dataset `{field.dataset_name}` not found for field `{field}`"
             )
         self.depends_on = [ds]
         for k in ds.dsschema().keys:
-            feature = featureset.feature(k)
-            if not feature:
+            f = featureset.feature(k)
+            if not f:
                 raise ValueError(
-                    f"Dataset key {k} not found in provider {featureset._name} for extractor {self.name}"
+                    f"Key field `{k}` of dataset `{ds._name}` not found in provider `{featureset._name}` for feature `{feature.name}`'s auto-generated extractor"
                 )
-            self.inputs.append(feature)
+            self.inputs.append(f)
 
 class DatasetLookupInfo:
     field: Field
-    default: Any
+    default: Optional[Any] = None
 
-    def __init__(self, field: Field, default_val: Any):
+    def __init__(self, field: Field, default_val: Optional[Any] = None):
         self.field = field
         self.default = default_val
 
