Add AvroDeserializer to Spark Kafka Processor
expediamatt committed Feb 11, 2024
1 parent 35a022c commit a448480
Showing 2 changed files with 174 additions and 55 deletions.
198 changes: 155 additions & 43 deletions sdk/python/feast/expediagroup/schema_registry/schema_registry.py
@@ -7,54 +7,166 @@
Author: [email protected]
"""

import requests
import json
import os
import tempfile
from typing import Dict

from confluent_kafka.schema_registry import SchemaRegistryClient
import requests
from confluent_kafka.schema_registry import RegisteredSchema, SchemaRegistryClient


class SchemaRegistry():
# spark: SparkSession
# format: str
# preprocess_fn: Optional[MethodType]
# join_keys: List[str]
props: Dict[str, str]
kafka_params: Dict[str, str]
schema_registry_config: Dict[str, str]
client: SchemaRegistryClient

def __init__(self):
pass

def get_properties(
user: String,
password: String,
urn: String,
environment: String,
cert_path: String, #https://stackoverflow.com/questions/55203791/python-requests-using-certificate-value-instead-of-path
) -> dict:
"""Discover a Schema Registry with the provided urn and credentials,
and obtain a set of properties for use in Schema Registry calls."""
discovery_url = "https://stream-discovery-service-{environment}.rcp.us-east-1.data.{environment}.exp-aws.net/v2/discovery/urn/{urn}".format(
environment=environment, urn=urn
)

response = requests.get(
discovery_url,
auth=(user, password),
headers={"Accept": "application/json"},
verify=cert_path,
)

if response.status_code != 200:
raise RuntimeError(
"Discovery API returned unexpected HTTP status: {status}".format(
status=str(response.status_code)
)
)

try:
props = json.loads(response.text)
except (TypeError, UnicodeDecodeError):
raise TypeError(
"Discovery API response did not contain valid json: {response}".format(
response=response.text
)
)

return props
def initialize_client(
self,
user: str,
password: str,
urn: str,
environment: str,
cert_path: str, # https://stackoverflow.com/questions/55203791/python-requests-using-certificate-value-instead-of-path
) -> None:
"""
Discover a Schema Registry with the provided urn and credentials,
obtain a set of properties for use in Schema Registry calls,
and initialize the SchemaRegistryClient.
"""

discovery_url = "https://stream-discovery-service-{environment}.rcp.us-east-1.data.{environment}.exp-aws.net/v2/discovery/urn/{urn}".format(
environment=environment, urn=urn
)

response = requests.get(
discovery_url,
auth=(user, password),
headers={"Accept": "application/json"},
verify=cert_path,
)

if response.status_code != 200:
raise RuntimeError(
"Discovery API returned unexpected HTTP status: {status}".format(
status=str(response.status_code)
)
)

try:
props = json.loads(response.text)
except (TypeError, UnicodeDecodeError):
raise TypeError(
"Discovery API response did not contain valid json: {response}".format(
response=response.text
)
)

self.props = props

# write ssl key and cert to disk
ssl_key_file, ssl_key_path = tempfile.mkstemp()
with os.fdopen(ssl_key_file, "w") as f:
f.write(props["serde"]["schema.registry.ssl.keystore.key"])

ssl_certificate_file, ssl_certificate_path = tempfile.mkstemp()
with os.fdopen(ssl_certificate_file, "w") as f:
f.write(props["serde"]["schema.registry.ssl.keystore.certificate.chain"])

self.kafka_params = {
"kafka.security.protocol": props["security"]["security.protocol"],
"kafka.bootstrap.servers": props["connection"]["bootstrap.servers"],
"subscribe": props["connection"]["topic"],
"startingOffsets": props["connection"]["auto.offset.reset"],
"kafka.ssl.truststore.certificates": props["security"][
"ssl.truststore.certificates"
],
"kafka.ssl.keystore.certificate.chain": props["security"][
"ssl.keystore.certificate.chain"
],
"kafka.ssl.keystore.key": props["security"]["ssl.keystore.key"],
"kafka.ssl.endpoint.identification.algorithm": props["security"][
"ssl.endpoint.identification.algorithm"
],
"kafka.ssl.truststore.type": props["security"]["ssl.truststore.type"],
"kafka.ssl.keystore.type": props["security"]["ssl.keystore.type"],
"kafka.topic": props["connection"]["topic"],
"kafka.schema.registry.url": props["serde"]["schema.registry.url"],
"kafka.schema.registry.topic": props["connection"]["topic"],
"kafka.schema.registry.ssl.keystore.type": props["serde"][
"schema.registry.ssl.keystore.type"
],
"kafka.schema.registry.ssl.keystore.certificate.chain": props["serde"][
"schema.registry.ssl.keystore.certificate.chain"
],
"kafka.schema.registry.ssl.keystore.key": props["serde"][
"schema.registry.ssl.keystore.key"
],
"kafka.schema.registry.ssl.truststore.certificates": props["serde"][
"schema.registry.ssl.truststore.certificates"
],
"kafka.schema.registry.ssl.truststore.type": props["serde"][
"schema.registry.ssl.truststore.type"
],
"value.subject.name.strategy": "io.confluent.kafka.serializers.subject.TopicRecordNameStrategy",
}

self.schema_registry_config = {
"schema.registry.topic": props["connection"]["topic"],
"schema.registry.url": props["serde"]["schema.registry.url"],
"schema.registry.ssl.keystore.type": props["serde"][
"schema.registry.ssl.keystore.type"
],
"schema.registry.ssl.keystore.certificate.chain": props["serde"][
"schema.registry.ssl.keystore.certificate.chain"
],
"schema.registry.ssl.keystore.key": props["serde"][
"schema.registry.ssl.keystore.key"
],
"schema.registry.ssl.truststore.certificates": props["serde"][
"schema.registry.ssl.truststore.certificates"
],
"schema.registry.ssl.truststore.type": props["serde"][
"schema.registry.ssl.truststore.type"
],
}

schema_registry_url = props["serde"]["schema.registry.url"]

self.client = SchemaRegistryClient(
{
"url": schema_registry_url,
"ssl.ca.location": cert_path,
"ssl.key.location": ssl_key_path,
"ssl.certificate.location": ssl_certificate_path,
}
)

def get_latest_version(
self,
topic_name: str,
) -> RegisteredSchema:
"""
Get the latest version of the topic.
"""
if not self.client:
raise RuntimeError("Client has not been initialized. Please call initialize_client first.")

latest_version = self.client.get_latest_version(topic_name)

return latest_version

def get_client(
self
) -> SchemaRegistryClient:
"""
Return the client.
"""
if not self.client:
raise RuntimeError("Client has not been initialized. Please call initialize_client first.")

return self.client
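
For orientation, a minimal usage sketch of the new SchemaRegistry class. The credentials, urn, environment, and certificate path below are placeholders for illustration, not values from this commit:

from feast.expediagroup.schema_registry.schema_registry import SchemaRegistry

registry = SchemaRegistry()

# Placeholder credentials and discovery coordinates; real values would come
# from a secret store and deployment configuration.
registry.initialize_client(
    user="svc-user",
    password="svc-password",
    urn="urn:example:my-stream",
    environment="test",
    cert_path="/etc/ssl/certs/ca-bundle.crt",
)

# Once initialized, the wrapped confluent_kafka client and the latest
# registered schema for a subject are available.
latest = registry.get_latest_version("my-topic-value")
print(latest.version, latest.schema_id)
client = registry.get_client()

Note that initialize_client both discovers the connection properties and writes the SSL key and certificate chain to temp files so the confluent_kafka client can present them.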
31 changes: 19 additions & 12 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -2,14 +2,14 @@
from typing import List, Optional

import pandas as pd
from feast.expediagroup.schema_registry.schema_registry import SchemaRegistry
from confluent_kafka.schema_registry.avro import AvroDeserializer
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.functions import col, from_json

from feast.data_format import AvroFormat, ConfluentAvroFormat, JsonFormat
from feast.data_source import KafkaSource, PushMode
from feast.expediagroup.schema_registry.schema_registry import SchemaRegistry
from feast.feature_store import FeatureStore
from feast.infra.contrib.stream_processor import (
ProcessorConfig,
@@ -23,6 +23,7 @@ class SparkProcessorConfig(ProcessorConfig):
spark_session: SparkSession
processing_time: str
query_timeout: int
schema_registry_client: Optional[SchemaRegistry]


class SparkKafkaProcessor(StreamProcessor):
@@ -57,26 +58,24 @@ def __init__(
self.format = "json"
elif isinstance(sfv.stream_source.kafka_options.message_format, ConfluentAvroFormat):
self.format = "confluent_avro"
self.init_confluent_avro_processor()

if not isinstance(config, SparkProcessorConfig):
raise ValueError("config is not spark processor config")
self.spark = config.spark_session
self.preprocess_fn = preprocess_fn
self.processing_time = config.processing_time
self.query_timeout = config.query_timeout
self.schema_registry_client = config.schema_registry_client if config.schema_registry_client else None
self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

if isinstance(sfv.stream_source.kafka_options.message_format, ConfluentAvroFormat):
self.init_confluent_avro_processor()

super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

def init_confluent_avro_processor(self) -> None:
"""Extra initialization for Confluent Avro processor, which uses
SchemaRegistry and the Avro Deserializer, both of which need initialization."""

user = "VAULT_SECRETS"
password = "VAULT_SECRETS"
urn = "NOT SURE"
environment = "NOT SURE"
"""Extra initialization for Confluent Avro processor."""
self.deserializer = AvroDeserializer(schema_registry_client=self.schema_registry_client.get_client())

def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
ingested_stream_df = self._ingest_stream_data()
@@ -115,8 +114,16 @@ def _ingest_stream_data(self) -> StreamTable:
self.data_source.kafka_options.message_format, ConfluentAvroFormat
):
raise ValueError("kafka source message format is not confluent_avro format")
raise ValueError("HOLY MOLY I AM NOT READY TO DEAL WITH CONFLUENT AVRO, GUYS")
stream_df = None

stream_df = (
self.spark.readStream.format("kafka")
.options(**self.kafka_options_config)
.load()
.select(
self.deserializer(col("value"))
)
.select("table.*")
)
else:
if not isinstance(
self.data_source.kafka_options.message_format, AvroFormat
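
Putting the two files together, a hedged end-to-end sketch of how the new pieces are wired up. The repo path, stream feature view name, and config values are illustrative assumptions; what this commit actually adds is the optional schema_registry_client field and the Confluent Avro ingestion path:

from pyspark.sql import SparkSession

from feast import FeatureStore
from feast.expediagroup.schema_registry.schema_registry import SchemaRegistry
from feast.infra.contrib.spark_kafka_processor import (
    SparkKafkaProcessor,
    SparkProcessorConfig,
)

spark = SparkSession.builder.appName("feast-stream-ingest").getOrCreate()
store = FeatureStore(repo_path=".")  # illustrative repo path

registry = SchemaRegistry()
registry.initialize_client(
    user="svc-user",  # placeholders, as in the earlier sketch
    password="svc-password",
    urn="urn:example:my-stream",
    environment="test",
    cert_path="/etc/ssl/certs/ca-bundle.crt",
)

config = SparkProcessorConfig(
    mode="spark",
    source="kafka",
    spark_session=spark,
    processing_time="30 seconds",
    query_timeout=15,
    schema_registry_client=registry,  # new optional field from this commit
)

# For a stream feature view whose Kafka source uses ConfluentAvroFormat, the
# processor builds an AvroDeserializer from the registry client during init.
processor = SparkKafkaProcessor(
    config=config,
    fs=store,
    sfv=store.get_stream_feature_view("my_stream_view"),  # illustrative name
)
processor.ingest_stream_feature_view()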
