diff --git a/core/src/main/java/org/infinispan/protostream/ProtobufUtil.java b/core/src/main/java/org/infinispan/protostream/ProtobufUtil.java index c672a9592..6be3e86ad 100644 --- a/core/src/main/java/org/infinispan/protostream/ProtobufUtil.java +++ b/core/src/main/java/org/infinispan/protostream/ProtobufUtil.java @@ -143,7 +143,7 @@ public static byte[] toWrappedByteArray(ImmutableSerializationContext ctx, Objec } public static byte[] toWrappedByteArray(ImmutableSerializationContext ctx, Object t, int bufferSize) throws IOException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(bufferSize); + ByteArrayOutputStream baos = new ByteArrayOutputStreamEx(bufferSize); WrappedMessage.write(ctx, TagWriterImpl.newInstanceNoBuffer(ctx, baos), t); return baos.toByteArray(); } @@ -155,7 +155,7 @@ public static ByteBuffer toWrappedByteBuffer(ImmutableSerializationContext ctx, } public static void toWrappedStream(ImmutableSerializationContext ctx, OutputStream out, Object t) throws IOException { - toWrappedStream(ctx, out, t, DEFAULT_STREAM_BUFFER_SIZE); + WrappedMessage.write(ctx, TagWriterImpl.newInstance(ctx, out), t); } public static void toWrappedStream(ImmutableSerializationContext ctx, OutputStream out, Object t, int bufferSize) throws IOException { diff --git a/core/src/main/java/org/infinispan/protostream/TagReader.java b/core/src/main/java/org/infinispan/protostream/TagReader.java index 8a9f208e8..bcac6912b 100644 --- a/core/src/main/java/org/infinispan/protostream/TagReader.java +++ b/core/src/main/java/org/infinispan/protostream/TagReader.java @@ -33,7 +33,9 @@ public interface TagReader extends RawProtoStreamReader { boolean readBool() throws IOException; - int readEnum() throws IOException; + default int readEnum() throws IOException { + return readInt32(); + } /** * Reads a {@code string} value. @@ -50,29 +52,48 @@ public interface TagReader extends RawProtoStreamReader { */ ByteBuffer readByteBuffer() throws IOException; - double readDouble() throws IOException; + /** + * Similar to {@link #readByteArray()} except that the reader impl may optimize creation of a sub TagReader from + * itself, possibly avoiding byte[] allocations + * @return a new TagReader + */ + TagReader subReaderFromArray() throws IOException; + + default double readDouble() throws IOException { + return Double.longBitsToDouble(readFixed64()); + } - float readFloat() throws IOException; + default float readFloat() throws IOException { + return Float.intBitsToFloat(readFixed32()); + } long readInt64() throws IOException; - long readUInt64() throws IOException; + default long readUInt64() throws IOException { + return readInt64(); + } long readSInt64() throws IOException; long readFixed64() throws IOException; - long readSFixed64() throws IOException; + default long readSFixed64() throws IOException { + return readFixed64(); + } int readInt32() throws IOException; - int readUInt32() throws IOException; + default int readUInt32() throws IOException { + return readInt32(); + } int readSInt32() throws IOException; int readFixed32() throws IOException; - int readSFixed32() throws IOException; + default int readSFixed32() throws IOException { + return readFixed32(); + } /** * Sets a limit (based on the length of the length delimited value) when entering an embedded message. diff --git a/core/src/main/java/org/infinispan/protostream/TagWriter.java b/core/src/main/java/org/infinispan/protostream/TagWriter.java index df11aea73..83eb5728e 100644 --- a/core/src/main/java/org/infinispan/protostream/TagWriter.java +++ b/core/src/main/java/org/infinispan/protostream/TagWriter.java @@ -14,9 +14,19 @@ public interface TagWriter extends RawProtoStreamWriter { // start low level ops void flush() throws IOException; - void writeTag(int number, int wireType) throws IOException; + /** + * Invoke after done with writer, this implies a flush if necessary + * It is necessary to invoke this on a writer returned from {@link #subWriter(int)} to actually push the data + */ + void close() throws IOException; - void writeTag(int number, WireType wireType) throws IOException; + default void writeTag(int number, int wireType) throws IOException { + writeVarint32(WireType.makeTag(number, wireType)); + } + + default void writeTag(int number, WireType wireType) throws IOException { + writeVarint32(WireType.makeTag(number, wireType)); + } void writeVarint32(int value) throws IOException; @@ -28,38 +38,70 @@ public interface TagWriter extends RawProtoStreamWriter { // start high level ops void writeString(int number, String value) throws IOException; - void writeInt32(int number, int value) throws IOException; + default void writeInt32(int number, int value) throws IOException { + if (value >= 0) { + writeUInt32(number, value); + } else { + writeUInt64(number, value); + } + } void writeUInt32(int number, int value) throws IOException; - void writeSInt32(int number, int value) throws IOException; + default void writeSInt32(int number, int value) throws IOException { + // Roll the bits in order to move the sign bit from position 31 to position 0, to reduce the wire length of negative numbers. + writeUInt32(number, (value << 1) ^ (value >> 31)); + } void writeFixed32(int number, int value) throws IOException; - void writeSFixed32(int number, int value) throws IOException; + default void writeSFixed32(int number, int value) throws IOException { + writeFixed32(number, value); + } void writeInt64(int number, long value) throws IOException; void writeUInt64(int number, long value) throws IOException; - void writeSInt64(int number, long value) throws IOException; + default void writeSInt64(int number, long value) throws IOException { + // Roll the bits in order to move the sign bit from position 63 to position 0, to reduce the wire length of negative numbers. + writeUInt64(number, (value << 1) ^ (value >> 63)); + } void writeFixed64(int number, long value) throws IOException; - void writeSFixed64(int number, long value) throws IOException; + default void writeSFixed64(int number, long value) throws IOException { + writeFixed64(number, value); + } - void writeEnum(int number, int value) throws IOException; + default void writeEnum(int number, int value) throws IOException { + writeInt32(number, value); + } void writeBool(int number, boolean value) throws IOException; - void writeDouble(int number, double value) throws IOException; + default void writeDouble(int number, double value) throws IOException { + writeFixed64(number, Double.doubleToRawLongBits(value)); + } - void writeFloat(int number, float value) throws IOException; + default void writeFloat(int number, float value) throws IOException { + writeFixed32(number, Float.floatToRawIntBits(value)); + } void writeBytes(int number, ByteBuffer value) throws IOException; - void writeBytes(int number, byte[] value) throws IOException; + default void writeBytes(int number, byte[] value) throws IOException { + writeBytes(number, value, 0, value.length); + } void writeBytes(int number, byte[] value, int offset, int length) throws IOException; // end high level ops + + /** + * Used to write a sub message that can be optimized by implementation. When the sub writer is complete, flush + * should be invoked to ensure + * @return + * @throws IOException + */ + TagWriter subWriter(int number, boolean nested) throws IOException; } diff --git a/core/src/main/java/org/infinispan/protostream/WrappedMessage.java b/core/src/main/java/org/infinispan/protostream/WrappedMessage.java index 22699e04e..28b29d5df 100644 --- a/core/src/main/java/org/infinispan/protostream/WrappedMessage.java +++ b/core/src/main/java/org/infinispan/protostream/WrappedMessage.java @@ -296,15 +296,13 @@ private static void writeMessage(ImmutableSerializationContext ctx, TagWriter ou if (t.getClass().isEnum()) { ((EnumMarshallerDelegate) marshallerDelegate).encode(WRAPPED_ENUM, (Enum) t, out); } else { - ByteArrayOutputStreamEx buffer = new ByteArrayOutputStreamEx(); - TagWriterImpl nestedCtx = TagWriterImpl.newInstanceNoBuffer(ctx, buffer); - marshallerDelegate.marshall(nestedCtx, null, t); - nestedCtx.flush(); - out.writeBytes(WRAPPED_MESSAGE, buffer.getByteBuffer()); + TagWriter nestedWriter = out.subWriter(WRAPPED_MESSAGE, false); + marshallerDelegate.marshall((ProtobufTagMarshaller.WriteContext) nestedWriter, null, t); + nestedWriter.close(); } } } - out.flush(); + out.close(); } private static void writeContainer(ImmutableSerializationContext ctx, TagWriter out, BaseMarshallerDelegate marshallerDelegate, Object container) throws IOException { @@ -355,7 +353,7 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in String typeName = null; Integer typeId = null; int enumValue = -1; - byte[] messageBytes = null; + TagReader messageReader = null; Object value = null; int fieldCount = 0; int expectedFieldCount = 1; @@ -398,7 +396,7 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in } case WRAPPED_MESSAGE << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { expectedFieldCount = 2; - messageBytes = in.readByteArray(); + messageReader = in.subReaderFromArray(); break; } case WRAPPED_STRING << WireType.TAG_TYPE_NUM_BITS | WireType.WIRETYPE_LENGTH_DELIMITED: { @@ -514,7 +512,7 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in } } - if (value == null && typeName == null && typeId == null && messageBytes == null) { + if (value == null && typeName == null && typeId == null && messageReader == null) { return null; } @@ -533,10 +531,9 @@ private static T readMessage(ImmutableSerializationContext ctx, TagReader in typeName = ctx.getDescriptorByTypeId(typeId).getFullName(); } BaseMarshallerDelegate marshallerDelegate = ((SerializationContextImpl) ctx).getMarshallerDelegate(typeName); - if (messageBytes != null) { + if (messageReader != null) { // it's a Message type - TagReaderImpl nestedInput = TagReaderImpl.newInstance(ctx, messageBytes); - return (T) marshallerDelegate.unmarshall(nestedInput, null); + return (T) marshallerDelegate.unmarshall((ProtobufTagMarshaller.ReadContext) messageReader, null); } else { // it's an Enum EnumMarshaller marshaller = (EnumMarshaller) marshallerDelegate.getMarshaller(); diff --git a/core/src/main/java/org/infinispan/protostream/annotations/impl/GeneratedMarshallerBase.java b/core/src/main/java/org/infinispan/protostream/annotations/impl/GeneratedMarshallerBase.java index ef166d9f1..32879a248 100644 --- a/core/src/main/java/org/infinispan/protostream/annotations/impl/GeneratedMarshallerBase.java +++ b/core/src/main/java/org/infinispan/protostream/annotations/impl/GeneratedMarshallerBase.java @@ -3,6 +3,7 @@ import java.io.IOException; import org.infinispan.protostream.ProtobufTagMarshaller; +import org.infinispan.protostream.TagWriter; import org.infinispan.protostream.impl.BaseMarshallerDelegate; import org.infinispan.protostream.impl.ByteArrayOutputStreamEx; import org.infinispan.protostream.impl.Log; @@ -46,6 +47,17 @@ protected final void writeNestedMessage(BaseMarshallerDelegate marshaller throw log.maxNestedMessageDepth(maxNestedMessageDepth, message.getClass()); } + if (ctx instanceof TagWriter) { + TagWriter nestedWriter = ((TagWriter) ctx).subWriter(fieldNumber, true); + marshallerDelegate.marshall((ProtobufTagMarshaller.WriteContext) nestedWriter, null, message); + nestedWriter.close(); + } else { + handleNonTagWriter(marshallerDelegate, ctx, fieldNumber, message); + } + } + + private void handleNonTagWriter(BaseMarshallerDelegate marshallerDelegate, ProtobufTagMarshaller.WriteContext ctx, + int fieldNumber, T message) throws IOException { ByteArrayOutputStreamEx baos = new ByteArrayOutputStreamEx(); TagWriterImpl nested = TagWriterImpl.newNestedInstance(ctx, baos); writeMessage(marshallerDelegate, nested, message); diff --git a/core/src/main/java/org/infinispan/protostream/impl/ByteArrayOutputStreamEx.java b/core/src/main/java/org/infinispan/protostream/impl/ByteArrayOutputStreamEx.java index 081c4856b..cacf3a673 100644 --- a/core/src/main/java/org/infinispan/protostream/impl/ByteArrayOutputStreamEx.java +++ b/core/src/main/java/org/infinispan/protostream/impl/ByteArrayOutputStreamEx.java @@ -21,4 +21,14 @@ public ByteArrayOutputStreamEx(int size) { public synchronized ByteBuffer getByteBuffer() { return ByteBuffer.wrap(buf, 0, count); } + + public int skipFixedVarint() { + int prev = count; + count += 5; + return prev; + } + + public void writePositiveFixedVarint(int pos) { + TagWriterImpl.writePositiveFixedVarint(buf, pos, count - pos - 5); + } } \ No newline at end of file diff --git a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java index 5a3585cea..eae920764 100644 --- a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java +++ b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java @@ -107,11 +107,6 @@ public boolean skipField(int tag) throws IOException { return decoder.skipField(tag); } - @Override - public long readUInt64() throws IOException { - return decoder.readVarint64(); - } - @Override public long readInt64() throws IOException { return decoder.readVarint64(); @@ -132,19 +127,9 @@ public int readFixed32() throws IOException { return decoder.readFixed32(); } - @Override - public double readDouble() throws IOException { - return Double.longBitsToDouble(decoder.readFixed64()); - } - - @Override - public float readFloat() throws IOException { - return Float.intBitsToFloat(decoder.readFixed32()); - } - @Override public boolean readBool() throws IOException { - return decoder.readVarint64() != 0L; + return decoder.readRawByte() != 0L; } @Override @@ -165,23 +150,9 @@ public ByteBuffer readByteBuffer() throws IOException { } @Override - public int readUInt32() throws IOException { - return decoder.readVarint32(); - } - - @Override - public int readEnum() throws IOException { - return decoder.readVarint32(); - } - - @Override - public int readSFixed32() throws IOException { - return decoder.readFixed32(); - } - - @Override - public long readSFixed64() throws IOException { - return decoder.readFixed64(); + public TagReader subReaderFromArray() throws IOException { + int length = decoder.readVarint32(); + return new TagReaderImpl(serCtx, decoder.decoderFromLength(length)); } @Override @@ -246,7 +217,6 @@ public TagReader getReader() { @Override public byte[] fullBufferArray() throws IOException { checkBufferUnused("fullBufferArray"); - return decoder.getBufferArray(); } @@ -397,6 +367,8 @@ final int readVarint32() throws IOException { abstract void popLimit(int oldLimit); + abstract Decoder decoderFromLength(int length) throws IOException; + /** * Sets a hard limit on how many bytes we can continue to read while parsing a message from current position. This is * useful to prevent corrupted or malicious messages with wrong length values to abuse memory allocation. Initially @@ -413,11 +385,7 @@ private static final class ByteArrayDecoder extends Decoder { // all positions are absolute private final int start; - private final int stop; private int pos; - private int end; // limit adjusted - - // number of bytes we are allowed to read starting from start position private int limit; private ByteArrayDecoder(byte[] array, int offset, int length) { @@ -438,9 +406,7 @@ private ByteArrayDecoder(byte[] array, int offset, int length) { } this.array = array; this.start = this.pos = offset; - this.limit = length; - this.stop = this.end = offset + length; - adjustEnd(); + this.limit = offset + length; } @Override @@ -448,51 +414,49 @@ int pushLimit(int limit) throws IOException { if (limit < 0) { throw log.negativeLength(); } - limit += pos - start; + limit += pos; int oldLimit = this.limit; if (limit > oldLimit) { // the end of a nested message cannot go beyond the end of the outer message throw log.messageTruncated(); } this.limit = limit; - adjustEnd(); return oldLimit; } @Override void popLimit(int oldLimit) { limit = oldLimit; - adjustEnd(); - } - - private void adjustEnd() { - end = stop - start > limit ? start + limit : stop; } @Override int getEnd() { - return end; + return limit; } @Override int getPos() { - return pos; + return pos - start; } @Override byte[] getBufferArray() throws IOException { - return array; + if (pos == 0 && limit == array.length) { + return array; + } else { + return Arrays.copyOfRange(array, pos, limit); + } } @Override boolean isAtEnd() { - return pos == end; + return pos == limit; } @Override String readString() throws IOException { int length = readVarint32(); - if (length > 0 && length <= end - pos) { + if (length > 0 && length <= limit - pos) { String value = new String(array, pos, length, UTF8); pos += length; return value; @@ -508,7 +472,7 @@ String readString() throws IOException { @Override ByteBuffer readRawByteBuffer(int length) throws IOException { - if (length > 0 && length <= end - pos) { + if (length > 0 && length <= limit - pos) { int from = pos; pos += length; return ByteBuffer.wrap(array, from, length).slice(); @@ -524,7 +488,7 @@ ByteBuffer readRawByteBuffer(int length) throws IOException { @Override protected void skipVarint() throws IOException { - if (end - pos >= MAX_VARINT_SIZE) { + if (limit - pos >= MAX_VARINT_SIZE) { for (int i = 0; i < MAX_VARINT_SIZE; i++) { if (array[pos++] >= 0) { return; @@ -543,7 +507,7 @@ protected void skipVarint() throws IOException { @Override long readVarint64() throws IOException { long value = 0; - if (end - pos >= MAX_VARINT_SIZE) { + if (limit - pos >= MAX_VARINT_SIZE) { for (int i = 0; i < 64; i += 7) { byte b = array[pos++]; value |= (long) (b & 0x7F) << i; @@ -606,7 +570,7 @@ byte readRawByte() throws IOException { @Override byte[] readRawByteArray(int length) throws IOException { - if (length > 0 && length <= end - pos) { + if (length > 0 && length <= limit - pos) { int from = pos; pos += length; return Arrays.copyOfRange(array, from, pos); @@ -625,13 +589,23 @@ protected void skipRawBytes(int length) throws IOException { if (length < 0) { throw log.negativeLength(); } - if (length <= end - pos) { + if (length <= limit - pos) { pos += length; return; } throw log.messageTruncated(); } + @Override + Decoder decoderFromLength(int length) throws IOException { + int currentPos = pos; + if (length + currentPos > limit) { + throw log.messageTruncated(); + } + pos += length; + return new ByteArrayDecoder(array, currentPos, length); + } + @Override int setGlobalLimit(int globalLimit) { return Integer.MAX_VALUE; @@ -695,7 +669,18 @@ int getPos() { @Override byte[] getBufferArray() throws IOException { - return buf.array(); + if (end < limit) { + throw log.messageTruncated(); + } + int pos = buf.position(); + int remaining = buf.remaining(); + if (pos == 0 && end == remaining && buf.hasArray()) { + return buf.array(); + } else { + byte[] bytes = new byte[remaining]; + buf.get(bytes, 0, remaining); + return bytes; + } } @@ -834,6 +819,13 @@ protected void skipRawBytes(int length) throws IOException { throw log.messageTruncated(); } + @Override + Decoder decoderFromLength(int length) throws IOException { + ByteBuffer buffer = buf.slice().limit(length); + buf.position(buffer.position() + length); + return new ByteBufferDecoder(buffer); + } + @Override int setGlobalLimit(int globalLimit) { return Integer.MAX_VALUE; @@ -981,11 +973,12 @@ int getPos() { @Override byte[] getBufferArray() throws IOException { - if (globalLimit == Integer.MAX_VALUE) { + int readLimit = Math.min(limit, globalLimit); + if (readLimit == Integer.MAX_VALUE) { pos = Integer.MAX_VALUE; return in.readAllBytes(); } else { - int length = globalLimit - pos; + int length = readLimit - pos; return readRawByteArray(length); } } @@ -1102,6 +1095,12 @@ private void skipNBytes(long n) throws IOException { } } } + + @Override + Decoder decoderFromLength(int length) throws IOException { + byte[] bytes = readRawByteArray(length); + return new ByteArrayDecoder(bytes, 0, length); + } } } diff --git a/core/src/main/java/org/infinispan/protostream/impl/TagWriterImpl.java b/core/src/main/java/org/infinispan/protostream/impl/TagWriterImpl.java index 37d83934e..e91b8a1af 100644 --- a/core/src/main/java/org/infinispan/protostream/impl/TagWriterImpl.java +++ b/core/src/main/java/org/infinispan/protostream/impl/TagWriterImpl.java @@ -10,15 +10,16 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; import org.infinispan.protostream.ImmutableSerializationContext; import org.infinispan.protostream.ProtobufTagMarshaller; -import org.infinispan.protostream.ProtobufUtil; import org.infinispan.protostream.TagWriter; import org.infinispan.protostream.descriptors.WireType; + /** * @author anistor@redhat.com * @since 3.0 @@ -66,9 +67,13 @@ public static TagWriterImpl newNestedInstance(ProtobufTagMarshaller.WriteContext } public static TagWriterImpl newInstance(ImmutableSerializationContext serCtx, OutputStream output) { - return new TagWriterImpl((SerializationContextImpl) serCtx, new OutputStreamEncoder(output, ProtobufUtil.DEFAULT_STREAM_BUFFER_SIZE)); + return new TagWriterImpl((SerializationContextImpl) serCtx, new OutputStreamNoBufferEncoder(output)); } + /** + * @deprecated since 4.6.3 Please use {@link #newInstance(ImmutableSerializationContext, OutputStream)} with a {@link java.io.BufferedOutputStream instead} + */ + @Deprecated public static TagWriterImpl newInstance(ImmutableSerializationContext serCtx, OutputStream output, int bufferSize) { return new TagWriterImpl((SerializationContextImpl) serCtx, new OutputStreamEncoder(output, bufferSize)); } @@ -105,13 +110,8 @@ public void flush() throws IOException { } @Override - public void writeTag(int number, int wireType) throws IOException { - encoder.writeVarint32(WireType.makeTag(number, wireType)); - } - - @Override - public void writeTag(int number, WireType wireType) throws IOException { - encoder.writeVarint32(WireType.makeTag(number, wireType)); + public void close() throws IOException { + encoder.close(); } @Override @@ -136,36 +136,16 @@ public void writeString(int number, String value) throws IOException { encoder.writeBytes(utf8buffer, 0, utf8buffer.length); } - @Override - public void writeInt32(int number, int value) throws IOException { - if (value >= 0) { - encoder.writeUInt32Field(number, value); - } else { - encoder.writeUInt64Field(number, value); - } - } - @Override public void writeUInt32(int number, int value) throws IOException { encoder.writeUInt32Field(number, value); } - @Override - public void writeSInt32(int number, int value) throws IOException { - // Roll the bits in order to move the sign bit from position 31 to position 0, to reduce the wire length of negative numbers. - encoder.writeUInt32Field(number, (value << 1) ^ (value >> 31)); - } - @Override public void writeFixed32(int number, int value) throws IOException { encoder.writeFixed32Field(number, value); } - @Override - public void writeSFixed32(int number, int value) throws IOException { - writeFixed32(number, value); - } - @Override public void writeInt64(int number, long value) throws IOException { encoder.writeUInt64Field(number, value); @@ -176,59 +156,54 @@ public void writeUInt64(int number, long value) throws IOException { encoder.writeUInt64Field(number, value); } - @Override - public void writeSInt64(int number, long value) throws IOException { - // Roll the bits in order to move the sign bit from position 63 to position 0, to reduce the wire length of negative numbers. - encoder.writeUInt64Field(number, (value << 1) ^ (value >> 63)); - } - @Override public void writeFixed64(int number, long value) throws IOException { encoder.writeFixed64Field(number, value); } - @Override - public void writeSFixed64(int number, long value) throws IOException { - writeFixed64(number, value); - } - - @Override - public void writeEnum(int number, int value) throws IOException { - writeInt32(number, value); - } - @Override public void writeBool(int number, boolean value) throws IOException { encoder.writeBoolField(number, value); } - @Override - public void writeDouble(int number, double value) throws IOException { - encoder.writeFixed64Field(number, Double.doubleToRawLongBits(value)); - } - - @Override - public void writeFloat(int number, float value) throws IOException { - encoder.writeFixed32Field(number, Float.floatToRawIntBits(value)); - } - @Override public void writeBytes(int number, ByteBuffer value) throws IOException { encoder.writeLengthDelimitedField(number, value.remaining()); encoder.writeBytes(value); } - @Override - public void writeBytes(int number, byte[] value) throws IOException { - writeBytes(number, value, 0, value.length); - } - @Override public void writeBytes(int number, byte[] value, int offset, int length) throws IOException { encoder.writeLengthDelimitedField(number, length); encoder.writeBytes(value, offset, length); } + @Override + public TagWriter subWriter(int number, boolean nested) throws IOException { + if (encoder.supportsFixedVarint()) { + writeVarint32(WireType.makeTag(number, WireType.WIRETYPE_LENGTH_DELIMITED)); + return nested ? new TagWriterImpl(this, new FixedVarintWrappedEncoder((FixedVarintEncoder) encoder)) : + new TagWriterImpl(serCtx, new FixedVarintWrappedEncoder((FixedVarintEncoder) encoder)); + } + // This ensures we aren't allocating a byte[] larger than we actually need + int space = bytesAvailableForVariableEncoding(encoder.remainingSpace()); + return nested ? new TagWriterImpl(this, new ArrayBasedWrappedEncoder(space, encoder, number)) : + new TagWriterImpl(serCtx, new ArrayBasedWrappedEncoder(space, encoder, number)); + } + + // Returns how many bytes are usable for data from a given range of bytes when inserting a variable int before + // the actual data + private int bytesAvailableForVariableEncoding(int spaceAllowed) { + if (spaceAllowed < 128) { + return spaceAllowed - 1; + } else if (spaceAllowed < 16384) { + return spaceAllowed - 2; + } else if (spaceAllowed < 2097151) { + return spaceAllowed - 3; + } + return spaceAllowed - (spaceAllowed < 268435455 ? 4 : 5); + } + @Override public void writeRawByte(byte value) throws IOException { encoder.writeByte(value); @@ -306,6 +281,14 @@ private abstract static class Encoder { void flush() throws IOException { } + void close() throws IOException { + flush(); + } + + int remainingSpace() { + return Integer.MAX_VALUE; + } + // high level ops, writing fields void writeUInt32Field(int fieldNumber, int value) throws IOException { @@ -353,6 +336,20 @@ void writeLengthDelimitedField(int fieldNumber, int length) throws IOException { abstract void writeBytes(byte[] value, int offset, int length) throws IOException; abstract void writeBytes(ByteBuffer value) throws IOException; + + boolean supportsFixedVarint() { + return false; + } + } + + abstract static class FixedVarintEncoder extends Encoder { + abstract int skipFixedVarint(); + + abstract void writePositiveFixedVarint(int pos); + @Override + boolean supportsFixedVarint() { + return true; + } } /** @@ -425,7 +422,7 @@ void writeFixed64(long value) { /** * Writes to a user provided byte array. */ - private static class ByteArrayEncoder extends Encoder { + private static class ByteArrayEncoder extends FixedVarintEncoder { private final byte[] array; @@ -457,6 +454,7 @@ private ByteArrayEncoder(byte[] array, int offset, int length) { this.pos = offset; } + @Override protected final int remainingSpace() { return limit - pos; } @@ -594,6 +592,26 @@ final void writeFixed64(long value) throws IOException { throw log.outOfWriteBufferSpace(e); } } + + @Override + int skipFixedVarint() { + int prev = pos; + pos += 5; + return prev; + } + + @Override + void writePositiveFixedVarint(int pos) { + TagWriterImpl.writePositiveFixedVarint(array, pos, this.pos - pos - 5); + } + } + + public static void writePositiveFixedVarint(byte[] array, int pos, int length) { + array[pos++] = (byte) (length & 0x7F | 0x80); + array[pos++] = (byte) ((length >>> 7) & 0x7F | 0x80); + array[pos++] = (byte) ((length >>> 14) & 0x7F | 0x80); + array[pos++] = (byte) ((length >>> 21) & 0x7F | 0x80); + array[pos] = (byte) ((length >>> 28) & 0x7F); } /** @@ -722,7 +740,7 @@ void writeFixed64(long value) throws IOException { } } - private static class OutputStreamNoBufferEncoder extends Encoder { + private static class OutputStreamNoBufferEncoder extends FixedVarintEncoder { private final OutputStream out; @@ -806,11 +824,28 @@ void flush() throws IOException { super.flush(); out.flush(); } + + @Override + int skipFixedVarint() { + return ((ByteArrayOutputStreamEx) out).skipFixedVarint(); + } + + @Override + void writePositiveFixedVarint(int pos) { + ((ByteArrayOutputStreamEx) out).writePositiveFixedVarint(pos); + } + + @Override + boolean supportsFixedVarint() { + return out instanceof ByteArrayOutputStreamEx; + } } /** * Writes to an {@link OutputStream} and performs internal buffering to minimize the number of stream writes. + * @Deprecated this is to be removed in next major */ + @Deprecated private static final class OutputStreamEncoder extends Encoder { private final ByteArrayEncoder buffer; @@ -922,4 +957,202 @@ void flush() throws IOException { buffer.flushToStream(out); } } + + private static class FixedVarintWrappedEncoder extends Encoder { + private final FixedVarintEncoder parentEncoder; + private final int originalPos; + private boolean closed; + + private FixedVarintWrappedEncoder(FixedVarintEncoder parentEncoder) { + this.parentEncoder = parentEncoder; + this.originalPos = parentEncoder.skipFixedVarint(); + } + + @Override + void writeVarint32(int value) throws IOException { + parentEncoder.writeVarint32(value); + } + + @Override + void writeVarint64(long value) throws IOException { + parentEncoder.writeVarint64(value); + } + + @Override + void writeFixed32(int value) throws IOException { + parentEncoder.writeFixed32(value); + } + + @Override + void writeFixed64(long value) throws IOException { + parentEncoder.writeFixed64(value); + } + + @Override + void writeByte(byte value) throws IOException { + parentEncoder.writeByte(value); + } + + @Override + void writeBytes(byte[] value, int offset, int length) throws IOException { + parentEncoder.writeBytes(value, offset, length); + } + + @Override + void writeBytes(ByteBuffer value) throws IOException { + parentEncoder.writeBytes(value); + } + + @Override + void close() throws IOException { + if (!closed) { + closed = true; + parentEncoder.writePositiveFixedVarint(originalPos); + } + } + } + + private static class ArrayBasedWrappedEncoder extends Encoder { + private final int maxSize; + private final Encoder parentEncoder; + private final int number; + private int pos = 0; + private byte[] bytes; + private boolean closed; + + public ArrayBasedWrappedEncoder(int maxSize, Encoder parentEncoder, int number) { + this.maxSize = maxSize; + this.parentEncoder = parentEncoder; + this.number = number; + bytes = new byte[Math.min(maxSize, 32)]; + } + + @Override + void writeVarint32(int value) throws IOException { + ensureSize(5); + try { + while (true) { + if ((value & 0xFFFFFF80) == 0) { + bytes[pos++] = (byte) value; + break; + } else { + bytes[pos++] = (byte) (value & 0x7F | 0x80); + value >>>= 7; + } + } + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeVarint64(long value) throws IOException { + ensureSize(10); + try { + while (true) { + if ((value & 0xFFFFFFFFFFFFFF80L) == 0) { + bytes[pos++] = (byte) value; + break; + } else { + bytes[pos++] = (byte) ((int) value & 0x7F | 0x80); + value >>>= 7; + } + } + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeFixed32(int value) throws IOException { + ensureSize(4); + try { + bytes[pos++] = (byte) (value & 0xFF); + bytes[pos++] = (byte) ((value >> 8) & 0xFF); + bytes[pos++] = (byte) ((value >> 16) & 0xFF); + bytes[pos++] = (byte) ((value >> 24) & 0xFF); + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeFixed64(long value) throws IOException { + ensureSize(8); + try { + bytes[pos++] = (byte) (value & 0xFF); + bytes[pos++] = (byte) ((value >> 8) & 0xFF); + bytes[pos++] = (byte) ((value >> 16) & 0xFF); + bytes[pos++] = (byte) ((value >> 24) & 0xFF); + bytes[pos++] = (byte) ((int) (value >> 32) & 0xFF); + bytes[pos++] = (byte) ((int) (value >> 40) & 0xFF); + bytes[pos++] = (byte) ((int) (value >> 48) & 0xFF); + bytes[pos++] = (byte) ((int) (value >> 56) & 0xFF); + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeByte(byte value) throws IOException { + ensureSize(1); + try { + bytes[pos++] = value; + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeBytes(byte[] value, int offset, int length) throws IOException { + ensureSize(length); + try { + System.arraycopy(value, offset, bytes, pos, length); + pos += length; + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + + @Override + void writeBytes(ByteBuffer value) throws IOException { + int length = value.remaining(); + ensureSize(length); + if (value.hasArray()) { + writeBytes(value.array(), value.arrayOffset() + value.position(), length); + value.position(value.position() + length); + } else { + try { + value.get(bytes, pos, length); + pos += length; + } catch (IndexOutOfBoundsException e) { + throw log.outOfWriteBufferSpace(e); + } + } + } + + @Override + void close() throws IOException { + if (!closed) { + closed = true; + parentEncoder.writeLengthDelimitedField(number, pos); + parentEncoder.writeBytes(bytes, 0, pos); + } + } + + private void ensureSize(int possibleLength) { + int targetSize = pos + possibleLength; + int currentSize = bytes.length; + while (targetSize > currentSize) { + if (currentSize > maxSize) { + currentSize = maxSize; + break; + } + currentSize <<= 1; + } + if (currentSize != bytes.length) { + bytes = Arrays.copyOf(bytes, currentSize); + } + } + } } diff --git a/core/src/test/java/org/infinispan/protostream/ProtobufUtilTest.java b/core/src/test/java/org/infinispan/protostream/ProtobufUtilTest.java index ba0773f37..e8f0b04c7 100644 --- a/core/src/test/java/org/infinispan/protostream/ProtobufUtilTest.java +++ b/core/src/test/java/org/infinispan/protostream/ProtobufUtilTest.java @@ -2,7 +2,6 @@ import static org.infinispan.protostream.domain.Account.Currency.BRL; import static org.infinispan.protostream.domain.Account.Currency.USD; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -58,7 +57,8 @@ public void testComputeMessageSize() throws Exception { messageSize = ProtobufUtil.computeWrappedMessageSize(ctx, user); - assertEquals(expectedMessageSize, messageSize); + // Actual array is 4 bigger because of fixed Varint + assertEquals(expectedMessageSize, messageSize + 4); } @Test(expected = MalformedProtobufException.class) @@ -118,7 +118,13 @@ public void testMessageWrapping() throws Exception { byte[] userBytes2 = ProtobufUtil.toByteArray(ctx, new WrappedMessage(user)); // assert that toWrappedByteArray works correctly as a shorthand for toByteArray on a WrappedMessage - assertArrayEquals(userBytes1, userBytes2); + assertWrappedArraysEqual(ctx, userBytes1, userBytes2); + } + + public static void assertWrappedArraysEqual(ImmutableSerializationContext ctx, byte[] array1, byte[] array2) throws IOException { + Object user1 = ProtobufUtil.fromWrappedByteArray(ctx, array1); + Object user2 = ProtobufUtil.fromWrappedByteArray(ctx, array2); + assertEquals(user1, user2); } @Test @@ -484,7 +490,7 @@ private void testJsonConversion(ImmutableSerializationContext ctx, T object, assertValid(json); byte[] bytes = ProtobufUtil.fromCanonicalJSON(ctx, new StringReader(json)); assertEquals(object, ProtobufUtil.fromWrappedByteArray(ctx, bytes)); - assertArrayEquals(marshalled, bytes); + assertWrappedArraysEqual(ctx, marshalled, bytes); } private void testJsonConversion(ImmutableSerializationContext ctx, T object) throws IOException { diff --git a/integrationtests/src/test/java/org/infinispan/protostream/integrationtests/processor/annotated_package/AnnotationOnPackageIntegrationTest.java b/integrationtests/src/test/java/org/infinispan/protostream/integrationtests/processor/annotated_package/AnnotationOnPackageIntegrationTest.java index 2172f9363..e65b43997 100644 --- a/integrationtests/src/test/java/org/infinispan/protostream/integrationtests/processor/annotated_package/AnnotationOnPackageIntegrationTest.java +++ b/integrationtests/src/test/java/org/infinispan/protostream/integrationtests/processor/annotated_package/AnnotationOnPackageIntegrationTest.java @@ -1,17 +1,16 @@ package org.infinispan.protostream.integrationtests.processor.annotated_package; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.io.IOException; import java.io.StringReader; -import java.util.Arrays; import java.util.ServiceLoader; import org.infinispan.protostream.GeneratedSchema; import org.infinispan.protostream.ProtobufUtil; +import org.infinispan.protostream.ProtobufUtilTest; import org.infinispan.protostream.SerializationContext; import org.infinispan.protostream.SerializationContextInitializer; import org.infinispan.protostream.annotations.AutoProtoSchemaBuilder; @@ -65,7 +64,7 @@ public void testUserWithLotsOfFields() throws IOException { String json = ProtobufUtil.toCanonicalJSON(serCtx, userBytes, true); byte[] jsonBytes = ProtobufUtil.fromCanonicalJSON(serCtx, new StringReader(json)); - assertArrayEquals(userBytes, jsonBytes); + ProtobufUtilTest.assertWrappedArraysEqual(serCtx, userBytes, jsonBytes); } @AutoProtoSchemaBuilder(dependsOn = ReusableInitializer.class, includeClasses = DependentInitializer.C.class, service = true)