# ---------------------------------------------------------------------------
# dissect/util/stream.py
# ---------------------------------------------------------------------------
import io
import os
import sys
import zlib
from bisect import bisect_left, bisect_right
from threading import Lock
from typing import BinaryIO, Optional, Union

# ---------------------------------------------------------------------------
# tests/test_stream.py
# ---------------------------------------------------------------------------
from unittest.mock import patch

import pytest


class ZlibStream(AlignedStream):
    """Create a zlib stream from another file-like object.

    Basically the same as ``gzip.GzipFile`` but for raw zlib streams.
    Due to the nature of zlib streams, seeking backwards requires resetting the decompression context.

    Args:
        fh: The source file-like object.
        size: The size the stream should be.
        align: The alignment size, passed on to ``AlignedStream``.
        **kwargs: Extra keyword arguments passed to ``zlib.decompressobj`` every time
            the decompression context is (re)created.
    """

    def __init__(self, fh: BinaryIO, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE, **kwargs):
        self._fh = fh

        self._zlib = None
        self._zlib_args = kwargs
        # Number of decompressed bytes produced since the last rewind.
        self._zlib_offset = 0
        # Compressed bytes that were handed to the decompressor but not consumed yet.
        # ``_zlib_prepend_offset`` is the read position inside that buffer, or ``None``
        # when there is nothing left to serve from it.
        self._zlib_prepend = b""
        self._zlib_prepend_offset = None
        self._rewind()

        super().__init__(size, align)

    def _rewind(self) -> None:
        """Seek the source to offset 0 and start over with a fresh decompression context."""
        self._fh.seek(0)
        self._zlib = zlib.decompressobj(**self._zlib_args)
        self._zlib_offset = 0
        self._zlib_prepend = b""
        self._zlib_prepend_offset = None

    def _seek_zlib(self, offset: int) -> None:
        """Move the decompressed position to ``offset``.

        A backwards seek rewinds and decompresses from the start; a forwards seek
        decompresses and discards data in chunks of at most ``self.align`` bytes.
        Stops early when no more data can be decompressed.
        """
        if offset < self._zlib_offset:
            self._rewind()

        while self._zlib_offset < offset:
            read_size = min(offset - self._zlib_offset, self.align)
            if self._read_zlib(read_size) == b"":
                break

    def _read_fh(self, length: int) -> bytes:
        """Read up to ``length`` compressed bytes, serving pending prepend data first."""
        if self._zlib_prepend_offset is None:
            return self._fh.read(length)

        start = self._zlib_prepend_offset
        end = start + length
        if end <= len(self._zlib_prepend):
            # The request can be satisfied from the prepend buffer alone.
            self._zlib_prepend_offset = end
            return self._zlib_prepend[start:end]

        # Drain what is left of the prepend buffer and top it up from the source.
        self._zlib_prepend_offset = None
        pending = self._zlib_prepend[start:]
        return pending + self._fh.read(length - len(pending))

    def _read_zlib(self, length: int) -> bytes:
        """Decompress and return up to ``length`` bytes from the current position.

        A negative ``length`` reads everything that is left (see :meth:`readall`).
        Fewer bytes than requested are returned when the compressed stream ends or
        the source runs out of data.
        """
        if length < 0:
            return self.readall()

        result = []
        # Once the decompressor has seen the end of the stream no more output can
        # follow, so do not keep reading (and buffering) trailing source data.
        while length > 0 and not self._zlib.eof:
            buf = self._read_fh(io.DEFAULT_BUFFER_SIZE)
            decompressed = self._zlib.decompress(buf, length)

            tail = self._zlib.unconsumed_tail
            if tail:
                if self._zlib_prepend_offset is not None:
                    # Part of the previous prepend buffer has not been served yet
                    # (possible when the read size shrank); it directly follows the
                    # unconsumed tail in the compressed stream, so keep it.
                    tail += self._zlib_prepend[self._zlib_prepend_offset :]
                self._zlib_prepend = tail
                self._zlib_prepend_offset = 0

            # Always keep the output: the decompressor may still emit pending data
            # after the source itself has been exhausted.
            result.append(decompressed)
            length -= len(decompressed)

            if not buf and not decompressed:
                # No more input and nothing left to flush.
                break

        buf = b"".join(result)
        self._zlib_offset += len(buf)
        return buf

    def _read(self, offset: int, length: int) -> bytes:
        """Return up to ``length`` decompressed bytes starting at decompressed ``offset``."""
        self._seek_zlib(offset)
        return self._read_zlib(length)

    def readall(self) -> bytes:
        """Return all decompressed data from the current position up to the end of the stream."""
        self._seek_zlib(self.tell())

        chunks = []
        # sys.maxsize means the max length of output buffer is unlimited,
        # so that the whole input buffer can be decompressed within one
        # .decompress() call.
        while data := self._read_zlib(sys.maxsize):
            chunks.append(data)

        return b"".join(chunks)


def test_zlib_stream():
    """Exercise sequential reads, backward/forward seeks and full reads on ``ZlibStream``."""
    # NOTE(review): ``stream`` is not imported in the visible part of the test module;
    # presumably it is imported further up in tests/test_stream.py - confirm.
    data = b"\x01" * 8192 + b"\x02" * 8192 + b"\x03" * 8192 + b"\x04" * 8192
    fh = stream.ZlibStream(io.BytesIO(zlib.compress(data)), size=8192 * 4, align=512)

    # Sequential reads until the end of the stream.
    assert fh.read(8192) == b"\x01" * 8192
    assert fh.read(8192) == b"\x02" * 8192
    assert fh.read(8192) == b"\x03" * 8192
    assert fh.read(8192) == b"\x04" * 8192
    assert fh.read(1) == b""

    # Seeking backwards.
    fh.seek(0)
    assert fh.read(8192) == b"\x01" * 8192

    fh.seek(1024)
    assert fh.read(8192) == b"\x01" * 7168 + b"\x02" * 1024

    fh.seek(512)
    assert fh.read(1024) == b"\x01" * 1024

    fh.seek(0)
    assert fh.readall() == data

    # Keep reading correctly when the source read size is very small.
    fh.seek(512)
    assert fh.read(1024) == b"\x01" * 1024
    with patch("io.DEFAULT_BUFFER_SIZE", 8):
        assert fh.read(1024) == b"\x01" * 1024

    fh.seek(0)
    assert fh.read() == data