Skip to content

Commit

Permalink
Merge pull request #704 from janosh/fix-zpath-pathlib
Browse files Browse the repository at this point in the history
Fix `zpath` when passing `pathlib.Path`
  • Loading branch information
shyuep authored Oct 21, 2024
2 parents 756b008 + 8163a11 commit d102547
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
9 changes: 5 additions & 4 deletions src/monty/os/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING

from monty.fnmatch import WildCard
Expand All @@ -14,7 +15,7 @@
from typing import Callable, Literal, Optional, Union


def zpath(filename: str) -> str:
def zpath(filename: str | Path) -> str:
"""
Returns an existing (zipped or unzipped) file path given the unzipped
version. If no path exists, returns the filename unmodified.
Expand All @@ -23,10 +24,10 @@ def zpath(filename: str) -> str:
filename: filename without zip extension
Returns:
filename with a zip extension (unless an unzipped version
exists). If filename is not found, the same filename is returned
unchanged.
str: filename with a zip extension (unless an unzipped version exists).
If filename is not found, the same filename is returned unchanged.
"""
filename = str(filename) # ensure we work with strings
exts = ("", ".gz", ".GZ", ".bz2", ".BZ2", ".z", ".Z")
for ext in exts:
filename = filename.removesuffix(ext)
Expand Down
36 changes: 35 additions & 1 deletion tests/test_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,51 @@


class TestPath:
def test_zpath(self, tmp_path: Path):
def test_zpath_str(self, tmp_path: Path):
tmp_gz = tmp_path / "tmp.gz"
tmp_gz.touch()
ret_path = zpath(str(tmp_gz))
assert isinstance(ret_path, str)
assert ret_path == str(tmp_gz)

tmp_not_bz2 = tmp_path / "tmp_not_bz2"
tmp_not_bz2.touch()

ret_path = zpath(f"{tmp_not_bz2}.bz2")
assert ret_path == str(tmp_not_bz2)
assert isinstance(ret_path, str)

def test_zpath_path(self, tmp_path: Path):
# Test with Path input
tmp_gz = tmp_path / "tmp.gz"
tmp_gz.touch()
ret_path = zpath(tmp_gz)
assert ret_path == str(tmp_gz)
assert isinstance(ret_path, str)

def test_zpath_multiple_extensions(self, tmp_path: Path):
exts = ["", ".gz", ".GZ", ".bz2", ".BZ2", ".z", ".Z"]
for ext in exts:
tmp_file = tmp_path / f"tmp{ext}"
# create files with all supported compression extensions
tmp_file.touch()

# zpath should return the file without compression extension
ret_path = zpath(tmp_path / "tmp")
assert ret_path == str(tmp_path / "tmp")
assert isinstance(ret_path, str)

(tmp_path / "tmp").unlink() # Remove the uncompressed file
ret_path = zpath(tmp_path / "tmp")
assert ret_path == str(tmp_path / "tmp.gz") # should find .gz first now

def test_zpath_nonexistent_file(self, tmp_path: Path):
# should return path as is for non-existent file
nonexistent = tmp_path / "nonexistent.txt"
ret_path = zpath(nonexistent)
assert ret_path == str(nonexistent)
ret_path = zpath(f"{nonexistent}.bz2")
assert ret_path == str(nonexistent)

def test_find_exts(self):
assert len(find_exts(MODULE_DIR, "py")) >= 18
Expand Down

0 comments on commit d102547

Please sign in to comment.