fix: Using mapInPandas for both spark stream and batch ingestions (#100)
* fix: Using mapInPandas for both spark stream and batch ingestions

---------

Co-authored-by: Bhargav Dodla <[email protected]>
EXPEbdodla and Bhargav Dodla authored Apr 15, 2024
1 parent d3094da commit 54ea08b
Showing 2 changed files with 152 additions and 50 deletions.
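For orientation, the core pattern this commit moves both ingestion paths onto, shown as a minimal standalone sketch (a hypothetical example, not code from the diff below): mapInPandas hands each executor an iterator of pandas DataFrames, so rows are written out in parallel on the executors instead of being collected to the driver with toPandas().

from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

def write_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in batches:
        # write pdf to the target store here (placeholder for the real write)
        yield pd.DataFrame({"status": [1]})  # dummy result matching the declared schema

# "status int" describes the dummy result; count() is the action that forces evaluation.
df.mapInPandas(write_partition, "status int").count()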
143 changes: 140 additions & 3 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,7 +1,8 @@
from types import MethodType
-from typing import List, Optional, Union
+from typing import List, Optional, Set, Union

import pandas as pd
import pyarrow
from pyspark import SparkContext
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
@@ -17,7 +18,11 @@
StreamProcessor,
StreamTable,
)
from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
_SparkSerializedArtifacts,
)
from feast.stream_feature_view import StreamFeatureView
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping


class SparkProcessorConfig(ProcessorConfig):
@@ -100,12 +105,18 @@ def __init__(
else "/tmp/checkpoint/"
)
self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
self.spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
feature_view=sfv, repo_config=fs.config
)
super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
ingested_stream_df = self._ingest_stream_data()
transformed_df = self._construct_transformation_plan(ingested_stream_df)
-        online_store_query = self._write_stream_data(transformed_df, to)
+        if self.fs.config.provider == "expedia":
+            online_store_query = self._write_stream_data_expedia(transformed_df, to)
+        else:
+            online_store_query = self._write_stream_data(transformed_df, to)
return online_store_query

def _ingest_stream_data(self) -> StreamTable:
@@ -140,7 +151,6 @@ def _ingest_stream_data(self) -> StreamTable:
"subscribe": self.data_source.kafka_options.topic,
"startingOffsets": "latest",
}

stream_df = (
self.spark.readStream.format("kafka")
.options(**spark_kafka_options)
Expand Down Expand Up @@ -184,6 +194,133 @@ def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
elif isinstance(self.sfv, StreamFeatureView):
return self.sfv.udf.__call__(df) if self.sfv.udf else df

def _write_stream_data_expedia(self, df: StreamTable, to: PushMode):
"""
        Keeps the materialization logic in sync with stream ingestion.
        Supports writing to the online store only; preprocess_fn is not supported.
        In Spark 3.2.2, toPandas() throws an error when the dataframe has boolean columns.
        Fixing that requires Spark 3.4.0 or numpy < 1.20.0, but Feast needs numpy >= 1.22.
        Switching to mapInPandas avoids the boolean-column problem, and it also avoids
        toPandas() loading all of the data into the driver's memory.
Error Message:
AttributeError: module 'numpy' has no attribute 'bool'.
`np.bool` was a deprecated alias for the builtin `bool`.
To avoid this error in existing code, use `bool` by itself.
Doing this will not modify any behavior and is safe.
If you specifically wanted the numpy scalar type, use `np.bool_` here.
"""

# TODO: Support writing to offline store and preprocess_fn. Remove _write_stream_data method

# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys):
for pdf in iterator:
(
feature_view,
online_store,
repo_config,
) = spark_serialized_artifacts.unserialize()

if isinstance(feature_view, StreamFeatureView):
ts_field = feature_view.timestamp_field
else:
ts_field = feature_view.stream_source.timestamp_field

# Extract the latest feature values for each unique entity row (i.e. the join keys).
pdf = (
pdf.sort_values(by=[*join_keys, ts_field], ascending=False)
.groupby(join_keys)
.nth(0)
)

table = pyarrow.Table.from_pandas(pdf)
if feature_view.batch_source.field_mapping is not None:
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)

join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}
rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
)

yield pd.DataFrame([pd.Series(range(1, 2))]) # dummy result

def batch_write(
sdf: DataFrame,
batch_id: int,
spark_serialized_artifacts,
join_keys,
feature_view,
):
drop_list: List[str] = []
fv_schema: Set[str] = set(
map(lambda field: field.name, feature_view.schema)
)
            # Add the timestamp field to the schema so it is not dropped from the dataframe
if isinstance(feature_view, StreamFeatureView):
fv_schema.add(feature_view.timestamp_field)
if feature_view.source.created_timestamp_column:
fv_schema.add(feature_view.source.created_timestamp_column)

if isinstance(feature_view, FeatureView):
if feature_view.stream_source is not None:
fv_schema.add(feature_view.stream_source.timestamp_field)
if feature_view.stream_source.created_timestamp_column:
fv_schema.add(
feature_view.stream_source.created_timestamp_column
)
else:
fv_schema.add(feature_view.batch_source.timestamp_field)
if feature_view.batch_source.created_timestamp_column:
fv_schema.add(
feature_view.batch_source.created_timestamp_column
)

for column in df.columns:
if column not in fv_schema:
drop_list.append(column)

if len(drop_list) > 0:
print(
f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe."
)

sdf.drop(*drop_list).mapInPandas(
lambda x: batch_write_pandas_df(
x, spark_serialized_artifacts, join_keys
),
"status int",
).count() # dummy action to force evaluation

query = (
df.writeStream.outputMode("update")
.option("checkpointLocation", self.checkpoint_location)
.trigger(processingTime=self.processing_time)
.foreachBatch(
lambda df, batch_id: batch_write(
df,
batch_id,
self.spark_serialized_artifacts,
self.join_keys,
self.sfv,
)
)
.start()
)

query.awaitTermination(timeout=self.query_timeout)
return query

def _write_stream_data(self, df: StreamTable, to: PushMode):
# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write(row: DataFrame, batch_id: int):
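An aside on the dedup step in batch_write_pandas_df above (the sort_values / groupby / nth chain): it keeps only the latest feature values per entity in each micro-batch. A small standalone pandas illustration with made-up data, not part of the commit:

import pandas as pd

join_keys = ["driver_id"]
ts_field = "event_timestamp"

pdf = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "event_timestamp": pd.to_datetime(
            ["2024-04-15 10:00", "2024-04-15 11:00", "2024-04-15 10:30"]
        ),
        "trips": [3, 5, 7],
    }
)

# Sort newest-first, then take the first row of each entity group.
latest = (
    pdf.sort_values(by=[*join_keys, ts_field], ascending=False)
    .groupby(join_keys)
    .nth(0)
)
# Result: one row per driver_id -- (1, 11:00, trips=5) and (2, 10:30, trips=7).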
59 changes: 12 additions & 47 deletions sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
@@ -4,7 +4,6 @@
from typing import Callable, List, Literal, Optional, Sequence, Union, cast

import dill
-import numpy as np
import pandas as pd
import pyarrow
from tqdm import tqdm
@@ -44,10 +43,6 @@ class SparkMaterializationEngineConfig(FeastConfigBaseModel):
partitions: int = 0
"""Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""

-    batch_size: int = 10000
-    """Batch size determines the number of rows to be written to the online store in a single batch per partitions.
-    To overwrite at each feature view level, set the tag 'batch_size' in the feature view definition."""


@dataclass
class SparkMaterializationJob(MaterializationJob):
@@ -184,23 +179,13 @@ def _materialize_one(
self.repo_config.batch_engine.partitions
)

-        # Calculate batch_size per feature_view
-        feature_view_batch_size = (
-            self.repo_config.batch_engine.batch_size
-            if "batch_size" not in feature_view.tags
-            or feature_view.tags["batch_size"] is None
-            else int(feature_view.tags["batch_size"])
-        )

print(
f"INFO!!! Processing {feature_view.name} with {spark_df.count()} records, batch size {feature_view_batch_size}"
f"INFO!!! Processing {feature_view.name} with {spark_df.count()} records"
)

-        spark_df.foreachPartition(
-            lambda x: _process_by_partition(
-                x, spark_serialized_artifacts, feature_view_batch_size
-            )
-        )
+        spark_df.mapInPandas(
+            lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
+        ).count()  # dummy action to force evaluation

return SparkMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
@@ -245,23 +230,20 @@ def unserialize(self):
return feature_view, online_store, repo_config


-def _process_by_partition(
-    rows,
+def _map_by_partition(
+    iterator,
    spark_serialized_artifacts: _SparkSerializedArtifacts,
-    batch_size: int,
):
"""Load pandas df to online store"""

-    # def write_to_online_store_in_batches(batch_dict, batch_id):
-    def write_to_online_store_in_batches(batch_df: pd.DataFrame):
-        batch_id = batch_df.name
+    for pdf in iterator:
+        pdf_row_count = pdf.shape[0]
        start_time = time.time()
        # convert to pyarrow table
-        if batch_df.shape[0] == 0:
+        if pdf_row_count == 0:
            print("INFO!!! Dataframe has 0 records to process")
            return

-        table = pyarrow.Table.from_pandas(batch_df)
+        table = pyarrow.Table.from_pandas(pdf)

# unserialize artifacts
(
@@ -291,24 +273,7 @@ def write_to_online_store_in_batches(batch_df: pd.DataFrame):
)
end_time = time.time()
print(
f"INFO!!! Processed batch {batch_id} in {int((end_time - start_time) * 1000)} milliseconds"
f"INFO!!! Processed batch with size {pdf_row_count} in {int((end_time - start_time) * 1000)} milliseconds"
)

-    start_time = time.time()
-    # Spark 3.3.0 or above supports toPandas() method. We are running on spark 3.2.2
-    pandas_dataframe = pd.DataFrame([row.asDict() for row in rows])
-    # TODO: For Pyspark applications, we should use py4j bridge to initialize loggers
-    # Temporarily using print to display logs
-    print(
-        f"INFO!!! Processing partition {pandas_dataframe.shape[0]} records, batch size {batch_size}"
-    )
-
-    if "fs_batch" in pandas_dataframe.columns:
-        raise ValueError(
-            "Column 'fs_batch' is reserved by Feature Store. Please rename to avoid conflicts."
-        )
-    pandas_dataframe["fs_batch"] = np.arange(len(pandas_dataframe)) // batch_size
-    pandas_dataframe.groupby("fs_batch").apply(write_to_online_store_in_batches)
-    print(
-        f"INFO!!! Processed partition {pandas_dataframe.shape[0]} records, batch size {batch_size}, time {int((time.time() - start_time))} Seconds"
-    )
+        yield pd.DataFrame([pd.Series(range(1, 2))])  # dummy result

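A closing note on the batch path: with mapInPandas, the size of each pandas DataFrame handed to _map_by_partition is governed by Spark's Arrow batching rather than the removed batch_size setting. The sketch below shows the relevant Spark option (a deployment-side assumption, not something this commit configures):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Rows per Arrow record batch, i.e. per pandas DataFrame passed to mapInPandas (default 10000).
    .config("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")
    .getOrCreate()
)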