Skip to content

Commit

Permalink
change saving to pathlib
Browse files Browse the repository at this point in the history
add tests

fix save_models
  • Loading branch information
mastoffel committed Mar 12, 2024
1 parent aa81a13 commit a1764a6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 107 deletions.
39 changes: 26 additions & 13 deletions autoemulate/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from pathlib import Path

import joblib
import numpy as np
Expand All @@ -23,13 +24,18 @@ def _save_model(self, model, path):
"""
model_name = get_model_name(model)
# check if path is directory
if path is not None and os.path.isdir(path):
path = os.path.join(path, model_name)
if path is not None and Path(path).is_dir():
path = Path(path) / model_name
# save with model name if path is None
if path is None:
path = model_name
path = Path(model_name)
else:
path = Path(path)

# model
# create directory if it doesn't exist
path.parent.mkdir(parents=True, exist_ok=True)

# save model
joblib.dump(model, path)

# metadata
Expand All @@ -51,19 +57,23 @@ def _save_models(self, models, path):
path : str
Path to save the models.
"""
if path is not None and not os.path.isdir(path):
if path is None:
save_dir = Path.cwd()
else:
save_dir = Path(path)
# create directory if it doesn't exist
os.makedirs(path, exist_ok=True)
raise ValueError("Path must be a directory")
for model in models.values():
self._save_model(model, path)
save_dir.parent.mkdir(parents=True, exist_ok=True)
for model in models:
model_path = save_dir / get_model_name(model)
self._save_model(model, model_path)

def _load_model(self, path):
"""Loads a model from disk and checks version."""
path = Path(path)
model = joblib.load(path)
meta_path = self._get_meta_path(path)
meta_path = Path(self._get_meta_path(path))

if not os.path.exists(meta_path):
if not meta_path.exists():
raise FileNotFoundError(f"Metadata file {meta_path} not found.")

with open(meta_path, "r") as f:
Expand All @@ -89,6 +99,9 @@ def _get_meta_path(self, path):
If the path has an extension, it is replaced with _meta.json.
Otherwise, _meta.json is appended to the path.
"""
base, ext = os.path.splitext(path)
meta_path = f"{base}_meta.json" if ext else f"{base}_meta{ext}.json"
path = Path(path)
if path.suffix:
meta_path = path.with_name(f"{path.stem}_meta.json")
else:
meta_path = path.with_name(path.name + "_meta.json")
return meta_path
165 changes: 71 additions & 94 deletions tests/test_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

Expand All @@ -22,117 +24,92 @@ def model():

@pytest.fixture
def models():
return {"RandomForest": RandomForest(), "GaussianProcess": GaussianProcess()}
return [RandomForest(), GaussianProcess()]


@pytest.fixture
def test_path():
return "test_model"


def test_save_model_w_path(model_serialiser, model, test_path):
model_serialiser._save_model(model, test_path)
assert os.path.exists(test_path)
assert os.path.exists(model_serialiser._get_meta_path(test_path))

with open(model_serialiser._get_meta_path(test_path), "r") as f:
meta = json.load(f)
assert "model" in meta
assert "scikit-learn" in meta
assert "numpy" in meta

os.remove(test_path)
os.remove(model_serialiser._get_meta_path(test_path))
def test_save_model_wo_path(model_serialiser, model):
with TemporaryDirectory() as temp_dir:
original_wd = os.getcwd()
os.chdir(temp_dir)

try:
model_serialiser._save_model(model, None)
model_name = get_model_name(model)
expected_path = Path(model_name)
expected_meta_path = model_serialiser._get_meta_path(model_name)
assert expected_path.exists()
assert expected_meta_path.exists()
finally:
os.chdir(original_wd)


def test_save_model_w_name(model_serialiser, model):
with TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir) / "test_model"
model_serialiser._save_model(model, test_path)
meta_path = model_serialiser._get_meta_path(test_path)
assert test_path.exists()
assert meta_path.exists()

with open(meta_path, "r") as f:
meta = json.load(f)
assert "model" in meta
assert "scikit-learn" in meta
assert "numpy" in meta


def test_save_model_w_dir(model_serialiser, model):
test_dir = "test_dir"
os.makedirs(test_dir, exist_ok=True)
model_serialiser._save_model(model, test_dir)
model_name = get_model_name(model)
assert os.path.exists(os.path.join(test_dir, model_name))
with TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir)
model_serialiser._save_model(model, test_path)
assert test_path.exists()
assert (test_path / "RandomForest").exists()
assert (test_path / "RandomForest_meta.json").exists()

os.remove(os.path.join(test_dir, model_name))
os.remove(os.path.join(test_dir, model_serialiser._get_meta_path(model_name)))
os.rmdir(test_dir)

def test_load_model(model_serialiser, model):
with TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir) / "test_model"
model_serialiser._save_model(model, test_path)
loaded_model = model_serialiser._load_model(test_path)
assert isinstance(loaded_model, type(model))

def test_save_model_wo_path(model_serialiser, model):
model_serialiser._save_model(model, None)
model_name = get_model_name(model)
assert os.path.exists(model_name)
assert os.path.exists(model_serialiser._get_meta_path(model_name))

with open(model_serialiser._get_meta_path(model_name), "r") as f:
meta = json.load(f)
assert "model" in meta
assert "scikit-learn" in meta
assert "numpy" in meta

os.remove(model_name)
os.remove(model_serialiser._get_meta_path(model_name))


def test_save_models_wo_dir(model_serialiser, models):
model_serialiser._save_models(models, None)
model_names = models.keys()
assert all([os.path.exists(model_name) for model_name in model_names])
assert all(
[
os.path.exists(model_serialiser._get_meta_path(model_name))
for model_name in model_names
]
)

for model_name in model_names:
os.remove(model_name)
os.remove(model_serialiser._get_meta_path(model_name))


def test_save_models_w_dir(model_serialiser, models):
test_dir = "test_dir"
os.makedirs(test_dir, exist_ok=True)
model_serialiser._save_models(models, test_dir)
model_names = models.keys()
assert all(
[
os.path.exists(os.path.join(test_dir, model_name))
for model_name in model_names
]
)
assert all(
[
os.path.exists(
os.path.join(test_dir, model_serialiser._get_meta_path(model_name))
)
for model_name in model_names
]
)

for model_name in model_names:
os.remove(os.path.join(test_dir, model_name))
os.remove(os.path.join(test_dir, model_serialiser._get_meta_path(model_name)))
os.rmdir(test_dir)


def test_load_model(model_serialiser, model, test_path):
model_serialiser._save_model(model, test_path)
loaded_model = model_serialiser._load_model(test_path)
assert isinstance(loaded_model, type(model))
os.remove(test_path)
os.remove(model_serialiser._get_meta_path(test_path))


def test_load_model_with_missing_meta_file(model_serialiser, model, test_path):
model_serialiser._save_model(model, test_path)
os.remove(model_serialiser._get_meta_path(test_path))
with pytest.raises(FileNotFoundError):
model_serialiser._load_model(test_path)
os.remove(test_path)

def test_load_model_with_missing_meta_file(model_serialiser, model):
with TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir) / "test_model"
model_serialiser._save_model(model, test_path)
meta_path = model_serialiser._get_meta_path(test_path)
meta_path.unlink()
with pytest.raises(FileNotFoundError):
model_serialiser._load_model(test_path)


def test_invalid_file_path(model_serialiser, model):
with pytest.raises(Exception):
# only the / makes it invalid
model_serialiser._save_model(model, "/invalid/path/model")
with pytest.raises(Exception):
model_serialiser._load_model("/invalid/path/model")


def test_save_models_wo_path(model_serialiser, models):
with TemporaryDirectory() as temp_dir:
original_wd = os.getcwd()
os.chdir(temp_dir)

try:
model_serialiser._save_models(models, None)
for model in models:
model_name = get_model_name(model)
expected_path = Path(model_name)
expected_meta_path = model_serialiser._get_meta_path(model_name)
assert expected_path.exists()
assert expected_meta_path.exists()
finally:
os.chdir(original_wd)

0 comments on commit a1764a6

Please sign in to comment.