diff --git a/src/ZstdSharp.Test/ZstdNetSteamingTests.cs b/src/ZstdSharp.Test/ZstdNetSteamingTests.cs
new file mode 100644
index 0000000..3f15152
--- /dev/null
+++ b/src/ZstdSharp.Test/ZstdNetSteamingTests.cs
@@ -0,0 +1,415 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+using ZstdSharp.Unsafe;
+
+namespace ZstdSharp.Test
+{
+    public enum DataFill
+    {
+        Random,
+        Sequential
+    }
+
+    internal static class DataGenerator
+    {
+        private static readonly Random Random = new(1234);
+
+        public const int LargeBufferSize = 1024 * 1024;
+        public const int SmallBufferSize = 1024;
+
+        public static MemoryStream GetSmallStream(DataFill dataFill) => GetStream(SmallBufferSize, dataFill);
+        public static MemoryStream GetLargeStream(DataFill dataFill) => GetStream(LargeBufferSize, dataFill);
+        public static MemoryStream GetStream(int length, DataFill dataFill) => new(GetBuffer(length, dataFill));
+
+        public static byte[] GetSmallBuffer(DataFill dataFill) => GetBuffer(SmallBufferSize, dataFill);
+        public static byte[] GetLargeBuffer(DataFill dataFill) => GetBuffer(LargeBufferSize, dataFill);
+
+        public static byte[] GetBuffer(int length, DataFill dataFill)
+        {
+            var buffer = new byte[length];
+            if (dataFill == DataFill.Random)
+                Random.NextBytes(buffer);
+            else
+            {
+                for (int i = 0; i < buffer.Length; i++)
+                    buffer[i] = (byte) (i % 256);
+            }
+
+            return buffer;
+        }
+    }
+
+    public class ZstdNetSteamingTests
+    {
+        [Fact]
+        public void StreamingCompressionZeroAndOneByte()
+        {
+            var data = new byte[] {0, 0, 0, 1, 2, 3, 4, 0, 0, 0};
+
+            var tempStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(tempStream))
+            {
+                compressionStream.Write(data, 0, 0);
+                compressionStream.Write(ReadOnlySpan<byte>.Empty);
+                compressionStream.WriteAsync(data, 0, 0).GetAwaiter().GetResult();
+                compressionStream.WriteAsync(ReadOnlyMemory<byte>.Empty).GetAwaiter().GetResult();
+
+                compressionStream.Write(data, 3, 1);
+                compressionStream.Write(new ReadOnlySpan<byte>(data, 4, 1));
+                compressionStream.Flush();
+                compressionStream.WriteAsync(data, 5, 1).GetAwaiter().GetResult();
+                compressionStream.WriteAsync(new ReadOnlyMemory<byte>(data, 6, 1)).GetAwaiter().GetResult();
+                compressionStream.FlushAsync().GetAwaiter().GetResult();
+            }
+
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            var result = new byte[data.Length];
+            using (var decompressionStream = new DecompressionStream(tempStream))
+            {
+                Assert.Equal(0, decompressionStream.Read(result, 0, 0));
+                Assert.Equal(0, decompressionStream.Read(Span<byte>.Empty));
+                Assert.Equal(0, decompressionStream.ReadAsync(result, 0, 0).GetAwaiter().GetResult());
+                Assert.Equal(0, decompressionStream.ReadAsync(Memory<byte>.Empty).GetAwaiter().GetResult());
+
+                Assert.Equal(1, decompressionStream.Read(result, 3, 1));
+                Assert.Equal(1, decompressionStream.Read(new Span<byte>(result, 4, 1)));
+                Assert.Equal(1, decompressionStream.ReadAsync(result, 5, 1).GetAwaiter().GetResult());
+                Assert.Equal(1, decompressionStream.ReadAsync(new Memory<byte>(result, 6, 1)).GetAwaiter().GetResult());
+            }
+
+            Assert.True(data.SequenceEqual(result));
+        }
+
+        [Theory]
+        [InlineData(new byte[0], 0, 0)]
+        [InlineData(new byte[] {1, 2, 3}, 1, 2)]
+        [InlineData(new byte[] {1, 2, 3}, 0, 2)]
+        [InlineData(new byte[] {1, 2, 3}, 1, 1)]
+        [InlineData(new byte[] {1, 2, 3}, 0, 3)]
+        public void StreamingCompressionSimpleWrite(byte[] data, int offset, int count)
+        {
+            var tempStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(tempStream))
+                compressionStream.Write(data, offset, count);
+
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            var resultStream = new MemoryStream();
+            using (var decompressionStream = new DecompressionStream(tempStream))
+                decompressionStream.CopyTo(resultStream);
+
+            var dataToCompress = new byte[count];
+            Array.Copy(data, offset, dataToCompress, 0, count);
+
+            Assert.True(dataToCompress.SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Theory]
+        [InlineData(1)]
+        [InlineData(2)]
+        [InlineData(3)]
+        [InlineData(5)]
+        [InlineData(9)]
+        [InlineData(10)]
+        public void StreamingDecompressionSimpleRead(int readCount)
+        {
+            var data = new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+
+            var tempStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(tempStream))
+                compressionStream.Write(data, 0, data.Length);
+
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            var buffer = new byte[data.Length];
+            using (var decompressionStream = new DecompressionStream(tempStream))
+            {
+                int bytesRead;
+                int totalBytesRead = 0;
+                while ((bytesRead = decompressionStream.Read(buffer, totalBytesRead,
+                           Math.Min(readCount, buffer.Length - totalBytesRead))) > 0)
+                {
+                    Assert.True(bytesRead <= readCount);
+                    totalBytesRead += bytesRead;
+                }
+
+                Assert.Equal(data.Length, totalBytesRead);
+            }
+
+            Assert.True(data.SequenceEqual(buffer));
+        }
+
+        [Fact]
+        public void StreamingCompressionFlushDataFromInternalBuffers()
+        {
+            var testBuffer = new byte[1];
+
+            var tempStream = new MemoryStream();
+            using var compressionStream = new CompressionStream(tempStream);
+            compressionStream.Write(testBuffer, 0, testBuffer.Length);
+            compressionStream.Flush();
+
+            Assert.True(tempStream.Length > 0);
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            //NOTE: without ZSTD_endStream call on compression
+            var resultStream = new MemoryStream();
+            using (var decompressionStream = new DecompressionStream(tempStream))
+                decompressionStream.CopyTo(resultStream);
+
+            Assert.True(testBuffer.SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Fact]
+        public void CompressionImprovesWithDictionary()
+        {
+            var dict = TrainDict();
+
+            var dataStream = DataGenerator.GetSmallStream(DataFill.Sequential);
+
+            var normalResultStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(normalResultStream))
+                dataStream.CopyTo(compressionStream);
+
+            dataStream.Seek(0, SeekOrigin.Begin);
+
+            var dictResultStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(dictResultStream))
+            {
+                compressionStream.LoadDictionary(dict);
+                dataStream.CopyTo(compressionStream);
+            }
+
+            Assert.True(normalResultStream.Length > dictResultStream.Length);
+
+            dictResultStream.Seek(0, SeekOrigin.Begin);
+
+            var resultStream = new MemoryStream();
+            using (var decompressionStream = new DecompressionStream(dictResultStream))
+            {
+                decompressionStream.LoadDictionary(dict);
+                decompressionStream.CopyTo(resultStream);
+            }
+
+            Assert.True(dataStream.ToArray().SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Fact]
+        public void CompressionShrinksData()
+        {
+            var dataStream = DataGenerator.GetLargeStream(DataFill.Sequential);
+
+            var resultStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(resultStream))
+                dataStream.CopyTo(compressionStream);
+
+            Assert.True(dataStream.Length > resultStream.Length);
+        }
+
+        [Fact]
+        public void RoundTrip_BatchToStreaming()
+        {
+            var data = DataGenerator.GetLargeBuffer(DataFill.Sequential);
+
+            byte[] compressed;
+            using (var compressor = new Compressor())
+                compressed = compressor.Wrap(data).ToArray();
+
+            var resultStream = new MemoryStream();
+            using (var decompressionStream = new DecompressionStream(new MemoryStream(compressed)))
+                decompressionStream.CopyTo(resultStream);
+
+            Assert.True(data.SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Fact]
+        public void RoundTrip_StreamingToBatch()
+        {
+            var dataStream = DataGenerator.GetLargeStream(DataFill.Sequential);
+
+            var tempStream = new MemoryStream();
+            using (var compressionStream = new CompressionStream(tempStream))
+                dataStream.CopyTo(compressionStream);
+
+            var resultBuffer = new byte[dataStream.Length];
+            using (var decompressor = new Decompressor())
+                Assert.Equal(dataStream.Length, decompressor.Unwrap(tempStream.ToArray(), resultBuffer, 0));
+
+            Assert.True(dataStream.ToArray().SequenceEqual(resultBuffer));
+        }
+
+        [Theory, CombinatorialData]
+        public void RoundTrip_StreamingToStreaming(
+            [CombinatorialValues(false, true)] bool useDict, [CombinatorialValues(false, true)] bool advanced,
+            [CombinatorialValues(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize,
+                DataGenerator.LargeBufferSize + 1)]
+            int zstdBufferSize,
+            [CombinatorialValues(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize,
+                DataGenerator.LargeBufferSize + 1)]
+            int copyBufferSize)
+        {
+            var dict = useDict ? TrainDict() : null;
+            var testStream = DataGenerator.GetLargeStream(DataFill.Sequential);
+
+            const int offset = 1;
+            var buffer = new byte[copyBufferSize + offset + 1];
+
+            var tempStream = new MemoryStream();
+            using (var compressionStream =
+                   new CompressionStream(tempStream, Compressor.DefaultCompressionLevel, zstdBufferSize))
+            {
+                compressionStream.LoadDictionary(dict);
+                if (advanced)
+                {
+                    compressionStream.SetParameter(ZSTD_cParameter.ZSTD_c_windowLog, 11);
+                    compressionStream.SetParameter(ZSTD_cParameter.ZSTD_c_checksumFlag, 1);
+                }
+
+                int bytesRead;
+                while ((bytesRead = testStream.Read(buffer, offset, copyBufferSize)) > 0)
+                    compressionStream.Write(buffer, offset, bytesRead);
+            }
+
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            var resultStream = new MemoryStream();
+            using (var decompressionStream = new DecompressionStream(tempStream, zstdBufferSize))
+            {
+                decompressionStream.LoadDictionary(dict);
+                if (advanced)
+                {
+                    decompressionStream.SetParameter(ZSTD_dParameter.ZSTD_d_windowLogMax, 11);
+                }
+
+                int bytesRead;
+                while ((bytesRead = decompressionStream.Read(buffer, offset, copyBufferSize)) > 0)
+                    resultStream.Write(buffer, offset, bytesRead);
+            }
+
+            Assert.True(testStream.ToArray().SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Theory, CombinatorialData]
+        public async Task RoundTrip_StreamingToStreamingAsync(
+            [CombinatorialValues(false, true)] bool useDict, [CombinatorialValues(false, true)] bool advanced,
+            [CombinatorialValues(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize,
+                DataGenerator.LargeBufferSize + 1)]
+            int zstdBufferSize,
+            [CombinatorialValues(1, 2, 7, 101, 1024, 65535, DataGenerator.LargeBufferSize,
+                DataGenerator.LargeBufferSize + 1)]
+            int copyBufferSize)
+        {
+            var dict = useDict ? TrainDict() : null;
+            var testStream = DataGenerator.GetLargeStream(DataFill.Sequential);
+
+            const int offset = 1;
+            var buffer = new byte[copyBufferSize + offset + 1];
+
+            var tempStream = new MemoryStream();
+            await using (var compressionStream =
+                         new CompressionStream(tempStream, Compressor.DefaultCompressionLevel, zstdBufferSize))
+            {
+                compressionStream.LoadDictionary(dict);
+                if (advanced)
+                {
+                    compressionStream.SetParameter(ZSTD_cParameter.ZSTD_c_windowLog, 11);
+                    compressionStream.SetParameter(ZSTD_cParameter.ZSTD_c_checksumFlag, 1);
+                }
+
+                int bytesRead;
+                while ((bytesRead = await testStream.ReadAsync(buffer, offset, copyBufferSize)) > 0)
+                    await compressionStream.WriteAsync(buffer, offset, bytesRead);
+            }
+
+            tempStream.Seek(0, SeekOrigin.Begin);
+
+            var resultStream = new MemoryStream();
+            await using (var decompressionStream = new DecompressionStream(tempStream, zstdBufferSize))
+            {
+                decompressionStream.LoadDictionary(dict);
+                if (advanced)
+                {
+                    decompressionStream.SetParameter(ZSTD_dParameter.ZSTD_d_windowLogMax, 11);
+                }
+
+                int bytesRead;
+                while ((bytesRead = await decompressionStream.ReadAsync(buffer, offset, copyBufferSize)) > 0)
+                    await resultStream.WriteAsync(buffer, offset, bytesRead);
+            }
+
+            Assert.True(testStream.ToArray().SequenceEqual(resultStream.ToArray()));
+        }
+
+        [Theory(Skip = "stress"), CombinatorialData]
+        public void RoundTrip_StreamingToStreaming_Stress([CombinatorialValues(true, false)] bool useDict,
+            [CombinatorialValues(true, false)] bool async)
+        {
+            long i = 0;
+            var dict = useDict ? TrainDict() : null;
+            Enumerable.Range(0, 10000)
+                .AsParallel()
+                .WithDegreeOfParallelism(Environment.ProcessorCount * 4)
+                .ForAll(n =>
+                {
+                    var testStream = DataGenerator.GetSmallStream(DataFill.Sequential);
+                    var cBuffer = new byte[1 + (int) (n % (testStream.Length * 11))];
+                    var dBuffer = new byte[1 + (int) (n % (testStream.Length * 13))];
+
+                    var tempStream = new MemoryStream();
+                    using (var compressionStream = new CompressionStream(tempStream, Compressor.DefaultCompressionLevel,
+                               1 + (int) (n % (testStream.Length * 17))))
+                    {
+                        compressionStream.LoadDictionary(dict);
+                        int bytesRead;
+                        int offset = n % cBuffer.Length;
+                        while ((bytesRead = testStream.Read(cBuffer, offset, cBuffer.Length - offset)) > 0)
+                        {
+                            if (async)
+                                compressionStream.WriteAsync(cBuffer, offset, bytesRead).GetAwaiter().GetResult();
+                            else
+                                compressionStream.Write(cBuffer, offset, bytesRead);
+                            if (Interlocked.Increment(ref i) % 100 == 0)
+                                GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true, true);
+                        }
+                    }
+
+                    tempStream.Seek(0, SeekOrigin.Begin);
+
+                    var resultStream = new MemoryStream();
+                    using (var decompressionStream =
+                           new DecompressionStream(tempStream, 1 + (int) (n % (testStream.Length * 19))))
+                    {
+                        decompressionStream.LoadDictionary(dict);
+                        int bytesRead;
+                        int offset = n % dBuffer.Length;
+                        while ((bytesRead = async
+                            ? decompressionStream.ReadAsync(dBuffer, offset, dBuffer.Length - offset).GetAwaiter()
+                                .GetResult()
+                            : decompressionStream.Read(dBuffer, offset, dBuffer.Length - offset)) > 0)
+                        {
+                            resultStream.Write(dBuffer, offset, bytesRead);
+                            if (Interlocked.Increment(ref i) % 100 == 0)
+                                GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true, true);
+                        }
+                    }
+
+                    Assert.True(testStream.ToArray().SequenceEqual(resultStream.ToArray()));
+                });
+        }
+
+        private static byte[] TrainDict()
+        {
+            var trainingData = new byte[100][];
+            for (int i = 0; i < trainingData.Length; i++)
+                trainingData[i] = DataGenerator.GetSmallBuffer(DataFill.Sequential);
+            return DictBuilder.TrainFromBuffer(trainingData);
+        }
+    }
+}
diff --git a/src/ZstdSharp.Test/ZstdSharp.Test.csproj b/src/ZstdSharp.Test/ZstdSharp.Test.csproj
index b2018de..25bb5d5 100644
--- a/src/ZstdSharp.Test/ZstdSharp.Test.csproj
+++ b/src/ZstdSharp.Test/ZstdSharp.Test.csproj
@@ -17,6 +17,7 @@
+
diff --git a/src/ZstdSharp/CompressionStream.cs b/src/ZstdSharp/CompressionStream.cs
new file mode 100644
index 0000000..6612f24
--- /dev/null
+++ b/src/ZstdSharp/CompressionStream.cs
@@ -0,0 +1,187 @@
+using System;
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using ZstdSharp.Unsafe;
+
+namespace ZstdSharp
+{
+    public class CompressionStream : Stream
+    {
+        private readonly Stream innerStream;
+        private readonly byte[] outputBuffer;
+        private Compressor compressor;
+        private ZSTD_outBuffer_s output;
+
+        public CompressionStream(Stream stream, int level = Compressor.DefaultCompressionLevel,
+            int bufferSize = 0)
+        {
+            if (stream == null)
+                throw new ArgumentNullException(nameof(stream));
+
+            if (!stream.CanWrite)
+                throw new ArgumentException("Stream is not writable", nameof(stream));
+
+            if (bufferSize < 0)
+                throw new ArgumentOutOfRangeException(nameof(bufferSize));
+
+            innerStream = stream;
+            compressor = new Compressor(level);
+
+            var outputBufferSize =
+                bufferSize > 0 ? bufferSize : (int) Methods.ZSTD_CStreamOutSize().EnsureZstdSuccess();
+            outputBuffer = ArrayPool<byte>.Shared.Rent(outputBufferSize);
+            output = new ZSTD_outBuffer_s {pos = 0, size = (nuint) outputBufferSize};
+        }
+
+        public void SetParameter(ZSTD_cParameter parameter, int value)
+        {
+            EnsureNotDisposed();
+            compressor.SetParameter(parameter, value);
+        }
+
+        public int GetParameter(ZSTD_cParameter parameter)
+        {
+            EnsureNotDisposed();
+            return compressor.GetParameter(parameter);
+        }
+
+        public void LoadDictionary(byte[] dict)
+        {
+            EnsureNotDisposed();
+            compressor.LoadDictionary(dict);
+        }
+
+        ~CompressionStream() => Dispose(false);
+
+        public override async ValueTask DisposeAsync()
+        {
+            if (compressor == null)
+                return;
+
+            try
+            {
+                await FlushAsync().ConfigureAwait(false);
+            }
+            finally
+            {
+                ReleaseUnmanagedResources();
+                GC.SuppressFinalize(this);
+            }
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            if (compressor == null)
+                return;
+
+            try
+            {
+                if (disposing)
+                    Flush();
+            }
+            finally
+            {
+                ReleaseUnmanagedResources();
+            }
+        }
+
+        private void ReleaseUnmanagedResources()
+        {
+            compressor.Dispose();
+            compressor = null;
+            ArrayPool<byte>.Shared.Return(outputBuffer);
+        }
+
+        public override void Flush()
+            => WriteInternal(null, true);
+
+        public override async Task FlushAsync(CancellationToken cancellationToken)
+            => await WriteInternalAsync(null, true, cancellationToken).ConfigureAwait(false);
+
+        public override void Write(byte[] buffer, int offset, int count)
+            => Write(new ReadOnlySpan<byte>(buffer, offset, count));
+
+        public override void Write(ReadOnlySpan<byte> buffer)
+            => WriteInternal(buffer, false);
+
+        private void WriteInternal(ReadOnlySpan<byte> buffer, bool lastChunk)
+        {
+            EnsureNotDisposed();
+
+            var input = new ZSTD_inBuffer_s {pos = 0, size = buffer != null ? (nuint) buffer.Length : 0};
+            nuint remaining;
+            do
+            {
+                output.pos = 0;
+                remaining = CompressStream(ref input, buffer,
+                    lastChunk ? ZSTD_EndDirective.ZSTD_e_end : ZSTD_EndDirective.ZSTD_e_continue);
+
+                var written = (int) output.pos;
+                if (written > 0)
+                    innerStream.Write(outputBuffer, 0, written);
+            } while (lastChunk ? remaining > 0 : input.pos < input.size);
+        }
+
+        private async ValueTask WriteInternalAsync(ReadOnlyMemory<byte>? buffer, bool lastChunk,
+            CancellationToken cancellationToken = default)
+        {
+            EnsureNotDisposed();
+
+            var input = new ZSTD_inBuffer_s {pos = 0, size = buffer.HasValue ? (nuint) buffer.Value.Length : 0};
+            nuint remaining;
+            do
+            {
+                output.pos = 0;
+                remaining = CompressStream(ref input, buffer.HasValue ? buffer.Value.Span : null,
+                    lastChunk ? ZSTD_EndDirective.ZSTD_e_end : ZSTD_EndDirective.ZSTD_e_continue);
+
+                var written = (int) output.pos;
+                if (written > 0)
+                    await innerStream.WriteAsync(outputBuffer, 0, written, cancellationToken).ConfigureAwait(false);
+            } while (lastChunk ? remaining > 0 : input.pos < input.size);
+        }
+
+        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            => WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
+
+        public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer,
+            CancellationToken cancellationToken = default)
+            => await WriteInternalAsync(buffer, false, cancellationToken).ConfigureAwait(false);
+
+        internal unsafe nuint CompressStream(ref ZSTD_inBuffer_s input, ReadOnlySpan<byte> inputBuffer,
+            ZSTD_EndDirective directive)
+        {
+            fixed (byte* inputBufferPtr = inputBuffer)
+            fixed (byte* outputBufferPtr = outputBuffer)
+            {
+                input.src = inputBufferPtr;
+                output.dst = outputBufferPtr;
+                return compressor.CompressStream(ref input, ref output, directive).EnsureZstdSuccess();
+            }
+        }
+
+        public override bool CanRead => false;
+        public override bool CanSeek => false;
+        public override bool CanWrite => true;
+
+        public override long Length => throw new NotSupportedException();
+
+        public override long Position
+        {
+            get => throw new NotSupportedException();
+            set => throw new NotSupportedException();
+        }
+
+        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+        public override void SetLength(long value) => throw new NotSupportedException();
+        public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+
+        private void EnsureNotDisposed()
+        {
+            if (compressor == null)
+                throw new ObjectDisposedException(nameof(CompressionStream));
+        }
+    }
+}
diff --git a/src/ZstdSharp/Compressor.cs b/src/ZstdSharp/Compressor.cs
index 43e88c7..b581cab 100644
--- a/src/ZstdSharp/Compressor.cs
+++ b/src/ZstdSharp/Compressor.cs
@@ -51,9 +51,7 @@ public void LoadDictionary(byte[] dict)
             {
                 fixed (byte* dictPtr = dict)
-                {
                     Methods.ZSTD_CCtx_loadDictionary(cctx, dictPtr, (nuint) dict.Length).EnsureZstdSuccess();
-                }
             }
         }
@@ -61,9 +59,7 @@ public Compressor(int level = DefaultCompressionLevel)
         {
             cctx = Methods.ZSTD_createCCtx();
             if (cctx == null)
-            {
                 throw new ZstdException(ZSTD_ErrorCode.ZSTD_error_GENERIC, "Failed to create cctx");
-            }
 
             Level = level;
         }
@@ -94,10 +90,9 @@ public int Wrap(ReadOnlySpan<byte> src, Span<byte> dest)
             EnsureNotDisposed();
             fixed (byte* srcPtr = src)
             fixed (byte* destPtr = dest)
-            {
-                return (int) Methods.ZSTD_compress2(cctx, destPtr, (nuint) dest.Length, srcPtr, (nuint) src.Length)
+                return (int) Methods
+                    .ZSTD_compress2(cctx, destPtr, (nuint) dest.Length, srcPtr, (nuint) src.Length)
                     .EnsureZstdSuccess();
-            }
         }
 
         public int Wrap(ArraySegment<byte> src, ArraySegment<byte> dest)
@@ -124,8 +119,15 @@ public void Dispose()
         private void EnsureNotDisposed()
         {
             if (cctx == null)
-            {
                 throw new ObjectDisposedException(nameof(Compressor));
+        }
+
+        internal nuint CompressStream(ref ZSTD_inBuffer_s input, ref ZSTD_outBuffer_s output, ZSTD_EndDirective directive)
+        {
+            fixed (ZSTD_inBuffer_s* inputPtr = &input)
+            fixed (ZSTD_outBuffer_s* outputPtr = &output)
+            {
+                return Methods.ZSTD_compressStream2(cctx, outputPtr, inputPtr, directive).EnsureZstdSuccess();
             }
         }
     }
diff --git a/src/ZstdSharp/DecompressionStream.cs b/src/ZstdSharp/DecompressionStream.cs
new file mode 100644
index 0000000..f3a3d36
--- /dev/null
+++ b/src/ZstdSharp/DecompressionStream.cs
@@ -0,0 +1,159 @@
+using System;
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using ZstdSharp.Unsafe;
+
+namespace ZstdSharp
+{
+    public class DecompressionStream : Stream
+    {
+        private readonly Stream innerStream;
+        private readonly byte[] inputBuffer;
+        private readonly int inputBufferSize;
+        private Decompressor decompressor;
+        private ZSTD_inBuffer_s input;
+        private nuint lastDecompressResult = 0;
+
+        public DecompressionStream(Stream stream, int bufferSize = 0)
+        {
+            if (stream == null)
+                throw new ArgumentNullException(nameof(stream));
+
+            if (!stream.CanRead)
+                throw new ArgumentException("Stream is not readable", nameof(stream));
+
+            if (bufferSize < 0)
+                throw new ArgumentOutOfRangeException(nameof(bufferSize));
+
+            innerStream = stream;
+            decompressor = new Decompressor();
+
+            inputBufferSize = bufferSize > 0 ? bufferSize : (int) Methods.ZSTD_CStreamInSize().EnsureZstdSuccess();
+            inputBuffer = ArrayPool<byte>.Shared.Rent(inputBufferSize);
+            input = new ZSTD_inBuffer_s {pos = (nuint) inputBufferSize, size = (nuint) inputBufferSize};
+        }
+
+        public void SetParameter(ZSTD_dParameter parameter, int value)
+        {
+            EnsureNotDisposed();
+            decompressor.SetParameter(parameter, value);
+        }
+
+        public int GetParameter(ZSTD_dParameter parameter)
+        {
+            EnsureNotDisposed();
+            return decompressor.GetParameter(parameter);
+        }
+
+        public void LoadDictionary(byte[] dict)
+        {
+            EnsureNotDisposed();
+            decompressor.LoadDictionary(dict);
+        }
+
+        ~DecompressionStream() => Dispose(false);
+
+        protected override void Dispose(bool disposing)
+        {
+            if (decompressor == null)
+                return;
+
+            if (lastDecompressResult != 0)
+                throw new EndOfStreamException("Premature end of stream");
+
+            decompressor.Dispose();
+            decompressor = null;
+            ArrayPool<byte>.Shared.Return(inputBuffer);
+        }
+
+        public override int Read(byte[] buffer, int offset, int count)
+            => Read(new Span<byte>(buffer, offset, count));
+
+        public override int Read(Span<byte> buffer)
+        {
+            EnsureNotDisposed();
+
+            var output = new ZSTD_outBuffer_s {pos = 0, size = (nuint) buffer.Length};
+            while (output.pos < output.size)
+            {
+                if (input.pos >= input.size)
+                {
+                    int bytesRead;
+                    if ((bytesRead = innerStream.Read(inputBuffer, 0, inputBufferSize)) == 0)
+                        break;
+
+                    input.size = (nuint) bytesRead;
+                    input.pos = 0;
+                }
+
+                lastDecompressResult = DecompressStream(ref output, buffer);
+            }
+
+            return (int) output.pos;
+        }
+
+        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            => ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
+
+        public override async ValueTask<int> ReadAsync(Memory<byte> buffer,
+            CancellationToken cancellationToken = default)
+        {
+            EnsureNotDisposed();
+
+            var output = new ZSTD_outBuffer_s {pos = 0, size = (nuint) buffer.Length};
+            while (output.pos < output.size)
+            {
+                if (input.pos >= input.size)
+                {
+                    int bytesRead;
+                    if ((bytesRead = await innerStream.ReadAsync(inputBuffer, 0, inputBufferSize, cancellationToken)
+                            .ConfigureAwait(false)) == 0)
+                        break;
+
+                    input.size = (nuint) bytesRead;
+                    input.pos = 0;
+                }
+
+                lastDecompressResult = DecompressStream(ref output, buffer.Span);
+            }
+
+            return (int) output.pos;
+        }
+
+        private unsafe nuint DecompressStream(ref ZSTD_outBuffer_s output, Span<byte> outputBuffer)
+        {
+            fixed (byte* inputBufferPtr = inputBuffer)
+            fixed (byte* outputBufferPtr = outputBuffer)
+            {
+                input.src = inputBufferPtr;
+                output.dst = outputBufferPtr;
+                return decompressor.DecompressStream(ref input, ref output);
+            }
+        }
+
+        public override bool CanRead => true;
+        public override bool CanSeek => false;
+        public override bool CanWrite => false;
+
+        public override long Length => throw new NotSupportedException();
+
+        public override long Position
+        {
+            get => throw new NotSupportedException();
+            set => throw new NotSupportedException();
+        }
+
+        public override void Flush() => throw new NotSupportedException();
+
+        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+        public override void SetLength(long value) => throw new NotSupportedException();
+        public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+
+        private void EnsureNotDisposed()
+        {
+            if (decompressor == null)
+                throw new ObjectDisposedException(nameof(DecompressionStream));
+        }
+    }
+}
diff --git a/src/ZstdSharp/Decompressor.cs b/src/ZstdSharp/Decompressor.cs
index cdce4c4..9d84031 100644
--- a/src/ZstdSharp/Decompressor.cs
+++ b/src/ZstdSharp/Decompressor.cs
@@ -11,9 +11,7 @@ public Decompressor()
         {
             dctx = Methods.ZSTD_createDCtx();
             if (dctx == null)
-            {
                 throw new ZstdException(ZSTD_ErrorCode.ZSTD_error_GENERIC, "Failed to create dctx");
-            }
         }
 
         ~Decompressor()
@@ -21,6 +19,20 @@ public Decompressor()
             ReleaseUnmanagedResources();
         }
 
+        public void SetParameter(ZSTD_dParameter parameter, int value)
+        {
+            EnsureNotDisposed();
+            Methods.ZSTD_DCtx_setParameter(dctx, parameter, value).EnsureZstdSuccess();
+        }
+
+        public int GetParameter(ZSTD_dParameter parameter)
+        {
+            EnsureNotDisposed();
+            int value;
+            Methods.ZSTD_DCtx_getParameter(dctx, parameter, &value).EnsureZstdSuccess();
+            return value;
+        }
+
         public void LoadDictionary(byte[] dict)
         {
             EnsureNotDisposed();
@@ -30,20 +42,15 @@ public void LoadDictionary(byte[] dict)
             }
             else
             {
                 fixed (byte* dictPtr = dict)
-                {
                     Methods.ZSTD_DCtx_loadDictionary(dctx, dictPtr, (nuint) dict.Length).EnsureZstdSuccess();
-                }
             }
         }
 
         public static ulong GetDecompressedSize(ReadOnlySpan<byte> src)
         {
             fixed (byte* srcPtr = src)
-            {
                 return Methods.ZSTD_decompressBound(srcPtr, (nuint) src.Length).EnsureContentSizeOk();
-            }
         }
 
         public static ulong GetDecompressedSize(ArraySegment<byte> src)
@@ -75,10 +82,9 @@ public int Unwrap(ReadOnlySpan<byte> src, Span<byte> dest)
             EnsureNotDisposed();
             fixed (byte* srcPtr = src)
             fixed (byte* destPtr = dest)
-            {
-                return (int) Methods.ZSTD_decompressDCtx(dctx, destPtr, (nuint) dest.Length, srcPtr, (nuint) src.Length)
+                return (int) Methods
+                    .ZSTD_decompressDCtx(dctx, destPtr, (nuint) dest.Length, srcPtr, (nuint) src.Length)
                    .EnsureZstdSuccess();
-            }
         }
 
         public int Unwrap(byte[] src, int srcOffset, int srcLength, byte[] dst, int dstOffset, int dstLength)
@@ -101,8 +107,15 @@ public void Dispose()
         private void EnsureNotDisposed()
        {
             if (dctx == null)
-            {
                 throw new ObjectDisposedException(nameof(Decompressor));
+        }
+
+        internal nuint DecompressStream(ref ZSTD_inBuffer_s input, ref ZSTD_outBuffer_s output)
+        {
+            fixed (ZSTD_inBuffer_s* inputPtr = &input)
+            fixed (ZSTD_outBuffer_s* outputPtr = &output)
+            {
+                return Methods.ZSTD_decompressStream(dctx, outputPtr, inputPtr).EnsureZstdSuccess();
             }
         }
     }
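
For reference, a minimal round-trip sketch of the CompressionStream and DecompressionStream API introduced by this patch, mirroring what the tests above exercise (assumes only the ZstdSharp namespace and default settings; the in-memory buffers are illustrative, not part of the change):

using System;
using System.IO;
using ZstdSharp;

public static class StreamingExample
{
    public static void Main()
    {
        var original = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

        // Compress: wrap the destination stream, write the plain bytes,
        // then dispose to flush the final zstd frame.
        var compressed = new MemoryStream();
        using (var compressionStream = new CompressionStream(compressed, Compressor.DefaultCompressionLevel))
            compressionStream.Write(original, 0, original.Length);

        // Decompress: wrap the compressed stream and read the plain bytes back.
        compressed.Seek(0, SeekOrigin.Begin);
        var decompressed = new MemoryStream();
        using (var decompressionStream = new DecompressionStream(compressed))
            decompressionStream.CopyTo(decompressed);

        Console.WriteLine(decompressed.Length == original.Length); // True
    }
}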