Skip to content

Commit

Permalink
Add a zlib stream implementation (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
Schamper authored Jan 15, 2024
1 parent 7041f19 commit d7ac7e3
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
94 changes: 94 additions & 0 deletions dissect/util/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io
import os
import sys
import zlib
from bisect import bisect_left, bisect_right
from threading import Lock
from typing import BinaryIO, Optional, Union
Expand Down Expand Up @@ -550,3 +552,95 @@ def _read(self, offset: int, length: int) -> bytes:
overlay_idx += 1

return b"".join(result)


class ZlibStream(AlignedStream):
    """Create a zlib stream from another file-like object.

    Basically the same as ``gzip.GzipFile`` but for raw zlib streams.

    Due to the nature of zlib streams, seeking backwards requires resetting the decompression context
    and decompressing again from the start of the source.

    Args:
        fh: The source file-like object.
        size: The size the stream should be.
        align: The alignment size, passed on to the base class constructor.
        **kwargs: Keyword arguments passed verbatim to ``zlib.decompressobj`` every time the
            decompression context is (re)created.
    """

    def __init__(self, fh: BinaryIO, size: Optional[int] = None, align: int = STREAM_BUFFER_SIZE, **kwargs):
        self._fh = fh

        # Current decompression context; replaced on every rewind.
        self._zlib = None
        self._zlib_args = kwargs
        # Offset in the *decompressed* data that the context has produced so far.
        self._zlib_offset = 0
        # Compressed bytes already read from ``fh`` but not yet consumed by zlib.
        self._zlib_prepend = b""
        # Read position inside ``_zlib_prepend``; ``None`` means there is nothing to serve from it.
        self._zlib_prepend_offset = None
        self._rewind()

        super().__init__(size, align)

    def _rewind(self) -> None:
        """Seek the source to its start and begin with a fresh decompression context."""
        self._fh.seek(0)
        self._zlib = zlib.decompressobj(**self._zlib_args)
        self._zlib_offset = 0
        self._zlib_prepend = b""
        self._zlib_prepend_offset = None

    def _seek_zlib(self, offset: int) -> None:
        """Move the decompression context to ``offset`` in the decompressed data.

        Going backwards restarts decompression from the beginning; going forwards decompresses
        and throws away data, at most ``self.align`` bytes at a time.
        """
        if offset < self._zlib_offset:
            self._rewind()

        while self._zlib_offset < offset:
            read_size = min(offset - self._zlib_offset, self.align)
            if self._read_zlib(read_size) == b"":
                # No more decompressed data available; the offset lies beyond the end.
                break

    def _read_fh(self, length: int) -> bytes:
        """Read ``length`` compressed bytes, serving pending unconsumed bytes before the source."""
        if self._zlib_prepend_offset is None:
            return self._fh.read(length)

        if self._zlib_prepend_offset + length <= len(self._zlib_prepend):
            # The request can be satisfied entirely from the pending bytes.
            offset = self._zlib_prepend_offset
            self._zlib_prepend_offset += length
            return self._zlib_prepend[offset : self._zlib_prepend_offset]

        # Hand out what is left of the pending bytes and top up from the source.
        offset = self._zlib_prepend_offset
        self._zlib_prepend_offset = None
        return self._zlib_prepend[offset:] + self._fh.read(length - len(self._zlib_prepend) + offset)

    def _read_zlib(self, length: int) -> bytes:
        """Decompress and return up to ``length`` bytes from the current context position.

        A negative ``length`` reads everything that is left. Fewer bytes than requested are
        returned when the source runs out of data.
        """
        if length < 0:
            return self.readall()

        result = []
        while length > 0:
            buf = self._read_fh(io.DEFAULT_BUFFER_SIZE)
            decompressed = self._zlib.decompress(buf, length)

            tail = self._zlib.unconsumed_tail
            if tail:
                if self._zlib_prepend_offset is not None:
                    # ``buf`` was only a part of the pending bytes (this happens when the pending
                    # buffer is larger than the current read size). Keep the part that has not
                    # been handed out yet, otherwise those compressed bytes are lost.
                    tail += self._zlib_prepend[self._zlib_prepend_offset :]
                self._zlib_prepend = tail
                self._zlib_prepend_offset = 0

            # Always keep the output: zlib may still return pending output after the source is
            # exhausted, if an earlier call was cut short by the output limit.
            result.append(decompressed)
            length -= len(decompressed)

            if not buf and not decompressed:
                # Source exhausted and nothing more to flush.
                break

        buf = b"".join(result)
        self._zlib_offset += len(buf)
        return buf

    def _read(self, offset: int, length: int) -> bytes:
        """Return ``length`` decompressed bytes starting at decompressed ``offset``."""
        self._seek_zlib(offset)
        return self._read_zlib(length)

    def readall(self) -> bytes:
        """Return all decompressed data from the current stream position until the end."""
        self._seek_zlib(self.tell())

        chunks = []
        # sys.maxsize means the max length of output buffer is unlimited,
        # so that the whole input buffer can be decompressed within one
        # .decompress() call.
        while data := self._read_zlib(sys.maxsize):
            chunks.append(data)

        return b"".join(chunks)
33 changes: 33 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import zlib
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -170,3 +172,34 @@ def test_overlay_stream():
fh.add((512 * 8) - 4, b"\x04" * 100)
fh.seek((512 * 8) - 4)
assert fh.read(100) == b"\x04" * 4


def test_zlib_stream():
    """Exercise sequential reads, forward and backward seeks, and full reads on a ``ZlibStream``."""
    chunk = 8192

    def filled(value: int, count: int) -> bytes:
        # ``count`` repetitions of the single byte ``value``.
        return bytes([value]) * count

    data = b"".join(filled(value, chunk) for value in (1, 2, 3, 4))
    compressed = io.BytesIO(zlib.compress(data))
    fh = stream.ZlibStream(compressed, size=chunk * 4, align=512)

    # Sequential reads walk through the four distinct regions, then hit the end.
    for value in (1, 2, 3, 4):
        assert fh.read(chunk) == filled(value, chunk)
    assert fh.read(1) == b""

    # Seeking back to the start.
    fh.seek(0)
    assert fh.read(chunk) == filled(1, chunk)

    # A read that straddles the first and second region.
    fh.seek(1024)
    assert fh.read(chunk) == filled(1, 7168) + filled(2, 1024)

    # Seeking backwards to a non-zero offset.
    fh.seek(512)
    assert fh.read(1024) == filled(1, 1024)

    fh.seek(0)
    assert fh.readall() == data

    # Continue reading after the source read size has shrunk.
    fh.seek(512)
    assert fh.read(1024) == filled(1, 1024)
    with patch("io.DEFAULT_BUFFER_SIZE", 8):
        assert fh.read(1024) == filled(1, 1024)

    fh.seek(0)
    assert fh.read() == data

0 comments on commit d7ac7e3

Please sign in to comment.