diff --git a/lib/spack/spack/test/util/compression.py b/lib/spack/spack/test/util/compression.py index 7b0f8d45fbcddd..74fced664252df 100644 --- a/lib/spack/spack/test/util/compression.py +++ b/lib/spack/spack/test/util/compression.py @@ -3,9 +3,11 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) + import os import shutil import sys +from itertools import product import pytest @@ -40,18 +42,24 @@ def compr_support_check(monkeypatch): @pytest.fixture -def archive_file(tmpdir_factory, request): - """Copy example archive to temp directory for test""" +def archive_file_and_extension(tmpdir_factory, request): + """Copy example archive to temp directory into an extension-less file for test""" archive_file_stub = os.path.join(datadir, "Foo") - extension = request.param + extension, add_extension = request.param tmpdir = tmpdir_factory.mktemp("compression") - shutil.copy(archive_file_stub + "." + extension, str(tmpdir)) - return os.path.join(str(tmpdir), "Foo.%s" % extension) - - -@pytest.mark.parametrize("archive_file", native_archive_list, indirect=True) -def test_native_unpacking(tmpdir_factory, archive_file): - util = scomp.decompressor_for(archive_file) + tmp_archive_file = os.path.join( + str(tmpdir), "Foo" + (("." + extension) if add_extension else "") + ) + shutil.copy(archive_file_stub + "." + extension, tmp_archive_file) + return (tmp_archive_file, extension) + + +@pytest.mark.parametrize( + "archive_file_and_extension", product(native_archive_list, [True, False]), indirect=True +) +def test_native_unpacking(tmpdir_factory, archive_file_and_extension): + archive_file, extension = archive_file_and_extension + util = scomp.decompressor_for(archive_file, extension) tmpdir = tmpdir_factory.mktemp("comp_test") with working_dir(str(tmpdir)): assert not os.listdir(os.getcwd()) @@ -64,9 +72,12 @@ def test_native_unpacking(tmpdir_factory, archive_file): @pytest.mark.skipif(sys.platform == "win32", reason="Only Python unpacking available on Windows") -@pytest.mark.parametrize("archive_file", ext_archive.keys(), indirect=True) -def test_system_unpacking(tmpdir_factory, archive_file, compr_support_check): +@pytest.mark.parametrize( + "archive_file_and_extension", [(ext, True) for ext in ext_archive.keys()], indirect=True +) +def test_system_unpacking(tmpdir_factory, archive_file_and_extension, compr_support_check): # actually run test + archive_file, _ = archive_file_and_extension util = scomp.decompressor_for(archive_file) tmpdir = tmpdir_factory.mktemp("system_comp_test") with working_dir(str(tmpdir)): @@ -102,3 +113,21 @@ def test_get_bad_extension(): @pytest.mark.parametrize("path", ext_archive.values()) def test_allowed_archive(path): assert scomp.allowed_archive(path) + + +@pytest.mark.parametrize("ext_path", ext_archive.items()) +def test_strip_compression_extension(ext_path): + ext, path = ext_path + stripped = scomp.strip_compression_extension(path) + if ext == "zip": + assert stripped == "Foo.zip" + stripped = scomp.strip_compression_extension(path, "zip") + assert stripped == "Foo" + elif ( + ext == "tar" + or ext in scomp.CONTRACTION_MAP.keys() + or ext in [".".join(ext) for ext in product(scomp.PRE_EXTS, scomp.EXTS)] + ): + assert stripped == "Foo.tar" or stripped == "Foo.TAR" + else: + assert stripped == "Foo" diff --git a/lib/spack/spack/util/compression.py b/lib/spack/spack/util/compression.py index ab7313d47dc342..fd8d17edb83bb5 100644 --- a/lib/spack/spack/util/compression.py +++ b/lib/spack/spack/util/compression.py @@ -21,6 +21,7 @@ PRE_EXTS = ["tar", "TAR"] EXTS = ["gz", "bz2", "xz", "Z"] NOTAR_EXTS = ["zip", "tgz", "tbz2", "tbz", "txz"] +CONTRACTION_MAP = {"tgz": "tar.gz", "txz": "tar.xz", "tbz": "tar.bz2", "tbz2": "tar.bz2"} # Add PRE_EXTS and EXTS last so that .tar.gz is matched *before* .tar or .gz ALLOWED_ARCHIVE_TYPES = ( @@ -77,8 +78,14 @@ def _system_untar(archive_file): archive_file (str): absolute path to the archive to be extracted. Can be one of .tar(.[gz|bz2|xz|Z]) or .(tgz|tbz|tbz2|txz). """ - outfile = os.path.basename(strip_extension(archive_file, "tar")) - + archive_file_no_ext = strip_extension(archive_file) + outfile = os.path.basename(archive_file_no_ext) + if archive_file_no_ext == archive_file: + # the archive file has no extension. Tar on windows cannot untar onto itself + # archive_file can be a tar file (which causes the problem on windows) but it can + # also have other extensions (on Unix) such as tgz, tbz2, ... + archive_file = archive_file_no_ext + "-input" + shutil.move(archive_file_no_ext, archive_file) tar = which("tar", required=True) tar.add_default_arg("-oxf") tar(archive_file) @@ -159,7 +166,12 @@ def _py_gunzip(archive_file): def _system_gunzip(archive_file): """Returns path to gunzip'd file Decompresses `.gz` compressed files via system gzip""" - decompressed_file = os.path.basename(strip_compression_extension(archive_file, "gz")) + archive_file_no_ext = strip_compression_extension(archive_file) + if archive_file_no_ext == archive_file: + # the zip file has no extension. On Unix gunzip cannot unzip onto itself + archive_file = archive_file + ".gz" + shutil.move(archive_file_no_ext, archive_file) + decompressed_file = os.path.basename(archive_file_no_ext) working_dir = os.getcwd() destination_abspath = os.path.join(working_dir, decompressed_file) compressed_file = os.path.basename(archive_file) @@ -233,14 +245,12 @@ def unarchive(archive_file): # record name of new archive so we can extract # and later clean up decomped_tarball = decompressor(archive_file) - if check_extension(decomped_tarball, "tar"): - # run tar on newly decomped archive - outfile = _system_untar(decomped_tarball) - # clean intermediate archive to mimic end result - # produced by one shot decomp/extraction - os.remove(decomped_tarball) - return outfile - return decomped_tarball + # run tar on newly decomped archive + outfile = _system_untar(decomped_tarball) + # clean intermediate archive to mimic end result + # produced by one shot decomp/extraction + os.remove(decomped_tarball) + return outfile return unarchive @@ -248,7 +258,7 @@ def unarchive(archive_file): def _py_lzma(archive_file): """Returns path to decompressed .xz files Decompress lzma compressed .xz files via python lzma module""" - decompressed_file = os.path.basename(strip_extension(archive_file, "xz")) + decompressed_file = os.path.basename(strip_compression_extension(archive_file, "xz")) archive_out = os.path.join(os.getcwd(), decompressed_file) with open(archive_out, "wb") as ar: with lzma.open(archive_file) as lar: @@ -707,15 +717,18 @@ def extension_from_path(path): def strip_compression_extension(path, ext=None): - """Returns path with last supported or provided archive extension stripped""" - path = expand_contracted_extension_in_path(path) - exts_to_check = EXTS - if ext: - exts_to_check = [ext] - for ext_check in exts_to_check: - mod_path = check_and_remove_ext(path, ext_check) - if mod_path != path: - return mod_path + """Returns path with last supported (can be combined with tar) or + provided archive extension stripped""" + path_ext = extension_from_path(path) + if path_ext: + path = expand_contracted_extension_in_path(path) + exts_to_check = EXTS + if ext: + exts_to_check = [ext] + for ext_check in exts_to_check: + mod_path = check_and_remove_ext(path, ext_check) + if mod_path != path: + return mod_path return path @@ -781,8 +794,7 @@ def expand_contracted_extension(extension): """Return expanded version of contracted extension i.e. .tgz -> .tar.gz, no op on non contracted extensions""" extension = extension.strip(".") - contraction_map = {"tgz": "tar.gz", "txz": "tar.xz", "tbz": "tar.bz2", "tbz2": "tar.bz2"} - return contraction_map.get(extension, extension) + return CONTRACTION_MAP.get(extension, extension) def compression_ext_from_compressed_archive(extension):