diff --git a/.github/workflows/ci-cd.yaml b/.github/workflows/ci-cd.yaml index 81dbb03..6109216 100644 --- a/.github/workflows/ci-cd.yaml +++ b/.github/workflows/ci-cd.yaml @@ -1,6 +1,7 @@ name: ci-cd on: + workflow_dispatch: push: branches: - main diff --git a/.github/workflows/push-docker.yaml b/.github/workflows/push-docker.yaml index 64a5a27..17fbc68 100644 --- a/.github/workflows/push-docker.yaml +++ b/.github/workflows/push-docker.yaml @@ -1,6 +1,7 @@ name: push-docker on: + workflow_dispatch: push: branches: - main diff --git a/Makefile b/Makefile index 92cac85..5272b9e 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: quality style tests +.PHONY: quality style types tests quality: black --check --target-version py39 --preview src/wandbfsspec tests @@ -9,5 +9,8 @@ style: black --target-version py39 --preview src/wandbfsspec tests isort src/wandbfsspec tests +types: + mypy src/wandbfsspec tests + tests: pytest tests/ --durations 0 -s \ No newline at end of file diff --git a/README.md b/README.md index 100b8df..19097e4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ The `wandbfsspec` implementation is based on https://github.com/fsspec/filesyste Here's an example on how to locate and open a file from the File System: ```python ->>> from wandbfsspec.core import WandbFileSystem +>>> from wandbfsspec.spec import WandbFileSystem >>> fs = WandbFileSystem(api_key="YOUR_API_KEY") >>> fs.ls("alvarobartt/wandbfsspec-tests/3s6km7mp") ['alvarobartt/wandbfsspec-tests/3s6km7mp/config.yaml', 'alvarobartt/wandbfsspec-tests/3s6km7mp/file.yaml', 'alvarobartt/wandbfsspec-tests/3s6km7mp/files', 'alvarobartt/wandbfsspec-tests/3s6km7mp/output.log', 'alvarobartt/wandbfsspec-tests/3s6km7mp/requirements.txt', 'alvarobartt/wandbfsspec-tests/3s6km7mp/wandb-metadata.json', 'alvarobartt/wandbfsspec-tests/3s6km7mp/wandb-summary.json'] @@ -37,7 +37,7 @@ b'some: data\nfor: testing' Which is similar to how to locate and open a file from the Artifact Storage (just changing the class and the path): ```python ->>> from wandbfsspec.core import WandbArtifactStore +>>> from wandbfsspec.spec import WandbArtifactStore >>> fs = WandbArtifactStore(api_key="YOUR_API_KEY") >>> fs.ls("wandb/yolo-chess/model/run_1dnrszzr_model/v8") ['wandb/yolo-chess/model/run_1dnrszzr_model/v8/last.pt'] diff --git a/pyproject.toml b/pyproject.toml index f0e4bc4..0c2d4b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ repository = "https://github.com/alvarobartt/wandbfsspec" license = "MIT License" [tool.poetry.plugins."fsspec.specs"] -wandbfs = "wandbfsspec.core.WandbFileSystem" -wandbas = "wandbfsspec.core.WandbArtifactStore" +wandbfs = "wandbfsspec.spec.WandbFileSystem" +wandbas = "wandbfsspec.spec.WandbArtifactStore" [tool.poetry.dependencies] python = ">=3.7,<3.10" @@ -52,6 +52,10 @@ implicit_reexport = false strict_equality = true # --strict end +[[tool.mypy.overrides]] +module = ["fsspec.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] log_cli = true log_cli_level = "INFO" diff --git a/src/wandbfsspec/core.py b/src/wandbfsspec/core.py index b5e6600..d29d5e1 100644 --- a/src/wandbfsspec/core.py +++ b/src/wandbfsspec/core.py @@ -1,247 +1,30 @@ # Copyright 2022 Alvaro Bartolome, alvarobartt @ GitHub # See LICENSE for details. -import datetime -import logging import os -import tempfile import urllib.request -from pathlib import Path -from typing import Any, Dict, List, Literal, Tuple, Union +from typing import Any, Literal, Union import wandb -from fsspec import AbstractFileSystem -from fsspec.spec import AbstractBufferedFile +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem -MAX_PATH_LENGTH_WITHOUT_FILE_PATH = 3 -MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH = 5 +__all__ = ["WandbFile", "WandbBaseFileSystem"] -logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) -__all__ = ["WandbFileSystem", "WandbArtifactStore"] - - -class WandbFileSystem(AbstractFileSystem): - protocol = "wandbfs" - - def __init__( - self, - api_key: Union[str, None] = None, - ) -> None: - super().__init__() - - if api_key: - os.environ["WANDB_API_KEY"] = api_key - - assert os.getenv("WANDB_API_KEY"), ( - "In order to connect to the wandb Public API you need to provide the API" - " key either via param `api_key`, setting the key in the environment" - " variable `WANDB_API_KEY`, or running `wandb login `." - ) - - self.api = wandb.Api() - - @classmethod - def split_path( - self, path: str - ) -> Tuple[str, Union[str, None], Union[str, None], Union[str, None]]: - path = self._strip_protocol(path=path) - path = path.lstrip("/") - if "/" not in path: - return (path, None, None, None) - path = path.split("/") - if len(path) > MAX_PATH_LENGTH_WITHOUT_FILE_PATH: - return ( - *path[:MAX_PATH_LENGTH_WITHOUT_FILE_PATH], - "/".join(path[MAX_PATH_LENGTH_WITHOUT_FILE_PATH:]), - ) - path += [None] * (MAX_PATH_LENGTH_WITHOUT_FILE_PATH - len(path)) - return (*path, None) - - def ls(self, path: str, detail: bool = False) -> Union[List[str], Dict[str, Any]]: - entity, project, run_id, file_path = self.split_path(path=path) - if entity and project and run_id: - _files = self.api.run(f"{entity}/{project}/{run_id}").files() - base_path = f"{entity}/{project}/{run_id}" - return self.__ls_files( - _files=_files, - base_path=f"{base_path}/{file_path}" if file_path else base_path, - file_path=file_path if file_path else Path("./"), - detail=detail, - ) - elif entity and project: - _files = self.api.runs(f"{entity}/{project}") - base_path = f"{entity}/{project}" - return self.__ls_projects_or_runs( - _files=_files, base_path=base_path, detail=detail - ) - elif entity: - _files = self.api.projects(entity=entity) - base_path = entity - return self.__ls_projects_or_runs( - _files=_files, base_path=base_path, detail=detail - ) - return [] - - @staticmethod - def __ls_files( - _files: List[str], - base_path: Union[str, Path], - file_path: Union[str, Path] = Path("./"), - detail: bool = False, - ) -> Union[List[str], Dict[str, Any]]: - file_path = Path(file_path) if isinstance(file_path, str) else file_path - files = [] - for _file in _files: - filename = Path(_file.name) - if file_path not in filename.parents: - continue - filename_strip = Path(_file.name.replace(f"{file_path}/", "")) - if filename_strip.is_dir() or len(filename_strip.parents) > 1: - filename_strip = filename_strip.parent.as_posix().split("/")[0] - path = f"{base_path}/{filename_strip}" - if any(f["name"] == path for f in files) if detail else path in files: - continue - files.append( - { - "name": path, - "type": "directory", - "size": 0, - } - if detail - else path - ) - continue - files.append( - { - "name": f"{base_path}/{filename.name}", - "type": "file", - "size": _file.size, - } - if detail - else f"{base_path}/{filename.name}" - ) - return files - - @staticmethod - def __ls_projects_or_runs( - _files: List[str], detail: bool = False - ) -> Union[List[str], Dict[str, Any]]: - files = [] - for _file in _files: - files.append( - { - "name": _file.name, - "type": "directory", - "size": 0, - } - if detail - else _file.name - ) - return files - - def modified(self, path: str) -> datetime.datetime: - """Return the modified timestamp of a file as a datetime.datetime""" - entity, project, run_id, file_path = self.split_path(path=path) - if not file_path: - raise ValueError - _file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) - if not _file: - raise ValueError - return datetime.datetime.fromisoformat(_file.updated_at) - - def open(self, path: str, mode: Literal["rb", "wb"] = "rb") -> None: - _, _, _, file_path = self.split_path(path=path) - if not file_path: - raise ValueError - return WandbFile(self, path=path, mode=mode) - - def url(self, path: str) -> str: - entity, project, run_id, file_path = self.split_path(path=path) - _file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) - if not _file: - raise ValueError - return _file.direct_url - - def cat_file( - self, path: str, start: Union[int, None] = None, end: Union[int, None] = None - ) -> bytes: - url = self.url(path=path) - req = urllib.request.Request(url=url) - if not start and not end: - start, end = 0, "" - req.add_header("Range", f"bytes={start}-{end}") - return urllib.request.urlopen(req).read() - - def put_file(self, lpath: str, rpath: str, **kwargs) -> None: - lpath_ext = os.path.splitext(lpath)[1] - if lpath_ext == "": - raise ValueError("`lpath` must be a file path with extension!") - rpath_ext = os.path.splitext(rpath)[1] - if rpath_ext != "" and rpath_ext != lpath_ext: - raise ValueError( - "`lpath` and `rpath` extensions must match if those are file paths!" - ) - lpath = os.path.abspath(lpath) - _lpath = lpath - entity, project, run_id, file_path = self.split_path(path=rpath) - if rpath_ext != "": - _lpath = os.path.abspath(file_path) - os.makedirs(os.path.dirname(_lpath), exist_ok=True) - os.replace(lpath, _lpath) - run = self.api.run(f"{entity}/{project}/{run_id}") - run.upload_file(path=_lpath, root=".") - - def get_file( - self, rpath: str, lpath: str, overwrite: bool = False, **kwargs - ) -> None: - if os.path.splitext(rpath)[1] == "": - raise ValueError("`rpath` must be a file path with extension!") - entity, project, run_id, file_path = self.split_path(path=rpath) - file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) - _lpath = lpath - if os.path.splitext(lpath)[1] != "": - lpath = os.path.dirname(lpath) - file.download(root=lpath, replace=overwrite) - src_path = os.path.abspath(f"{lpath}/{rpath.split('/')[-1]}") - tgt_path = os.path.abspath(_lpath) - if src_path != tgt_path and not os.path.isdir(tgt_path): - os.rename(src_path, tgt_path) - - def rm_file(self, path: str) -> None: - entity, project, run_id, file_path = self.split_path(path=path) - file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) - file.delete() - - def cp_file(self, path1: str, path2: str, **kwargs) -> None: - path1_ext = os.path.splitext(path1)[1] - if path1_ext == "": - raise ValueError(f"Path {path1} must be a file path with extension!") - path2_ext = os.path.splitext(path2)[1] - if path2_ext == "": - raise ValueError(f"Path {path1} must be a file path with extension!") - if path1_ext != path2_ext: - raise ValueError("Path extensions must be the same for both parameters!") - with tempfile.TemporaryDirectory() as f: - self.get_file(lpath=f, rpath=path1, overwrite=True) - _, _, _, file_path = self.split_path(path=path1) - self.put_file(lpath=f"{f}/{file_path}", rpath=path2) - - -class WandbFile(AbstractBufferedFile): +class WandbFile(AbstractBufferedFile): # type: ignore def __init__( - self, fs: WandbFileSystem, path: str, mode: Literal["rb", "wb"] = "rb" + self, fs: AbstractFileSystem, path: str, mode: Literal["rb", "wb"] = "rb" ) -> None: super().__init__(fs=fs, path=path, mode=mode) def _fetch_range( self, start: Union[int, None] = None, end: Union[int, None] = None - ) -> bytes: + ) -> Any: return self.fs.cat_file(path=self.path, start=start, end=end) -class WandbArtifactStore(AbstractFileSystem): - protocol = "wandbas" +class WandbBaseFileSystem(AbstractFileSystem): # type: ignore + protocol: Literal["wandbfs", "wandbas"] def __init__( self, @@ -261,214 +44,24 @@ def __init__( self.api = wandb.Api() @classmethod - def split_path( - self, path: str - ) -> Tuple[str, Union[str, None], Union[str, None], Union[str, None]]: - path = self._strip_protocol(path=path) - path = path.lstrip("/") - if "/" not in path: - return (path, *[None] * MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH) - path = path.split("/") - if len(path) > MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH: - return ( - *path[:MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH], - "/".join(path[MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH:]), - ) - path += [None] * (MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH - len(path)) - return (*path, None) - - def ls(self, path: str, detail: bool = False) -> Union[List[str], Dict[str, Any]]: - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - _, - ) = self.split_path(path=path) - if entity and project and artifact_type and artifact_name and artifact_version: - return [ - f"{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}/{f.name}" - if not detail - else { - "name": f"{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}/{f.name}", - "type": "file", - "size": f.size, - } - for f in self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ).files() - ] - elif entity and project and artifact_type and artifact_name: - return [ - f"{entity}/{project}/{artifact_type}/{artifact_name}/{v.name.split(':')[1]}" - if not detail - else { - "name": f"{entity}/{project}/{artifact_type}/{artifact_name}/{v.name.split(':')[1]}", - "type": "directory", - "size": 0, - } - for v in self.api.artifact_versions( - name=f"{entity}/{project}/{artifact_name}", type_name=artifact_type - ) - ] - elif entity and project and artifact_type: - return [ - f"{entity}/{project}/{artifact_type}/{c.name}" - if not detail - else { - "name": f"{entity}/{project}/{artifact_type}/{c.name}", - "type": "directory", - "size": 0, - } - for c in self.api.artifact_type( - project=f"{entity}/{project}", type_name=artifact_type - ).collections() - ] - elif entity and project: - return [ - f"{entity}/{project}/{a.name}" - if not detail - else { - "name": f"{entity}/{project}/{a.name}", - "type": "directory", - "size": 0, - } - for a in self.api.artifact_types(project=f"{entity}/{project}") - ] - elif entity: - return [ - f"{entity}/{p.name}" - if not detail - else { - "name": f"{entity}/{p.name}", - "type": "directory", - "size": 0, - } - for p in self.api.projects(entity=entity) - ] - return [] - - def created(self, path: str) -> datetime.datetime: - """Return the created timestamp of a file as a datetime.datetime""" - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - _, - ) = self.split_path(path=path) - artifact = self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ) - if not artifact: - raise ValueError - return datetime.datetime.fromisoformat(artifact.created_at) - - def modified(self, path: str) -> datetime.datetime: - """Return the modified timestamp of a file as a datetime.datetime""" - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - _, - ) = self.split_path(path=path) - artifact = self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ) - if not artifact: - raise ValueError - return datetime.datetime.fromisoformat(artifact.updated_at) + def split_path(self, path: str) -> Any: + raise NotImplementedError("Needs to be implemented!") - def open(self, path: str, mode: Literal["rb", "wb"] = "rb") -> None: - ( - _, - _, - _, - _, - _, - file_path, - ) = self.split_path(path=path) + def open(self, path: str, mode: Literal["rb", "wb"] = "rb") -> WandbFile: + *_, file_path = self.split_path(path=path) if not file_path: raise ValueError return WandbFile(self, path=path, mode=mode) def url(self, path: str) -> str: - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - file_path, - ) = self.split_path(path=path) - artifact = self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ) - manifest = artifact._load_manifest() - digest = manifest.entries[file_path].digest - digest_id = wandb.util.b64_to_hex_id(digest) - return f"https://api.wandb.ai/artifactsV2/gcp-us/{artifact.entity}/{artifact.id}/{digest_id}" + raise NotImplementedError("Needs to be implemented!") def cat_file( self, path: str, start: Union[int, None] = None, end: Union[int, None] = None - ) -> bytes: + ) -> Any: url = self.url(path=path) req = urllib.request.Request(url=url) if not start and not end: - start, end = 0, "" + start, end = 0, "" # type: ignore req.add_header("Range", f"bytes={start}-{end}") return urllib.request.urlopen(req).read() - - def get_file( - self, lpath: str, rpath: str, overwrite: bool = False, **kwargs - ) -> None: - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - file_path, - ) = self.split_path(path=rpath) - artifact = self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ) - path = artifact.get_path(name=file_path) - if os.path.exists(lpath) and not overwrite: - return - path.download(root=lpath) - - def rm_file(self, path: str, force_rm: bool = False) -> None: - ( - entity, - project, - artifact_type, - artifact_name, - artifact_version, - file_path, - ) = self.split_path(path=path) - if not file_path: - if not force_rm: - logging.info( - "In order to remove an artifact, you'll need to pass" - " `force_rm=True`." - ) - return - artifact = self.api.artifact( - name=f"{entity}/{project}/{artifact_name}:{artifact_version}", - type=artifact_type, - ) - artifact.delete(delete_aliases=True) - return - logging.info( - "W&B just lets you remove complete artifact versions not artifact files." - ) diff --git a/src/wandbfsspec/spec.py b/src/wandbfsspec/spec.py new file mode 100644 index 0000000..631f1e9 --- /dev/null +++ b/src/wandbfsspec/spec.py @@ -0,0 +1,403 @@ +# Copyright 2022 Alvaro Bartolome, alvarobartt @ GitHub +# See LICENSE for details. + +import datetime +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import wandb + +from wandbfsspec.core import WandbBaseFileSystem + +MAX_PATH_LENGTH_WITHOUT_FILE_PATH = 3 +MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH = 5 + +logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO) + +__all__ = ["WandbFileSystem", "WandbArtifactStore"] + + +class WandbFileSystem(WandbBaseFileSystem): + protocol = "wandbfs" # type: ignore + + @classmethod + def split_path( + self, path: str + ) -> Tuple[str, Union[str, None], Union[str, None], Union[str, None]]: + path = self._strip_protocol(path=path) + path = path.lstrip("/") + if "/" not in path: + return (path, None, None, None) + path = path.split("/") # type: ignore + if len(path) > MAX_PATH_LENGTH_WITHOUT_FILE_PATH: + return ( # type: ignore + *path[:MAX_PATH_LENGTH_WITHOUT_FILE_PATH], + "/".join(path[MAX_PATH_LENGTH_WITHOUT_FILE_PATH:]), + ) + path += [None] * (MAX_PATH_LENGTH_WITHOUT_FILE_PATH - len(path)) # type: ignore + return (*path, None) # type: ignore + + @staticmethod + def __ls_files( + _files: List[Any], + base_path: Union[str, Path], + file_path: Union[str, Path] = Path("./"), + detail: bool = False, + ) -> Union[List[str], List[Dict[str, Any]]]: + file_path = Path(file_path) if isinstance(file_path, str) else file_path + files: Union[List[str], List[Dict[str, Any]]] = list() # type: ignore + for _file in _files: + filename = Path(_file.name) + if file_path not in filename.parents: + continue + filename_strip = Path(_file.name.replace(f"{file_path}/", "")) + if filename_strip.is_dir() or len(filename_strip.parents) > 1: + filename_strip = filename_strip.parent.as_posix().split("/")[0] # type: ignore + path = f"{base_path}/{filename_strip}" + if any(f["name"] == path for f in files) if detail else path in files: # type: ignore + continue + files.append( + { # type: ignore + "name": path, + "type": "directory", + "size": 0, + } + if detail + else path + ) + continue + files.append( + { # type: ignore + "name": f"{base_path}/{filename.name}", + "type": "file", + "size": _file.size, + } + if detail + else f"{base_path}/{filename.name}" + ) + return files + + @staticmethod + def __ls_projects_or_runs( + _files: List[Any], detail: bool = False + ) -> Union[List[str], List[Dict[str, Any]]]: + files = [] + for _file in _files: + files.append( + { + "name": _file.name, + "type": "directory", + "size": 0, + } + if detail + else _file.name + ) + return files + + def ls( + self, path: str, detail: bool = False + ) -> Union[List[str], List[Dict[str, Any]]]: + entity, project, run_id, file_path = self.split_path(path=path) + if entity and project and run_id: + _files = self.api.run(f"{entity}/{project}/{run_id}").files() # type: ignore + base_path = f"{entity}/{project}/{run_id}" + return self.__ls_files( + _files=_files, + base_path=f"{base_path}/{file_path}" if file_path else base_path, + file_path=file_path if file_path else Path("./"), + detail=detail, + ) + elif entity and project: + _files = self.api.runs(f"{entity}/{project}") + base_path = f"{entity}/{project}" + return self.__ls_projects_or_runs(_files=_files, detail=detail) + elif entity: + _files = self.api.projects(entity=entity) # type: ignore + base_path = entity + return self.__ls_projects_or_runs(_files=_files, detail=detail) + raise ValueError("You need to at least provide an `entity` value!") + + def modified(self, path: str) -> datetime.datetime: + """Return the modified timestamp of a file as a datetime.datetime""" + entity, project, run_id, file_path = self.split_path(path=path) + if not file_path: + raise ValueError( + "`file_path` can't be None, make sure the `path` is valid!" + ) + _file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) # type: ignore + if not _file: + raise FileNotFoundError( + f"`file` at {file_path} for {entity}/{project}/{run_id} couldn't be" + " found or doesn't exist!" + ) + return datetime.datetime.fromisoformat(_file.updated_at) + + def url(self, path: str) -> str: + entity, project, run_id, file_path = self.split_path(path=path) + _file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) # type: ignore + if not _file: + raise FileNotFoundError( + f"`file` at {file_path} for {entity}/{project}/{run_id} couldn't be" + " found or doesn't exist!" + ) + return str(_file.direct_url) + + def put_file(self, lpath: str, rpath: str, **kwargs: Dict[str, Any]) -> None: + lpath_ext = os.path.splitext(lpath)[1] + if lpath_ext == "": + raise ValueError("`lpath` must be a file path with extension!") + rpath_ext = os.path.splitext(rpath)[1] + if rpath_ext != "" and rpath_ext != lpath_ext: + raise ValueError( + "`lpath` and `rpath` extensions must match if those are file paths!" + ) + lpath = os.path.abspath(lpath) + _lpath = lpath + entity, project, run_id, file_path = self.split_path(path=rpath) + if rpath_ext != "": + _lpath = os.path.abspath(file_path) # type: ignore + os.makedirs(os.path.dirname(_lpath), exist_ok=True) + os.replace(lpath, _lpath) + run = self.api.run(f"{entity}/{project}/{run_id}") # type: ignore + run.upload_file(path=_lpath, root=".") + + def get_file( + self, rpath: str, lpath: str, overwrite: bool = False, **kwargs: Dict[str, Any] + ) -> None: + if os.path.splitext(rpath)[1] == "": + raise ValueError("`rpath` must be a file path with extension!") + entity, project, run_id, file_path = self.split_path(path=rpath) + file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) # type: ignore + _lpath = lpath + if os.path.splitext(lpath)[1] != "": + lpath = os.path.dirname(lpath) + file.download(root=lpath, replace=overwrite) + src_path = os.path.abspath(f"{lpath}/{rpath.split('/')[-1]}") + tgt_path = os.path.abspath(_lpath) + if src_path != tgt_path and not os.path.isdir(tgt_path): + os.rename(src_path, tgt_path) + + def rm_file(self, path: str) -> None: + entity, project, run_id, file_path = self.split_path(path=path) + file = self.api.run(f"{entity}/{project}/{run_id}").file(name=file_path) # type: ignore + file.delete() + + def cp_file(self, path1: str, path2: str, **kwargs: Dict[str, Any]) -> None: + path1_ext = os.path.splitext(path1)[1] + if path1_ext == "": + raise ValueError(f"Path {path1} must be a file path with extension!") + path2_ext = os.path.splitext(path2)[1] + if path2_ext == "": + raise ValueError(f"Path {path1} must be a file path with extension!") + if path1_ext != path2_ext: + raise ValueError("Path extensions must be the same for both parameters!") + with tempfile.TemporaryDirectory() as f: + self.get_file(lpath=f, rpath=path1, overwrite=True) + _, _, _, file_path = self.split_path(path=path1) + self.put_file(lpath=f"{f}/{file_path}", rpath=path2) + + +class WandbArtifactStore(WandbBaseFileSystem): + protocol = "wandbas" # type: ignore + + @classmethod + def split_path( + self, path: str + ) -> Tuple[ + str, + Union[str, None], + Union[str, None], + Union[str, None], + Union[str, None], + Union[str, None], + ]: + path = self._strip_protocol(path=path) + path = path.lstrip("/") + if "/" not in path: + return (path, *[None] * MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH) # type: ignore + path = path.split("/") # type: ignore + if len(path) > MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH: + return ( # type: ignore + *path[:MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH], + "/".join(path[MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH:]), + ) + path += [None] * (MAX_ARTIFACT_LENGTH_WITHOUT_FILE_PATH - len(path)) # type: ignore + return (*path, None) # type: ignore + + def ls( + self, path: str, detail: bool = False + ) -> Union[List[str], List[Dict[str, Any]]]: + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + _, + ) = self.split_path(path=path) + if entity and project and artifact_type and artifact_name and artifact_version: + return [ + f"{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}/{f.name}" + if not detail + else { + "name": f"{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}/{f.name}", + "type": "file", + "size": f.size, + } + for f in self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ).files() + ] + elif entity and project and artifact_type and artifact_name: + return [ + f"{entity}/{project}/{artifact_type}/{artifact_name}/{v.name.split(':')[1]}" + if not detail + else { + "name": f"{entity}/{project}/{artifact_type}/{artifact_name}/{v.name.split(':')[1]}", + "type": "directory", + "size": 0, + } + for v in self.api.artifact_versions( # type: ignore + name=f"{entity}/{project}/{artifact_name}", type_name=artifact_type + ) + ] + elif entity and project and artifact_type: + return [ + f"{entity}/{project}/{artifact_type}/{c.name}" + if not detail + else { + "name": f"{entity}/{project}/{artifact_type}/{c.name}", + "type": "directory", + "size": 0, + } + for c in self.api.artifact_type( # type: ignore + project=f"{entity}/{project}", type_name=artifact_type + ).collections() + ] + elif entity and project: + return [ + f"{entity}/{project}/{a.name}" + if not detail + else { + "name": f"{entity}/{project}/{a.name}", + "type": "directory", + "size": 0, + } + for a in self.api.artifact_types(project=f"{entity}/{project}") # type: ignore + ] + elif entity: + return [ + f"{entity}/{p.name}" + if not detail + else { + "name": f"{entity}/{p.name}", + "type": "directory", + "size": 0, + } + for p in self.api.projects(entity=entity) # type: ignore + ] + raise ValueError("You need to at least provide an `entity` value!") + + def created(self, path: str) -> datetime.datetime: + """Return the created timestamp of a file as a datetime.datetime""" + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + _, + ) = self.split_path(path=path) + artifact = self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ) + if not artifact: + raise ValueError("`artifact` is None, make sure that it exists!") + return datetime.datetime.fromisoformat(artifact.created_at) + + def modified(self, path: str) -> datetime.datetime: + """Return the modified timestamp of a file as a datetime.datetime""" + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + _, + ) = self.split_path(path=path) + artifact = self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ) + if not artifact: + raise ValueError("`artifact` is None, make sure that it exists!") + return datetime.datetime.fromisoformat(artifact.updated_at) + + def url(self, path: str) -> str: + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + file_path, + ) = self.split_path(path=path) + artifact = self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ) + manifest = artifact._load_manifest() + digest = manifest.entries[file_path].digest + digest_id = wandb.util.b64_to_hex_id(digest) + return f"https://api.wandb.ai/artifactsV2/gcp-us/{artifact.entity}/{artifact.id}/{digest_id}" + + def get_file( + self, lpath: str, rpath: str, overwrite: bool = False, **kwargs: Dict[str, Any] + ) -> None: + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + file_path, + ) = self.split_path(path=rpath) + artifact = self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ) + path = artifact.get_path(name=file_path) + if os.path.exists(lpath) and not overwrite: + return + path.download(root=lpath) + + def rm_file(self, path: str, force_rm: bool = False) -> None: + ( + entity, + project, + artifact_type, + artifact_name, + artifact_version, + file_path, + ) = self.split_path(path=path) + if not file_path: + if not force_rm: + logging.info( + "In order to remove an artifact, you'll need to pass" + " `force_rm=True`." + ) + return + artifact = self.api.artifact( # type: ignore + name=f"{entity}/{project}/{artifact_name}:{artifact_version}", + type=artifact_type, + ) + artifact.delete(delete_aliases=True) + return + logging.info( + "W&B just lets you remove complete artifact versions not artifact files." + ) diff --git a/tests/conftest.py b/tests/conftest.py index 80dee2e..a0db532 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,15 +4,23 @@ import os import pytest +from _pytest.fixtures import SubRequest + +from wandbfsspec.spec import WandbArtifactStore, WandbFileSystem from .utils import MockRun -MOCK_RUN = MockRun( +MOCK_RUN = MockRun( # type: ignore entity=os.getenv("WANDB_ENTITY", "alvarobartt"), project=os.getenv("WANDB_PROJECT", "wandbfsspec-tests"), ) +@pytest.fixture(params=[WandbArtifactStore.protocol, WandbFileSystem.protocol]) +def protocol(request: SubRequest) -> str: + return request.param # type: ignore + + @pytest.fixture def entity() -> str: return MOCK_RUN.entity @@ -25,19 +33,19 @@ def project() -> str: @pytest.fixture def run_id() -> str: - return MOCK_RUN.run_id + return MOCK_RUN.run_id # type: ignore @pytest.fixture def artifact_type() -> str: - return MOCK_RUN.artifact_type + return MOCK_RUN.artifact_type # type: ignore @pytest.fixture def artifact_name() -> str: - return MOCK_RUN.artifact_name + return MOCK_RUN.artifact_name # type: ignore @pytest.fixture def artifact_version() -> str: - return MOCK_RUN.artifact_version + return MOCK_RUN.artifact_version # type: ignore diff --git a/tests/test_core.py b/tests/test_core.py deleted file mode 100644 index 89fa95e..0000000 --- a/tests/test_core.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2022 Alvaro Bartolome, alvarobartt @ GitHub -# See LICENSE for details. - -import datetime -from typing import List - -import pytest - -from wandbfsspec.core import WandbArtifactStore, WandbFile, WandbFileSystem - - -class TestWandbFileSystem: - """Test `wandbfsspec.core.WandbFileSystem` class methods.""" - - @pytest.fixture(autouse=True) - @pytest.mark.usefixtures("entity", "project", "run_id") - def setup_method(self, entity: str, project: str, run_id: str) -> None: - self.fs = WandbFileSystem() - self.path = f"{self.fs.protocol}://{entity}/{project}/{run_id}" - self.file_path = "file.yaml" - - def teardown(self): - del self.fs - - def test_ls(self) -> None: - """Test `WandbFileSystem.ls` method.""" - files = self.fs.ls(path=self.path) - assert isinstance(files, List) - - def test_modified(self) -> None: - modified_at = self.fs.modified(path=f"{self.path}/{self.file_path}") - assert isinstance(modified_at, datetime.datetime) - - def test_open(self) -> None: - _file = self.fs.open(path=f"{self.path}/{self.file_path}") - assert isinstance(_file, WandbFile) - - -class TestWandbArtifactStore: - """Test `wandbfsspec.core.WandbArtifactStore` class methods.""" - - @pytest.fixture(autouse=True) - @pytest.mark.usefixtures( - "entity", "project", "artifact_type", "artifact_name", "artifact_version" - ) - def setup_method( - self, - entity: str, - project: str, - artifact_type: str, - artifact_name: str, - artifact_version: str, - ) -> None: - self.fs = WandbArtifactStore() - self.path = f"{self.fs.protocol}://{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}" - self.file_path = "file.yaml" - - def teardown(self): - del self.fs - - def test_ls(self) -> None: - """Test `WandbArtifactStore.ls` method.""" - files = self.fs.ls(path=self.path) - assert isinstance(files, List) - - def test_created(self) -> None: - created = self.fs.created(path=f"{self.path}/{self.file_path}") - assert isinstance(created, datetime.datetime) - - def test_modified(self) -> None: - modified_at = self.fs.modified(path=f"{self.path}/{self.file_path}") - assert isinstance(modified_at, datetime.datetime) - - def test_open(self) -> None: - _file = self.fs.open(path=f"{self.path}/{self.file_path}") - assert isinstance(_file, WandbFile) - - def test_rm_file(self) -> None: - self.fs.rm_file(path=self.path) - self.fs.rm_file(path=self.path, force_rm=True) diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py new file mode 100644 index 0000000..a365bee --- /dev/null +++ b/tests/test_fsspec.py @@ -0,0 +1,62 @@ +# Copyright 2022 Alvaro Bartolome, alvarobartt @ GitHub +# See LICENSE for details. + +import datetime +from typing import List + +import fsspec +import pytest + +from wandbfsspec.core import WandbFile + + +class TestFsspecFileSystem: + """Test `fsspec.FileSystem` class methods for `wandbfs` and `wandbas`.""" + + @pytest.fixture(autouse=True) + @pytest.mark.usefixtures( + "protocol", + "entity", + "project", + "run_id", + "artifact_type", + "artifact_name", + "artifact_version", + ) + def setup_method( + self, + protocol: str, + entity: str, + project: str, + run_id: str, + artifact_type: str, + artifact_name: str, + artifact_version: str, + ) -> None: + self.fs = fsspec.filesystem(protocol) + + self.base_path = f"{protocol}://{entity}/{project}" + self.path = ( + f"{self.base_path}/{run_id}" + if protocol == "wandbfs" + else f"{self.base_path}/{artifact_type}/{artifact_name}/{artifact_version}" + ) + self.file_path = "file.yaml" + + def teardown(self) -> None: + del self.fs + + def test_ls(self) -> None: + """Test `fsspec.FileSystem.ls` method.""" + files = self.fs.ls(path=self.path) + assert isinstance(files, List) + + def test_modified(self) -> None: + """Test `fsspec.FileSystem.modified` method.""" + modified_at = self.fs.modified(path=f"{self.path}/{self.file_path}") + assert isinstance(modified_at, datetime.datetime) + + def test_open(self) -> None: + """Test `fsspec.FileSystem.open` method.""" + with self.fs.open(path=f"{self.path}/{self.file_path}") as f: + assert isinstance(f, WandbFile) diff --git a/tests/test_spec.py b/tests/test_spec.py index 62b6b66..2a83d4c 100644 --- a/tests/test_spec.py +++ b/tests/test_spec.py @@ -4,36 +4,74 @@ import datetime from typing import List -import fsspec import pytest -from wandbfsspec.core import WandbFile, WandbFileSystem +from wandbfsspec.core import WandbFile +from wandbfsspec.spec import WandbArtifactStore, WandbFileSystem -class TestFsspecFileSystem: - """Test `fsspec.FileSystem` class methods for `wandbfs`.""" +class TestWandbFileSystem: + """Test `wandbfsspec.core.WandbFileSystem` class methods.""" @pytest.fixture(autouse=True) @pytest.mark.usefixtures("entity", "project", "run_id") def setup_method(self, entity: str, project: str, run_id: str) -> None: - self.fs = fsspec.filesystem(WandbFileSystem.protocol) - self.path = f"{WandbFileSystem.protocol}://{entity}/{project}/{run_id}" + self.fs = WandbFileSystem() + self.path = f"{self.fs.protocol}://{entity}/{project}/{run_id}" self.file_path = "file.yaml" - def teardown(self): + def teardown(self) -> None: del self.fs def test_ls(self) -> None: - """Test `fsspec.FileSystem.ls` method.""" + """Test `WandbFileSystem.ls` method.""" files = self.fs.ls(path=self.path) assert isinstance(files, List) def test_modified(self) -> None: - """Test `fsspec.FileSystem.modified` method.""" modified_at = self.fs.modified(path=f"{self.path}/{self.file_path}") assert isinstance(modified_at, datetime.datetime) def test_open(self) -> None: - """Test `fsspec.FileSystem.open` method.""" - with self.fs.open(path=f"{self.path}/{self.file_path}") as f: - assert isinstance(f, WandbFile) + _file = self.fs.open(path=f"{self.path}/{self.file_path}") + assert isinstance(_file, WandbFile) + + +class TestWandbArtifactStore: + """Test `wandbfsspec.core.WandbArtifactStore` class methods.""" + + @pytest.fixture(autouse=True) + @pytest.mark.usefixtures( + "entity", "project", "artifact_type", "artifact_name", "artifact_version" + ) + def setup_method( + self, + entity: str, + project: str, + artifact_type: str, + artifact_name: str, + artifact_version: str, + ) -> None: + self.fs = WandbArtifactStore() + self.path = f"{self.fs.protocol}://{entity}/{project}/{artifact_type}/{artifact_name}/{artifact_version}" + self.file_path = "file.yaml" + + def teardown(self) -> None: + del self.fs + + def test_ls(self) -> None: + """Test `WandbArtifactStore.ls` method.""" + files = self.fs.ls(path=self.path) + assert isinstance(files, List) + + def test_created(self) -> None: + created = self.fs.created(path=f"{self.path}/{self.file_path}") + assert isinstance(created, datetime.datetime) + + def test_modified(self) -> None: + modified_at = self.fs.modified(path=f"{self.path}/{self.file_path}") + assert isinstance(modified_at, datetime.datetime) + + def test_open(self) -> None: + _file = self.fs.open(path=f"{self.path}/{self.file_path}") + assert isinstance(_file, WandbFile) diff --git a/tests/utils.py b/tests/utils.py index 70ad718..b65e8ed 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,7 +22,7 @@ class MockRun: artifact_name: Union[str, None] = None artifact_version: Union[str, None] = None - def __post_init__(self): + def __post_init__(self) -> None: assert os.getenv("WANDB_API_KEY"), ( "In order to connect to the wandb Public API you need to provide the API" " key either via param `api_key`, setting the key in the environment"