diff --git a/hfutils/archive/__init__.py b/hfutils/archive/__init__.py index 92207966dc4..e9e2e786f71 100644 --- a/hfutils/archive/__init__.py +++ b/hfutils/archive/__init__.py @@ -17,8 +17,9 @@ .. warning:: The creation of archive files in the RAR format is not supported, as we utilize the `rarfile `_ library, which does not offer functionality for creating RAR files. """ -from .base import register_archive_type, archive_pack, archive_unpack, get_archive_type, get_archive_extname -from .rar import _rar_pack, _rar_unpack -from .sevenz import _7z_pack, _7z_unpack -from .tar import _tarfile_pack, _tarfile_unpack -from .zip import _zip_pack, _zip_unpack +from .base import register_archive_type, archive_pack, archive_unpack, get_archive_type, get_archive_extname, \ + archive_writer, ArchiveWriter +from .rar import _rar_pack, _rar_unpack, RARWriter +from .sevenz import _7z_pack, _7z_unpack, SevenZWriter +from .tar import _tarfile_pack, _tarfile_unpack, TarWriter +from .zip import _zip_pack, _zip_unpack, ZipWriter diff --git a/hfutils/archive/base.py b/hfutils/archive/base.py index 20d173a9d5e..5cc91ad3d12 100644 --- a/hfutils/archive/base.py +++ b/hfutils/archive/base.py @@ -14,10 +14,43 @@ import warnings from typing import List, Dict, Tuple, Callable, Optional -_KNOWN_ARCHIVE_TYPES: Dict[str, Tuple[List[str], Callable, Callable]] = {} +class ArchiveWriter: + def __init__(self, archive_file: str): + self.archive_file = archive_file + self._handler = None -def register_archive_type(name: str, exts: List[str], fn_pack: Callable, fn_unpack: Callable): + def _create_handler(self): + raise NotImplementedError # pragma: no cover + + def _add_file(self, filename: str, arcname: str): + raise NotImplementedError # pragma: no cover + + def open(self): + if self._handler is None: + self._handler = self._create_handler() + + def add(self, filename: str, arcname: str): + return self._add_file(filename, arcname) + + def close(self): + if self._handler is not None: + self._handler.close() + self._handler = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +_FN_WRITER = Callable[[str], ArchiveWriter] +_KNOWN_ARCHIVE_TYPES: Dict[str, Tuple[List[str], Callable, Callable, _FN_WRITER]] = {} + + +def register_archive_type(name: str, exts: List[str], fn_pack: Callable, fn_unpack: Callable, fn_writer: _FN_WRITER): """ Register a custom archive type with associated file extensions and packing/unpacking functions. @@ -45,7 +78,7 @@ def register_archive_type(name: str, exts: List[str], fn_pack: Callable, fn_unpa """ if len(exts) == 0: raise ValueError(f'At least one extension name for archive type {name!r} should be provided.') - _KNOWN_ARCHIVE_TYPES[name] = (exts, fn_pack, fn_unpack) + _KNOWN_ARCHIVE_TYPES[name] = (exts, fn_pack, fn_unpack, fn_writer) def get_archive_extname(type_name: str) -> str: @@ -65,7 +98,7 @@ def get_archive_extname(type_name: str) -> str: '.zip' """ if type_name in _KNOWN_ARCHIVE_TYPES: - exts, _, _ = _KNOWN_ARCHIVE_TYPES[type_name] + exts, _, _, _ = _KNOWN_ARCHIVE_TYPES[type_name] return exts[0] else: raise ValueError(f'Unknown archive type - {type_name!r}.') @@ -95,7 +128,7 @@ def archive_pack(type_name: str, directory: str, archive_file: str, Example: >>> archive_pack('zip', '/path/to/directory', '/path/to/archive.zip', pattern='*.txt') """ - exts, fn_pack, _ = _KNOWN_ARCHIVE_TYPES[type_name] + exts, fn_pack, _, _ = _KNOWN_ARCHIVE_TYPES[type_name] if not any(os.path.normcase(archive_file).endswith(extname) for extname in exts): warnings.warn(f'The archive type {type_name!r} should be one of the {exts!r}, ' f'but file name {archive_file!r} is assigned. ' @@ -122,7 +155,7 @@ def get_archive_type(archive_file: str) -> str: 'gztar' """ archive_file = os.path.normcase(archive_file) - for type_name, (exts, _, _) in _KNOWN_ARCHIVE_TYPES.items(): + for type_name, (exts, _, _, _) in _KNOWN_ARCHIVE_TYPES.items(): if any(archive_file.endswith(extname) for extname in exts): return type_name @@ -149,5 +182,15 @@ def archive_unpack(archive_file: str, directory: str, silent: bool = False, pass >>> archive_unpack('/path/to/archive.zip', '/path/to/extract') """ type_name = get_archive_type(archive_file) - _, _, fn_unpack = _KNOWN_ARCHIVE_TYPES[type_name] + _, _, fn_unpack, _ = _KNOWN_ARCHIVE_TYPES[type_name] return fn_unpack(archive_file, directory, silent=silent, password=password) + + +def archive_writer(type_name: str, archive_file: str) -> ArchiveWriter: + exts, _, _, fn_writer = _KNOWN_ARCHIVE_TYPES[type_name] + if not any(os.path.normcase(archive_file).endswith(extname) for extname in exts): + warnings.warn(f'The archive type {type_name!r} should be one of the {exts!r}, ' + f'but file name {archive_file!r} is assigned. ' + f'We strongly recommend using a regular extension name for the archive file.') + + return fn_writer(archive_file) diff --git a/hfutils/archive/rar.py b/hfutils/archive/rar.py index a636f25c13d..458b90c6897 100644 --- a/hfutils/archive/rar.py +++ b/hfutils/archive/rar.py @@ -11,7 +11,7 @@ import os from typing import Optional -from .base import register_archive_type +from .base import register_archive_type, ArchiveWriter try: import rarfile @@ -19,6 +19,18 @@ rarfile = None +class RARWriter(ArchiveWriter): + def __init__(self, archive_file: str): + super().__init__(archive_file) + raise RuntimeError('RAR format writing is not supported.') + + def _create_handler(self): + raise NotImplementedError # pragma: no cover + + def _add_file(self, filename: str, arcname: str): + raise NotImplementedError # pragma: no cover + + def _rar_pack(directory, zip_file, pattern: Optional[str] = None, silent: bool = False, clear: bool = False): """ Placeholder function for RAR packing (not supported). @@ -58,4 +70,4 @@ def _rar_unpack(rar_file, directory, silent: bool = False, password: Optional[st if rarfile is not None: - register_archive_type('rar', ['.rar'], _rar_pack, _rar_unpack) + register_archive_type('rar', ['.rar'], _rar_pack, _rar_unpack, RARWriter) diff --git a/hfutils/archive/sevenz.py b/hfutils/archive/sevenz.py index 97a9f9fd7d4..b81cdb121d3 100644 --- a/hfutils/archive/sevenz.py +++ b/hfutils/archive/sevenz.py @@ -11,7 +11,7 @@ import os from typing import Optional -from .base import register_archive_type +from .base import register_archive_type, ArchiveWriter from ..utils import tqdm, walk_files try: @@ -20,6 +20,14 @@ py7zr = None +class SevenZWriter(ArchiveWriter): + def _create_handler(self): + return py7zr.SevenZipFile(self.archive_file, 'w') + + def _add_file(self, filename: str, arcname: str): + return self._handler.write(filename, arcname) + + def _7z_pack(directory, sz_file, pattern: Optional[str] = None, silent: bool = False, clear: bool = False): """ Pack files from a directory into a 7z archive. @@ -35,11 +43,11 @@ def _7z_pack(directory, sz_file, pattern: Optional[str] = None, silent: bool = F :param clear: If True, remove source files after packing. :type clear: bool, optional """ - with py7zr.SevenZipFile(sz_file, 'w') as zf: + with SevenZWriter(sz_file) as zf: progress = tqdm(walk_files(directory, pattern=pattern), silent=silent, desc=f'Packing {directory!r} ...') for file in progress: progress.set_description(file) - zf.write(os.path.join(directory, file), file) + zf.add(os.path.join(directory, file), file) if clear: os.remove(os.path.join(directory, file)) @@ -65,4 +73,4 @@ def _7z_unpack(sz_file, directory, silent: bool = False, password: Optional[str] if py7zr is not None: - register_archive_type('7z', ['.7z'], _7z_pack, _7z_unpack) + register_archive_type('7z', ['.7z'], _7z_pack, _7z_unpack, SevenZWriter) diff --git a/hfutils/archive/tar.py b/hfutils/archive/tar.py index 8aab9bb9474..85e0c7e1e52 100644 --- a/hfutils/archive/tar.py +++ b/hfutils/archive/tar.py @@ -15,7 +15,7 @@ from functools import partial from typing import Literal, Optional -from .base import register_archive_type +from .base import register_archive_type, ArchiveWriter from .zip import _ZLIB_SUPPORTED from ..utils import walk_files, tqdm @@ -38,6 +38,29 @@ CompressTyping = Literal['', 'gzip', 'bzip2', 'xz'] +class TarWriter(ArchiveWriter): + + def __init__(self, archive_file: str, compress: CompressTyping = "gzip"): + super().__init__(archive_file) + if compress is None: + self._tar_compression = '' + elif compress == 'gzip': + self._tar_compression = 'gz' + elif compress == 'bzip2': + self._tar_compression = 'bz2' + elif compress == 'xz': + self._tar_compression = 'xz' + else: + raise ValueError("bad value for 'compress', or compression format not " + "supported : {0}".format(compress)) + + def _create_handler(self): + return tarfile.open(self.archive_file, f'w|{self._tar_compression}') + + def _add_file(self, filename: str, arcname: str): + return self._handler.add(filename, arcname) + + def _tarfile_pack(directory, tar_file, pattern: Optional[str] = None, compress: CompressTyping = "gzip", silent: bool = False, clear: bool = False): """ @@ -57,19 +80,7 @@ def _tarfile_pack(directory, tar_file, pattern: Optional[str] = None, :type clear: bool :raises ValueError: If an unsupported compression method is specified. """ - if compress is None: - tar_compression = '' - elif compress == 'gzip': - tar_compression = 'gz' - elif compress == 'bzip2': - tar_compression = 'bz2' - elif compress == 'xz': - tar_compression = 'xz' - else: - raise ValueError("bad value for 'compress', or compression format not " - "supported : {0}".format(compress)) - - with tarfile.open(tar_file, f'w|{tar_compression}') as tar: + with TarWriter(tar_file, compress=compress) as tar: progress = tqdm(walk_files(directory, pattern=pattern), silent=silent, desc=f'Packing {directory!r} ...') for file in progress: progress.set_description(file) @@ -125,10 +136,30 @@ def _tarfile_unpack(tar_file, directory, silent: bool = False, numeric_owner=Fal # Register various tar archive types based on available compression libraries -register_archive_type('tar', ['.tar'], partial(_tarfile_pack, compress=None), _tarfile_unpack) +register_archive_type( + 'tar', ['.tar'], + partial(_tarfile_pack, compress=None), + _tarfile_unpack, + partial(TarWriter, compress=None), +) if _ZLIB_SUPPORTED: - register_archive_type('gztar', ['.tar.gz', '.tgz'], partial(_tarfile_pack, compress='gzip'), _tarfile_unpack) + register_archive_type( + 'gztar', ['.tar.gz', '.tgz'], + partial(_tarfile_pack, compress='gzip'), + _tarfile_unpack, + partial(TarWriter, compress='gzip'), + ) if _BZ2_SUPPORTED: - register_archive_type('bztar', ['.tar.bz2', '.tbz2'], partial(_tarfile_pack, compress='bzip2'), _tarfile_unpack) + register_archive_type( + 'bztar', ['.tar.bz2', '.tbz2'], + partial(_tarfile_pack, compress='bzip2'), + _tarfile_unpack, + partial(TarWriter, compress='bzip2'), + ) if _LZMA_SUPPORTED: - register_archive_type('xztar', ['.tar.xz', '.txz'], partial(_tarfile_pack, compress='xz'), _tarfile_unpack) + register_archive_type( + 'xztar', ['.tar.xz', '.txz'], + partial(_tarfile_pack, compress='xz'), + _tarfile_unpack, + partial(TarWriter, compress='xz'), + ) diff --git a/hfutils/archive/zip.py b/hfutils/archive/zip.py index 922cc16bf67..1fced3208cb 100644 --- a/hfutils/archive/zip.py +++ b/hfutils/archive/zip.py @@ -9,7 +9,7 @@ import zipfile from typing import Optional -from .base import register_archive_type +from .base import register_archive_type, ArchiveWriter from ..utils import tqdm, walk_files try: @@ -21,6 +21,14 @@ _ZLIB_SUPPORTED = False +class ZipWriter(ArchiveWriter): + def _create_handler(self): + return zipfile.ZipFile(self.archive_file, "w", compression=zipfile.ZIP_DEFLATED) + + def _add_file(self, filename: str, arcname: str): + return self._handler.write(filename, arcname) + + def _zip_pack(directory, zip_file, pattern: Optional[str] = None, silent: bool = False, clear: bool = False): """ Pack a directory into a ZIP file. @@ -36,11 +44,12 @@ def _zip_pack(directory, zip_file, pattern: Optional[str] = None, silent: bool = :param clear: If True, remove original files after packing. :type clear: bool """ - with zipfile.ZipFile(zip_file, "w", compression=zipfile.ZIP_DEFLATED) as zf: - progress = tqdm(walk_files(directory, pattern=pattern), silent=silent, desc=f'Packing {directory!r} ...') + with ZipWriter(zip_file) as zf: + progress = tqdm(walk_files(directory, pattern=pattern), + silent=silent, desc=f'Packing {directory!r} ...') for file in progress: progress.set_description(file) - zf.write(os.path.join(directory, file), file) + zf.add(os.path.join(directory, file), file) if clear: os.remove(os.path.join(directory, file)) @@ -70,4 +79,4 @@ def _zip_unpack(zip_file, directory, silent: bool = False, password: Optional[st if _ZLIB_SUPPORTED: - register_archive_type('zip', ['.zip'], _zip_pack, _zip_unpack) + register_archive_type('zip', ['.zip'], _zip_pack, _zip_unpack, ZipWriter) diff --git a/test/archive/test_base.py b/test/archive/test_base.py index dd35040a963..eda919d0a35 100644 --- a/test/archive/test_base.py +++ b/test/archive/test_base.py @@ -33,4 +33,4 @@ def test_pack_with_warning(self, raw_dir): def test_empty_register(self): with pytest.raises(ValueError): - register_archive_type('xxx', [], lambda: None, lambda: None) + register_archive_type('xxx', [], lambda: None, lambda: None, lambda x: None)