From cf0f2f29a9ec3ee7d341649def4fe99f42edc299 Mon Sep 17 00:00:00 2001 From: Bhargav Dodla <13788369+EXPEbdodla@users.noreply.github.com> Date: Wed, 6 Nov 2024 20:33:24 -0800 Subject: [PATCH] feat: Supports nested struct columns as features, timestamp fields (#153) * feat: Supports nested struct columns as features, timestamp fields * fix: Added field mapping support for spark streaming * feat: Add support for field mapping in SparkOfflineStore * feat: Add timing for batch write operations in SparkKafkaProcessor * feat: Enhance SparkKafkaProcessor logging and add unit tests for SparkOfflineStore * fix: Remove unnecessary f-string usage in SparkOfflineStore tests * fix: Renamed integration test file name to avoid conflicts * refactor: Remove unused ingest_df method and clean up imports in ExpediaProvider --------- Co-authored-by: Bhargav Dodla --- .../feast/expediagroup/provider/expedia.py | 39 ------ .../infra/contrib/spark_kafka_processor.py | 72 +++++----- .../spark/spark_materialization_engine.py | 2 + .../contrib/spark_offline_store/spark.py | 25 +++- sdk/python/feast/utils.py | 33 ++++- ...spark.py => test_spark_materialization.py} | 0 .../contrib/spark_offline_store/test_spark.py | 129 ++++++++++++++++++ 7 files changed, 216 insertions(+), 84 deletions(-) rename sdk/python/tests/integration/materialization/contrib/spark/{test_spark.py => test_spark_materialization.py} (100%) create mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py diff --git a/sdk/python/feast/expediagroup/provider/expedia.py b/sdk/python/feast/expediagroup/provider/expedia.py index dc9fed9d49..dd991652b9 100644 --- a/sdk/python/feast/expediagroup/provider/expedia.py +++ b/sdk/python/feast/expediagroup/provider/expedia.py @@ -1,12 +1,7 @@ import logging -from typing import List, Set -import pandas as pd - -from feast.feature_view import FeatureView from feast.infra.passthrough_provider import PassthroughProvider from feast.repo_config import RepoConfig -from feast.stream_feature_view import StreamFeatureView logger = logging.getLogger(__name__) @@ -24,37 +19,3 @@ def __init__(self, config: RepoConfig): ) super().__init__(config) - - def ingest_df( - self, - feature_view: FeatureView, - df: pd.DataFrame, - ): - drop_list: List[str] = [] - fv_schema: Set[str] = set(map(lambda field: field.name, feature_view.schema)) - # Add timestamp field to the schema so we don't delete from dataframe - if isinstance(feature_view, StreamFeatureView): - fv_schema.add(feature_view.timestamp_field) - if feature_view.source.created_timestamp_column: - fv_schema.add(feature_view.source.created_timestamp_column) - - if isinstance(feature_view, FeatureView): - if feature_view.stream_source is not None: - fv_schema.add(feature_view.stream_source.timestamp_field) - if feature_view.stream_source.created_timestamp_column: - fv_schema.add(feature_view.stream_source.created_timestamp_column) - else: - fv_schema.add(feature_view.batch_source.timestamp_field) - if feature_view.batch_source.created_timestamp_column: - fv_schema.add(feature_view.batch_source.created_timestamp_column) - - for column in df.columns: - if column not in fv_schema: - drop_list.append(column) - - if len(drop_list) > 0: - print( - f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe." 
- ) - - super().ingest_df(feature_view, df.drop(drop_list, axis=1)) diff --git a/sdk/python/feast/infra/contrib/spark_kafka_processor.py b/sdk/python/feast/infra/contrib/spark_kafka_processor.py index 12832a5f66..20fd6b28ad 100644 --- a/sdk/python/feast/infra/contrib/spark_kafka_processor.py +++ b/sdk/python/feast/infra/contrib/spark_kafka_processor.py @@ -1,3 +1,4 @@ +import time from types import MethodType from typing import List, Optional, Set, Union, no_type_check @@ -199,7 +200,37 @@ def _ingest_stream_data(self) -> StreamTable: def _construct_transformation_plan(self, df: StreamTable) -> StreamTable: if isinstance(self.sfv, FeatureView): - return df + # Apply field mapping if it exists. + if self.sfv.stream_source is not None: + if self.sfv.stream_source.field_mapping is not None: + for ( + field_mapping_key, + field_mapping_value, + ) in self.sfv.stream_source.field_mapping.items(): + df = df.withColumn(field_mapping_value, df[field_mapping_key]) + + # Drop unused columns + ## Note: This may need reconsideration when we support writing to offline store for Feature Views + drop_list: List[str] = [] + fv_schema: Set[str] = set( + map(lambda field: field.name, self.sfv.schema) + ) + + fv_schema.add(self.sfv.stream_source.timestamp_field) + if self.sfv.stream_source.created_timestamp_column: + fv_schema.add(self.sfv.stream_source.created_timestamp_column) + + for column in df.columns: + if column not in fv_schema: + drop_list.append(column) + + if len(drop_list) > 0: + print( + f"INFO!!! Dropping extra columns in the DataFrame: {drop_list}. Avoid unnecessary columns in the dataframe." + ) + return df.drop(*drop_list) + else: + raise Exception(f"Stream source is not defined for {self.sfv.name}") elif isinstance(self.sfv, StreamFeatureView): return self.sfv.udf.__call__(df) if self.sfv.udf else df @@ -271,45 +302,16 @@ def batch_write( join_keys, feature_view, ): - drop_list: List[str] = [] - fv_schema: Set[str] = set( - map(lambda field: field.name, feature_view.schema) - ) - # Add timestamp field to the schema so we don't delete from dataframe - if isinstance(feature_view, StreamFeatureView): - fv_schema.add(feature_view.timestamp_field) - if feature_view.source.created_timestamp_column: - fv_schema.add(feature_view.source.created_timestamp_column) - - if isinstance(feature_view, FeatureView): - if feature_view.stream_source is not None: - fv_schema.add(feature_view.stream_source.timestamp_field) - if feature_view.stream_source.created_timestamp_column: - fv_schema.add( - feature_view.stream_source.created_timestamp_column - ) - else: - fv_schema.add(feature_view.batch_source.timestamp_field) - if feature_view.batch_source.created_timestamp_column: - fv_schema.add( - feature_view.batch_source.created_timestamp_column - ) - - for column in df.columns: - if column not in fv_schema: - drop_list.append(column) - - if len(drop_list) > 0: - print( - f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe." 
- ) - - sdf.drop(*drop_list).mapInPandas( + start_time = time.time() + sdf.mapInPandas( lambda x: batch_write_pandas_df( x, spark_serialized_artifacts, join_keys ), "status int", ).count() # dummy action to force evaluation + print( + f"Time taken to write batch {batch_id} is: {(time.time() - start_time) * 1000:.2f} ms" + ) query = ( df.writeStream.outputMode("update") diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 8adfeb67f1..d33daa0b60 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -252,6 +252,8 @@ def _map_by_partition( ) = spark_serialized_artifacts.unserialize() if feature_view.batch_source.field_mapping is not None: + # Spark offline store does the field mapping during pull_latest_from_table_or_query + # This is for the case where the offline store is not spark table = _run_pyarrow_field_mapping( table, feature_view.batch_source.field_mapping ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 237528442f..18491cb58d 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -34,6 +34,7 @@ from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.type_map import spark_schema_to_np_dtypes +from feast.utils import _get_fields_with_aliases # Make sure spark warning are ignored warnings.simplefilter("ignore", RuntimeWarning) @@ -91,16 +92,23 @@ def pull_latest_from_table_or_query( if created_timestamp_column: timestamps.append(created_timestamp_column) timestamp_desc_string = " DESC, ".join(timestamps) + " DESC" - field_string = ", ".join(join_key_columns + feature_name_columns + timestamps) + + (fields_with_aliases, aliases) = _get_fields_with_aliases( + fields=join_key_columns + feature_name_columns + timestamps, + field_mappings=data_source.field_mapping, + ) + + fields_as_string = ", ".join(fields_with_aliases) + aliases_as_string = ", ".join(aliases) start_date_str = _format_datetime(start_date) end_date_str = _format_datetime(end_date) query = f""" SELECT - {field_string} + {aliases_as_string} {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} FROM ( - SELECT {field_string}, + SELECT {fields_as_string}, ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_ FROM {from_expression} t1 WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}') @@ -280,14 +288,19 @@ def pull_all_from_table_or_query( spark_session = get_spark_session_or_start_new_with_repoconfig( store_config=config.offline_store ) - - fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field]) from_expression = data_source.get_table_query_string() start_date = start_date.astimezone(tz=timezone.utc) end_date = end_date.astimezone(tz=timezone.utc) + (fields_with_aliases, aliases) = _get_fields_with_aliases( + fields=join_key_columns + feature_name_columns + [timestamp_field], + field_mappings=data_source.field_mapping, + ) + + fields_with_alias_string = ", ".join(fields_with_aliases) + query = f""" - SELECT 
{fields} + SELECT {fields_with_alias_string} FROM {from_expression} WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' """ diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 992869557a..a67647ccff 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -103,7 +103,8 @@ def _get_requested_feature_views_to_features_dict( on_demand_feature_views: List["OnDemandFeatureView"], ) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]: """Create a dict of FeatureView -> List[Feature] for all requested features. - Set full_feature_names to True to have feature names prefixed by their feature view name.""" + Set full_feature_names to True to have feature names prefixed by their feature view name. + """ feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list) on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = ( @@ -209,6 +210,28 @@ def _run_pyarrow_field_mapping( return table +def _get_fields_with_aliases( + fields: List[str], + field_mappings: Dict[str, str], +) -> Tuple[List[str], List[str]]: + """ + Get a list of fields with aliases based on the field mappings. + """ + for field in fields: + if "." in field and field not in field_mappings: + raise ValueError( + f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields." + ) + fields_with_aliases = [ + f"{field} AS {field_mappings[field]}" if field in field_mappings else field + for field in fields + ] + aliases = [ + field_mappings[field] if field in field_mappings else field for field in fields + ] + return (fields_with_aliases, aliases) + + def _coerce_datetime(ts): """ Depending on underlying time resolution, arrow to_pydict() sometimes returns pd @@ -678,9 +701,11 @@ def _populate_response_from_feature_data( """ # Add the feature names to the response. 
requested_feature_refs = [ - f"{table.projection.name_to_use()}__{feature_name}" - if full_feature_names - else feature_name + ( + f"{table.projection.name_to_use()}__{feature_name}" + if full_feature_names + else feature_name + ) for feature_name in requested_features ] online_features_response.metadata.feature_names.val.extend(requested_feature_refs) diff --git a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py b/sdk/python/tests/integration/materialization/contrib/spark/test_spark_materialization.py similarity index 100% rename from sdk/python/tests/integration/materialization/contrib/spark/test_spark.py rename to sdk/python/tests/integration/materialization/contrib/spark/test_spark_materialization.py diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py new file mode 100644 index 0000000000..b8f8cc4247 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py @@ -0,0 +1,129 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, + SparkOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.repo_config import RepoConfig + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session): + mock_spark_session = MagicMock() + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source = SparkSource( + name="test_nested_batch_source", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + ) + + # Define the parameters for the method + join_key_columns = ["key1", "key2"] + feature_name_columns = ["feature1", "feature2"] + timestamp_field = "event_header.event_published_datetime_utc" + created_timestamp_column = "created_timestamp" + start_date = datetime(2021, 1, 1) + end_date = datetime(2021, 1, 2) + + # Call the method + retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query( + config=test_repo_config, + data_source=test_data_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + expected_query = """SELECT + key1, key2, feature1, feature2, nested_timestamp, created_timestamp + + FROM ( + SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp, + ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_ + FROM `offline_store_database_name`.`offline_store_table_name` t1 + WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 
00:00:00.000000') + ) t2 + WHERE feast_row_ = 1""" # noqa: W293 + + assert isinstance(retrieval_job, RetrievalJob) + assert retrieval_job.query.strip() == expected_query.strip() + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_pull_latest_from_table_without_nested_timestamp_or_query( + mock_get_spark_session, +): + mock_spark_session = MagicMock() + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source = SparkSource( + name="test_batch_source", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name", + timestamp_field="event_published_datetime_utc", + ) + + # Define the parameters for the method + join_key_columns = ["key1", "key2"] + feature_name_columns = ["feature1", "feature2"] + timestamp_field = "event_published_datetime_utc" + created_timestamp_column = "created_timestamp" + start_date = datetime(2021, 1, 1) + end_date = datetime(2021, 1, 2) + + # Call the method + retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query( + config=test_repo_config, + data_source=test_data_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + expected_query = """SELECT + key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp + + FROM ( + SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp, + ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_ + FROM `offline_store_database_name`.`offline_store_table_name` t1 + WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') + ) t2 + WHERE feast_row_ = 1""" # noqa: W293 + + assert isinstance(retrieval_job, RetrievalJob) + assert retrieval_job.query.strip() == expected_query.strip()
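
For reference, a minimal usage sketch of the behaviour the new unit tests exercise: with this change, a SparkSource whose event timestamp lives inside a nested struct column can be declared by renaming the nested column through field_mapping and pointing timestamp_field at the resulting alias. The database, table, and column names below are illustrative and mirror the test fixtures above rather than any real dataset.

    from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
        SparkSource,
    )

    # Illustrative source whose event timestamp is a nested struct column.
    nested_timestamp_source = SparkSource(
        name="test_nested_batch_source",
        description="Example source with a nested timestamp column",
        table="offline_store_database_name.offline_store_table_name",
        # Rename the nested column to a flat alias via field_mapping ...
        field_mapping={
            "event_header.event_published_datetime_utc": "nested_timestamp",
        },
        # ... and reference that alias as the timestamp field.
        timestamp_field="nested_timestamp",
    )

As asserted in the expected SQL above, pull_latest_from_table_or_query selects the nested column with an AS alias in the inner query (via the new _get_fields_with_aliases helper in sdk/python/feast/utils.py) and refers only to the alias in the outer SELECT and in the ORDER BY used for deduplication.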