diff --git a/CHANGELOG.md b/CHANGELOG.md index 74f13c0..090807b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added + +- `ParamDB.path` to retrieve the database path. +- `ParamDB.latest_commit` to retrieve the latest commit entry. + ## [0.8.0] (June 9 2023) ### Changed diff --git a/paramdb/_database.py b/paramdb/_database.py index a77e317..07c9071 100644 --- a/paramdb/_database.py +++ b/paramdb/_database.py @@ -144,10 +144,16 @@ class Root(Struct): """ def __init__(self, path: str): + self._path = path self._engine = create_engine(URL.create("sqlite+pysqlite", database=path)) self._Session = sessionmaker(self._engine) # pylint: disable=invalid-name _Base.metadata.create_all(self._engine) + @property + def path(self) -> str: + """Path of the database file.""" + return self._path + def commit(self, message: str, data: T) -> int: """ Commit the given data to the database with the given message and return the ID @@ -198,10 +204,9 @@ def load(self, commit_id: int | None = None, *, load_classes: bool = True) -> An if data is None: raise IndexError( f"cannot load most recent commit because database" - f" '{self._engine.url.database}' has no commits" + f" '{self._path}' has no commits" if commit_id is None - else f"commit {commit_id} does not exist in database" - f" '{self._engine.url.database}'" + else f"commit {commit_id} does not exist in database" f" '{self._path}'" ) return json.loads( _decompress(data), @@ -217,6 +222,18 @@ def num_commits(self) -> int: count = session.execute(select_stmt).scalar() return count if count is not None else 0 + @property + def latest_commit(self) -> CommitEntry | None: + """Latest commit added to the database, or None if the database is empty.""" + max_id_func = func.max(_Snapshot.id) # pylint: disable=not-callable + select_max_id = select(max_id_func).scalar_subquery() + select_stmt = select( + _Snapshot.id, _Snapshot.message, _Snapshot.timestamp + ).where(_Snapshot.id == select_max_id) + with self._Session() as session: + latest_entry = session.execute(select_stmt).mappings().first() + return None if latest_entry is None else CommitEntry(**latest_entry) + def commit_history( self, start: int | None = None, diff --git a/tests/test_database.py b/tests/test_database.py index 04bc4b2..0f055f0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -48,6 +48,12 @@ def test_create_database(db_path: str) -> None: assert os.path.exists(db_path) +def test_path(db_path: str) -> None: + """Database path can be retrieved.""" + param_db = ParamDB[Any](db_path) + assert param_db.path == db_path + + def test_commit_not_json_serializable_fails(db_path: str) -> None: """Fails to commit a class that ParamDB does not know how to convert to JSON.""" @@ -271,6 +277,32 @@ def test_num_commits(db_path: str, param: CustomParam) -> None: assert param_db.num_commits == 10 +def test_empty_latest_commit(db_path: str) -> None: + """An empty database has a latest_commit of None.""" + param_db = ParamDB[CustomStruct](db_path) + assert param_db.latest_commit is None + + +def test_latest_commit(db_path: str, param: CustomParam) -> None: + """The database has the correct value of latest_commit after each commit.""" + param_db = ParamDB[CustomParam](db_path) + for i in range(10): + # Make the commit + message = f"Commit {i}" + start = time.time() + sleep_for_datetime() + commit_id = param_db.commit(message, param) + sleep_for_datetime() + end = time.time() + + # Assert latest_commit matches the commit that was just made + latest_commit = param_db.latest_commit + assert latest_commit is not None + assert latest_commit.id == commit_id + assert latest_commit.message == message + assert start < latest_commit.timestamp.timestamp() < end + + def test_empty_commit_history(db_path: str) -> None: """Loads an empty commit history from an empty database.""" param_db = ParamDB[CustomStruct](db_path)