Skip to content

Commit

Permalink
Fix bug in SlicedInputStream with zero length (#4863)
Browse files Browse the repository at this point in the history
Per the contract of InputStream#read(byte[], int, int):

    If len is zero, then no bytes are read and 0 is returned

SlicedInputStream had a bug where if a zero length was passed then it
would drain all the underlying streams and return -1. This was uncovered
by using InputStream#readAllBytes in new code under development. The
only existing usage of SlicedInputStream should not be vulnerable to this
bug.

I've also added a check for invalid arguments and created tests to
ensure the proper exceptions are thrown per the InputStream contract. In
the test I've replaced a "readFully" method with an equivalent
"readNBytes" that was introduced in Java 11.

Signed-off-by: Andrew Ross <[email protected]>

Signed-off-by: Andrew Ross <[email protected]>
(cherry picked from commit 6571db7)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

# Conflicts:
#	CHANGELOG.md
  • Loading branch information
github-actions[bot] committed Nov 2, 2022
1 parent 90fc25b commit 2ce0f0a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;

/**
* A {@link SlicedInputStream} is a logical
Expand Down Expand Up @@ -100,6 +101,11 @@ public final int read() throws IOException {

@Override
public final int read(byte[] buffer, int offset, int length) throws IOException {
Objects.checkFromIndexSize(offset, length, buffer.length);
if (length == 0) {
return 0;
}

final InputStream stream = currentStream();
if (stream == null) {
return -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,18 @@
package org.opensearch.index.snapshots.blobstore;

import com.carrotsearch.randomizedtesting.generators.RandomNumbers;

import org.hamcrest.MatcherAssert;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -86,11 +91,9 @@ protected InputStream openSlice(int slice) throws IOException {
assertThat(random.nextInt(Byte.MAX_VALUE), equalTo(input.read()));
break;
default:
byte[] b = randomBytes(random);
byte[] buffer = new byte[b.length];
int read = readFully(input, buffer);
assertThat(b.length, equalTo(read));
assertArrayEquals(b, buffer);
byte[] expectedBytes = randomBytes(random);
byte[] actualBytes = input.readNBytes(expectedBytes.length);
assertArrayEquals(expectedBytes, actualBytes);
break;
}
}
Expand All @@ -107,19 +110,45 @@ protected InputStream openSlice(int slice) throws IOException {

}

private int readFully(InputStream stream, byte[] buffer) throws IOException {
for (int i = 0; i < buffer.length;) {
int read = stream.read(buffer, i, buffer.length - i);
if (read == -1) {
if (i == 0) {
return -1;
} else {
return i;
}
public void testReadZeroLength() throws IOException {
try (InputStream input = createSingleByteStream()) {
final byte[] buffer = new byte[100];
final int read = input.read(buffer, 0, 0);
MatcherAssert.assertThat(read, equalTo(0));
}
}

public void testInvalidOffsetAndLength() throws IOException {
try (InputStream input = createSingleByteStream()) {
final byte[] buffer = new byte[100];
expectThrows(NullPointerException.class, () -> input.read(null, 0, 10));
expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, -1, 10));
expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, 0, -1));
expectThrows(IndexOutOfBoundsException.class, () -> input.read(buffer, 0, buffer.length + 1));
}
}

public void testReadAllBytes() throws IOException {
final byte[] expectedResults = randomByteArrayOfLength(50_000);
final int numSlices = 200;
final int slizeSize = expectedResults.length / numSlices;

final List<byte[]> arraySlices = new ArrayList<>(numSlices);
for (int i = 0; i < numSlices; i++) {
final int offset = slizeSize * i;
arraySlices.add(Arrays.copyOfRange(expectedResults, offset, offset + slizeSize));
}
// Create a SlicedInputStream that will return the expected result in 2 slices
final byte[] actualResults;
try (InputStream is = new SlicedInputStream(numSlices) {
@Override
protected InputStream openSlice(int slice) {
return new ByteArrayInputStream(arraySlices.get(slice));
}
i += read;
}) {
actualResults = is.readAllBytes();
}
return buffer.length;
assertArrayEquals(expectedResults, actualResults);
}

private byte[] randomBytes(Random random) {
Expand All @@ -129,6 +158,15 @@ private byte[] randomBytes(Random random) {
return data;
}

private static InputStream createSingleByteStream() {
return new SlicedInputStream(1) {
@Override
protected InputStream openSlice(int slice) {
return new ByteArrayInputStream(new byte[] { 1 });
}
};
}

private static final class CheckClosedInputStream extends FilterInputStream {

public boolean closed = false;
Expand Down

0 comments on commit 2ce0f0a

Please sign in to comment.