diff --git a/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java b/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java index 60183d73..7993ad34 100644 --- a/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java +++ b/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java @@ -69,6 +69,9 @@ public static int decompress( } while (value == 255 && input < inputLimit - 15); } + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } // copy literal long literalEnd = input + literalLength; @@ -127,6 +130,9 @@ public static int decompress( while (value == 255); } matchLength += MIN_MATCH; // implicit length from initial 4-byte match in encoder + if (matchLength < 0) { + throw new MalformedInputException(input - inputAddress); + } long matchOutputLimit = output + matchLength; diff --git a/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java b/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java index bb30c2d1..2a8dcc4f 100644 --- a/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java +++ b/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java @@ -248,6 +248,10 @@ else if ((command & 0b1100_0000) != 0) { } firstCommand = false; + if (matchLength < 0) { + throw new MalformedInputException(input - inputAddress); + } + // copy match if (matchLength != 0) { // lzo encodes match offset minus one @@ -316,6 +320,9 @@ else if ((command & 0b1100_0000) != 0) { } // copy literal + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } long literalOutputLimit = output + literalLength; if (literalOutputLimit > fastOutputLimit || input + literalLength > inputLimit - SIZE_OF_LONG) { if (literalOutputLimit > outputLimit) { diff --git a/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java b/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java index 7cca9afa..450f9a05 100644 --- a/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java +++ b/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java @@ -116,6 +116,9 @@ private static int uncompressAll( if ((opCode & 0x3) == LITERAL) { int literalLength = length + trailer; + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } // copy literal long literalOutputLimit = output + literalLength; diff --git a/src/test/java/io/airlift/compress/snappy/TestSnappy.java b/src/test/java/io/airlift/compress/snappy/TestSnappy.java index 91c14cb5..b2e92d79 100644 --- a/src/test/java/io/airlift/compress/snappy/TestSnappy.java +++ b/src/test/java/io/airlift/compress/snappy/TestSnappy.java @@ -16,8 +16,12 @@ import io.airlift.compress.AbstractTestCompression; import io.airlift.compress.Compressor; import io.airlift.compress.Decompressor; +import io.airlift.compress.MalformedInputException; import io.airlift.compress.thirdparty.XerialSnappyCompressor; import io.airlift.compress.thirdparty.XerialSnappyDecompressor; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestSnappy extends AbstractTestCompression @@ -45,4 +49,22 @@ protected Decompressor getVerifyDecompressor() { return new XerialSnappyDecompressor(); } + + @Test + public void testInvalidLiteralLength() + { + byte[] data = { + // Encoded uncompressed length 1024 + -128, 8, + // op-code + (byte) 252, + // Trailer value Integer.MAX_VALUE + (byte) 0b1111_1111, (byte) 0b1111_1111, (byte) 0b1111_1111, (byte) 0b0111_1111, + // Some arbitrary data + 0, 0, 0, 0, 0, 0, 0, 0 + }; + + assertThatThrownBy(() -> new SnappyDecompressor().decompress(data, 0, data.length, new byte[1024], 0, 1024)) + .isInstanceOf(MalformedInputException.class); + } }