NU-1848: Improve missing Flink Kafka Source / Sink TypeInformation (#…
lciolecki authored Nov 18, 2024
1 parent 14306e2 commit 5ebf989
Showing 16 changed files with 580 additions and 72 deletions.
1 change: 1 addition & 0 deletions docs/Changelog.md
@@ -11,6 +11,7 @@
### 1.19.0 (Not released yet)

* [#7145](https://github.com/TouK/nussknacker/pull/7145) Lift TypingResult information for dictionaries
* [#7116](https://github.com/TouK/nussknacker/pull/7116) Improve missing Flink Kafka Source / Sink TypeInformation

## 1.18

8 changes: 8 additions & 0 deletions docs/MigrationGuide.md
@@ -2,6 +2,14 @@

To see the biggest differences please consult the [changelog](Changelog.md).

## In version 1.19.0 (Not released yet)

### Other changes

* [#7116](https://github.com/TouK/nussknacker/pull/7116) Improve missing Flink Kafka Source / Sink TypeInformation
  * Support for the old `ConsumerRecord` constructor used by Flink 1.14 / 1.15 has been dropped
  * If you used Kafka source/sink components in your scenarios, the state of these scenarios won't be restored

## In version 1.18.0 (Not released yet)

### Configuration changes
@@ -2,6 +2,7 @@ package pl.touk.nussknacker.engine.flink.api.typeinformation

import org.apache.flink.api.common.typeinfo.{TypeInformation, Types}
import pl.touk.nussknacker.engine.api.context.ValidationContext
import pl.touk.nussknacker.engine.api.generics.GenericType
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypingResult}
import pl.touk.nussknacker.engine.api.{Context, ValueWithContext}
import pl.touk.nussknacker.engine.util.Implicits.RichStringList
@@ -38,8 +39,14 @@ trait TypeInformationDetection extends Serializable {
forClass(klass)
}

def forClass[T](klass: Class[T]): TypeInformation[T] =
forType[T](Typed.typedClass(klass))
def forClass[T](klass: Class[T]): TypeInformation[T] = {
// Typed.typedClass doesn't support Any
if (klass == classOf[Any]) {
Types.GENERIC(klass)
} else {
forType[T](Typed.typedClass(klass))
}
}

def forType[T](typingResult: TypingResult): TypeInformation[T]

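For illustration only (this sketch is not part of the commit): Typed.typedClass cannot represent Any, so forClass now falls back to a generic, Kryo-backed TypeInformation for classOf[Any] and keeps the typed path for every other class. The sketch below uses Flink's TypeInformation.of as a stand-in for the project's forType(Typed.typedClass(...)) path:

import org.apache.flink.api.common.typeinfo.{TypeInformation, Types}

object ForClassFallbackSketch {

  // Stand-in for TypeInformationDetection.forClass: Any has no TypedClass representation,
  // so it gets a generic (Kryo-based) TypeInformation instead of going through forType.
  def forClassSketch[T](klass: Class[T]): TypeInformation[T] =
    if (klass == classOf[Any]) Types.GENERIC(klass)
    else TypeInformation.of(klass) // stand-in for forType(Typed.typedClass(klass))

  def main(args: Array[String]): Unit = {
    println(forClassSketch(classOf[Any]))    // GenericType<java.lang.Object>
    println(forClassSketch(classOf[String])) // String
  }
}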
@@ -12,11 +12,9 @@ import pl.touk.nussknacker.engine.api.component.{
import pl.touk.nussknacker.engine.api.process.ProcessObjectDependencies
import pl.touk.nussknacker.engine.kafka.KafkaConfig
import pl.touk.nussknacker.engine.kafka.source.flink.FlinkKafkaSourceImplFactory
import pl.touk.nussknacker.engine.schemedkafka.FlinkUniversalSchemaBasedSerdeProvider
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.SchemaRegistryClientFactory
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.universal.{
UniversalSchemaBasedSerdeProvider,
UniversalSchemaRegistryClientFactory
}
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.universal.UniversalSchemaRegistryClientFactory
import pl.touk.nussknacker.engine.schemedkafka.sink.UniversalKafkaSinkFactory
import pl.touk.nussknacker.engine.schemedkafka.sink.flink.FlinkKafkaUniversalSinkImplFactory
import pl.touk.nussknacker.engine.schemedkafka.source.UniversalKafkaSourceFactory
@@ -36,7 +34,7 @@ class FlinkKafkaComponentProvider extends ComponentProvider {
import docsConfig._
def universal(componentType: ComponentType) = s"DataSourcesAndSinks#kafka-$componentType"

val universalSerdeProvider = UniversalSchemaBasedSerdeProvider.create(schemaRegistryClientFactory)
val universalSerdeProvider = FlinkUniversalSchemaBasedSerdeProvider.create(schemaRegistryClientFactory)

List(
ComponentDefinition(
@@ -0,0 +1,44 @@
package pl.touk.nussknacker.engine.schemedkafka.flink.typeinfo;

import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.kafka.clients.consumer.ConsumerRecord;

/**
 * A {@link TypeSerializerSnapshot} for the {@link ConsumerRecordSerializer} used by the Scala {@link ConsumerRecordTypeInfo}.
*/
public final class ConsumerRecordTypeSerializerSnapshot<K, V>
extends CompositeTypeSerializerSnapshot<ConsumerRecord<K, V>, ConsumerRecordSerializer<K, V>> {

final private static int VERSION = 1;

public ConsumerRecordTypeSerializerSnapshot() {
super();
}

public ConsumerRecordTypeSerializerSnapshot(ConsumerRecordSerializer<K, V> serializerInstance) {
super(serializerInstance);
}

@Override
protected int getCurrentOuterSnapshotVersion() {
return VERSION;
}

@Override
protected TypeSerializer<?>[] getNestedSerializers(ConsumerRecordSerializer<K, V> outerSerializer) {
return new TypeSerializer[] { outerSerializer.keySerializer(), outerSerializer.valueSerializer() };
}

@Override
protected ConsumerRecordSerializer<K, V> createOuterSerializerWithNestedSerializers(TypeSerializer<?>[] nestedSerializers) {
@SuppressWarnings("unchecked")
TypeSerializer<K> keySerializer = (TypeSerializer<K>) nestedSerializers[0];

@SuppressWarnings("unchecked")
TypeSerializer<V> valueSerializer = (TypeSerializer<V>) nestedSerializers[1];

return new ConsumerRecordSerializer<>(keySerializer, valueSerializer);
}
}
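For illustration only (not part of the commit): CompositeTypeSerializerSnapshot persists the snapshots of the nested key and value serializers and rebuilds the outer ConsumerRecordSerializer from them on restore. A minimal round-trip sketch in Scala, assuming Flink's StringSerializer for both key and value:

import org.apache.flink.api.common.typeutils.base.StringSerializer
import org.apache.flink.core.memory.{DataInputDeserializer, DataOutputSerializer}
import pl.touk.nussknacker.engine.schemedkafka.flink.typeinfo.{ConsumerRecordSerializer, ConsumerRecordTypeSerializerSnapshot}

object SnapshotRoundTripSketch extends App {
  val serializer =
    new ConsumerRecordSerializer[String, String](StringSerializer.INSTANCE, StringSerializer.INSTANCE)

  // Write the snapshot roughly the way Flink does when checkpointing serializer configuration
  val snapshot = serializer.snapshotConfiguration()
  val out      = new DataOutputSerializer(128)
  snapshot.writeSnapshot(out)

  // On restore, Flink instantiates the snapshot via its no-arg constructor and reads it back
  val restoredSnapshot = new ConsumerRecordTypeSerializerSnapshot[String, String]()
  restoredSnapshot.readSnapshot(snapshot.getCurrentVersion, new DataInputDeserializer(out.getCopyOfBuffer), getClass.getClassLoader)

  // ConsumerRecordSerializer compares its nested serializers in equals
  assert(restoredSnapshot.restoreSerializer() == serializer)
}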
@@ -0,0 +1,26 @@
package pl.touk.nussknacker.engine.schemedkafka

import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.serialization.KafkaSchemaRegistryBasedValueSerializationSchemaFactory
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.universal.UniversalSchemaBasedSerdeProvider.createSchemaIdFromMessageExtractor
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.universal.{UniversalKafkaDeserializerFactory, UniversalSchemaValidator, UniversalSerializerFactory, UniversalToJsonFormatterFactory}
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.{SchemaBasedSerdeProvider, SchemaRegistryClientFactory}
import pl.touk.nussknacker.engine.schemedkafka.source.flink.FlinkKafkaSchemaRegistryBasedKeyValueDeserializationSchemaFactory

object FlinkUniversalSchemaBasedSerdeProvider {

def create(schemaRegistryClientFactory: SchemaRegistryClientFactory): SchemaBasedSerdeProvider = {
SchemaBasedSerdeProvider(
new KafkaSchemaRegistryBasedValueSerializationSchemaFactory(
schemaRegistryClientFactory,
UniversalSerializerFactory
),
new FlinkKafkaSchemaRegistryBasedKeyValueDeserializationSchemaFactory(
schemaRegistryClientFactory,
new UniversalKafkaDeserializerFactory(createSchemaIdFromMessageExtractor)
),
new UniversalToJsonFormatterFactory(schemaRegistryClientFactory, createSchemaIdFromMessageExtractor),
UniversalSchemaValidator
)
}

}
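A minimal usage sketch (not part of the diff; it assumes UniversalSchemaRegistryClientFactory, imported by FlinkKafkaComponentProvider above, is the schema registry client factory in use). Compared to UniversalSchemaBasedSerdeProvider, the Flink-aware provider keeps the universal serializer, formatter and validator but swaps in the Flink-specific key/value deserialization schema factory, so Kafka sources get a proper ConsumerRecord TypeInformation:

import pl.touk.nussknacker.engine.schemedkafka.FlinkUniversalSchemaBasedSerdeProvider
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.SchemaBasedSerdeProvider
import pl.touk.nussknacker.engine.schemedkafka.schemaregistry.universal.UniversalSchemaRegistryClientFactory

object FlinkSerdeProviderSketch {
  // Mirrors the one-line change in FlinkKafkaComponentProvider shown earlier in this diff
  val serdeProvider: SchemaBasedSerdeProvider =
    FlinkUniversalSchemaBasedSerdeProvider.create(UniversalSchemaRegistryClientFactory)
}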
@@ -0,0 +1,204 @@
package pl.touk.nussknacker.engine.schemedkafka.flink.typeinfo

import com.github.ghik.silencer.silent
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.{TypeSerializer, TypeSerializerSnapshot}
import org.apache.flink.core.memory.{DataInputView, DataOutputView}
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.header.internals.{RecordHeader, RecordHeaders}
import org.apache.kafka.common.record.TimestampType

import java.util.{Objects, Optional}

class ConsumerRecordTypeInfo[K, V](val keyTypeInfo: TypeInformation[K], val valueTypeInfo: TypeInformation[V])
extends TypeInformation[ConsumerRecord[K, V]] {

override def getTypeClass: Class[ConsumerRecord[K, V]] = classOf[ConsumerRecord[K, V]]

@silent("deprecated")
override def createSerializer(
config: org.apache.flink.api.common.ExecutionConfig
): TypeSerializer[ConsumerRecord[K, V]] = {
new ConsumerRecordSerializer[K, V](keyTypeInfo.createSerializer(config), valueTypeInfo.createSerializer(config))
}

// ConsumerRecord has 8 simple fields (all fields except headers, key and value)
override def getArity: Int = 8

// TODO: find out what's the correct value here
// ConsumerRecord: 8 simple fields (without headers, key, value) + 2 fields for headers + key fields + value fields
override def getTotalFields: Int = 8 + 2 + keyTypeInfo.getTotalFields + valueTypeInfo.getTotalFields

override def isKeyType: Boolean = false

override def isBasicType: Boolean = false

override def isTupleType: Boolean = false

override def toString: String =
s"ConsumerRecordTypeInfo($keyTypeInfo, $valueTypeInfo)"

override def canEqual(obj: Any): Boolean =
obj.isInstanceOf[ConsumerRecordTypeInfo[_, _]]

override def equals(obj: Any): Boolean =
obj match {
case info: ConsumerRecordTypeInfo[_, _] =>
keyTypeInfo.equals(info.keyTypeInfo) && valueTypeInfo.equals(info.valueTypeInfo)
case _ => false
}

override def hashCode(): Int =
Objects.hash(keyTypeInfo, valueTypeInfo)
}

class ConsumerRecordSerializer[K, V](val keySerializer: TypeSerializer[K], val valueSerializer: TypeSerializer[V])
extends TypeSerializer[ConsumerRecord[K, V]] {

override def getLength: Int = -1

override def isImmutableType: Boolean = true

override def createInstance(): ConsumerRecord[K, V] =
new ConsumerRecord[K, V](null, 0, 0, null.asInstanceOf[K], null.asInstanceOf[V])

override def duplicate(): TypeSerializer[ConsumerRecord[K, V]] = {
val keyDuplicated = keySerializer.duplicate()
val valueDuplicated = valueSerializer.duplicate()

if (keyDuplicated.equals(keySerializer) && valueDuplicated.equals(valueSerializer)) {
this
} else {
new ConsumerRecordSerializer(keyDuplicated, valueDuplicated)
}
}

override def copy(record: ConsumerRecord[K, V]): ConsumerRecord[K, V] =
new ConsumerRecord[K, V](
record.topic(),
record.partition(),
record.offset(),
record.timestamp(),
record.timestampType(),
ConsumerRecord.NULL_SIZE,
ConsumerRecord.NULL_SIZE,
record.key(),
record.value(),
record.headers(),
record.leaderEpoch()
)

override def copy(record: ConsumerRecord[K, V], reuse: ConsumerRecord[K, V]): ConsumerRecord[K, V] =
copy(record)

override def copy(source: DataInputView, target: DataOutputView): Unit =
serialize(deserialize(source), target)

override def serialize(record: ConsumerRecord[K, V], target: DataOutputView): Unit = {
target.writeUTF(record.topic())
target.writeInt(record.partition())
target.writeLong(record.offset())
target.writeLong(record.timestamp())

// The timestamp type id fits in a short, which takes less space than an int
target.writeShort(record.timestampType().id)

target.writeInt(record.serializedKeySize())
target.writeInt(record.serializedValueSize())

// Serialize the key (can be null)
if (record.key() == null) {
target.writeBoolean(false)
} else {
target.writeBoolean(true)
keySerializer.serialize(record.key(), target)
}

// Serialize the value (can be null)
if (record.value() == null) {
target.writeBoolean(false)
} else {
target.writeBoolean(true)
valueSerializer.serialize(record.value(), target)
}

if (record.leaderEpoch().isPresent) {
target.writeBoolean(true)
target.writeInt(record.leaderEpoch.get())
} else {
target.writeBoolean(false)
}

target.writeInt(record.headers().toArray.length)
record.headers().forEach { header =>
target.writeUTF(header.key())
target.writeInt(header.value().length)
target.write(header.value())
}
}

override def deserialize(reuse: ConsumerRecord[K, V], source: DataInputView): ConsumerRecord[K, V] =
deserialize(source)

override def deserialize(source: DataInputView): ConsumerRecord[K, V] = {
val topic = source.readUTF()
val partition = source.readInt()
val offset = source.readLong()
val timestamp = source.readLong()
val timestampTypeId = source.readShort().toInt
val serializedKeySize = source.readInt()
val serializedValueSize = source.readInt()

val key = if (source.readBoolean()) keySerializer.deserialize(source) else null.asInstanceOf[K]
val value = if (source.readBoolean()) valueSerializer.deserialize(source) else null.asInstanceOf[V]
val leaderEpoch = if (source.readBoolean()) Optional.of[Integer](source.readInt()) else Optional.empty[Integer]

val headers = (0 until source.readInt()).foldLeft(new RecordHeaders) { (headers, _) =>
val name = source.readUTF()
val len = source.readInt()

val value = new Array[Byte](len)
source.readFully(value)

val header = new RecordHeader(name, value)
headers.add(header)
headers
}

val timestampType =
TimestampType
.values()
.toList
.find(_.id == timestampTypeId)
.getOrElse(throw new IllegalArgumentException(s"Unknown TimestampType id: $timestampTypeId."))

new ConsumerRecord[K, V](
topic,
partition,
offset,
timestamp,
timestampType,
serializedKeySize,
serializedValueSize,
key,
value,
headers,
leaderEpoch
)
}

override def snapshotConfiguration(): TypeSerializerSnapshot[ConsumerRecord[K, V]] =
new ConsumerRecordTypeSerializerSnapshot(this)

override def equals(obj: Any): Boolean = {
obj match {
case other: ConsumerRecordSerializer[_, _] =>
keySerializer.equals(other.keySerializer) && valueSerializer.equals(other.valueSerializer)
case _ => false
}
}

override def hashCode(): Int =
Objects.hash(keySerializer, valueSerializer)

}
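For illustration only (not part of the commit): the serializer writes a fixed layout: topic, partition, offset, timestamp, a short timestamp-type id, the serialized key/value sizes, null-flagged key and value, an optional leader epoch, and finally the header count followed by each header. A standalone round-trip sketch with String key/value serializers and a hypothetical record:

import java.util.Optional

import org.apache.flink.api.common.typeutils.base.StringSerializer
import org.apache.flink.core.memory.{DataInputDeserializer, DataOutputSerializer}
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.header.internals.RecordHeaders
import org.apache.kafka.common.record.TimestampType
import pl.touk.nussknacker.engine.schemedkafka.flink.typeinfo.ConsumerRecordSerializer

object ConsumerRecordSerializerRoundTripSketch extends App {
  val serializer =
    new ConsumerRecordSerializer[String, String](StringSerializer.INSTANCE, StringSerializer.INSTANCE)

  // Hypothetical record: topic, offset, key and value are made-up values for the example
  val record = new ConsumerRecord[String, String](
    "transactions",
    0,
    42L,
    1700000000000L,
    TimestampType.CREATE_TIME,
    ConsumerRecord.NULL_SIZE,
    ConsumerRecord.NULL_SIZE,
    "key-1",
    "value-1",
    new RecordHeaders(),
    Optional.empty[Integer]()
  )

  val out = new DataOutputSerializer(256)
  serializer.serialize(record, out)

  val restored = serializer.deserialize(new DataInputDeserializer(out.getCopyOfBuffer))

  // ConsumerRecord has no equals, so compare the interesting fields directly
  assert(restored.topic() == record.topic() && restored.offset() == record.offset())
  assert(restored.key() == record.key() && restored.value() == record.value())
}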
@@ -3,6 +3,7 @@ package pl.touk.nussknacker.engine.schemedkafka.sink.flink
import com.typesafe.scalalogging.LazyLogging
import io.confluent.kafka.schemaregistry.ParsedSchema
import org.apache.flink.api.common.functions.{RichMapFunction, RuntimeContext}
import org.apache.flink.api.common.typeinfo.{TypeInformation, Types}
import org.apache.flink.configuration.Configuration
import org.apache.flink.formats.avro.typeutils.NkSerializableParsedSchema
import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink}
@@ -13,6 +14,7 @@ import pl.touk.nussknacker.engine.api.validation.ValidationMode
import pl.touk.nussknacker.engine.api.{Context, LazyParameter, ValueWithContext}
import pl.touk.nussknacker.engine.flink.api.exception.{ExceptionHandler, WithExceptionHandler}
import pl.touk.nussknacker.engine.flink.api.process.{FlinkCustomNodeContext, FlinkSink}
import pl.touk.nussknacker.engine.flink.typeinformation.KeyedValueType
import pl.touk.nussknacker.engine.flink.util.keyed
import pl.touk.nussknacker.engine.flink.util.keyed.KeyedValueMapper
import pl.touk.nussknacker.engine.kafka.serialization.KafkaSerializationSchema
@@ -40,12 +42,21 @@ class FlinkKafkaUniversalSink(
override def registerSink(
dataStream: DataStream[ValueWithContext[Value]],
flinkNodeContext: FlinkCustomNodeContext
): DataStreamSink[_] =
// FIXME: Missing map TypeInformation
): DataStreamSink[_] = {

// TODO: Creating TypeInformation for Avro / JSON Schema is difficult because of schema evolution, so we rely on Kryo (e.g. the serializer for GenericRecordWithSchemaId)
val typeInfo = KeyedValueType
.info(
Types.STRING, // KafkaSink for key supports only String
Types.GENERIC(classOf[AnyRef])
)
.asInstanceOf[TypeInformation[KeyedValue[AnyRef, AnyRef]]]

dataStream
.map(new EncodeAvroRecordFunction(flinkNodeContext))
.map(new EncodeAvroRecordFunction(flinkNodeContext), typeInfo)
.filter(_.value != null)
.addSink(toFlinkFunction)
}

def prepareValue(
ds: DataStream[Context],

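For illustration only (not part of the commit): the essential change above is passing an explicit TypeInformation to .map instead of relying on Flink's reflective type extraction (the removed FIXME). A generic Flink sketch of the same idea, with Tuple2 standing in for KeyedValue, STRING for the key and GENERIC for the Kryo-serialized value:

import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.common.typeinfo.Types
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment

object ExplicitMapTypeInfoSketch extends App {
  val env = StreamExecutionEnvironment.getExecutionEnvironment

  env
    .fromElements("a", "b")
    // Explicit output TypeInformation: the key is always a String, the value stays a generic,
    // Kryo-serialized type, the same shape as the KeyedValueType.info(...) call in the sink above
    .map(
      new MapFunction[String, Tuple2[String, AnyRef]] {
        override def map(value: String): Tuple2[String, AnyRef] =
          new Tuple2[String, AnyRef](value, value.toUpperCase)
      },
      Types.TUPLE[Tuple2[String, AnyRef]](Types.STRING, Types.GENERIC(classOf[AnyRef]))
    )
    .print()

  env.execute("explicit-map-type-info-sketch")
}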