diff --git a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java index 5c1801da2..5a3585cea 100644 --- a/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java +++ b/core/src/main/java/org/infinispan/protostream/impl/TagReaderImpl.java @@ -870,9 +870,7 @@ String readString() throws IOException { int length = readVarint32(); if (length > 0 && length <= limit - pos) { byte[] bytes = readRawByteArray(length); - String value = new String(bytes, 0, length, UTF8); - pos += length; - return value; + return new String(bytes, 0, length, UTF8); } if (length == 0) { return ""; @@ -885,6 +883,9 @@ String readString() throws IOException { @Override ByteBuffer readRawByteBuffer(int length) throws IOException { + if (length == 0) { + return EMPTY_BUFFER; + } byte[] bytes = readRawByteArray(length); return ByteBuffer.wrap(bytes); } @@ -1046,9 +1047,16 @@ byte[] readRawByteArray(int length) throws IOException { if (pos > globalLimit) { throw log.globalLimitExceeded(); } + int readTotal = 0; + int readAmount; byte[] array = new byte[length]; - int readAmount = in.read(array); - if (readAmount != length) { + while ((readAmount = in.read(array, readTotal, length - readTotal)) != -1) { + readTotal += readAmount; + if (readTotal == length) { + break; + } + } + if (readTotal != length) { throw log.messageTruncated(); } return array; @@ -1066,10 +1074,7 @@ byte[] readRawByteArray(int length) throws IOException { protected void skipRawBytes(int length) throws IOException { if (length <= limit - pos && length >= 0) { pos += length; - long skipAmount = in.skip(length); - if (skipAmount != length) { - throw log.messageTruncated(); - } + skipNBytes(length); } else { if (length < 0) { throw log.negativeLength(); @@ -1077,6 +1082,26 @@ protected void skipRawBytes(int length) throws IOException { throw log.messageTruncated(); } } + + // Copied from InputStream, we can't use Java 12 or newer just yet, can be removed when on a newer version. + private void skipNBytes(long n) throws IOException { + while (n > 0) { + long ns = in.skip(n); + if (ns > 0 && ns <= n) { + // adjust number to skip + n -= ns; + } else if (ns == 0) { // no bytes skipped + // read one byte to check for EOS + if (in.read() == -1) { + throw log.messageTruncated(); + } + // one byte read so decrement number to skip + n--; + } else { // skipped negative or too many bytes + throw new IOException("Unable to skip exactly"); + } + } + } } }