diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8cfd7f41..047cf9a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,9 +14,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Changed
 
 - **Breaking change:** The required SQLAlchemy version is now 1.4.
-  This may break clients that use the low-level query interface of `DatabaseMapping`,
-  e.g. `entity_sq`, as the queries now return proper SQLAlchemy 1.4 `Result` objects.
-  The high-level functions (`get_item`, `add_item` etc.) work as usual.
+  - The low-level query interface of `DatabaseMapping` (`entity_sq`, `alternative_sq`, ...)
+    now returns proper SQLAlchemy 1.4 `Result` objects.
+  - Using `DatabaseMapping` as a context manager (e.g. with the `with` statement)
+    now opens and closes a session.
+  - With the higher-level interface (`get_item()`, `add_item()`, ...),
+    the session is opened automatically as needed.
+  - The low-level query interface requires the session to be opened manually,
+    i.e. all queries must be done inside a `with` block.
 
 ### Deprecated
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 369c5baf..45ad68a2 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -36,7 +36,6 @@
 from .db_mapping_commit_mixin import DatabaseMappingCommitMixin
 from .db_mapping_query_mixin import DatabaseMappingQueryMixin
 from .exception import NothingToCommit, SpineDBAPIError, SpineDBVersionError, SpineIntegrityError
-from .export_functions import export_data
 from .filters.tools import apply_filter_stack, load_filters, pop_filter_configs
 from .helpers import (
     Asterisk,
@@ -46,9 +45,9 @@
     create_new_spine_database_from_bind,
     model_meta,
 )
-from .import_functions import import_data
 from .mapped_items import item_factory
 from .spine_db_client import get_db_url_from_server
+from .temp_id import TempId, resolve
 
 logging.getLogger("alembic").setLevel(logging.CRITICAL)
 
@@ -86,18 +85,26 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`.
     For example, you can call :meth:`fetch_more` in a dedicated thread while you do some work on the main thread.
-    This will nicely place items in the in-memory mapping so you can access them later, without
+    This will nicely place items in the in-memory mapping, so you can access them later, without
     the overhead of fetching them from the DB.
     The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
     while bypassing the in-memory mapping entirely.
 
-    You can use this class as a context manager, e.g.::
+    You must use this class as a context manager, e.g.::
 
         with DatabaseMapping(db_url) as db_map:
             # Do stuff with db_map
             ...
+    or::
+
+        db_map = DatabaseMapping(db_url)
+        ...
+        with db_map:
+            # Do stuff with db_map
+            ...
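+
+    The low-level query interface (:meth:`query` together with subquery attributes
+    such as ``entity_sq``) needs the open session, so those calls must also happen
+    inside the ``with`` block, e.g.::
+
+        with DatabaseMapping(db_url) as db_map:
+            entities = db_map.query(db_map.entity_sq).all()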
+ """ _sq_name_by_item_type = { @@ -176,7 +183,6 @@ def __init__( self._metadata = MetaData() self._metadata.reflect(self.engine) self._tablenames = [t.name for t in self._metadata.sorted_tables] - self._connection = None self._session = None self._context_open_count = 0 if self._filter_configs is not None: @@ -184,23 +190,20 @@ def __init__( apply_filter_stack(self, stack) def __enter__(self): - self.closed = False + if self._closed: + return None self._context_open_count += 1 - if self._connection is None: - self._connection = self.engine.connect() if self._session is None: - self._session = Session(self._connection) + self._session = Session(self.engine) return self def __exit__(self, _exc_type, _exc_val, _exc_tb): self._context_open_count -= 1 if self._context_open_count == 0: - self.close() + self._session.close() + self._session = None return False - def __del__(self): - self.close() - @staticmethod def item_types(): return [x for x in DatabaseMapping._sq_name_by_item_type if not item_factory(x).is_protected] @@ -214,7 +217,8 @@ def item_factory(item_type): return item_factory(item_type) def _query_commit_count(self): - return self.query(self.commit_sq).count() + with self: + return self.query(self.commit_sq).count() def _make_sq(self, item_type): sq_name = self._sq_name_by_item_type[item_type] @@ -762,6 +766,79 @@ def purge_items(self, item_type): """ return bool(self.remove_items(item_type, Asterisk)) + def _make_query(self, item_type, **kwargs): + """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. + + Args: + item_type (str): item type + **kwargs: query filters + + Returns: + :class:`~spinedb_api.query.Query` or None if the mapping is closed. + """ + sq = self._make_sq(item_type) + qry = self._session.query(sq) + for key, value in kwargs.items(): + if isinstance(value, tuple): + continue + value = resolve(value) + if hasattr(sq.c, key): + qry = qry.filter(getattr(sq.c, key) == value) + elif key in self.item_factory(item_type)._external_fields: + src_key, key = self.item_factory(item_type)._external_fields[key] + ref_type = self.item_factory(item_type)._references[src_key] + ref_sq = self._make_sq(ref_type) + try: + qry = qry.filter(getattr(sq.c, src_key) == getattr(ref_sq.c, "id"), getattr(ref_sq.c, key) == value) + except AttributeError: + pass + return qry + + def _get_next_chunk(self, item_type, offset, limit, **kwargs): + """Gets chunk of items from the DB. + + Returns: + list(dict): list of dictionary items. 
+ """ + with self: + qry = self._make_query(item_type, **kwargs) + if not qry: + return [] + if not limit: + return [x._asdict() for x in qry] + return [x._asdict() for x in qry.limit(limit).offset(offset)] + + def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, **kwargs): + """See base class.""" + chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) + if not chunk: + return [] + if real_commit_count is None: + real_commit_count = self._query_commit_count() + is_db_dirty = self._get_commit_count() != real_commit_count + if is_db_dirty: + # We need to fetch the most recent references because their ids might have changed in the DB + for ref_type in self.item_factory(item_type).ref_types(): + if ref_type != item_type: + self.do_fetch_all(ref_type, commit_count=real_commit_count) + mapped_table = self.mapped_table(item_type) + items = [] + new_items = [] + # Add items first + for x in chunk: + item, new = mapped_table.add_item_from_db(x, not is_db_dirty) + if new: + new_items.append(item) + else: + item.handle_refetch() + items.append(item) + # Once all items are added, add the unique key values + # Otherwise items that refer to other items that come later in the query will be seen as corrupted + for item in new_items: + mapped_table.add_unique(item) + item.become_referrer() + return items + def fetch_more(self, item_type, offset=0, limit=None, **kwargs): """Fetches items from the DB into the in-memory mapping, incrementally. @@ -835,29 +912,33 @@ def commit_session(self, comment, apply_compatibility_transforms=True): """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") - dirty_items = self._dirty_items() - if not dirty_items: - raise NothingToCommit() - commit = self._metadata.tables["commit"] - commit_item = {"user": self.username, "date": datetime.now(timezone.utc), "comment": comment} - try: - # TODO: The below locks the DB in sqlite, how about other dialects? - commit_id = self._connection.execute(commit.insert(), commit_item).inserted_primary_key[0] - except DBAPIError as e: - raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e - for tablename, (to_add, to_update, to_remove) in dirty_items: - for item in to_add + to_update + to_remove: - item.commit(commit_id) - # Remove before add, to help with keeping integrity constraints - self._do_remove_items(self._connection, tablename, *{x["id"] for x in to_remove}) - self._do_update_items(self._connection, tablename, *to_update) - self._do_add_items(self._connection, tablename, *to_add) - self._session.commit() - if self._memory: - self._memory_dirty = True - transformation_info = compatibility_transformations(self._connection, apply=apply_compatibility_transforms) - self._commit_count = self._query_commit_count() - return transformation_info + with self: + dirty_items = self._dirty_items() + if not dirty_items: + raise NothingToCommit() + commit = self._metadata.tables["commit"] + commit_item = {"user": self.username, "date": datetime.now(timezone.utc), "comment": comment} + connection = self._session.connection() + try: + # TODO: The below locks the DB in sqlite, how about other dialects? 
+                commit_id = connection.execute(commit.insert(), commit_item).inserted_primary_key[0]
+            except DBAPIError as e:
+                raise SpineDBAPIError(f"Failed to commit: {e.orig.args}") from e
+            for tablename, (to_add, to_update, to_remove) in dirty_items:
+                for item in to_add + to_update + to_remove:
+                    item.commit(commit_id)
+                # Remove before add, to help with keeping integrity constraints
+                self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove})
+                self._do_update_items(connection, tablename, *to_update)
+                self._do_add_items(connection, tablename, *to_add)
+            self._session.commit()
+            if self._memory:
+                self._memory_dirty = True
+            transformation_info = compatibility_transformations(
+                self._session.connection(), apply=apply_compatibility_transforms
+            )
+            self._commit_count = self._query_commit_count()
+            return transformation_info
 
     def rollback_session(self):
         """Discards all the changes from the in-memory mapping."""
@@ -878,35 +959,6 @@ def has_external_commits(self):
         """
         return self._commit_count != self._query_commit_count()
 
-    def close(self):
-        """Closes this DB mapping.
-        For instance::
-
-            class MyDBMappingWrapper:
-                def __init__(self, url):
-                    self._db_map = DatabaseMapping(url)
-
-                # More methods that do stuff with self._db_map
-
-                def __del__(self):
-                    self._db_map.close()
-
-        Otherwise, the usage as context manager is recommended::
-
-            with DatabaseMapping(url) as db_map:
-                # Do stuff with db_map
-                ...
-            # db_map.close() is automatically called when leaving this block
-        """
-        if self._session is not None:
-            self._session.commit()
-            self._session.close()
-            self._session = None
-        if self._connection is not None:
-            self._connection.close()
-            self._connection = None
-        self.closed = True
-
     def add_ext_entity_metadata(self, *items, **kwargs):
         metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items)
         self.add_items("metadata", *metadata_items, **kwargs)
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 0bb3c27c..070c5383 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -12,7 +12,6 @@
 from contextlib import suppress
 from difflib import SequenceMatcher
 from enum import Enum, auto, unique
-from multiprocessing import Lock, RLock
 from typing import Set
 from .exception import SpineDBAPIError
 from .helpers import Asterisk
@@ -42,11 +41,10 @@ class DatabaseMappingBase:
     """
 
     def __init__(self):
-        self.closed = True
+        self._closed = False
+        self._context_open_count = 0
         self._mapped_tables = {}
         self._fetched = {}
-        self._locker_lock = Lock()
-        self._locks = {}
         self._commit_count = None
         item_types = self.item_types()
         self._sorted_item_types = []
@@ -57,6 +55,17 @@ def __init__(self):
             else:
                 self._sorted_item_types.append(item_type)
 
+    def __del__(self):
+        self.close()
+
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
+    def close(self):
+        """Closes this DB mapping."""
+        self._closed = True
+
     @staticmethod
     def item_types():
         """Returns a list of public item types from the DB mapping schema (equivalent to the table names).
@@ -93,36 +102,6 @@ def item_factory(item_type):
         """
         raise NotImplementedError()
 
-    def _make_query(self, item_type, **kwargs):
-        """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.
-
-        Args:
-            item_type (str)
-            **kwargs: query filters
-
-        Returns:
-            :class:`~spinedb_api.query.Query` or None if the mapping is closed.
- """ - if self.closed: - return None - sq = self._make_sq(item_type) - qry = self.query(sq) - for key, value in kwargs.items(): - if isinstance(value, tuple): - continue - value = resolve(value) - if hasattr(sq.c, key): - qry = qry.filter(getattr(sq.c, key) == value) - elif key in self.item_factory(item_type)._external_fields: - src_key, key = self.item_factory(item_type)._external_fields[key] - ref_type = self.item_factory(item_type)._references[src_key] - ref_sq = self._make_sq(ref_type) - try: - qry = qry.filter(getattr(sq.c, src_key) == getattr(ref_sq.c, "id"), getattr(ref_sq.c, key) == value) - except AttributeError: - pass - return qry - def _make_sq(self, item_type): """Returns a :class:`~sqlalchemy.sql.expression.Alias` object representing a subquery to collect items of given type. @@ -288,19 +267,6 @@ def get_mapped_item(self, item_type, id_, fetch=True): mapped_table = self.mapped_table(item_type) return mapped_table.find_item_by_id(id_, fetch=fetch) or {} - def _get_next_chunk(self, item_type, offset, limit, **kwargs): - """Gets chunk of items from the DB. - - Returns: - list(dict): list of dictionary items. - """ - qry = self._make_query(item_type, **kwargs) - if not qry: - return [] - if not limit: - return [x._asdict() for x in qry] - return [x._asdict() for x in qry.limit(limit).offset(offset)] - def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, **kwargs): """Fetches items from the DB and adds them to the mapping. @@ -314,34 +280,7 @@ def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, Returns: list(MappedItem): items fetched from the DB. """ - chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) - if not chunk: - return [] - if real_commit_count is None: - real_commit_count = self._query_commit_count() - is_db_dirty = self._get_commit_count() != real_commit_count - if is_db_dirty: - # We need to fetch the most recent references because their ids might have changed in the DB - for ref_type in self.item_factory(item_type).ref_types(): - if ref_type != item_type: - self.do_fetch_all(ref_type, commit_count=real_commit_count) - mapped_table = self.mapped_table(item_type) - items = [] - new_items = [] - # Add items first - for x in chunk: - item, new = mapped_table.add_item_from_db(x, not is_db_dirty) - if new: - new_items.append(item) - else: - item.handle_refetch() - items.append(item) - # Once all items are added, add the unique key values - # Otherwise items that refer to other items that come later in the query will be seen as corrupted - for item in new_items: - mapped_table.add_unique(item) - item.become_referrer() - return items + raise NotImplementedError() def _get_commit_count(self): """Returns current commit count. 
@@ -362,16 +301,11 @@ def do_fetch_all(self, item_type, commit_count=None):
             item_type (str)
             commit_count (int,optional)
         """
-        with self._locker_lock:
-            if item_type not in self._locks:
-                self._locks[item_type] = RLock()
-            lock = self._locks[item_type]
-        with lock:
-            if commit_count is None:
-                commit_count = self._get_commit_count()
-            if self._fetched.get(item_type, -1) < commit_count:
-                self._fetched[item_type] = commit_count
-                self.do_fetch_more(item_type, offset=0, limit=None, real_commit_count=commit_count)
+        if commit_count is None:
+            commit_count = self._get_commit_count()
+        if self._fetched.get(item_type, -1) < commit_count:
+            self._fetched[item_type] = commit_count
+            self.do_fetch_more(item_type, offset=0, limit=None, real_commit_count=commit_count)
 
 
 class _MappedTable(dict):
diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py
index 9483c721..029b7f9b 100644
--- a/spinedb_api/spine_io/exporters/writer.py
+++ b/spinedb_api/spine_io/exporters/writer.py
@@ -37,7 +37,7 @@ def write(db_map, writer, *mappings, empty_data_header=True, max_tables=None, ma
         empty_data_header = len(mappings) * [empty_data_header]
     if isinstance(group_fns, str):
         group_fns = len(mappings) * [group_fns]
-    with _new_write(writer):
+    with _new_write(writer), db_map:
         for mapping, header_for_empty_data, group_fn in zip(mappings, empty_data_header, group_fns):
             mapping = drop_non_positioned_tail(copy(mapping))
             for title, title_key in titles(mapping, db_map, limit=max_tables):
@@ -130,7 +130,7 @@ def _new_table(writer, table_name, title_key):
         title_key (dict, optional)
 
     Yields:
-        bool: whether or not the new table was successfully started
+        bool: whether the new table was successfully started
     """
     try:
         table_started = writer.start_table(table_name, title_key)
diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py
index 9a83d787..3a42dce6 100644
--- a/tests/test_db_mapping_base.py
+++ b/tests/test_db_mapping_base.py
@@ -33,7 +33,7 @@ def item_factory(item_type):
     def _query_commit_count(self):
         return -1
 
-    def _make_query(self, _item_type, **kwargs):
+    def _make_query(self, _item_type, session, **kwargs):
         return None