Skip to content

Commit

Permalink
IPROTO-266 Allow for custom Encoder/Decoder implementations for TagWr…
Browse files Browse the repository at this point in the history
…iter/TagReader
  • Loading branch information
wburns committed Aug 10, 2023
1 parent 4eadd2d commit ca65721
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 262 deletions.
55 changes: 55 additions & 0 deletions core/src/main/java/org/infinispan/protostream/Decoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.infinispan.protostream;

import java.io.IOException;
import java.nio.ByteBuffer;

public interface Decoder {
int getEnd();

int getPos();

byte[] getBufferArray() throws IOException;

boolean isAtEnd() throws IOException;

int readTag() throws IOException;

void checkLastTagWas(int expectedTag) throws IOException;

boolean skipField(int tag) throws IOException;

void skipVarint() throws IOException;

void skipRawBytes(int length) throws IOException;

String readString() throws IOException;

byte readRawByte() throws IOException;

byte[] readRawByteArray(int length) throws IOException;

ByteBuffer readRawByteBuffer(int length) throws IOException;

int readVarint32() throws IOException;

long readVarint64() throws IOException;

int readFixed32() throws IOException;

long readFixed64() throws IOException;

int pushLimit(int newLimit) throws IOException;

void popLimit(int oldLimit);

Decoder decoderFromLength(int length) throws IOException;

/**
* Sets a hard limit on how many bytes we can continue to read while parsing a message from current position. This is
* useful to prevent corrupted or malicious messages with wrong length values to abuse memory allocation. Initially
* this limit is set to {@code Integer.MAX_INT}, which means the protection mechanism is disabled by default.
* The limit is only useful when processing streams. Setting a limit for a decoder backed by a byte array is useless
* because the memory allocation already happened.
*/
int setGlobalLimit(int globalLimit);
}
50 changes: 50 additions & 0 deletions core/src/main/java/org/infinispan/protostream/Encoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.infinispan.protostream;

import java.io.IOException;
import java.nio.ByteBuffer;

public interface Encoder {
void flush() throws IOException;

void close() throws IOException;

int remainingSpace();

void writeUInt32Field(int fieldNumber, int value) throws IOException;

void writeUInt64Field(int fieldNumber, long value) throws IOException;

void writeFixed32Field(int fieldNumber, int value) throws IOException;

void writeFixed64Field(int fieldNumber, long value) throws IOException;

void writeBoolField(int fieldNumber, boolean value) throws IOException;

void writeLengthDelimitedField(int fieldNumber, int length) throws IOException;

void writeVarint32(int value) throws IOException;

void writeVarint64(long value) throws IOException;

void writeFixed32(int value) throws IOException;

void writeFixed64(long value) throws IOException;

void writeByte(byte value) throws IOException;

void writeBytes(byte[] value, int offset, int length) throws IOException;

void writeBytes(ByteBuffer value) throws IOException;

default int skipFixedVarint() {
throw new UnsupportedOperationException();
}

default void writePositiveFixedVarint(int pos) {
throw new UnsupportedOperationException();
}

default boolean supportsFixedVarint() {
return false;
}
}
16 changes: 16 additions & 0 deletions core/src/main/java/org/infinispan/protostream/ProtobufUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public static void writeTo(ImmutableSerializationContext ctx, OutputStream out,
write(ctx, TagWriterImpl.newInstance(ctx, out), t);
}

public static void writeTo(ImmutableSerializationContext ctx, Encoder encoder, Object t) throws IOException {
write(ctx, TagWriterImpl.newInstance(ctx, encoder), t);
}

public static byte[] toByteArray(ImmutableSerializationContext ctx, Object t) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream(DEFAULT_ARRAY_BUFFER_SIZE);
writeTo(ctx, baos, t);
Expand Down Expand Up @@ -112,6 +116,10 @@ public static <A> A fromByteBuffer(ImmutableSerializationContext ctx, ByteBuffer
return readFrom(TagReaderImpl.newInstance(ctx, byteBuffer), clazz);
}

public static <A> A fromDecoder(ImmutableSerializationContext ctx, Decoder decoder, Class<A> clazz) throws IOException {
return readFrom(TagReaderImpl.newInstance(ctx, decoder), clazz);
}

/**
* Parses a top-level message that was wrapped according to the org.infinispan.protostream.WrappedMessage proto
* definition.
Expand All @@ -137,6 +145,10 @@ public static <A> A fromWrappedStream(ImmutableSerializationContext ctx, InputSt
return WrappedMessage.read(ctx, TagReaderImpl.newInstance(ctx, in));
}

public static <A> A fromWrappedDecoder(ImmutableSerializationContext ctx, Decoder decoder) throws IOException {
return WrappedMessage.read(ctx, TagReaderImpl.newInstance(ctx, decoder));
}

//todo [anistor] should make it possible to plug in a custom wrapping strategy instead of the default one
public static byte[] toWrappedByteArray(ImmutableSerializationContext ctx, Object t) throws IOException {
return toWrappedByteArray(ctx, t, DEFAULT_ARRAY_BUFFER_SIZE);
Expand All @@ -162,6 +174,10 @@ public static void toWrappedStream(ImmutableSerializationContext ctx, OutputStre
WrappedMessage.write(ctx, TagWriterImpl.newInstance(ctx, out, bufferSize), t);
}

public static void toWrappedEncoder(ImmutableSerializationContext ctx, Encoder encoder, Object t) throws IOException {
WrappedMessage.write(ctx, TagWriterImpl.newInstance(ctx, encoder), t);
}

/**
* Converts a Protobuf encoded message to its <a href="https://developers.google.com/protocol-buffers/docs/proto3#json">
* canonical JSON representation</a>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ default int readEnum() throws IOException {
* itself, possibly avoiding byte[] allocations
* @return a new TagReader
*/
TagReader subReaderFromArray() throws IOException;
ProtobufTagMarshaller.ReadContext subReaderFromArray() throws IOException;

default double readDouble() throws IOException {
return Double.longBitsToDouble(readFixed64());
Expand Down
16 changes: 10 additions & 6 deletions core/src/main/java/org/infinispan/protostream/TagWriter.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.infinispan.protostream;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;

Expand All @@ -9,14 +10,14 @@
* @author [email protected]
* @since 4.4
*/
public interface TagWriter extends RawProtoStreamWriter {
public interface TagWriter extends RawProtoStreamWriter, Closeable {

// start low level ops
void flush() throws IOException;

/**
* Invoke after done with writer, this implies a flush if necessary
* It is necessary to invoke this on a writer returned from {@link #subWriter(int)} to actually push the data
* It is necessary to invoke this on a writer returned from {@link #subWriter(int, boolean)} to actually push the data
*/
void close() throws IOException;

Expand Down Expand Up @@ -99,9 +100,12 @@ default void writeBytes(int number, byte[] value) throws IOException {

/**
* Used to write a sub message that can be optimized by implementation. When the sub writer is complete, flush
* should be invoked to ensure
* @return
* @throws IOException
* should be invoked to ensure bytes are written and close should be invoked to free any resources related to the
* context (note close will flush as well)
* @param number the message number of the sub message
* @param nested whether this is a nested message or a new one
* @return a write context for a sub message
* @throws IOException exception if there is an issue
*/
TagWriter subWriter(int number, boolean nested) throws IOException;
ProtobufTagMarshaller.WriteContext subWriter(int number, boolean nested) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.infinispan.protostream;

import java.io.Closeable;
import java.io.IOException;
import java.time.Instant;
import java.util.Date;
Expand Down Expand Up @@ -296,9 +297,9 @@ private static void writeMessage(ImmutableSerializationContext ctx, TagWriter ou
if (t.getClass().isEnum()) {
((EnumMarshallerDelegate) marshallerDelegate).encode(WRAPPED_ENUM, (Enum) t, out);
} else {
TagWriter nestedWriter = out.subWriter(WRAPPED_MESSAGE, false);
marshallerDelegate.marshall((ProtobufTagMarshaller.WriteContext) nestedWriter, null, t);
nestedWriter.close();
ProtobufTagMarshaller.WriteContext nestedWriter = out.subWriter(WRAPPED_MESSAGE, false);
marshallerDelegate.marshall(nestedWriter, null, t);
nestedWriter.getWriter().close();
}
}
}
Expand Down Expand Up @@ -353,7 +354,7 @@ private static <T> T readMessage(ImmutableSerializationContext ctx, TagReader in
String typeName = null;
Integer typeId = null;
int enumValue = -1;
TagReader messageReader = null;
ProtobufTagMarshaller.ReadContext readContext = null;
Object value = null;
int fieldCount = 0;
int expectedFieldCount = 1;
Expand Down Expand Up @@ -396,7 +397,7 @@ private static <T> T readMessage(ImmutableSerializationContext ctx, TagReader in
}
case WRAPPED_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: {
expectedFieldCount = 2;
messageReader = in.subReaderFromArray();
readContext = in.subReaderFromArray();
break;
}
case WRAPPED_STRING << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: {
Expand Down Expand Up @@ -512,7 +513,7 @@ private static <T> T readMessage(ImmutableSerializationContext ctx, TagReader in
}
}

if (value == null && typeName == null && typeId == null && messageReader == null) {
if (value == null && typeName == null && typeId == null && readContext == null) {
return null;
}

Expand All @@ -531,9 +532,9 @@ private static <T> T readMessage(ImmutableSerializationContext ctx, TagReader in
typeName = ctx.getDescriptorByTypeId(typeId).getFullName();
}
BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(typeName);
if (messageReader != null) {
if (readContext != null) {
// it's a Message type
return (T) marshallerDelegate.unmarshall((ProtobufTagMarshaller.ReadContext) messageReader, null);
return (T) marshallerDelegate.unmarshall(readContext, null);
} else {
// it's an Enum
EnumMarshaller marshaller = (EnumMarshaller) marshallerDelegate.getMarshaller();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package org.infinispan.protostream.annotations.impl;

import java.io.Closeable;
import java.io.IOException;

import org.infinispan.protostream.ProtobufTagMarshaller;
import org.infinispan.protostream.TagWriter;
import org.infinispan.protostream.impl.BaseMarshallerDelegate;
import org.infinispan.protostream.impl.ByteArrayOutputStreamEx;
import org.infinispan.protostream.impl.Log;
import org.infinispan.protostream.impl.TagWriterImpl;

/**
* Base class for generated message marshallers. Provides some handy helper methods.
Expand Down Expand Up @@ -47,20 +46,11 @@ protected final <T> void writeNestedMessage(BaseMarshallerDelegate<T> marshaller
throw log.maxNestedMessageDepth(maxNestedMessageDepth, message.getClass());
}

if (ctx instanceof TagWriter) {
TagWriter nestedWriter = ((TagWriter) ctx).subWriter(fieldNumber, true);
marshallerDelegate.marshall((ProtobufTagMarshaller.WriteContext) nestedWriter, null, message);
nestedWriter.close();
} else {
handleNonTagWriter(marshallerDelegate, ctx, fieldNumber, message);
TagWriter tagWriter = ctx.getWriter();
ProtobufTagMarshaller.WriteContext nestedWriter = tagWriter.subWriter(fieldNumber, true);
marshallerDelegate.marshall(nestedWriter, null, message);
if (nestedWriter instanceof Closeable) {
((Closeable) nestedWriter).close();
}
}

private <T> void handleNonTagWriter(BaseMarshallerDelegate<T> marshallerDelegate, ProtobufTagMarshaller.WriteContext ctx,
int fieldNumber, T message) throws IOException {
ByteArrayOutputStreamEx baos = new ByteArrayOutputStreamEx();
TagWriterImpl nested = TagWriterImpl.newNestedInstance(ctx, baos);
writeMessage(marshallerDelegate, nested, message);
ctx.getWriter().writeBytes(fieldNumber, baos.getByteBuffer());
}
}
Loading

0 comments on commit ca65721

Please sign in to comment.