feat: Added expedia provider and batching materialization process
Bhargav Dodla committed Mar 21, 2024
1 parent 8e4bdc8 commit 52958c6
Showing 9 changed files with 187 additions and 60 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -227,4 +227,5 @@ sdk/python/feast/binaries/
/sdk/python/feast/open_api/
/sdk/python/feast/test_open_api/

venv39/*
venv39/*
test.go
Empty file.
60 changes: 60 additions & 0 deletions sdk/python/feast/expediagroup/provider/expedia.py
@@ -0,0 +1,60 @@
import logging
from typing import List, Set

import pandas as pd

from feast.feature_view import FeatureView
from feast.infra.passthrough_provider import PassthroughProvider
from feast.repo_config import RepoConfig
from feast.stream_feature_view import StreamFeatureView

logger = logging.getLogger(__name__)


class ExpediaProvider(PassthroughProvider):
def __init__(self, config: RepoConfig):
logger.info("Initializing Expedia provider...")

if config.batch_engine.type != "spark.engine":
logger.warning("Expedia provider recommends spark materialization engine")

if config.offline_store.type != "spark":
logger.warning(
"Expedia provider recommends spark offline store as it only support SparkSource as Batch source"
)

super().__init__(config)

def ingest_df(
self,
feature_view: FeatureView,
df: pd.DataFrame,
):
drop_list: List[str] = []
fv_schema: Set[str] = set(map(lambda field: field.name, feature_view.schema))
# Add timestamp fields to the schema so they are not dropped from the dataframe
if isinstance(feature_view, StreamFeatureView):
fv_schema.add(feature_view.timestamp_field)
if feature_view.source.created_timestamp_column:
fv_schema.add(feature_view.source.created_timestamp_column)

if isinstance(feature_view, FeatureView):
if feature_view.stream_source is not None:
fv_schema.add(feature_view.stream_source.timestamp_field)
if feature_view.stream_source.created_timestamp_column:
fv_schema.add(feature_view.stream_source.created_timestamp_column)
else:
fv_schema.add(feature_view.batch_source.timestamp_field)
if feature_view.batch_source.created_timestamp_column:
fv_schema.add(feature_view.batch_source.created_timestamp_column)

for column in df.columns:
if column not in fv_schema:
drop_list.append(column)

if len(drop_list) > 0:
print(
f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe."
)

super().ingest_df(feature_view, df.drop(drop_list, axis=1))
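
For reference, the sketch below mirrors the column pruning that ExpediaProvider.ingest_df performs before handing the dataframe to the underlying provider. It is only an illustration of the technique, not part of the commit; the schema fields and column names are hypothetical.

from typing import Set

import pandas as pd

# Hypothetical feature view schema: the columns the online store should receive.
fv_schema: Set[str] = {"driver_id", "trips_today", "event_timestamp"}

df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "trips_today": [5, 9],
        "event_timestamp": pd.to_datetime(["2024-03-20", "2024-03-21"], utc=True),
        "debug_note": ["a", "b"],  # extra column, not part of the schema
    }
)

# Same idea as ingest_df: collect columns that are not in the schema and drop them.
drop_list = [column for column in df.columns if column not in fv_schema]
print(f"Dropping extra columns: {drop_list}")  # ['debug_note']
df = df.drop(drop_list, axis=1)
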
9 changes: 9 additions & 0 deletions sdk/python/feast/feature_store.py
@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import itertools
import logging
import os
import warnings
from collections import Counter, defaultdict
@@ -106,6 +107,8 @@

warnings.simplefilter("ignore", DeprecationWarning)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer

@@ -2626,12 +2629,18 @@ def _print_materialization_log(
f" to {Style.BRIGHT + Fore.GREEN}{end_date.replace(microsecond=0).astimezone()}{Style.RESET_ALL}"
f" into the {Style.BRIGHT + Fore.GREEN}{online_store}{Style.RESET_ALL} online store.\n"
)
logger.info(
f"Materializing {num_feature_views} feature views from {start_date.replace(microsecond=0).astimezone()} to {end_date.replace(microsecond=0).astimezone()} into the {online_store} online store."
)
else:
print(
f"Materializing {Style.BRIGHT + Fore.GREEN}{num_feature_views}{Style.RESET_ALL} feature views"
f" to {Style.BRIGHT + Fore.GREEN}{end_date.replace(microsecond=0).astimezone()}{Style.RESET_ALL}"
f" into the {Style.BRIGHT + Fore.GREEN}{online_store}{Style.RESET_ALL} online store.\n"
)
logger.info(
f"Materializing {num_feature_views} feature views to {end_date.replace(microsecond=0).astimezone()} into the {online_store} online store."
)


def _validate_feature_views(feature_views: List[BaseFeatureView]):
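
Note that the logger.info calls added above are silent unless the application configures logging; a minimal sketch using only the standard library (the format string is just an example):

import logging

# Emit INFO-level records, including the new materialization log lines, to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
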
22 changes: 8 additions & 14 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,5 +1,5 @@
from types import MethodType
from typing import List, Optional, Set, Union
from typing import List, Optional, Union

import pandas as pd
from pyspark import SparkContext
@@ -180,18 +180,9 @@ def _ingest_stream_data(self) -> StreamTable:

def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
if isinstance(self.sfv, FeatureView):
drop_list: List[str] = []
fv_schema: Set[str] = set(map(lambda field: field.name, self.sfv.schema))
# Add timestamp field to the schema so we don't delete from dataframe
if isinstance(self.sfv, StreamFeatureView):
fv_schema.add(self.sfv.timestamp_field)
else:
fv_schema.add(self.sfv.stream_source.timestamp_field)
for column in df.columns:
if column not in fv_schema:
drop_list.append(column)
return df.drop(*drop_list)
return self.sfv.udf.__call__(df) if self.sfv.udf else df
return df
elif isinstance(self.sfv, StreamFeatureView):
return self.sfv.udf.__call__(df) if self.sfv.udf else df

def _write_stream_data(self, df: StreamTable, to: PushMode):
# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
@@ -209,7 +200,10 @@ def batch_write(row: DataFrame, batch_id: int):
.groupby(self.join_keys)
.nth(0)
)
rows["created"] = pd.to_datetime("now", utc=True)
# The 'created' column is not used anywhere downstream, and the Expedia provider
# drops unused columns from the dataframe, so this assignment is commented out.
# rows["created"] = pd.to_datetime("now", utc=True)

# Reset indices to ensure the dataframe has all the required columns.
rows = rows.reset_index()
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union, cast

import dill
import numpy as np
import pandas as pd
import pyarrow
from tqdm import tqdm
@@ -33,8 +34,6 @@
_run_pyarrow_field_mapping,
)

logger = logging.getLogger(__name__)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
"""Batch Materialization Engine config for spark engine"""
@@ -45,6 +44,11 @@ class SparkMaterializationEngineConfig(FeastConfigBaseModel):
partitions: int = 0
"""Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""

batch_size: int = 10000
"""Batch size determines the number of rows written to the online store in a single batch per partition.
Adjust this value based on the number of parallel executors so that the online store can handle the load.
To override this per feature view, set the 'batch_size' tag in the feature view definition."""


@dataclass
class SparkMaterializationJob(MaterializationJob):
@@ -181,8 +185,22 @@ def _materialize_one(
self.repo_config.batch_engine.partitions
)

# Calculate batch_size per feature_view
feature_view_batch_size = (
self.repo_config.batch_engine.batch_size
if "batch_size" not in feature_view.tags
or feature_view.tags["batch_size"] is None
else int(feature_view.tags["batch_size"])
)

print(
f"INFO!!! Processing {feature_view.name} with {spark_df.count()} records, batch size {feature_view_batch_size}"
)

spark_df.foreachPartition(
lambda x: _process_by_partition(x, spark_serialized_artifacts)
lambda x: _process_by_partition(
x, spark_serialized_artifacts, feature_view_batch_size
)
)

return SparkMaterializationJob(
@@ -228,38 +246,66 @@ def unserialize(self):
return feature_view, online_store, repo_config


def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
def _process_by_partition(
rows,
spark_serialized_artifacts: _SparkSerializedArtifacts,
batch_size: int,
):
"""Load pandas df to online store"""

# convert to pyarrow table
dicts = []
for row in rows:
dicts.append(row.asDict())
# def write_to_online_store_in_batches(batch_dict, batch_id):
def write_to_online_store_in_batches(batch_df: pd.DataFrame):
batch_id = batch_df.name
start_time = time.time()
# convert to pyarrow table
if batch_df.shape[0] == 0:
print("INFO!!! Dataframe has 0 records to process")
return

df = pd.DataFrame.from_records(dicts)
if df.shape[0] == 0:
logger.info("Dataframe has 0 records to process")
return
table = pyarrow.Table.from_pandas(batch_df)

table = pyarrow.Table.from_pandas(df)
# unserialize artifacts
(
feature_view,
online_store,
repo_config,
) = spark_serialized_artifacts.unserialize()

if feature_view.batch_source.field_mapping is not None:
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)

# unserialize artifacts
feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}

if feature_view.batch_source.field_mapping is not None:
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
)
end_time = time.time()
print(
f"INFO!!! Time taken to write batch {batch_id} is {int((end_time - start_time) * 1000)} milliseconds"
)

join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
# Spark 3.3.0 or above supports the toPandas() method; we are running on Spark 3.2.2.
pandas_dataframe = pd.DataFrame([row.asDict() for row in rows])
# TODO: For PySpark applications, we should use the py4j bridge to initialize loggers.
# Temporarily using print to display logs.
print(
f"INFO!!! Processing a partition with {pandas_dataframe.shape[0]} records and batch size {batch_size}"
)

if "fs_batch" in pandas_dataframe.columns:
raise ValueError(
"Column 'fs_batch' is reserved by Feature Store. Please rename to avoid conflicts."
)
pandas_dataframe["fs_batch"] = np.arange(len(pandas_dataframe)) // batch_size
pandas_dataframe.groupby("fs_batch").apply(write_to_online_store_in_batches)
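
Two details of the batching above are worth a quick illustration: each partition's rows are assigned an integer bucket with np.arange(len(df)) // batch_size and then written group by group, and the engine-wide batch_size can be overridden per feature view through its tags. The sketch below is illustrative only and uses made-up values.

import numpy as np
import pandas as pd

# 1) Batching mechanics: 10 rows with batch_size=4 yields buckets 0,0,0,0,1,1,1,1,2,2,
#    so groupby("fs_batch") produces three batches of sizes 4, 4 and 2.
df = pd.DataFrame({"value": range(10)})
batch_size = 4
df["fs_batch"] = np.arange(len(df)) // batch_size
print(df.groupby("fs_batch").size().to_dict())  # {0: 4, 1: 4, 2: 2}

# 2) Per-feature-view override: the engine checks tags["batch_size"] first and falls
#    back to SparkMaterializationEngineConfig.batch_size (default 10000).
tags = {"batch_size": "5000"}  # e.g. a feature view defined with tags={"batch_size": "5000"}
engine_default_batch_size = 10_000
effective_batch_size = (
    engine_default_batch_size
    if "batch_size" not in tags or tags["batch_size"] is None
    else int(tags["batch_size"])
)
print(effective_batch_size)  # 5000
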
Original file line number Diff line number Diff line change
@@ -75,11 +75,11 @@ def pull_latest_from_table_or_query(
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
assert isinstance(data_source, SparkSource)

warnings.warn(
"The spark offline store is an experimental feature in alpha development. "
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)
# warnings.warn(
# "The spark offline store is an experimental feature in alpha development. "
# "Some functionality may still be unstable so functionality can change in the future.",
# RuntimeWarning,
# )

print("Pulling latest features from spark offline store")

@@ -134,11 +134,11 @@ def get_historical_features(
for fv in feature_views:
assert isinstance(fv.batch_source, SparkSource)

warnings.warn(
"The spark offline store is an experimental feature in alpha development. "
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)
# warnings.warn(
# "The spark offline store is an experimental feature in alpha development. "
# "Some functionality may still be unstable so functionality can change in the future.",
# RuntimeWarning,
# )

spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
@@ -276,11 +276,11 @@ def pull_all_from_table_or_query(
"""
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
assert isinstance(data_source, SparkSource)
warnings.warn(
"The spark offline store is an experimental feature in alpha development. "
"This API is unstable and it could and most probably will be changed in the future.",
RuntimeWarning,
)
# warnings.warn(
# "The spark offline store is an experimental feature in alpha development. "
# "This API is unstable and it could and most probably will be changed in the future.",
# RuntimeWarning,
# )

spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
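
Rather than commenting the experimental-feature warnings out of the store itself, a caller could also suppress just that message; a minimal standard-library sketch, shown only as an alternative to what the commit does:

import warnings

# Silence only the spark offline store's "experimental feature" RuntimeWarning.
warnings.filterwarnings(
    "ignore",
    message="The spark offline store is an experimental feature",
    category=RuntimeWarning,
)
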
1 change: 1 addition & 0 deletions sdk/python/feast/infra/provider.py
@@ -25,6 +25,7 @@
"aws": "feast.infra.aws.AwsProvider",
"local": "feast.infra.local.LocalProvider",
"azure": "feast.infra.contrib.azure_provider.AzureProvider",
"expedia": "feast.expediagroup.provider.expedia.ExpediaProvider",
}


16 changes: 16 additions & 0 deletions sdk/python/feast/repo_config.py
@@ -211,6 +211,16 @@ def __init__(self, **data: Any):
self._offline_config = "redshift"
elif data["provider"] == "azure":
self._offline_config = "mssql"
elif data["provider"] == "expedia":
spark_config = {
"type": "spark",
"spark_conf": {
"spark.sql.catalog.spark_catalog": "org.apache.iceberg.spark.SparkCatalog",
"spark.sql.catalog.spark_catalog.type": "hive",
"spark.sql.iceberg.handle-timestamp-without-timezone": "true",
},
}
self._offline_config = spark_config

self._online_store = None
if "online_store" in data:
@@ -224,12 +234,16 @@
self._online_config = "dynamodb"
elif data["provider"] == "rockset":
self._online_config = "rockset"
elif data["provider"] == "expedia":
self._online_config = "redis"

self._batch_engine = None
if "batch_engine" in data:
self._batch_engine_config = data["batch_engine"]
elif "batch_engine_config" in data:
self._batch_engine_config = data["batch_engine_config"]
elif data["provider"] == "expedia":
self._batch_engine_config = "spark.engine"
else:
# Defaults to using local in-process materialization engine.
self._batch_engine_config = "local"
@@ -390,6 +404,8 @@ def _validate_offline_store_config(cls, values):
values["offline_store"]["type"] = "redshift"
if values["provider"] == "azure":
values["offline_store"]["type"] = "mssql"
if values["provider"] == "expedia":
values["offline_store"]["type"] = "spark"

offline_store_type = values["offline_store"]["type"]

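
With the defaults introduced above, a repo config that sets provider: expedia and omits the offline_store, online_store, and batch_engine sections resolves to a spark offline store (with the Iceberg catalog settings shown earlier), a redis online store, and the spark.engine batch engine. Expressed as the equivalent Python dict (project name and registry path are hypothetical placeholders):

# Minimal repo configuration for the expedia provider; only "provider" is needed to
# trigger the defaults, the other values here are hypothetical placeholders.
repo_config_data = {
    "project": "my_project",        # hypothetical
    "provider": "expedia",
    "registry": "data/registry.db", # hypothetical
    # Leaving offline_store, online_store and batch_engine unset lets the expedia
    # defaults apply: spark offline store, redis online store, spark.engine engine.
}
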
