Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compatibility for urls in generate_inputs #452

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion snakebids/core/input_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,24 @@ def _all_custom_paths(config: InputsConfig):
return all(comp.get("custom_path") for comp in config.values())


def _is_local_relative(path: Path | str):
"""Test if a path location is local path relative to the current working directory.

Parameter
---------
path
A UPath, Path, or str object to be checked

Returns
-------
is_url : bool
True if the path is relative
"""
path_str = str(path)
is_doubleslash_schemed = "://" in path_str
return not is_doubleslash_schemed and not os.path.isabs(path_str)


def _gen_bids_layout(
*,
bids_dir: Path | str,
Expand Down Expand Up @@ -723,7 +741,7 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str,
# If path is relative, we need to get a slash in front of it to ensure parsing works
# correctly. So prepend "./" or ".\" and run function again, then strip before
# returning
if not os.path.isabs(path) and get_first_dir(path) != ".":
if _is_local_relative(path) and get_first_dir(path) != ".":
path_, wildcard_values = _parse_bids_path(os.path.join(".", path), entities)
return str(Path(path_)), wildcard_values

Expand Down
36 changes: 36 additions & 0 deletions snakebids/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,24 @@
# Keys of the ``entities`` mapping of the loaded "bids" config
valid_entities: tuple[str, ...] = tuple(BidsConfig.load("bids").entities.keys())

# Any character of the UTF-8 codec except the path separator "/" and NUL
path_characters = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8")
# Characters drawn from when schemes() builds a random scheme prefix.
# The character set follows a StackOverflow answer by Gumbo:
# https://stackoverflow.com/questions/1547899/which-characters-make-a-url-invalid#1547940
scheme_characters = st.sampled_from(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~?#[]@!$&'()*+,;="
)


def nothing() -> Any:
    """Return the result of ``st.nothing()``, annotated as ``Any``."""
    empty_strategy = st.nothing()  # type: ignore
    return empty_strategy


def schemes() -> st.SearchStrategy[str]:
    """Build a strategy for the prefix part of a url.

    Each value is either a non-empty run of ``scheme_characters`` followed by
    ``://``, or one of a few fixed prefixes.
    """
    well_known = st.sampled_from(["C://", "c:\\", "gs://", "s3://"])
    generated = st.text(scheme_characters, min_size=1).map(
        lambda name: name + "://"
    )
    return generated | well_known


def paths(
*,
min_segments: int = 0,
Expand Down Expand Up @@ -73,6 +85,30 @@ def paths(
return result


def schemed_paths(
    *,
    min_segments: int = 0,
    max_segments: int | None = None,
) -> st.SearchStrategy[str]:
    """Generate string paths that may begin with a ``scheme://`` prefix.

    ``paths()`` cannot produce these directly: it returns Path objects by default,
    and pathlib.Path turns ``//`` into ``/`` when parsing. The scheme is therefore
    generated separately and joined onto a stringified relative path.

    Un-schemed paths are produced as well, whenever no scheme is drawn.

    Parameters
    ----------
    min_segments
        Forwarded to ``paths()``
    max_segments
        Forwarded to ``paths()``
    """

    def _blank_if_missing(prefix: str | None) -> str:
        # A drawn ``None`` stands for "no scheme"
        return "" if prefix is None else prefix

    prefixes: st.SearchStrategy[str] = (st.none() | schemes()).map(
        _blank_if_missing
    )
    bodies: st.SearchStrategy[str] = paths(
        min_segments=min_segments,
        max_segments=max_segments,
        absolute=False,
    ).map(str)
    return st.tuples(prefixes, bodies).map("".join)


def bids_entity(
*,
blacklist_entities: Container[BidsEntity | str] | None = None,
Expand Down
73 changes: 67 additions & 6 deletions snakebids/tests/test_generate_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import warnings
from collections import defaultdict
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Any, Iterable, Literal, NamedTuple, TypedDict, TypeVar, cast

import attrs
Expand All @@ -30,6 +30,7 @@
_all_custom_paths,
_gen_bids_layout,
_get_components,
_is_local_relative,
_normalize_database_args,
_parse_bids_path,
_parse_custom_path,
Expand Down Expand Up @@ -1602,6 +1603,49 @@ def test_all_custom_paths(count: int):
assert not _all_custom_paths(config)


class TestRecogPathSchemes:
    """Checks that ``_is_local_relative`` classifies path locations correctly."""

    # (path, kind) pairs; kind is one of "RELATIVE", "ABSOLUTE" or "NETWORK"
    PATH_AND_TYPES = (
        ("file", "RELATIVE"),
        ("hello", "RELATIVE"),
        ("gs", "RELATIVE"),
        ("./hello/world", "RELATIVE"),
        ("hello/world", "RELATIVE"),
        ("/hello/world", "ABSOLUTE"),
        ("gs://some/google/cloud/bucket", "NETWORK"),
        ("s3://some/s3/bucket", "NETWORK"),
    )

    @pytest.mark.parametrize(("path", "path_type"), PATH_AND_TYPES)
    def test_is_local_relative(self, path: str, path_type: str):
        expected = path_type == "RELATIVE"

        # The raw string is always checked
        assert (
            _is_local_relative(path) == expected
        ), f"Path {path} fails is local relative path test."

        # The Path(path) form is checked too, except for network locations
        if path_type != "NETWORK":
            assert _is_local_relative(Path(path)) == expected

    @pytest.mark.skipif(
        sys.version_info < (3, 12), reason="Path class has no with_segments()"
    )
    @pytest.mark.parametrize(
        ("path", "path_type"),
        [(loc, kind) for loc, kind in PATH_AND_TYPES if kind == "RELATIVE"],
    )
    def test_path_subclassing(self, path: str, path_type: str):
        # Google cloud is not posix, but for mocking purposes any subclass of
        # Path will do
        class MockGCSPath(PosixPath):
            def __init__(self, *pathsegments: str):
                super().__init__(*pathsegments)

            def __str__(self):  # __fspath__ calls __str__ by default
                plain = super().__str__()
                return f"gs://{plain}"

        mocked = MockGCSPath(path)
        assert not _is_local_relative(mocked)


@example_if(
sys.version_info >= (3, 8),
dataset=BidsDataset(
Expand Down Expand Up @@ -1866,16 +1910,30 @@ def test_when_all_custom_paths_no_layout_indexed(


class TestParseBidsPath:
@given(component=sb_st.bids_components(max_values=1, restrict_patterns=True))
def test_splits_wildcards_from_path(self, component: BidsComponent):
@given(
component=sb_st.bids_components(max_values=1, restrict_patterns=True),
scheme=sb_st.schemes() | st.none(),
)
def test_splits_wildcards_from_path(
self, component: BidsComponent, scheme: str | None
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
tpl_path, matches = _parse_bids_path(path, entities)
assert tpl_path.format(**matches) == path

@given(component=sb_st.bids_components(max_values=1, restrict_patterns=True))
def test_one_match_found_for_each_entity(self, component: BidsComponent):
@given(
component=sb_st.bids_components(max_values=1, restrict_patterns=True),
scheme=sb_st.schemes() | st.none(),
)
def test_one_match_found_for_each_entity(
self, component: BidsComponent, scheme: str | None
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
_, matches = _parse_bids_path(path, entities)
assert set(matches.items()) == {
Expand All @@ -1886,12 +1944,15 @@ def test_one_match_found_for_each_entity(self, component: BidsComponent):
component=sb_st.bids_components(
max_values=1, restrict_patterns=True, extra_entities=False
),
scheme=sb_st.schemes() | st.none(),
entity=sb_st.bids_entity(),
)
def test_missing_match_leads_to_error(
self, component: BidsComponent, entity: BidsEntity
self, component: BidsComponent, scheme: str | None, entity: BidsEntity
):
path = component.expand()[0]
if scheme is not None:
path = f"{scheme}{path}"
entities = [BidsEntity.normalize(e).entity for e in component.zip_lists]
assume(entity.entity not in entities)
with pytest.raises(BidsParseError) as err:
Expand Down
Loading