Skip to content

Commit

Permalink
improve io funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Sep 8, 2023
1 parent 2cc2e13 commit 102425c
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 151 deletions.
312 changes: 162 additions & 150 deletions bioimageio/core/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from bioimageio.core._internal.utils import get_parent_url, write_zip
from bioimageio.spec import ResourceDescription
from bioimageio.spec import load_description as load_description_from_content
from bioimageio.spec import load_description as load_description
from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase
from bioimageio.spec._internal.constants import DISCOVER, LATEST
from bioimageio.spec._internal.constants import DISCOVER
from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, ValidationContext, YamlValue
from bioimageio.spec.description import dump_description
from bioimageio.spec.model.v0_4 import WeightsFormat
Expand All @@ -26,106 +26,43 @@

StrictFileSource = Union[HttpUrl, FilePath]
FileSource = Union[StrictFileSource, str]
RdfSource = Union[FileSource, RdfContent, ResourceDescription, str]
RdfSource = Union[FileSource, ResourceDescription]


class RawRdf(NamedTuple):
content: RdfContent
root: Union[HttpUrl, DirectoryPath]
file_name: str
LEGACY_RDF_NAME = "rdf.yaml"


def load_description(
rdf_source: RdfSource,
def read_description(
rdf_source: FileSource,
/,
*,
context: Optional[ValidationContext] = None,
format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
context = context or ValidationContext()
rdf_content = _get_rdf_content_and_update_context(rdf_source, context)
return load_description_from_content(
rdf_content,
context=context,
rdf = download_rdf(rdf_source)
return load_description(
rdf.content,
context=ValidationContext(root=rdf.root, file_name=rdf.file_name),
format_version=format_version,
)


LEGACY_RDF_NAME = "rdf.yaml"


def read_rdf_content(
def read_description_and_validate(
rdf_source: FileSource,
/,
*,
known_hash: Optional[str] = None,
rdf_encoding: str = "utf-8",
) -> RawRdf:
try:
rdf_source = TypeAdapter(StrictFileSource).validate_python(rdf_source)
except ValidationError as e:
raise e

if isinstance(rdf_source, AnyUrl):
_ls: Any = pooch.retrieve(url=str(rdf_source), known_hash=known_hash)
local_source = Path(_ls)
root: Union[HttpUrl, DirectoryPath] = get_parent_url(rdf_source)
else:
local_source = rdf_source
root = rdf_source.parent

if is_zipfile(local_source):
out_path = local_source.with_suffix(local_source.suffix + ".unzip")
with ZipFile(local_source, "r") as f:
rdfs = [fname for fname in f.namelist() if fname.endswith(".bioimageio.yaml")]
if len(rdfs) > 1:
raise ValueError(f"Multiple RDFs in one package not yet supported (found {rdfs}).")
elif len(rdfs) == 1:
rdf_file_name = rdfs[0]
elif LEGACY_RDF_NAME in f.namelist():
rdf_file_name = LEGACY_RDF_NAME
else:
raise ValueError(
f"No RDF found in {local_source}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)."
)

f.extractall(out_path)
local_source = out_path / rdf_file_name

with local_source.open(encoding=rdf_encoding) as f:
content: YamlValue = yaml.load(f)

if not isinstance(content, collections.abc.Mapping):
raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.")

return RawRdf(
content=cast(RdfContent, content),
root=root,
file_name=extract_file_name(rdf_source),
)
) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
rdf = download_rdf(rdf_source)
return load_description_and_validate(rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name))


def resolve_source(
source: Union[HttpUrl, FilePath, RelativeFilePath, str],
def load_description_and_validate(
rdf_content: RdfContent,
/,
*,
known_hash: Optional[str] = None,
root: Union[DirectoryPath, AnyUrl, None] = None,
) -> FilePath:
if isinstance(source, str):
source = TypeAdapter(Union[HttpUrl, FilePath, RelativeFilePath]).validate_python(source)

if isinstance(source, RelativeFilePath):
if root is None:
raise ValueError(f"Cannot resolve relative file path '{source}' without root.")

source = source.get_absolute(root)

if isinstance(source, AnyUrl):
_s: Any = pooch.retrieve(str(source), known_hash=known_hash)
source = Path(_s)

return source
context: Optional[ValidationContext] = None,
format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
"""load and validate a BioImage.IO description from the content of a resource description file (RDF)"""
rd, summary = load_description(rdf_content, context=context, format_version=format_version)
# todo: add dynamic validation
return rd, summary


def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union[FilePath, TextIO]):
Expand All @@ -141,69 +78,10 @@ def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union
yaml.dump(content, file)


def load_description_and_validate(
rdf_source: RdfSource,
/,
*,
context: Optional[ValidationContext] = None,
) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
"""load and validate a BioImage.IO description from the content of a resource description file (RDF)"""
context = context or ValidationContext()
rdf_content = _get_rdf_content_and_update_context(rdf_source, context)
rd, summary = load_description_from_content(rdf_content, context=context, format_version=LATEST)
# todo: add dynamic validation
return rd, summary


def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent:
if isinstance(rdf_source, (AnyUrl, Path, str)):
rdf = read_rdf_content(rdf_source)
rdf_source = rdf.content
context.root = rdf.root
context.file_name = rdf.file_name
elif isinstance(rdf_source, ResourceDescriptionBase):
rdf_source = dump_description(rdf_source, exclude_unset=False)

return rdf_source


def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription:
if not isinstance(rdf_source, ResourceDescriptionBase):
descr, summary = load_description(rdf_source, context=context)
if descr is None:
rdf_source_msg = (
f"{{name={rdf_source.get('name', 'missing'), ...}}})"
if isinstance(rdf_source, collections.abc.Mapping)
else rdf_source
)
raise ValueError(f"Failed to load {rdf_source_msg}:\n{summary.format()}")
rdf_source = descr

return rdf_source


def validate(
rdf_source: RdfSource,
/,
*,
context: Optional[ValidationContext] = None,
) -> ValidationSummary:
_rd, summary = load_description_and_validate(rdf_source, context=context)
return summary


def validate_format_only(
rdf_source: Union[ResourceDescription, RdfContent, FileSource], context: Optional[ValidationContext] = None
) -> ValidationSummary:
_rd, summary = load_description(rdf_source, context=context)
return summary


def prepare_resource_package(
rdf_source: RdfSource,
/,
*,
context: Optional[ValidationContext] = None,
weights_priority_order: Optional[Sequence[WeightsFormat]] = None,
) -> Dict[FileName, Union[FilePath, RdfContent]]:
"""Prepare to package a resource description; downloads all required files.
Expand All @@ -214,8 +92,17 @@ def prepare_resource_package(
weights_priority_order: If given only the first weights format present in the model is included.
If none of the prioritized weights formats is found all are included.
"""
context = context or ValidationContext()
rd = _get_description_and_update_context(rdf_source, context)
if isinstance(rdf_source, ResourceDescriptionBase):
rd = rdf_source
_ctxt = rd._internal_validation_context # pyright: ignore[reportPrivateUsage]
context = ValidationContext(root=_ctxt["root"], file_name=_ctxt["file_name"])
else:
rdf = download_rdf(rdf_source)
context = ValidationContext(root=rdf.root, file_name=rdf.file_name)
rd = load_description(
rdf.content,
context=context,
)
package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order)

local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {}
Expand All @@ -227,14 +114,11 @@ def prepare_resource_package(

return local_package_content

# output_folder.mkdir(parents=True, exist_ok=True)


def write_package(
rdf_source: RdfSource,
/,
*,
context: Optional[ValidationContext] = None,
compression: int = ZIP_DEFLATED,
compression_level: int = 1,
output_path: Optional[os.PathLike[str]] = None,
Expand Down Expand Up @@ -278,3 +162,131 @@ def write_package(

write_zip(output_path, package_content, compression=compression, compression_level=compression_level)
return output_path


class _LocalFile(NamedTuple):
    """A locally available file together with where it originally came from (see `download`)."""

    # path of the file on the local file system
    path: FilePath
    # parent URL (remote source) or parent directory (local source) of the original source
    original_root: Union[AnyUrl, DirectoryPath]
    # file name extracted from the original source
    original_file_name: str


class _LocalRdf(NamedTuple):
    """Parsed RDF content together with the location of the source it was read from."""

    # mapping loaded from the RDF YAML file
    content: RdfContent
    # parent URL or parent directory of the source the RDF was obtained from
    root: Union[AnyUrl, DirectoryPath]
    # file name of the source the RDF was obtained from
    file_name: str


def download(
    source: FileSource,
    /,
    *,
    known_hash: Optional[str] = None,
) -> _LocalFile:
    """Make `source` available as a local file.

    Args:
        source: URL or path of the file; validated with `_interprete_file_source`.
        known_hash: forwarded to `pooch.retrieve` for URL sources; not used for local paths.

    Returns:
        The local path, the root of the source (parent URL or parent directory)
        and the file name extracted from the source.
    """
    strict_source = _interprete_file_source(source)
    if not isinstance(strict_source, AnyUrl):
        # already a local file: nothing to fetch
        local_path = strict_source
        origin: Union[HttpUrl, DirectoryPath] = strict_source.parent
    else:
        retrieved: Any = pooch.retrieve(url=str(strict_source), known_hash=known_hash)
        local_path = Path(retrieved)
        origin = get_parent_url(strict_source)

    return _LocalFile(
        path=local_path,
        original_root=origin,
        original_file_name=extract_file_name(strict_source),
    )


def download_rdf(
    source: FileSource,
    /,
    *,
    known_hash: Optional[str] = None,
    rdf_encoding: str = "utf-8",
) -> _LocalRdf:
    """Obtain `source` locally and load the RDF it contains.

    If the local file is a zip archive, the whole archive is extracted to a sibling
    directory (archive name plus ".unzip") and the RDF inside it is read instead.

    Args:
        source: URL or path of an RDF file or of a zip package containing one.
        known_hash: forwarded to `download`.
        rdf_encoding: text encoding used to open the RDF file.

    Returns:
        The loaded RDF content plus root and file name of the original `source`.

    Raises:
        ValueError: if a zip package holds several '*.bioimageio.yaml' files or no RDF at all.
        TypeError: if the loaded YAML content is not a mapping.
    """
    local_file = download(source, known_hash=known_hash)
    rdf_path = local_file.path
    if is_zipfile(rdf_path):
        archive_path = rdf_path
        unzip_dir = archive_path.with_suffix(archive_path.suffix + ".unzip")
        with ZipFile(archive_path, "r") as archive:
            members = archive.namelist()
            candidates = [member for member in members if member.endswith(".bioimageio.yaml")]
            if len(candidates) > 1:
                raise ValueError(f"Multiple RDFs in one package not yet supported (found {candidates}).")

            if candidates:
                rdf_member = candidates[0]
            elif LEGACY_RDF_NAME in members:
                # fall back to the legacy file name
                rdf_member = LEGACY_RDF_NAME
            else:
                raise ValueError(
                    f"No RDF found in {archive_path}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)."
                )

            archive.extractall(unzip_dir)
            rdf_path = unzip_dir / rdf_member

    # NOTE(review): `yaml` is defined outside this view; confirm it is configured as a safe loader
    with rdf_path.open(encoding=rdf_encoding) as rdf_file:
        content: YamlValue = yaml.load(rdf_file)

    if not isinstance(content, collections.abc.Mapping):
        raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.")

    return _LocalRdf(
        content=cast(RdfContent, content),
        root=local_file.original_root,
        file_name=local_file.original_file_name,
    )


def resolve_source(
    source: Union[FileSource, RelativeFilePath],
    /,
    *,
    known_hash: Optional[str] = None,
    root: Union[DirectoryPath, AnyUrl, None] = None,
) -> FilePath:
    """Return a local file path for `source`.

    Args:
        source: URL, file path or relative file path.
        known_hash: forwarded to `download`.
        root: base used to make a `RelativeFilePath` absolute; required for relative sources.

    Returns:
        Path of the locally available file.

    Raises:
        ValueError: if `source` is a `RelativeFilePath` and no `root` is given.
    """
    is_relative = isinstance(source, RelativeFilePath)
    if is_relative and root is None:
        raise ValueError(f"Cannot resolve relative file path '{source}' without root.")

    target = source.get_absolute(root) if is_relative else source
    local_file = download(target, known_hash=known_hash)
    return local_file.path


# def _get_rdf_content(rdf_source: RdfSource) -> Tuple[RdfContent, ValidationContext]:
# if isinstance(rdf_source, (AnyUrl, Path, str)):
# rdf = read_rdf_content(rdf_source)
# rdf_content = rdf.content
# context = ValidationContext(root=rdf.root, file_name=rdf.file_name)
# elif isinstance(rdf_source, ResourceDescriptionBase):
# rdf_content = dump_description(rdf_source, exclude_unset=False)
# ctxt = rdf_source._internal_validation_context # pyright: ignore[reportPrivateUsage]
# context = ValidationContext(root=ctxt["root"], file_name=ctxt["file_name"])
# else:
# rdf_content = rdf_source
# context = ValidationContext()

# return rdf_content, context


# def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent:
# if isinstance(rdf_source, (AnyUrl, Path, str)):
# rdf = read_rdf_content(rdf_source)
# rdf_source = rdf.content
# context.root = rdf.root
# context.file_name = rdf.file_name
# elif isinstance(rdf_source, ResourceDescriptionBase):
# rdf_source = dump_description(rdf_source, exclude_unset=False)

# return rdf_source


def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription:
    """Return a resource description for `rdf_source`.

    Args:
        rdf_source: an already loaded resource description (returned unchanged) or
            RDF content given as a dict (loaded with `load_description`).
        context: validation context forwarded to `load_description`.

    Returns:
        The given or newly loaded resource description.

    Raises:
        ValueError: if loading the RDF content does not yield a description.
        TypeError: if `rdf_source` is neither a resource description nor a dict.
    """
    if isinstance(rdf_source, ResourceDescriptionBase):
        # nothing to load; without this branch the function fell through to an unbound `descr`
        return rdf_source

    if not isinstance(rdf_source, dict):
        # fail explicitly instead of with an UnboundLocalError at the final return
        raise TypeError(f"Expected a resource description or RDF content (dict), but got '{type(rdf_source)}'.")

    descr, summary = load_description(rdf_source, context=context)
    if descr is None:
        # only report the name; the full RDF content may be large
        name = rdf_source.get("name", "missing")
        raise ValueError(f"Failed to load {{name={name}, ...}}:\n{summary.format()}")

    return descr


def _interprete_file_source(file_source: FileSource) -> StrictFileSource:
    """Validate `file_source` against `StrictFileSource` and return the validated value.

    Errors raised by the type adapter are not caught here.
    """
    # todo: prettier file source validation error (e.g. catch ValidationError and re-raise)
    strict_adapter = TypeAdapter(StrictFileSource)
    return strict_adapter.validate_python(file_source)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pythonPlatform = "All"

[tool.pytest.ini_options]
addopts = "--capture=no --doctest-modules --failed-first"
# testpaths = ["bioimageio", "scripts", "example", "tests"]

[tool.ruff]
line-length = 120
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]
target-version = "py38"

0 comments on commit 102425c

Please sign in to comment.