Skip to content

Commit

Permalink
Add ParamDB.path and ParamDB.latest_commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Hadley committed Jun 29, 2023
1 parent dcfca2d commit 7a2f9db
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `ParamDB.path` to retrieve the database path.
- `ParamDB.latest_commit` to retrieve the latest commit entry.

## [0.8.0] (June 9 2023)

### Changed
Expand Down
23 changes: 20 additions & 3 deletions paramdb/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,16 @@ class Root(Struct):
"""

def __init__(self, path: str):
self._path = path
self._engine = create_engine(URL.create("sqlite+pysqlite", database=path))
self._Session = sessionmaker(self._engine) # pylint: disable=invalid-name
_Base.metadata.create_all(self._engine)

@property
def path(self) -> str:
"""Path of the database file."""
return self._path

def commit(self, message: str, data: T) -> int:
"""
Commit the given data to the database with the given message and return the ID
Expand Down Expand Up @@ -198,10 +204,9 @@ def load(self, commit_id: int | None = None, *, load_classes: bool = True) -> An
if data is None:
raise IndexError(
f"cannot load most recent commit because database"
f" '{self._engine.url.database}' has no commits"
f" '{self._path}' has no commits"
if commit_id is None
else f"commit {commit_id} does not exist in database"
f" '{self._engine.url.database}'"
else f"commit {commit_id} does not exist in database" f" '{self._path}'"
)
return json.loads(
_decompress(data),
Expand All @@ -217,6 +222,18 @@ def num_commits(self) -> int:
count = session.execute(select_stmt).scalar()
return count if count is not None else 0

@property
def latest_commit(self) -> CommitEntry | None:
"""Latest commit added to the database, or None if the database is empty."""
max_id_func = func.max(_Snapshot.id) # pylint: disable=not-callable
select_max_id = select(max_id_func).scalar_subquery()
select_stmt = select(
_Snapshot.id, _Snapshot.message, _Snapshot.timestamp
).where(_Snapshot.id == select_max_id)
with self._Session() as session:
latest_entry = session.execute(select_stmt).mappings().first()
return None if latest_entry is None else CommitEntry(**latest_entry)

def commit_history(
self,
start: int | None = None,
Expand Down
32 changes: 32 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def test_create_database(db_path: str) -> None:
assert os.path.exists(db_path)


def test_path(db_path: str) -> None:
"""Database path can be retrieved."""
param_db = ParamDB[Any](db_path)
assert param_db.path == db_path


def test_commit_not_json_serializable_fails(db_path: str) -> None:
"""Fails to commit a class that ParamDB does not know how to convert to JSON."""

Expand Down Expand Up @@ -271,6 +277,32 @@ def test_num_commits(db_path: str, param: CustomParam) -> None:
assert param_db.num_commits == 10


def test_empty_latest_commit(db_path: str) -> None:
"""An empty database has a latest_commit of None."""
param_db = ParamDB[CustomStruct](db_path)
assert param_db.latest_commit is None


def test_latest_commit(db_path: str, param: CustomParam) -> None:
"""The database has the correct value of latest_commit after each commit."""
param_db = ParamDB[CustomParam](db_path)
for i in range(10):
# Make the commit
message = f"Commit {i}"
start = time.time()
sleep_for_datetime()
commit_id = param_db.commit(message, param)
sleep_for_datetime()
end = time.time()

# Assert latest_commit matches the commit that was just made
latest_commit = param_db.latest_commit
assert latest_commit is not None
assert latest_commit.id == commit_id
assert latest_commit.message == message
assert start < latest_commit.timestamp.timestamp() < end


def test_empty_commit_history(db_path: str) -> None:
"""Loads an empty commit history from an empty database."""
param_db = ParamDB[CustomStruct](db_path)
Expand Down

0 comments on commit 7a2f9db

Please sign in to comment.