Skip to content

Commit

Permalink
Merge pull request #452 from Karl5766/main
Browse files Browse the repository at this point in the history
Add _is_local_relative() check to avoid replacing // with / for schemed paths
  • Loading branch information
Karl5766 authored Jul 27, 2024
2 parents 96fb5e1 + 75770b0 commit e8e944c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 7 deletions.
20 changes: 19 additions & 1 deletion snakebids/core/input_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,24 @@ def _all_custom_paths(config: InputsConfig):
return all(comp.get("custom_path") for comp in config.values())


def _is_local_relative(path: Path | str):
    """Check whether a path is a local path relative to the working directory.

    The path is converted to text with ``str()``. It is rejected when that text
    contains ``"://"`` anywhere (treated as a scheme marker, as in
    ``gs://bucket/file``) or when ``os.path.isabs`` reports it as absolute.

    Parameters
    ----------
    path
        A UPath, Path, or str object to be checked

    Returns
    -------
    bool
        True if the path has no ``"://"`` marker and is not absolute
    """
    text = str(path)
    # Anything carrying a "scheme://" marker is never a local relative path
    if "://" in text:
        return False
    return not os.path.isabs(text)


def _gen_bids_layout(
*,
bids_dir: Path | str,
Expand Down Expand Up @@ -723,7 +741,7 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str,
# If path is relative, we need to get a slash in front of it to ensure parsing works
# correctly. So prepend "./" or ".\" and run function again, then strip before
# returning
if not os.path.isabs(path) and get_first_dir(path) != ".":
if _is_local_relative(path) and get_first_dir(path) != ".":
path_, wildcard_values = _parse_bids_path(os.path.join(".", path), entities)
return str(Path(path_)), wildcard_values

Expand Down
36 changes: 36 additions & 0 deletions snakebids/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,24 @@
# Keys of the ``entities`` mapping from the loaded "bids" BidsConfig
valid_entities: tuple[str, ...] = tuple(BidsConfig.load("bids").entities.keys())

# Characters for a single path segment: the "/" separator and the NUL character
# are excluded, and the alphabet is restricted to the UTF-8 codec
path_characters = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8")
# Alphabet used to build random scheme prefixes (letters, digits and URL
# punctuation). Note that ":" and "/" are not part of it.
# StackOverflow answer by Gumbo
# https://stackoverflow.com/questions/1547899/which-characters-make-a-url-invalid#1547940
# NOTE(review): this looks like the set of characters valid anywhere in a URL,
# which is broader than what a scheme proper allows -- presumably intentional to
# stress the path parser; confirm.
scheme_characters = st.sampled_from(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~?#[]@!$&'()*+,;="
)


def nothing() -> Any:
    """Return hypothesis' ``st.nothing()`` strategy, typed as ``Any``."""
    empty_strategy: Any = st.nothing()  # type: ignore
    return empty_strategy


def schemes() -> st.SearchStrategy[str]:
    """Generate the prefix part of a url.

    Draws either a random run of one or more ``scheme_characters`` followed by
    ``://``, or one of a handful of fixed prefixes.
    """
    known_prefixes = st.sampled_from(["C://", "c:\\", "gs://", "s3://"])
    generated_prefixes = st.text(scheme_characters, min_size=1).map(
        lambda chars: f"{chars}://"
    )
    # Randomly generated prefixes stay first in the union, as before
    return generated_prefixes | known_prefixes


def paths(
*,
min_segments: int = 0,
Expand Down Expand Up @@ -73,6 +85,30 @@ def paths(
return result


def schemed_paths(
    *,
    min_segments: int = 0,
    max_segments: int | None = None,
) -> st.SearchStrategy[str]:
    """Generate string paths that may begin with a ``scheme://`` prefix.

    These paths are incompatible with ``paths()`` because ``pathlib.Path`` turns
    ``//`` into ``/`` when parsing, and ``paths()`` returns Path objects by
    default. Hence this separate strategy, which works on plain strings.

    The prefix is optional, so un-schemed relative paths are produced as well.

    Parameters
    ----------
    min_segments
        Forwarded to ``paths()``
    max_segments
        Forwarded to ``paths()``
    """

    def _prefix_or_empty(prefix):
        # A missing scheme contributes nothing to the final string
        return "" if prefix is None else prefix

    def _concat(pair):
        prefix, tail = pair
        return prefix + tail

    prefix_st: st.SearchStrategy[str] = (st.none() | schemes()).map(_prefix_or_empty)
    relative_st: st.SearchStrategy[str] = paths(
        min_segments=min_segments,
        max_segments=max_segments,
        absolute=False,
    ).map(str)
    return st.tuples(prefix_st, relative_st).map(_concat)


def bids_entity(
*,
blacklist_entities: Container[BidsEntity | str] | None = None,
Expand Down
73 changes: 67 additions & 6 deletions snakebids/tests/test_generate_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import warnings
from collections import defaultdict
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Any, Iterable, Literal, NamedTuple, TypedDict, TypeVar, cast

import attrs
Expand All @@ -30,6 +30,7 @@
_all_custom_paths,
_gen_bids_layout,
_get_components,
_is_local_relative,
_normalize_database_args,
_parse_bids_path,
_parse_custom_path,
Expand Down Expand Up @@ -1602,6 +1603,49 @@ def test_all_custom_paths(count: int):
assert not _all_custom_paths(config)


class TestRecogPathSchemes:
    """Tests for ``_is_local_relative`` on relative, absolute and schemed paths."""

    # (path, category) pairs; category is one of RELATIVE, ABSOLUTE or NETWORK
    PATH_AND_TYPES = (
        ("file", "RELATIVE"),
        ("hello", "RELATIVE"),
        ("gs", "RELATIVE"),
        ("./hello/world", "RELATIVE"),
        ("hello/world", "RELATIVE"),
        ("/hello/world", "ABSOLUTE"),
        ("gs://some/google/cloud/bucket", "NETWORK"),
        ("s3://some/s3/bucket", "NETWORK"),
    )

    @pytest.mark.parametrize(("path", "path_type"), PATH_AND_TYPES)
    def test_is_local_relative(self, path: str, path_type: str):
        """Only RELATIVE entries are reported as local relative paths."""
        expected = path_type == "RELATIVE"

        # First check the raw string
        assert (
            _is_local_relative(path) == expected
        ), f"Path {path} fails is local relative path test."

        # Then the corresponding Path(path); network paths are only checked
        # in their string form
        if path_type != "NETWORK":
            assert _is_local_relative(Path(path)) == expected

    @pytest.mark.skipif(
        sys.version_info < (3, 12), reason="Path class has no with_segments()"
    )
    @pytest.mark.parametrize(
        ("path", "path_type"),
        [(item, kind) for item, kind in PATH_AND_TYPES if kind == "RELATIVE"],
    )
    def test_path_subclassing(self, path: str, path_type: str):
        """A Path subclass whose ``str()`` adds a scheme is not local relative."""

        # Google cloud is not posix, for mocking purpose however we just
        # need a class that is a subclass of Path
        class MockGCSPath(PosixPath):
            def __init__(self, *segments: str):
                super().__init__(*segments)

            def __str__(self):  # __fspath__ calls __str__ by default
                return f"gs://{super().__str__()}"

        mocked = MockGCSPath(path)
        assert not _is_local_relative(mocked)


@example_if(
sys.version_info >= (3, 8),
dataset=BidsDataset(
Expand Down Expand Up @@ -1866,16 +1910,30 @@ def test_when_all_custom_paths_no_layout_indexed(


class TestParseBidsPath:
@given(component=sb_st.bids_components(max_values=1, restrict_patterns=True))
def test_splits_wildcards_from_path(self, component: BidsComponent):
@given(
component=sb_st.bids_components(max_values=1, restrict_patterns=True),
scheme=sb_st.schemes() | st.none(),
)
def test_splits_wildcards_from_path(
self, component: BidsComponent, scheme: str | None
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
tpl_path, matches = _parse_bids_path(path, entities)
assert tpl_path.format(**matches) == path

@given(component=sb_st.bids_components(max_values=1, restrict_patterns=True))
def test_one_match_found_for_each_entity(self, component: BidsComponent):
@given(
component=sb_st.bids_components(max_values=1, restrict_patterns=True),
scheme=sb_st.schemes() | st.none(),
)
def test_one_match_found_for_each_entity(
self, component: BidsComponent, scheme: str | None
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
_, matches = _parse_bids_path(path, entities)
assert set(matches.items()) == {
Expand All @@ -1886,12 +1944,15 @@ def test_one_match_found_for_each_entity(self, component: BidsComponent):
component=sb_st.bids_components(
max_values=1, restrict_patterns=True, extra_entities=False
),
scheme=sb_st.schemes() | st.none(),
entity=sb_st.bids_entity(),
)
def test_missing_match_leads_to_error(
self, component: BidsComponent, entity: BidsEntity
self, component: BidsComponent, scheme: str | None, entity: BidsEntity
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
assume(entity.entity not in entities)
with pytest.raises(BidsParseError) as err:
Expand Down

0 comments on commit e8e944c

Please sign in to comment.