From e46d463ec699160cd73133cef75db1246b313d9c Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Fri, 6 Sep 2024 10:59:55 -0400 Subject: [PATCH] Allow using cache directory set in env var --- pyproject.toml | 2 +- spotiflow/model/pretrained.py | 11 ++++++++--- spotiflow/model/spotiflow.py | 1 + spotiflow/sample_data/datasets.py | 13 ++++++++++--- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1899cd..ac68f64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools==71.0", +requires = ["setuptools<=71.0", "setuptools_scm[toml]>=6.2", "wheel", "oldest-supported-numpy"] diff --git a/spotiflow/model/pretrained.py b/spotiflow/model/pretrained.py index 305c202..be9491a 100644 --- a/spotiflow/model/pretrained.py +++ b/spotiflow/model/pretrained.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -6,7 +7,6 @@ from ..utils.get_file import get_file - @dataclass class RegisteredModel: """ @@ -25,8 +25,13 @@ def list_registered(): def _default_cache_dir(): - return Path("~").expanduser() / ".spotiflow" / "models" - + default_cache_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None) + if default_cache_dir is None: + return Path("~").expanduser() / ".spotiflow" / "models" + default_cache_dir = Path(default_cache_dir) + if default_cache_dir.stem != "models": + default_cache_dir = default_cache_dir / "models" + return default_cache_dir def get_pretrained_model_path(name: str, cache_dir: Optional[Path] = None) -> Path: """ diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 2ebd448..9d8ce11 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -193,6 +193,7 @@ def from_pretrained( inference_mode (bool, optional): whether to set the model in eval mode. Defaults to True. which (str, optional): which checkpoint to load. Defaults to "best". 
map_location (str, optional): device string to load the model to. Defaults to 'auto' (hardware-based). + cache_dir (Optional[Union[Path, str]], optional): directory to cache the model. Defaults to None. If None, will use the default cache directory (given by the env var SPOTIFLOW_CACHE_DIR if set, otherwise ~/.spotiflow). Returns: Self: loaded model diff --git a/spotiflow/sample_data/datasets.py b/spotiflow/sample_data/datasets.py index eba812e..731bd97 100644 --- a/spotiflow/sample_data/datasets.py +++ b/spotiflow/sample_data/datasets.py @@ -1,8 +1,9 @@ +import os from dataclasses import dataclass from pathlib import Path from typing import Optional, Union -from ..utils import get_data, NotRegisteredError +from ..utils import NotRegisteredError, get_data from ..utils.get_file import get_file @@ -23,9 +24,14 @@ class RegisteredDataset: def list_registered(): return list(_REGISTERED.keys()) - def _default_cache_dir(): - return Path("~").expanduser() / ".spotiflow" / "datasets" + default_cache_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None) + if default_cache_dir is None: + return Path("~").expanduser() / ".spotiflow" / "datasets" + default_cache_dir = Path(default_cache_dir) + if default_cache_dir.stem != "datasets": + default_cache_dir = default_cache_dir / "datasets" + return default_cache_dir def get_training_datasets_path(name: str, cache_dir: Optional[Path] = None) -> Path: @@ -57,6 +63,7 @@ def load_dataset(name: str, include_test: bool=False, cache_dir: Optional[Union[ Args: name (str): the name of the dataset to load. include_test (bool, optional): whether to include the test set in the returned data. Defaults to False. + cache_dir (Optional[Union[Path, str]], optional): directory to cache the dataset. Defaults to None. If None, will use the default cache directory (given by the env var SPOTIFLOW_CACHE_DIR if set, otherwise ~/.spotiflow). """ if name not in _REGISTERED: raise NotRegisteredError(f"No training dataset named {name} found. 
Available datasets: {','.join(sorted(list_registered()))}")