
Commit 075c4c0

fix: Implement batch write with connector for feature value processing in SparkKafkaProcessor
Bhargav Dodla committed Jan 15, 2025
1 parent eb6660f commit 075c4c0
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
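Context for the diff below: instead of routing every micro-batch through Feast's serialized online-store writer, the commit adds a path that appends batches to ScyllaDB via the Spark Cassandra connector. A minimal sketch of that connector write, assuming the spark-cassandra-connector package is on the classpath and spark.cassandra.connection.host is configured on the session (the entity_key column and row values are illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Illustrative schema; the real columns come from the stream feature view.
df = spark.createDataFrame([("k1", bytearray(b"payload"))], ["entity_key", "feature_value"])
(
    df.write.format("org.apache.spark.sql.cassandra")
    .mode("append")
    .options(table="mlpfs_scylladb_perf_test_cc_stream_fv", keyspace="feast")
    .save()
)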
@@ -8,8 +8,9 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.functions import col, from_json, udf
from pyspark.sql.streaming import StreamingQuery
from pyspark.sql.types import BinaryType

from feast import FeatureView
from feast.data_format import AvroFormat, ConfluentAvroFormat, JsonFormat, StreamFormat
@@ -332,21 +333,51 @@ def batch_write(
f"Time taken to write batch {batch_id} is: {(time.time() - start_time) * 1000:.2f} ms"
)

def batch_write_with_connector(
    sdf: DataFrame,
    batch_id: int,
):
    start_time = time.time()
    # Drop the Kafka event metadata column; only feature columns are persisted.
    sdf = sdf.drop("event_header")
    # Encode the string feature value as UTF-8 bytes to match the blob column type.
    convert_to_blob = udf(lambda s: s.encode("utf-8"), BinaryType())
    sdf = sdf.withColumn("feature_value", convert_to_blob(col("feature_value")))
    # Append the micro-batch directly through the Spark Cassandra connector.
    sdf.write.format("org.apache.spark.sql.cassandra").mode("append").options(
        table="mlpfs_scylladb_perf_test_cc_stream_fv", keyspace="feast"
    ).save()
    print(
        f"Time taken to write batch {batch_id} is: {(time.time() - start_time) * 1000:.2f} ms"
    )

query = None
# Hard-coded routing: only the "cc_stream_fv" stream feature view takes the
# new connector path; every other view keeps the original batch_write path.
if self.sfv.name != "cc_stream_fv":
    query = (
        df.writeStream.outputMode("update")
        .option("checkpointLocation", self.checkpoint_location)
        .trigger(processingTime=self.processing_time)
        .foreachBatch(
            lambda df, batch_id: batch_write(
                df,
                batch_id,
                self.spark_serialized_artifacts,
                self.join_keys,
                self.sfv,
            )
        )
        .start()
    )
else:
    query = (
        df.writeStream.outputMode("update")
        .option("checkpointLocation", self.checkpoint_location)
        .trigger(processingTime=self.processing_time)
        .foreachBatch(
            lambda df, batch_id: batch_write_with_connector(
                df,
                batch_id,
            )
        )
        .start()
    )

query.awaitTermination(timeout=self.query_timeout)
return query
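The feature_value conversion above is plain UTF-8 encoding wrapped in a UDF so the connector can store the value in a blob column. A self-contained sketch of just that step, using a made-up single-column micro-batch:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import BinaryType

spark = SparkSession.builder.getOrCreate()
# Hypothetical stand-in for one streaming micro-batch.
df = spark.createDataFrame([('{"conv_rate": 0.85}',)], ["feature_value"])
convert_to_blob = udf(lambda s: s.encode("utf-8"), BinaryType())
df = df.withColumn("feature_value", convert_to_blob(col("feature_value")))
df.printSchema()  # root |-- feature_value: binary (nullable = true)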
