From c9e413163323e4138f4e8f0e545220ccc3aedfb9 Mon Sep 17 00:00:00 2001
From: Liz Fong-Jones
Date: Thu, 30 Sep 2021 16:37:43 -0700
Subject: [PATCH] pass level param through

---
 compress.go |  2 +-
 zstd.go     | 25 ++++++++++++++++++++++---
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/compress.go b/compress.go
index 12cd7c3d5..8436dc20e 100644
--- a/compress.go
+++ b/compress.go
@@ -187,7 +187,7 @@ func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) {
 		}
 		return buf.Bytes(), nil
 	case CompressionZSTD:
-		return zstdCompress(nil, data)
+		return zstdCompress(level, nil, data)
 	default:
 		return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)}
 	}
diff --git a/zstd.go b/zstd.go
index e23bfc477..336564ac0 100644
--- a/zstd.go
+++ b/zstd.go
@@ -1,18 +1,37 @@
 package sarama
 
 import (
+	"sync"
+
 	"github.com/klauspost/compress/zstd"
 )
 
+var zstdEncMap sync.Map
+
 var (
 	zstdDec, _ = zstd.NewReader(nil)
-	zstdEnc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
 )
 
+func getEncoder(level int) *zstd.Encoder {
+	if ret, ok := zstdEncMap.Load(level); ok {
+		return ret.(*zstd.Encoder)
+	}
+	// It's possible to race and create multiple new writers.
+	// Only one will survive GC after use.
+	encoderLevel := zstd.SpeedDefault
+	if level != CompressionLevelDefault {
+		encoderLevel = zstd.EncoderLevelFromZstd(level)
+	}
+	zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true),
+		zstd.WithEncoderLevel(encoderLevel))
+	zstdEncMap.Store(level, zstdEnc)
+	return zstdEnc
+}
+
 func zstdDecompress(dst, src []byte) ([]byte, error) {
 	return zstdDec.DecodeAll(src, dst)
 }
 
-func zstdCompress(dst, src []byte) ([]byte, error) {
-	return zstdEnc.EncodeAll(src, dst), nil
+func zstdCompress(level int, dst, src []byte) ([]byte, error) {
+	return getEncoder(level).EncodeAll(src, dst), nil
 }
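
A minimal, standalone sketch of the per-level encoder cache pattern this patch introduces, assuming only github.com/klauspost/compress/zstd. The `levelDefault` constant and helper names below are illustrative stand-ins, not Sarama's API; the point is to show one cached `*zstd.Encoder` per requested compression level, built lazily on first use.

```go
package main

import (
	"fmt"
	"sync"

	"github.com/klauspost/compress/zstd"
)

// levelDefault is a hypothetical sentinel meaning "use the library default level".
const levelDefault = -1000

// encoders caches one *zstd.Encoder per requested level (map[int]*zstd.Encoder).
var encoders sync.Map

func getEncoder(level int) *zstd.Encoder {
	if enc, ok := encoders.Load(level); ok {
		return enc.(*zstd.Encoder)
	}
	// Concurrent callers may both miss the cache and build an encoder;
	// whichever Store lands last wins and the other becomes garbage.
	encLevel := zstd.SpeedDefault
	if level != levelDefault {
		encLevel = zstd.EncoderLevelFromZstd(level)
	}
	enc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel(encLevel))
	encoders.Store(level, enc)
	return enc
}

func main() {
	data := []byte("hello, zstd")

	// EncodeAll compresses the whole buffer in one call; the nil writer is fine
	// because the encoder is only used in this buffer-to-buffer mode.
	compressed := getEncoder(3).EncodeAll(data, nil)

	dec, _ := zstd.NewReader(nil)
	roundTripped, _ := dec.DecodeAll(compressed, nil)
	fmt.Println(string(roundTripped)) // prints "hello, zstd"
}
```

Repeated calls with the same level reuse the cached encoder, so the cost of `zstd.NewWriter` is paid once per distinct level rather than once per message.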