Skip to content

Commit

Permalink
Don't always require DatabaseMapping as context manager, drop locks
Browse files Browse the repository at this point in the history
Re #121
  • Loading branch information
soininen committed Jan 14, 2025
1 parent 00c870d commit 23bdb36
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 157 deletions.
11 changes: 8 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- **Breaking change:** The required SQLAlchemy version is now 1.4.
This may break clients that use the low-level query interface of `DatabaseMapping`,
e.g. `entity_sq`, as the queries now return proper SQLAlchemy 1.4 `Result` objects.
The high-level functions (`get_item`, `add_item` etc.) work as usual.
- The low-level query interface of `DatabaseMapping` (`entity_sq`, `alternative_sq`,...)
now return proper SQLAlchemy 1.4 `Result` objects.
- Using `DatabaseMapping` as a context manager (e.g. with the `with` statement)
now opens and closes a session.
- With the higher-level interface (`get_item()`, `add_item()`,...),
the session is opened automatically as needed.
- The low-level query interface requires the session to be opened manually,
i.e. all queries must be done inside a `with` block.

### Deprecated

Expand Down
184 changes: 118 additions & 66 deletions spinedb_api/db_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from .db_mapping_commit_mixin import DatabaseMappingCommitMixin
from .db_mapping_query_mixin import DatabaseMappingQueryMixin
from .exception import NothingToCommit, SpineDBAPIError, SpineDBVersionError, SpineIntegrityError
from .export_functions import export_data
from .filters.tools import apply_filter_stack, load_filters, pop_filter_configs
from .helpers import (
Asterisk,
Expand All @@ -46,9 +45,9 @@
create_new_spine_database_from_bind,
model_meta,
)
from .import_functions import import_data
from .mapped_items import item_factory
from .spine_db_client import get_db_url_from_server
from .temp_id import TempId, resolve

logging.getLogger("alembic").setLevel(logging.CRITICAL)

Expand Down Expand Up @@ -86,18 +85,26 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`.
For example, you can call :meth:`fetch_more` in a dedicated thread while you do some work on the main thread.
This will nicely place items in the in-memory mapping so you can access them later, without
This will nicely place items in the in-memory mapping, so you can access them later, without
the overhead of fetching them from the DB.
The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
while bypassing the in-memory mapping entirely.
You can use this class as a context manager, e.g.::
You must use this class as a context manager, e.g.::
with DatabaseMapping(db_url) as db_map:
# Do stuff with db_map
...
or::
db_map = DatabaseMapping(db_url)
...
with db_map:
# Do stuff with db_map
...
"""

_sq_name_by_item_type = {
Expand Down Expand Up @@ -176,31 +183,27 @@ def __init__(
self._metadata = MetaData()
self._metadata.reflect(self.engine)
self._tablenames = [t.name for t in self._metadata.sorted_tables]
self._connection = None
self._session = None
self._context_open_count = 0
if self._filter_configs is not None:
stack = load_filters(self._filter_configs)
apply_filter_stack(self, stack)

def __enter__(self):
self.closed = False
if self._closed:
return None

Check warning on line 194 in spinedb_api/db_mapping.py

View check run for this annotation

Codecov / codecov/patch

spinedb_api/db_mapping.py#L194

Added line #L194 was not covered by tests
self._context_open_count += 1
if self._connection is None:
self._connection = self.engine.connect()
if self._session is None:
self._session = Session(self._connection)
self._session = Session(self.engine)
return self

def __exit__(self, _exc_type, _exc_val, _exc_tb):
self._context_open_count -= 1
if self._context_open_count == 0:
self.close()
self._session.close()
self._session = None
return False

def __del__(self):
self.close()

@staticmethod
def item_types():
return [x for x in DatabaseMapping._sq_name_by_item_type if not item_factory(x).is_protected]
Expand All @@ -214,7 +217,8 @@ def item_factory(item_type):
return item_factory(item_type)

def _query_commit_count(self):
return self.query(self.commit_sq).count()
with self:
return self.query(self.commit_sq).count()

def _make_sq(self, item_type):
sq_name = self._sq_name_by_item_type[item_type]
Expand Down Expand Up @@ -762,6 +766,79 @@ def purge_items(self, item_type):
"""
return bool(self.remove_items(item_type, Asterisk))

def _make_query(self, item_type, **kwargs):
"""Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.
Args:
item_type (str): item type
**kwargs: query filters
Returns:
:class:`~spinedb_api.query.Query` or None if the mapping is closed.
"""
sq = self._make_sq(item_type)
qry = self._session.query(sq)
for key, value in kwargs.items():
if isinstance(value, tuple):
continue
value = resolve(value)
if hasattr(sq.c, key):
qry = qry.filter(getattr(sq.c, key) == value)
elif key in self.item_factory(item_type)._external_fields:
src_key, key = self.item_factory(item_type)._external_fields[key]
ref_type = self.item_factory(item_type)._references[src_key]
ref_sq = self._make_sq(ref_type)
try:
qry = qry.filter(getattr(sq.c, src_key) == getattr(ref_sq.c, "id"), getattr(ref_sq.c, key) == value)
except AttributeError:
pass
return qry

def _get_next_chunk(self, item_type, offset, limit, **kwargs):
"""Gets chunk of items from the DB.
Returns:
list(dict): list of dictionary items.
"""
with self:
qry = self._make_query(item_type, **kwargs)
if not qry:
return []

Check warning on line 806 in spinedb_api/db_mapping.py

View check run for this annotation

Codecov / codecov/patch

spinedb_api/db_mapping.py#L806

Added line #L806 was not covered by tests
if not limit:
return [x._asdict() for x in qry]
return [x._asdict() for x in qry.limit(limit).offset(offset)]

Check warning on line 809 in spinedb_api/db_mapping.py

View check run for this annotation

Codecov / codecov/patch

spinedb_api/db_mapping.py#L809

Added line #L809 was not covered by tests

def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, **kwargs):
"""See base class."""
chunk = self._get_next_chunk(item_type, offset, limit, **kwargs)
if not chunk:
return []
if real_commit_count is None:
real_commit_count = self._query_commit_count()
is_db_dirty = self._get_commit_count() != real_commit_count
if is_db_dirty:
# We need to fetch the most recent references because their ids might have changed in the DB
for ref_type in self.item_factory(item_type).ref_types():
if ref_type != item_type:
self.do_fetch_all(ref_type, commit_count=real_commit_count)
mapped_table = self.mapped_table(item_type)
items = []
new_items = []
# Add items first
for x in chunk:
item, new = mapped_table.add_item_from_db(x, not is_db_dirty)
if new:
new_items.append(item)
else:
item.handle_refetch()
items.append(item)
# Once all items are added, add the unique key values
# Otherwise items that refer to other items that come later in the query will be seen as corrupted
for item in new_items:
mapped_table.add_unique(item)
item.become_referrer()
return items

def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
"""Fetches items from the DB into the in-memory mapping, incrementally.
Expand Down Expand Up @@ -835,29 +912,33 @@ def commit_session(self, comment, apply_compatibility_transforms=True):
"""
if not comment:
raise SpineDBAPIError("Commit message cannot be empty.")
dirty_items = self._dirty_items()
if not dirty_items:
raise NothingToCommit()
commit = self._metadata.tables["commit"]
commit_item = {"user": self.username, "date": datetime.now(timezone.utc), "comment": comment}
try:
# TODO: The below locks the DB in sqlite, how about other dialects?
commit_id = self._connection.execute(commit.insert(), commit_item).inserted_primary_key[0]
except DBAPIError as e:
raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e
for tablename, (to_add, to_update, to_remove) in dirty_items:
for item in to_add + to_update + to_remove:
item.commit(commit_id)
# Remove before add, to help with keeping integrity constraints
self._do_remove_items(self._connection, tablename, *{x["id"] for x in to_remove})
self._do_update_items(self._connection, tablename, *to_update)
self._do_add_items(self._connection, tablename, *to_add)
self._session.commit()
if self._memory:
self._memory_dirty = True
transformation_info = compatibility_transformations(self._connection, apply=apply_compatibility_transforms)
self._commit_count = self._query_commit_count()
return transformation_info
with self:
dirty_items = self._dirty_items()
if not dirty_items:
raise NothingToCommit()
commit = self._metadata.tables["commit"]
commit_item = {"user": self.username, "date": datetime.now(timezone.utc), "comment": comment}
connection = self._session.connection()
try:
# TODO: The below locks the DB in sqlite, how about other dialects?
commit_id = connection.execute(commit.insert(), commit_item).inserted_primary_key[0]
except DBAPIError as e:
raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e
for tablename, (to_add, to_update, to_remove) in dirty_items:
for item in to_add + to_update + to_remove:
item.commit(commit_id)
# Remove before add, to help with keeping integrity constraints
self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove})
self._do_update_items(connection, tablename, *to_update)
self._do_add_items(connection, tablename, *to_add)
self._session.commit()
if self._memory:
self._memory_dirty = True
transformation_info = compatibility_transformations(
self._session.connection(), apply=apply_compatibility_transforms
)
self._commit_count = self._query_commit_count()
return transformation_info

def rollback_session(self):
"""Discards all the changes from the in-memory mapping."""
Expand All @@ -878,35 +959,6 @@ def has_external_commits(self):
"""
return self._commit_count != self._query_commit_count()

def close(self):
"""Closes this DB mapping.
For instance::
class MyDBMappingWrapper:
def __init__(self, url):
self._db_map = DatabaseMapping(url)
# More methods that do stuff with self._db_map
def __del__(self):
self._db_map.close()
Otherwise, the usage as context manager is recommended::
with DatabaseMapping(url) as db_map:
# Do stuff with db_map
...
# db_map.close() is automatically called when leaving this block
"""
if self._session is not None:
self._session.commit()
self._session.close()
self._session = None
if self._connection is not None:
self._connection.close()
self._connection = None
self.closed = True

def add_ext_entity_metadata(self, *items, **kwargs):
metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items)
self.add_items("metadata", *metadata_items, **kwargs)
Expand Down
Loading

0 comments on commit 23bdb36

Please sign in to comment.