Skip to content

Commit

Permalink
Allow using cache directory set in env var
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Sep 6, 2024
1 parent 8e228f9 commit e46d463
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools==71.0",
requires = ["setuptools<=71.0",
"setuptools_scm[toml]>=6.2",
"wheel",
"oldest-supported-numpy"]
Expand Down
11 changes: 8 additions & 3 deletions spotiflow/model/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
Expand All @@ -6,7 +7,6 @@
from ..utils.get_file import get_file



@dataclass
class RegisteredModel:
"""
Expand All @@ -25,8 +25,13 @@ def list_registered():


def _default_cache_dir():
    """Return the default directory used to cache pretrained models.

    The location is taken from the ``SPOTIFLOW_CACHE_DIR`` environment
    variable when it is set to a non-empty value; a ``models`` subdirectory
    is appended unless the given path already ends in a component named
    exactly ``models``. Otherwise ``~/.spotiflow/models`` is used.

    Returns:
        Path: the default cache directory for models.
    """
    env_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None)
    # An empty variable is treated as unset: Path("") would silently resolve
    # to a relative "models" directory in the current working directory.
    if not env_dir:
        return Path("~").expanduser() / ".spotiflow" / "models"
    cache_dir = Path(env_dir)
    # Compare the full last component (.name), not .stem: .stem drops the
    # suffix, so e.g. "models.v2" would wrongly be taken for "models".
    if cache_dir.name != "models":
        cache_dir = cache_dir / "models"
    return cache_dir

def get_pretrained_model_path(name: str, cache_dir: Optional[Path] = None) -> Path:
"""
Expand Down
1 change: 1 addition & 0 deletions spotiflow/model/spotiflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def from_pretrained(
inference_mode (bool, optional): whether to set the model in eval mode. Defaults to True.
which (str, optional): which checkpoint to load. Defaults to "best".
map_location (str, optional): device string to load the model to. Defaults to 'auto' (hardware-based).
cache_dir (Optional[Union[Path, str]], optional): directory to cache the model. Defaults to None. If None, will use the default cache directory (the "models" subdirectory of the directory given by the env var SPOTIFLOW_CACHE_DIR if set, otherwise ~/.spotiflow/models).
Returns:
Self: loaded model
Expand Down
13 changes: 10 additions & 3 deletions spotiflow/sample_data/datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

from ..utils import get_data, NotRegisteredError
from ..utils import NotRegisteredError, get_data
from ..utils.get_file import get_file


Expand All @@ -23,9 +24,14 @@ class RegisteredDataset:
def list_registered():
    """Return the keys of the ``_REGISTERED`` registry as a new list."""
    registered_names = [*_REGISTERED.keys()]
    return registered_names


def _default_cache_dir():
    """Return the default directory used to cache downloaded datasets.

    The location is taken from the ``SPOTIFLOW_CACHE_DIR`` environment
    variable when it is set to a non-empty value; a ``datasets``
    subdirectory is appended unless the given path already ends in a
    component named exactly ``datasets``. Otherwise
    ``~/.spotiflow/datasets`` is used.

    Returns:
        Path: the default cache directory for datasets.
    """
    env_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None)
    # An empty variable is treated as unset: Path("") would silently resolve
    # to a relative "datasets" directory in the current working directory.
    if not env_dir:
        return Path("~").expanduser() / ".spotiflow" / "datasets"
    cache_dir = Path(env_dir)
    # Compare the full last component (.name), not .stem: .stem drops the
    # suffix, so e.g. "datasets.old" would wrongly be taken for "datasets".
    if cache_dir.name != "datasets":
        cache_dir = cache_dir / "datasets"
    return cache_dir


def get_training_datasets_path(name: str, cache_dir: Optional[Path] = None) -> Path:
Expand Down Expand Up @@ -57,6 +63,7 @@ def load_dataset(name: str, include_test: bool=False, cache_dir: Optional[Union[
Args:
name (str): the name of the dataset to load.
include_test (bool, optional): whether to include the test set in the returned data. Defaults to False.
cache_dir (Optional[Union[Path, str]], optional): directory to cache the dataset. Defaults to None. If None, will use the default cache directory (the "datasets" subdirectory of the directory given by the env var SPOTIFLOW_CACHE_DIR if set, otherwise ~/.spotiflow/datasets).
"""
if name not in _REGISTERED:
raise NotRegisteredError(f"No training dataset named {name} found. Available datasets: {','.join(sorted(list_registered()))}")
Expand Down

0 comments on commit e46d463

Please sign in to comment.