Skip to content

Commit

Permalink
Initial commit for adding zstd (de)compression support for workload c…
Browse files Browse the repository at this point in the history
…orpora

Signed-off-by: beaioun <[email protected]>
  • Loading branch information
beaioun committed Nov 8, 2023
1 parent 3adacc0 commit 7c95c4b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
40 changes: 39 additions & 1 deletion osbenchmark/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import tarfile
import zipfile
from contextlib import suppress
import zstandard as zstd

import mmap

Expand Down Expand Up @@ -249,7 +250,7 @@ def is_archive(name):
:return: True iff the given file name is an archive that is also recognized for decompression by Benchmark.
"""
_, ext = splitext(name)
return ext in [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2"]
return ext in [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".zst"]


def is_executable(name):
Expand All @@ -272,6 +273,28 @@ def compress(source_directory, archive_name):
_zipdir(source_directory, archive)


def compress_zstd(source_directory, archive_name):
"""
Compress a directory tree using Zstandard compression.
:param source_directory: The source directory to compress. Must be readable.
:param archive_name: The absolute path including the file name of the archive. Must have the extension .zst.
"""
zstc = zstd.ZstdCompressor()

with open(archive_name, "wb") as archive_file:
with zstc.stream_writer(archive_file) as compressor:
for root, _, files in os.walk(source_directory):
for file in files:
file_path = os.path.join(root, file)
rel_path = os.path.relpath(file_path, source_directory)
# Write the file path (relative) to the archive to recreate the directory structure
compressor.write(rel_path.encode("utf-8"))
with open(file_path, "rb") as source_file:
# Write the content of the file to the archive
for chunk in source_file:
compressor.write(chunk)


def decompress(zip_name, target_directory):
"""
Decompresses the provided archive to the target directory. The following file extensions are supported:
Expand All @@ -283,6 +306,7 @@ def decompress(zip_name, target_directory):
* tar.gz
* tgz
* tar.bz2
* zst
The decompression method is chosen based on the file extension.
Expand All @@ -303,6 +327,8 @@ def decompress(zip_name, target_directory):
_do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib)
elif extension in [".tar", ".tar.gz", ".tgz", ".tar.bz2"]:
_do_decompress(target_directory, tarfile.open(zip_name))
elif extension == ".zst":
_do_decompress_zstd(target_directory, zip_name)
else:
raise RuntimeError("Unsupported file extension [%s]. Cannot decompress [%s]" % (extension, zip_name))

Expand Down Expand Up @@ -344,6 +370,18 @@ def _do_decompress_manually_with_lib(target_directory, filename, compressed_file
compressed_file.close()


def _do_decompress_zstd(target_directory, filename):
path_without_extension = os.path.splitext(os.path.basename(filename))[0]
try:
with open(filename, 'rb') as compressed_file:
zstd_decompressor = zstd.ZstdDecompressor()
with open(os.path.join(target_directory, path_without_extension), "wb") as new_file:
for chunk in zstd_decompressor.read_to_iter(compressed_file.read):
new_file.write(chunk)
except Exception as e:
logging.getLogger(__name__).warning("Failed to decompress [%s] with Zstandard. Error: %s.", filename, str(e))


def _do_decompress(target_directory, compressed_file):
try:
compressed_file.extractall(path=target_directory)
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def str_from_file(name):
# jmespath: MIT
# s3transfer: Apache 2.0
"boto3==1.28.62",
# Licence: BSD-3-Clause
"zstandard==0.22.0",
]

tests_require = [
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_has_extension(self):

class TestDecompression:
def test_decompresses_supported_file_formats(self):
for ext in ["zip", "gz", "bz2", "tgz", "tar.bz2", "tar.gz"]:
for ext in ["zip", "gz", "bz2", "tgz", "tar.bz2", "tar.gz", "zst"]:
tmp_dir = tempfile.mkdtemp()
archive_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources", f"test.txt.{ext}")
decompressed_path = os.path.join(tmp_dir, "test.txt")
Expand All @@ -90,7 +90,7 @@ def test_decompresses_supported_file_formats(self):

@mock.patch.object(io, "is_executable", return_value=False)
def test_decompresses_supported_file_formats_with_lib_as_failover(self, mocked_is_executable):
for ext in ["zip", "gz", "bz2", "tgz", "tar.bz2", "tar.gz"]:
for ext in ["zip", "gz", "bz2", "tgz", "tar.bz2", "tar.gz", "zst"]:
tmp_dir = tempfile.mkdtemp()
archive_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources", f"test.txt.{ext}")
decompressed_path = os.path.join(tmp_dir, "test.txt")
Expand Down

0 comments on commit 7c95c4b

Please sign in to comment.