From 9078b05a9b016fc75154340f04a799918fcdefd3 Mon Sep 17 00:00:00 2001 From: William Burns Date: Thu, 4 May 2023 16:57:26 -0700 Subject: [PATCH] IPROTO-264 Remove stream buffer when unmarshalling objects --- .../protostream/impl/TagReaderImpl.java | 357 +++++------------- 1 file changed, 93 insertions(+), 264 deletions(-) diff --git a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java index dbb89b75d..5c1801da2 100644 --- a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java +++ b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java @@ -7,20 +7,18 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.PushbackInputStream; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.infinispan.protostream.ImmutableSerializationContext; import org.infinispan.protostream.MalformedProtobufException; import org.infinispan.protostream.ProtobufTagMarshaller; -import org.infinispan.protostream.ProtobufUtil; import org.infinispan.protostream.TagReader; import org.infinispan.protostream.descriptors.WireType; @@ -63,7 +61,7 @@ private TagReaderImpl(SerializationContextImpl serCtx, Decoder decoder) { } public static TagReaderImpl newNestedInstance(ProtobufTagMarshaller.ReadContext parent, InputStream input) { - return new TagReaderImpl((TagReaderImpl) parent, new InputStreamDecoder(input, ProtobufUtil.DEFAULT_STREAM_BUFFER_SIZE)); + return new TagReaderImpl((TagReaderImpl) parent, new InputStreamDecoder(input)); } public static TagReaderImpl newNestedInstance(ProtobufTagMarshaller.ReadContext parent, byte[] buf) { @@ -71,7 +69,7 @@ public static TagReaderImpl newNestedInstance(ProtobufTagMarshaller.ReadContext } public static TagReaderImpl newInstance(ImmutableSerializationContext serCtx, InputStream input) { - return new TagReaderImpl((SerializationContextImpl) serCtx, new InputStreamDecoder(input, ProtobufUtil.DEFAULT_STREAM_BUFFER_SIZE)); + return new TagReaderImpl((SerializationContextImpl) serCtx, new InputStreamDecoder(input)); } public static TagReaderImpl newInstance(ImmutableSerializationContext serCtx, ByteBuffer buf) { @@ -846,49 +844,33 @@ private static final class InputStreamDecoder extends Decoder { private final InputStream in; - private final byte[] buf; - - /** - * The end position of buffered data. This is limit adjusted. - */ - private int end; - /** * Current position. */ private int pos; - /** - * Number of bytes already read before the current buffer. - */ - private int bytesBeforeStart = 0; - - /** - * Number of bytes after the limit. - */ - private int bytesAfterLimit = 0; - /** * Absolute position (from start of input data) of the last byte we are allowed to read by last pushLimit. */ private int limit = Integer.MAX_VALUE; - private InputStreamDecoder(InputStream in, int bufferSize) { + private InputStreamDecoder(InputStream in) { if (in == null) { throw new IllegalArgumentException("input stream cannot be null"); } - this.in = in; - bufferSize = Math.max(bufferSize, MAX_VARINT_SIZE * 2); - this.buf = new byte[bufferSize]; - this.end = 0; - this.pos = 0; + if (in.markSupported()) { + this.in = in; + } else { + this.in = new PushbackInputStream(in); + } } @Override String readString() throws IOException { int length = readVarint32(); - if (length > 0 && length <= end - pos) { - String value = new String(buf, pos, length, UTF8); + if (length > 0 && length <= limit - pos) { + byte[] bytes = readRawByteArray(length); + String value = new String(bytes, 0, length, UTF8); pos += length; return value; } @@ -898,51 +880,20 @@ String readString() throws IOException { if (length < 0) { throw log.negativeLength(); } - if (length <= buf.length) { - fillBuffer(length); - String value = new String(buf, pos, length, UTF8); - pos += length; - return value; - } - return new String(readRawBytesLarge(length), UTF8); + throw log.messageTruncated(); } @Override ByteBuffer readRawByteBuffer(int length) throws IOException { - if (length <= end - pos && length > 0) { - int from = pos; - pos += length; - return ByteBuffer.wrap(Arrays.copyOfRange(buf, from, pos)); - } - if (length == 0) { - return EMPTY_BUFFER; - } - if (length < 0) { - throw log.negativeLength(); - } - if (length <= buf.length) { - fillBuffer(length); - int from = pos; - pos += length; - return ByteBuffer.wrap(Arrays.copyOfRange(buf, from, pos)); - } - // TODO [anistor] implement a readRawByteBufferLarge, using off-heap allocation - return ByteBuffer.wrap(readRawBytesLarge(length)); + byte[] bytes = readRawByteArray(length); + return ByteBuffer.wrap(bytes); } @Override protected void skipVarint() throws IOException { - if (end - pos >= MAX_VARINT_SIZE) { - for (int i = 0; i < MAX_VARINT_SIZE; i++) { - if (buf[pos++] >= 0) { - return; - } - } - } else { - for (int i = 0; i < MAX_VARINT_SIZE; i++) { - if (readRawByte() >= 0) { - return; - } + for (int i = 0; i < MAX_VARINT_SIZE; i++) { + if (readRawByte() >= 0) { + return; } } throw log.malformedVarint(); @@ -951,21 +902,11 @@ protected void skipVarint() throws IOException { @Override long readVarint64() throws IOException { long value = 0; - if (end - pos >= MAX_VARINT_SIZE) { - for (int i = 0; i < 64; i += 7) { - byte b = buf[pos++]; - value |= (long) (b & 0x7F) << i; - if (b >= 0) { - return value; - } - } - } else { - for (int i = 0; i < 64; i += 7) { - byte b = readRawByte(); - value |= (long) (b & 0x7F) << i; - if (b >= 0) { - return value; - } + for (int i = 0; i < 64; i += 7) { + byte b = readRawByte(); + value |= (long) (b & 0x7F) << i; + if (b >= 0) { + return value; } } throw log.malformedVarint(); @@ -973,32 +914,28 @@ long readVarint64() throws IOException { @Override int readFixed32() throws IOException { - if (end - pos < FIXED_32_SIZE) { - fillBuffer(FIXED_32_SIZE); + if (limit - pos < FIXED_32_SIZE) { + throw log.messageTruncated(); } - int value = (buf[pos] & 0xFF) - | ((buf[pos + 1] & 0xFF) << 8) - | ((buf[pos + 2] & 0xFF) << 16) - | ((buf[pos + 3] & 0xFF) << 24); - pos += FIXED_32_SIZE; - return value; + return (readRawByte() & 0xFF) + | ((readRawByte() & 0xFF) << 8) + | ((readRawByte() & 0xFF) << 16) + | ((readRawByte() & 0xFF) << 24); } @Override long readFixed64() throws IOException { - if (end - pos < FIXED_64_SIZE) { - fillBuffer(FIXED_64_SIZE); + if (limit - pos < FIXED_64_SIZE) { + throw log.messageTruncated(); } - long value = (buf[pos] & 0xFFL) - | ((buf[pos + 1] & 0xFFL) << 8) - | ((buf[pos + 2] & 0xFFL) << 16) - | ((buf[pos + 3] & 0xFFL) << 24) - | ((buf[pos + 4] & 0xFFL) << 32) - | ((buf[pos + 5] & 0xFFL) << 40) - | ((buf[pos + 6] & 0xFFL) << 48) - | ((buf[pos + 7] & 0xFFL) << 56); - pos += FIXED_64_SIZE; - return value; + return (readRawByte() & 0xFFL) + | ((readRawByte() & 0xFFL) << 8) + | ((readRawByte() & 0xFFL) << 16) + | ((readRawByte() & 0xFFL) << 24) + | ((readRawByte() & 0xFFL) << 32) + | ((readRawByte() & 0xFFL) << 40) + | ((readRawByte() & 0xFFL) << 48) + | ((readRawByte() & 0xFFL) << 56); } @Override @@ -1016,37 +953,24 @@ int pushLimit(int limit) throws IOException { if (limit < 0) { throw log.negativeLength(); } - limit = bytesBeforeStart + pos + limit; + limit = pos + limit; int oldLimit = this.limit; if (limit > oldLimit) { // the end of a nested message cannot go beyond the end of the outer message throw log.messageTruncated(); } this.limit = limit; - adjustEnd(); return oldLimit; } @Override void popLimit(int oldLimit) { limit = oldLimit; - adjustEnd(); - } - - private void adjustEnd() { - end += bytesAfterLimit; - int absEnd = bytesBeforeStart + end; - if (absEnd > limit) { - bytesAfterLimit = absEnd - limit; - end -= bytesAfterLimit; - } else { - bytesAfterLimit = 0; - } } @Override int getEnd() { - return end; + return limit; } @Override @@ -1056,8 +980,13 @@ int getPos() { @Override byte[] getBufferArray() throws IOException { - fillBuffer(buf.length); - return buf; + if (globalLimit == Integer.MAX_VALUE) { + pos = Integer.MAX_VALUE; + return in.readAllBytes(); + } else { + int length = globalLimit - pos; + return readRawByteArray(length); + } } InputStream getInputStream() { @@ -1066,75 +995,63 @@ InputStream getInputStream() { @Override boolean isAtEnd() throws IOException { - return pos == end && !tryFillBuffer(1); - } - - /** - * Ensure that at least the requested number of bytes, or more, but no more than the buffer capacity are - * available in the buffer. - */ - private void fillBuffer(int requestedBytes) throws IOException { - if (!tryFillBuffer(requestedBytes)) { - throw log.messageTruncated(); - } - } - - /** - * Tries to fill the buffer with at least the requested number of bytes, or more, but no more than the buffer - * capacity and indicates if the operation succeeded or failed either due to lack of available data in stream or - * by reaching the limit set with pushLimit. - */ - private boolean tryFillBuffer(int requestedBytes) throws IOException { - if (requestedBytes + pos <= end) { - // all requested bytes already available; nothing to do + if (pos == limit) { return true; } - - if (requestedBytes + bytesBeforeStart + pos > limit) { - // oops, should not exceed the limit + if (in.available() > 0) { return false; } - - // slide existing data if some bytes are already consumed - if (pos > 0) { - if (end > pos) { - System.arraycopy(buf, pos, buf, 0, end - pos); - } - bytesBeforeStart += pos; - end -= pos; - pos = 0; + if (in instanceof PushbackInputStream) { + return isPushbackDone(); } + return isMarkDone(); + } - // fill the space at the end with data from stream - int read = in.read(buf, end, buf.length - end); - if (read <= 0) { - // EOF or maybe our buffer is full and we attempted to read 0 bytes - return false; + private boolean isPushbackDone() throws IOException { + int intVal = in.read(); + if (intVal < 0) { + return true; } - end += read; + ((PushbackInputStream) in).unread(intVal); + return false; + } - if (requestedBytes + bytesBeforeStart - globalLimit > 0) { - throw log.globalLimitExceeded(); + private boolean isMarkDone() throws IOException { + in.mark(1); + if (in.read() < 0) { + return true; } - - adjustEnd(); - return end >= requestedBytes || end == buf.length || tryFillBuffer(requestedBytes); + in.reset(); + return false; } + @Override byte readRawByte() throws IOException { - if (pos == end) { - fillBuffer(1); + if (pos == limit) { + throw log.messageTruncated(); + } + int byteValue = in.read(); + if (byteValue < 0) { + throw log.messageTruncated(); } - return buf[pos++]; + pos++; + return (byte) byteValue; } @Override byte[] readRawByteArray(int length) throws IOException { - if (length > 0 && length <= end - pos) { - int from = pos; + if (length > 0 && length <= limit - pos) { pos += length; - return Arrays.copyOfRange(buf, from, pos); + if (pos > globalLimit) { + throw log.globalLimitExceeded(); + } + byte[] array = new byte[length]; + int readAmount = in.read(array); + if (readAmount != length) { + throw log.messageTruncated(); + } + return array; } if (length == 0) { return EMPTY; @@ -1142,110 +1059,22 @@ byte[] readRawByteArray(int length) throws IOException { if (length < 0) { throw log.negativeLength(); } - if (length <= buf.length) { - fillBuffer(length); - int from = pos; - pos += length; - return Arrays.copyOfRange(buf, from, pos); - } - return readRawBytesLarge(length); - } - - // handle the unhappy case when the length does not fit in the internal buffer - private byte[] readRawBytesLarge(int length) throws IOException { - if (length < 0) { - throw new IllegalArgumentException("Length must not be negative"); - } - - int total = bytesBeforeStart + pos + length; - if (total - globalLimit > 0) { - throw log.globalLimitExceeded(); - } - if (total > limit) { - // limit exceeded, skip up to limit and fail - skipRawBytes(limit - bytesBeforeStart - pos); - throw log.messageTruncated(); - } - - int buffered = end - pos; - int needed = length - buffered; - if (needed <= 0) { - throw new IllegalStateException("The needed data already exists in buffer!"); - } - int oldPos = pos; - bytesBeforeStart += end; - pos = 0; - end = 0; - - if (needed < ProtobufUtil.DEFAULT_STREAM_BUFFER_SIZE || needed <= in.available()) { - byte[] bytes = new byte[length]; - System.arraycopy(buf, oldPos, bytes, 0, buffered); - while (buffered < bytes.length) { - int read = in.read(bytes, buffered, length - buffered); - if (read <= 0) { - throw log.messageTruncated(); - } - bytesBeforeStart += read; - buffered += read; - } - return bytes; - } - - // read in segments to avoid allocating full length at once to prevent sudden death - List segments = new ArrayList<>(); - while (needed > 0) { - byte[] segment = new byte[Math.min(needed, ProtobufUtil.DEFAULT_STREAM_BUFFER_SIZE)]; - int segPos = 0; - while (segPos < segment.length) { - int read = in.read(segment, segPos, segment.length - segPos); - if (read <= 0) { - throw log.messageTruncated(); - } - segPos += read; - bytesBeforeStart += read; - } - segments.add(segment); - needed -= segment.length; - } - - // stitch the segments and hope not to blow up - byte[] bytes = new byte[length]; - System.arraycopy(buf, oldPos, bytes, 0, buffered); - int segPos = buffered; - for (int i = 0; i < segments.size(); i++) { - byte[] segment = segments.get(i); - System.arraycopy(segment, 0, bytes, segPos, segment.length); - segPos += segment.length; - segments.set(i, null); - } - return bytes; + throw log.messageTruncated(); } @Override protected void skipRawBytes(int length) throws IOException { - if (length <= end - pos && length >= 0) { + if (length <= limit - pos && length >= 0) { pos += length; + long skipAmount = in.skip(length); + if (skipAmount != length) { + throw log.messageTruncated(); + } } else { if (length < 0) { throw log.negativeLength(); } - - if (bytesBeforeStart + pos + length > limit) { - // limit exceeded, skip up to limit and fail - skipRawBytes(limit - bytesBeforeStart - pos); - throw log.messageTruncated(); - } - - length -= end - pos; - while (true) { - pos = end; - fillBuffer(1); - if (length <= end) { - pos = length; - break; - } - length -= end; - } + throw log.messageTruncated(); } } }