diff --git a/CHANGELOG.md b/CHANGELOG.md index 72876ad0ac789..9d5efd066b399 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -130,9 +130,10 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Better plural stemmer than minimal_english ([#4738](https://github.com/opensearch-project/OpenSearch/pull/4738)) - Fix AbstractStringFieldDataTestCase tests to account for TotalHits lower bound ([4867](https://github.com/opensearch-project/OpenSearch/pull/4867)) - [Segment Replication] Fix bug of replica shard's translog not purging on index flush when segment replication is enabled ([4975](https://github.com/opensearch-project/OpenSearch/pull/4975)) +- Fix bug in SlicedInputStream with zero length ([#4863](https://github.com/opensearch-project/OpenSearch/pull/4863)) - ### Security [Unreleased]: https://github.com/opensearch-project/OpenSearch/compare/2.2.0...HEAD -[2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.2.0...2.x \ No newline at end of file +[2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.2.0...2.x diff --git a/server/src/main/java/org/opensearch/index/snapshots/blobstore/SlicedInputStream.java b/server/src/main/java/org/opensearch/index/snapshots/blobstore/SlicedInputStream.java index b9b27547fe50c..ae2e33d41c6ce 100644 --- a/server/src/main/java/org/opensearch/index/snapshots/blobstore/SlicedInputStream.java +++ b/server/src/main/java/org/opensearch/index/snapshots/blobstore/SlicedInputStream.java @@ -35,6 +35,7 @@ import java.io.IOException; import java.io.InputStream; +import java.util.Objects; /** * A {@link SlicedInputStream} is a logical @@ -100,6 +101,11 @@ public final int read() throws IOException { @Override public final int read(byte[] buffer, int offset, int length) throws IOException { + Objects.checkFromIndexSize(offset, length, buffer.length); + if (length == 0) { + return 0; + } + final InputStream stream = currentStream(); if (stream == null) { return -1; diff --git a/server/src/test/java/org/opensearch/index/snapshots/blobstore/SlicedInputStreamTests.java b/server/src/test/java/org/opensearch/index/snapshots/blobstore/SlicedInputStreamTests.java index 3e337bbe3adae..76fb8f62b5468 100644 --- a/server/src/test/java/org/opensearch/index/snapshots/blobstore/SlicedInputStreamTests.java +++ b/server/src/test/java/org/opensearch/index/snapshots/blobstore/SlicedInputStreamTests.java @@ -32,6 +32,8 @@ package org.opensearch.index.snapshots.blobstore; import com.carrotsearch.randomizedtesting.generators.RandomNumbers; + +import org.hamcrest.MatcherAssert; import org.opensearch.test.OpenSearchTestCase; import java.io.ByteArrayInputStream; @@ -39,6 +41,9 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Random; import static org.hamcrest.Matchers.equalTo; @@ -86,11 +91,9 @@ protected InputStream openSlice(int slice) throws IOException { assertThat(random.nextInt(Byte.MAX_VALUE), equalTo(input.read())); break; default: - byte[] b = randomBytes(random); - byte[] buffer = new byte[b.length]; - int read = readFully(input, buffer); - assertThat(b.length, equalTo(read)); - assertArrayEquals(b, buffer); + byte[] expectedBytes = randomBytes(random); + byte[] actualBytes = input.readNBytes(expectedBytes.length); + assertArrayEquals(expectedBytes, actualBytes); break; } } @@ -107,19 +110,45 @@ protected InputStream openSlice(int slice) throws IOException { } - private int readFully(InputStream stream, byte[] buffer) throws IOException { - for (int i = 0; i < buffer.length;) { - int read = stream.read(buffer, i, buffer.length - i); - if (read == -1) { - if (i == 0) { - return -1; - } else { - return i; - } + public void testReadZeroLength() throws IOException { + try (InputStream input = createSingleByteStream()) { + final byte[] buffer = new byte[100]; + final int read = input.read(buffer, 0, 0); + MatcherAssert.assertThat(read, equalTo(0)); + } + } + + public void testInvalidOffsetAndLength() throws IOException { + try (InputStream input = createSingleByteStream()) { + final byte[] buffer = new byte[100]; + expectThrows(NullPointerException.class, () -> input.read(null, 0, 10)); + expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, -1, 10)); + expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, 0, -1)); + expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, 0, buffer.length + 1)); + } + } + + public void testReadAllBytes() throws IOException { + final byte[] expectedResults = randomByteArrayOfLength(50_000); + final int numSlices = 200; + final int slizeSize = expectedResults.length / numSlices; + + final List arraySlices = new ArrayList<>(numSlices); + for (int i = 0; i < numSlices; i++) { + final int offset = slizeSize * i; + arraySlices.add(Arrays.copyOfRange(expectedResults, offset, offset + slizeSize)); + } + // Create a SlicedInputStream that will return the expected result in 2 slices + final byte[] actualResults; + try (InputStream is = new SlicedInputStream(numSlices) { + @Override + protected InputStream openSlice(int slice) { + return new ByteArrayInputStream(arraySlices.get(slice)); } - i += read; + }) { + actualResults = is.readAllBytes(); } - return buffer.length; + assertArrayEquals(expectedResults, actualResults); } private byte[] randomBytes(Random random) { @@ -129,6 +158,15 @@ private byte[] randomBytes(Random random) { return data; } + private static InputStream createSingleByteStream() { + return new SlicedInputStream(1) { + @Override + protected InputStream openSlice(int slice) { + return new ByteArrayInputStream(new byte[] { 1 }); + } + }; + } + private static final class CheckClosedInputStream extends FilterInputStream { public boolean closed = false;