diff --git a/gozstd.go b/gozstd.go index 85f07d0..e90e0a2 100644 --- a/gozstd.go +++ b/gozstd.go @@ -45,6 +45,8 @@ import ( // DefaultCompressionLevel is the default compression level. const DefaultCompressionLevel = 3 // Obtained from ZSTD_CLEVEL_DEFAULT. +const maxFrameContentSize = 256 << 20 // 256 MB + // Compress appends compressed src to dst and returns the result. func Compress(dst, src []byte) []byte { return compressDictLevel(dst, src, nil, DefaultCompressionLevel) @@ -257,14 +259,14 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte // Slow path - resize dst to fit decompressed data. srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - decompressBound := int(C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))) - switch uint64(decompressBound) { - case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN): + contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src))) + switch { + case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize: return streamDecompress(dst, src, dd) - case uint64(C.ZSTD_CONTENTSIZE_ERROR): + case contentSize == C.ZSTD_CONTENTSIZE_ERROR: return dst, fmt.Errorf("cannot decompress invalid src") } - decompressBound++ + decompressBound := int(contentSize) + 1 if n := dstLen + decompressBound - cap(dst); n > 0 { // This should be optimized since go 1.11 - see https://golang.org/doc/go1.11#performance-compiler. diff --git a/gozstd_test.go b/gozstd_test.go index 3a465b8..e870505 100644 --- a/gozstd_test.go +++ b/gozstd_test.go @@ -54,6 +54,14 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) { }) } +func TestDecompressTooLarge(t *testing.T) { + src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67} + _, err := Decompress(nil, src) + if err == nil { + t.Fatalf("expecting error when decompressing malformed frame") + } +} + func mustUnhex(dataHex string) []byte { data, err := hex.DecodeString(dataHex) if err != nil {