diff --git a/pyproject.toml b/pyproject.toml index f0f43ad7..2b17c1eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,7 @@ classifiers = [ ] requires-python = ">=3.9" dependencies = [ - # v1.4 does not pass tests - "SQLAlchemy >=1.3, <1.4", + "SQLAlchemy >=1.4, <1.5", "alembic >=1.7", "datapackage >=1.15.2", "python-dateutil >=2.8.1", diff --git a/spinedb_api/alembic/env.py b/spinedb_api/alembic/env.py index e503d1f2..c129e357 100644 --- a/spinedb_api/alembic/env.py +++ b/spinedb_api/alembic/env.py @@ -1,4 +1,3 @@ -from __future__ import with_statement from logging.config import fileConfig from alembic import context from sqlalchemy import engine_from_config, pool diff --git a/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py b/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py index 5a5bbdac..b7a548aa 100644 --- a/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py +++ b/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py @@ -32,20 +32,20 @@ def upgrade(): sa.Column("display_order", sa.Integer, nullable=False), sa.Column( "display_status", - sa.Enum(DisplayStatus, name="display_status_enum"), + sa.Enum(DisplayStatus, name="display_status_enum", create_constraint=True), server_default=DisplayStatus.visible.name, nullable=False, ), sa.Column("display_font_color", sa.String(6), server_default=sa.null()), sa.Column("display_background_color", sa.String(6), server_default=sa.null()), sa.ForeignKeyConstraint( - ["entity_class_display_mode_id"], + ("entity_class_display_mode_id",), ["entity_class_display_mode.id"], onupdate="CASCADE", ondelete="CASCADE", ), sa.ForeignKeyConstraint( - ["entity_class_id"], + ("entity_class_id",), ["entity_class.id"], onupdate="CASCADE", ondelete="CASCADE", diff --git a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py index 9e641ab1..e2fe132f 100644 --- a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py +++ b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py @@ -30,8 +30,8 @@ def upgrade(): tfm = session.query(Base.classes.tool_feature_method).all() session.query(Base.classes.parameter_value_list).delete() session.query(Base.classes.tool_feature_method).delete() - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) # Change schema if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: diff --git a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py index 957a971f..52cb8234 100644 --- a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py +++ b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py @@ -17,8 +17,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("metadata_id", sa.Integer, server_default=sa.null())) diff --git a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py index 6ea3bad9..181e8338 100644 --- a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py +++ b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from alembic import op import sqlalchemy as sa +from sqlalchemy import text from sqlalchemy.ext.automap import automap_base from sqlalchemy.orm import sessionmaker @@ -90,8 +91,8 @@ def alter_tables_after_update(): None, "alternative", ("alternative_id",), ("id",), onupdate="CASCADE", ondelete="CASCADE" ) - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("alternative_id", sa.Integer, server_default=sa.null())) @@ -101,7 +102,8 @@ def alter_tables_after_update(): date = datetime.utcnow() conn = op.get_bind() conn.execute( - """ + text( + """ UPDATE next_id SET user = :user, @@ -109,7 +111,8 @@ def alter_tables_after_update(): alternative_id = 2, scenario_id = 1, scenario_alternative_id = 1 - """, + """ + ), user=user, date=date, ) diff --git a/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py b/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py index ef1ea2a5..5be0e6a8 100644 --- a/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py +++ b/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py @@ -17,8 +17,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("parameter_tag_id", sa.Integer)) diff --git a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py index cec2f6d9..749bbd98 100644 --- a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py +++ b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py @@ -8,6 +8,7 @@ from alembic import op import sqlalchemy as sa +from sqlalchemy import text from spinedb_api.helpers import naming_convention # revision identifiers, used by Alembic. @@ -106,28 +107,28 @@ def upgrade(): def _get_constraints(): conn = op.get_bind() - meta = sa.MetaData(conn) - meta.reflect() + meta = sa.MetaData() + meta.reflect(conn) return [[c.name for c in meta.tables[tname].constraints] for tname in ["entity_class", "entity"]] def _persist_data(): conn = op.get_bind() - meta = sa.MetaData(conn) - meta.reflect() + meta = sa.MetaData() + meta.reflect(conn) ecd_items = [ - {"entity_class_id": x["entity_class_id"], "dimension_id": x["member_class_id"], "position": x["dimension"]} - for x in conn.execute("SELECT * FROM relationship_entity_class") + {"entity_class_id": x.entity_class_id, "dimension_id": x.member_class_id, "position": x.dimension} + for x in conn.execute(text("SELECT * FROM relationship_entity_class")) ] ee_items = [ { - "entity_id": x["entity_id"], - "entity_class_id": x["entity_class_id"], - "element_id": x["member_id"], - "dimension_id": x["member_class_id"], - "position": x["dimension"], + "entity_id": x.entity_id, + "entity_class_id": x.entity_class_id, + "element_id": x.member_id, + "dimension_id": x.member_class_id, + "position": x.dimension, } - for x in conn.execute("SELECT * FROM relationship_entity") + for x in conn.execute(text("SELECT * FROM relationship_entity")) ] op.bulk_insert(meta.tables["entity_class_dimension"], ecd_items) op.bulk_insert(meta.tables["entity_element"], ee_items) diff --git a/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py b/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py index 827e40b8..65c5fdbe 100644 --- a/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py +++ b/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py @@ -8,6 +8,7 @@ from alembic import op import sqlalchemy as sa +from spinedb_api import naming_convention # revision identifiers, used by Alembic. revision = "7e2e66ae0f8f" @@ -17,8 +18,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) with op.batch_alter_table("entity_class_display_mode") as batch_op: batch_op.drop_constraint("pk_entity_class_display_mode", type_="primary") batch_op.create_primary_key("pk_display_mode", ["id"]) diff --git a/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py b/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py index 903f3db9..d7c5793c 100644 --- a/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py +++ b/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py @@ -22,8 +22,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.alter_column("parameter_id", new_column_name="parameter_definition_id", type_=sa.Integer) diff --git a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py index b4c0dbd4..a0275621 100644 --- a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py +++ b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py @@ -25,8 +25,8 @@ def upgrade(): conn = op.get_bind() Session = sessionmaker(bind=conn) session = Session() - meta = MetaData(conn) - meta.reflect() + meta = MetaData() + meta.reflect(conn) list_value = meta.tables["list_value"] parameter_definition = meta.tables["parameter_definition"] parameter_value = meta.tables["parameter_value"] diff --git a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py index dbff1e34..cdc4e55d 100644 --- a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py +++ b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py @@ -17,8 +17,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("entity_group_id", sa.Integer)) diff --git a/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py b/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py index 4a2902fd..e1ff6c50 100644 --- a/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py +++ b/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py @@ -9,6 +9,7 @@ from datetime import datetime from alembic import op import sqlalchemy as sa +from sqlalchemy import text # revision identifiers, used by Alembic. revision = "bba1e2ef5153" @@ -111,128 +112,138 @@ def insert_into_new_tables(): op.execute("""INSERT INTO entity_type (id, name) VALUES (2, "relationship")""") # More difficult ones conn = op.get_bind() - meta = sa.MetaData(conn) - meta.reflect() + meta = sa.MetaData() + meta.reflect(conn) # entity class level entity_classes = [ { "type_id": 1, - "name": r["name"], - "description": r["description"], - "display_order": r["display_order"], - "display_icon": r["display_icon"], - "hidden": r["hidden"], - "commit_id": r["commit_id"], + "name": r.name, + "description": r.description, + "display_order": r.display_order, + "display_icon": r.display_icon, + "hidden": r.hidden, + "commit_id": r.commit_id, } for r in conn.execute( - "SELECT name, description, display_order, display_icon, hidden, commit_id FROM object_class" + text("SELECT name, description, display_order, display_icon, hidden, commit_id FROM object_class") ) ] + [ { "type_id": 2, - "name": r["name"], + "name": r.name, "description": None, "display_order": None, "display_icon": None, - "hidden": r["hidden"], - "commit_id": r["commit_id"], + "hidden": r.hidden, + "commit_id": r.commit_id, } - for r in conn.execute("SELECT name, hidden, commit_id FROM relationship_class GROUP BY name") + for r in conn.execute(text("SELECT name, hidden, commit_id FROM relationship_class GROUP BY name")) ] op.bulk_insert(meta.tables["entity_class"], entity_classes) # Id mappings obj_cls_to_ent_cls = { - r["object_class_id"]: r["entity_class_id"] + r.object_class_id: r.entity_class_id for r in conn.execute( - """ + text( + """ SELECT object_class.id AS object_class_id, entity_class.id AS entity_class_id FROM object_class, entity_class WHERE entity_class.type_id = 1 AND object_class.name = entity_class.name """ + ) ) } rel_cls_to_ent_cls = { - r["relationship_class_id"]: r["entity_class_id"] + r.relationship_class_id: r.entity_class_id for r in conn.execute( - """ + text( + """ SELECT relationship_class.id AS relationship_class_id, entity_class.id AS entity_class_id FROM relationship_class, entity_class WHERE entity_class.type_id = 2 AND relationship_class.name = entity_class.name - GROUP BY relationship_class_id, entity_class_id + GROUP BY relationship_class_id, entity_class_id """ + ) ) } temp_relationship_classes = [ - {"entity_class_id": r["id"], "type_id": 2, "commit_id": r["commit_id"]} - for r in conn.execute("SELECT id, commit_id FROM entity_class WHERE type_id = 2") + {"entity_class_id": r.id, "type_id": 2, "commit_id": r.commit_id} + for r in conn.execute(text("SELECT id, commit_id FROM entity_class WHERE type_id = 2")) ] op.bulk_insert(meta.tables["temp_relationship_class"], temp_relationship_classes) relationship_entity_classes = [ { - "entity_class_id": rel_cls_to_ent_cls[r["id"]], - "dimension": r["dimension"], - "member_class_id": obj_cls_to_ent_cls[r["object_class_id"]], + "entity_class_id": rel_cls_to_ent_cls[r.id], + "dimension": r.dimension, + "member_class_id": obj_cls_to_ent_cls[r.object_class_id], "member_class_type_id": 1, - "commit_id": r["commit_id"], + "commit_id": r.commit_id, } - for r in conn.execute("SELECT id, dimension, object_class_id, commit_id FROM relationship_class") + for r in conn.execute(text("SELECT id, dimension, object_class_id, commit_id FROM relationship_class")) ] op.bulk_insert(meta.tables["relationship_entity_class"], relationship_entity_classes) # entity level entities = [ - {"type_id": 1, "class_id": obj_cls_to_ent_cls[r["class_id"]], "name": r["name"], "commit_id": r["commit_id"]} - for r in conn.execute("SELECT class_id, name, commit_id FROM object") + {"type_id": 1, "class_id": obj_cls_to_ent_cls[r.class_id], "name": r.name, "commit_id": r.commit_id} + for r in conn.execute(text("SELECT class_id, name, commit_id FROM object")) ] + [ - {"type_id": 2, "class_id": rel_cls_to_ent_cls[r["class_id"]], "name": r["name"], "commit_id": r["commit_id"]} - for r in conn.execute("SELECT class_id, name, commit_id FROM relationship GROUP BY class_id, name") + {"type_id": 2, "class_id": rel_cls_to_ent_cls[r.class_id], "name": r.name, "commit_id": r.commit_id} + for r in conn.execute(text("SELECT class_id, name, commit_id FROM relationship GROUP BY class_id, name")) ] op.bulk_insert(meta.tables["entity"], entities) # Id mappings obj_to_ent = { - r["object_id"]: r["entity_id"] + r.object_id: r.entity_id for r in conn.execute( - """ + text( + """ SELECT object.id AS object_id, entity.id AS entity_id FROM object, entity WHERE entity.type_id = 1 AND object.name = entity.name """ + ) ) } rel_to_ent = { - r["relationship_id"]: r["entity_id"] + r.relationship_id: r.entity_id for r in conn.execute( - """ + text( + """ SELECT relationship.id AS relationship_id, entity.id AS entity_id FROM relationship, entity WHERE entity.type_id = 2 AND relationship.name = entity.name - GROUP BY relationship_id, entity_id + GROUP BY relationship_id, entity_id """ + ) ) } temp_relationships = [ - {"entity_id": r["id"], "entity_class_id": r["class_id"], "type_id": 2, "commit_id": r["commit_id"]} - for r in conn.execute("SELECT id, class_id, commit_id FROM entity WHERE type_id = 2") + {"entity_id": r.id, "entity_class_id": r.class_id, "type_id": 2, "commit_id": r.commit_id} + for r in conn.execute(text("SELECT id, class_id, commit_id FROM entity WHERE type_id = 2")) ] op.bulk_insert(meta.tables["temp_relationship"], temp_relationships) relationship_entities = [ { - "entity_id": rel_to_ent[r["id"]], - "entity_class_id": rel_cls_to_ent_cls[r["class_id"]], - "dimension": r["dimension"], - "member_id": obj_to_ent[r["object_id"]], - "member_class_id": obj_cls_to_ent_cls[r["object_class_id"]], - "commit_id": r["commit_id"], + "entity_id": rel_to_ent[r.id], + "entity_class_id": rel_cls_to_ent_cls[r.class_id], + "dimension": r.dimension, + "member_id": obj_to_ent[r.object_id], + "member_class_id": obj_cls_to_ent_cls[r.object_class_id], + "commit_id": r.commit_id, } for r in conn.execute( - """ + text( + """ SELECT r.id, r.class_id, r.dimension, o.class_id AS object_class_id, r.object_id, r.commit_id FROM relationship AS r, object AS o WHERE r.object_id = o.id """ + ) ) ] op.bulk_insert(meta.tables["relationship_entity"], relationship_entities) @@ -291,42 +302,48 @@ def alter_tables_before_update(meta): def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_to_ent): conn = op.get_bind() - ent_to_ent_cls = {r["id"]: r["class_id"] for r in conn.execute("SELECT id, class_id FROM entity")} + ent_to_ent_cls = {r.id: r.class_id for r in conn.execute(text("SELECT id, class_id FROM entity"))} for object_class_id, entity_class_id in obj_cls_to_ent_cls.items(): conn.execute( - "UPDATE object_class SET entity_class_id = :entity_class_id, type_id = 1 WHERE id = :object_class_id", + text("UPDATE object_class SET entity_class_id = :entity_class_id, type_id = 1 WHERE id = :object_class_id"), entity_class_id=entity_class_id, object_class_id=object_class_id, ) conn.execute( - """ + text( + """ UPDATE parameter_definition SET entity_class_id = :entity_class_id WHERE object_class_id = :object_class_id - """, + """ + ), entity_class_id=entity_class_id, object_class_id=object_class_id, ) for relationship_class_id, entity_class_id in rel_cls_to_ent_cls.items(): conn.execute( - """ + text( + """ UPDATE parameter_definition SET entity_class_id = :entity_class_id WHERE relationship_class_id = :relationship_class_id - """, + """ + ), entity_class_id=entity_class_id, relationship_class_id=relationship_class_id, ) for object_id, entity_id in obj_to_ent.items(): conn.execute( - "UPDATE object SET entity_id = :entity_id, type_id = 1 WHERE id = :object_id", + text("UPDATE object SET entity_id = :entity_id, type_id = 1 WHERE id = :object_id"), entity_id=entity_id, object_id=object_id, ) entity_class_id = ent_to_ent_cls[entity_id] conn.execute( - """ + text( + """ UPDATE parameter_value SET entity_id = :entity_id, entity_class_id = :entity_class_id WHERE object_id = :object_id - """, + """ + ), entity_id=entity_id, entity_class_id=entity_class_id, object_id=object_id, @@ -334,28 +351,31 @@ def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_ for relationship_id, entity_id in rel_to_ent.items(): entity_class_id = ent_to_ent_cls[entity_id] conn.execute( - """ + text( + """ UPDATE parameter_value SET entity_id = :entity_id, entity_class_id = :entity_class_id WHERE relationship_id = :relationship_id - """, + """ + ), entity_id=entity_id, entity_class_id=entity_class_id, relationship_id=relationship_id, ) # Clean our potential mess. # E.g., I've seen parameter definitions with an invalid relationship_class_id for some reason...! - conn.execute("DELETE FROM parameter_definition WHERE entity_class_id IS NULL") - conn.execute("DELETE FROM parameter_value WHERE entity_class_id IS NULL OR entity_id IS NULL") + conn.execute(text("DELETE FROM parameter_definition WHERE entity_class_id IS NULL")) + conn.execute(text("DELETE FROM parameter_value WHERE entity_class_id IS NULL OR entity_id IS NULL")) if "next_id" not in meta.tables: return - row = conn.execute("SELECT MAX(id) FROM entity_class").fetchone() + row = conn.execute(text("SELECT MAX(id) FROM entity_class")).fetchone() entity_class_id = row[0] + 1 if row else 1 - row = conn.execute("SELECT MAX(id) FROM entity").fetchone() + row = conn.execute(text("SELECT MAX(id) FROM entity")).fetchone() entity_id = row[0] + 1 if row else 1 user = "alembic" date = datetime.utcnow() conn.execute( - """ + text( + """ UPDATE next_id SET user = :user, @@ -364,7 +384,8 @@ def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_ entity_type_id = 3, entity_class_id = :entity_class_id, entity_id = :entity_id - """, + """ + ), user=user, date=date, entity_class_id=entity_class_id, diff --git a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py index d42bc00a..03b758ca 100644 --- a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py +++ b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py @@ -48,6 +48,6 @@ def _get_scalar_values_by_id(table, value_label, type_label, connection): value_column = getattr(table.c, value_label) type_column = getattr(table.c, type_label) return { - row["id"]: row[value_label] + row.id: row._mapping[value_label] for row in connection.execute(sa.select([table.c.id, value_column]).where(type_column == None)) } diff --git a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py index 822e994b..68ebe50f 100644 --- a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py +++ b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py @@ -17,8 +17,8 @@ def upgrade(): - m = sa.MetaData(op.get_bind()) - m.reflect() + m = sa.MetaData() + m.reflect(op.get_bind()) if "next_id" in m.tables: with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("tool_id", sa.Integer, server_default=sa.null())) diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 81133bb8..34b2122a 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -27,8 +27,8 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea Returns: tuple: list of entity classes to add, update and ids to remove """ - meta = sa.MetaData(conn) - meta.reflect() + meta = sa.MetaData() + meta.reflect(conn) lv_table = meta.tables["list_value"] pd_table = meta.tables["parameter_definition"] if use_existing_tool_feature_method: @@ -53,7 +53,7 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea # It's a new DB without tool/feature/method or we don't want to use them... # we take 'is_active' as feature and JSON "yes" and true as methods lv_id_by_pdef_id = { - x["parameter_definition_id"]: x["id"] + x.parameter_definition_id: x.id for x in conn.execute( sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")]) .where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id) @@ -63,10 +63,10 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea } # Collect 'is_active' default values list_value_id = sa.case( - [(pd_table.c.default_type == "list_value_ref", sa.cast(pd_table.c.default_value, sa.Integer()))], else_=None + (pd_table.c.default_type == "list_value_ref", sa.cast(pd_table.c.default_value, sa.Integer())), else_=None ) is_active_default_vals = [ - {c: x[c] for c in ("entity_class_id", "parameter_definition_id", "list_value_id")} + {c: x._mapping[c] for c in ("entity_class_id", "parameter_definition_id", "list_value_id")} for x in conn.execute( sa.select( [ @@ -114,8 +114,8 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe list: entity_alternative items to update list: parameter_value ids to remove """ - meta = sa.MetaData(conn) - meta.reflect() + meta = sa.MetaData() + meta.reflect(conn) ea_table = meta.tables["entity_alternative"] lv_table = meta.tables["list_value"] pv_table = meta.tables["parameter_value"] @@ -142,7 +142,7 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe # we take 'is_active' as feature and JSON "yes" and true as methods pd_table = meta.tables["parameter_definition"] lv_id_by_pdef_id = { - x["parameter_definition_id"]: x["id"] + x.parameter_definition_id: x.id for x in conn.execute( sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")]) .where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id) @@ -151,11 +151,9 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe ) } # Collect 'is_active' parameter values - list_value_id = sa.case( - [(pv_table.c.type == "list_value_ref", sa.cast(pv_table.c.value, sa.Integer()))], else_=None - ) + list_value_id = sa.case((pv_table.c.type == "list_value_ref", sa.cast(pv_table.c.value, sa.Integer())), else_=None) is_active_pvals = [ - {c: x[c] for c in ("id", "entity_id", "alternative_id", "parameter_definition_id", "list_value_id")} + {c: x._mapping[c] for c in ("id", "entity_id", "alternative_id", "parameter_definition_id", "list_value_id")} for x in conn.execute( sa.select([pv_table, list_value_id.label("list_value_id")]).where( pv_table.c.parameter_definition_id.in_(lv_id_by_pdef_id) @@ -164,7 +162,7 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe ] # Compute new entity_alternative items from 'is_active' parameter values, # where 'active' is True if the value of 'is_active' is the one from the tool_feature_method specification - current_ea_ids = {(x["entity_id"], x["alternative_id"]): x["id"] for x in conn.execute(sa.select([ea_table]))} + current_ea_ids = {(x.entity_id, x.alternative_id): x.id for x in conn.execute(sa.select([ea_table]))} new_ea_items = { (x["entity_id"], x["alternative_id"]): { "entity_id": x["entity_id"], diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 11df2525..3e777694 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -26,10 +26,11 @@ from alembic.migration import MigrationContext from alembic.script import ScriptDirectory from alembic.util.exc import CommandError -from sqlalchemy import MetaData, create_engine, inspect +from sqlalchemy import MetaData, create_engine, inspect, text from sqlalchemy.engine.url import URL, make_url from sqlalchemy.event import listen from sqlalchemy.exc import ArgumentError, DatabaseError, DBAPIError +from sqlalchemy.orm import Query, Session from sqlalchemy.pool import NullPool from .compatibility import compatibility_transformations from .db_mapping_base import DatabaseMappingBase, Status @@ -46,7 +47,6 @@ model_meta, ) from .mapped_items import item_factory -from .query import Query from .spine_db_client import get_db_url_from_server logging.getLogger("alembic").setLevel(logging.CRITICAL) @@ -152,7 +152,8 @@ def __init__( if isinstance(db_url, str): filter_configs, db_url = pop_filter_configs(db_url) elif isinstance(db_url, URL): - filter_configs = db_url.query.pop("spinedbfilter", []) + filter_configs = db_url.query.get("spinedbfilter", []) + db_url = db_url.difference_update_query("spinedbfilter") else: filter_configs = [] self._filter_configs = filter_configs if apply_filters else None @@ -171,18 +172,25 @@ def __init__( listen(self.engine, "close", self._receive_engine_close) if self._memory: copy_database_bind(self.engine, self._original_engine) - self._metadata = MetaData(self.engine) - self._metadata.reflect() + self._metadata = MetaData() + self._metadata.reflect(self.engine) self._tablenames = [t.name for t in self._metadata.sorted_tables] + self._connection = self.engine.connect() + self._session = Session(self._connection) if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) def __enter__(self): + if self._connection is None: + self._connection = self.engine.connect() + if self._session is None: + self._session = Session(self._connection) return self def __exit__(self, _exc_type, _exc_val, _exc_tb): self.close() + return False def __del__(self): self.close() @@ -302,7 +310,7 @@ def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_tim ) from None with engine.begin() as connection: if sa_url.drivername == "sqlite": - connection.execute("BEGIN IMMEDIATE") + connection.execute(text("BEGIN IMMEDIATE")) # TODO: Do other dialects need to lock? migration_context = MigrationContext.configure(connection) try: @@ -777,7 +785,7 @@ def fetch_all(self, *item_types): item_type = self.real_item_type(item_type) self.do_fetch_all(item_type, commit_count) - def query(self, *args, **kwargs): + def query(self, *entities, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. To perform custom ``SELECT`` statements, call this method with one or more of the documented @@ -805,9 +813,9 @@ def query(self, *args, **kwargs): ).group_by(db_map.entity_class_sq.c.name).all() Returns: - :class:`~spinedb_api.query.Query`: The resulting query. + :class:`~sqlalchemy.orm.Query`: The resulting query. """ - return Query(self.engine, *args) + return self._session.query(*entities, **kwargs) def commit_session(self, comment, apply_compatibility_transforms=True): """Commits the changes from the in-memory mapping to the database. @@ -866,7 +874,7 @@ def has_external_commits(self): return self._commit_count != self._query_commit_count() def close(self): - """Closes this DB mapping. This is only needed if you're keeping a long-lived session. + """Closes this DB mapping. For instance:: class MyDBMappingWrapper: @@ -885,6 +893,12 @@ def __del__(self): ... # db_map.close() is automatically called when leaving this block """ + if self._session is not None: + self._session.close() + self._session = None + if self._connection is not None: + self._connection.close() + self._connection = None self.closed = True def add_ext_entity_metadata(self, *items, **kwargs): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 311a7052..cf8c8ed4 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -298,8 +298,8 @@ def _get_next_chunk(self, item_type, offset, limit, **kwargs): if not qry: return [] if not limit: - return [dict(x) for x in qry] - return [dict(x) for x in qry.limit(limit).offset(offset)] + return [x._asdict() for x in qry] + return [x._asdict() for x in qry.limit(limit).offset(offset)] def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, **kwargs): """Fetches items from the DB and adds them to the mapping. diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 3fa4e698..e888b422 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -47,7 +47,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): if id_items: connection.execute(table.insert(), [x.resolve() for x in id_items]) if temp_id_items: - current_ids = {x["id"] for x in connection.execute(table.select())} + current_ids = {x.id for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 available_ids = set(range(1, next_id)) - current_ids required_id_count = len(temp_id_items) - len(available_ids) diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index da606d3c..29b9ef5b 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -13,7 +13,8 @@ from types import MethodType from sqlalchemy import Integer, Table, and_, case, cast, func, or_ from sqlalchemy.orm import aliased -from sqlalchemy.sql.expression import Alias, label +from sqlalchemy.sql import Subquery +from sqlalchemy.sql.expression import label from .helpers import forward_sweep, group_concat @@ -95,7 +96,7 @@ def _func(x, tables): getattr(self, attr) table_to_sq_attr = {} for attr, val in vars(self).items(): - if not isinstance(val, Alias): + if not isinstance(val, Subquery): continue tables = set() forward_sweep(val, _func, tables) @@ -124,7 +125,7 @@ def _subquery(self, tablename): tablename (str): the table to be queried. Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ table = self._metadata.tables[tablename] return self.query(table).subquery(tablename + "_sq") @@ -138,7 +139,7 @@ def superclass_subclass_sq(self): SELECT * FROM superclass_subclass Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._superclass_subclass_sq is None: self._superclass_subclass_sq = self._subquery("superclass_subclass") @@ -153,7 +154,7 @@ def entity_class_sq(self): SELECT * FROM entity_class Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_class_sq is None: self._entity_class_sq = self._make_entity_class_sq() @@ -168,7 +169,7 @@ def entity_class_dimension_sq(self): SELECT * FROM entity_class_dimension Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_class_dimension_sq is None: self._entity_class_dimension_sq = self._subquery("entity_class_dimension") @@ -191,7 +192,7 @@ def wide_entity_class_sq(self): ec.id == ecd.entity_class_id Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._wide_entity_class_sq is None: entity_class_dimension_sq = ( @@ -258,7 +259,7 @@ def entity_sq(self): SELECT * FROM entity Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_sq is None: self._entity_sq = self._make_entity_sq() @@ -273,7 +274,7 @@ def entity_element_sq(self): SELECT * FROM entity_element Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_element_sq is None: self._entity_element_sq = self._make_entity_element_sq() @@ -295,7 +296,7 @@ def wide_entity_sq(self): e.id == ee.entity_id Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._wide_entity_sq is None: entity_element_sq = ( @@ -342,7 +343,7 @@ def entity_group_sq(self): SELECT * FROM entity_group Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_group_sq is None: self._entity_group_sq = self._make_entity_group_sq() @@ -357,7 +358,7 @@ def display_mode_sq(self): SELECT * FROM display_mode Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._display_mode_sq is None: self._display_mode_sq = self._subquery("display_mode") @@ -372,7 +373,7 @@ def entity_class_display_mode_sq(self): SELECT * FROM entity_class_display_mode Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_class_display_mode_sq is None: self._entity_class_display_mode_sq = self._subquery("entity_class_display_mode") @@ -387,7 +388,7 @@ def alternative_sq(self): SELECT * FROM alternative Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._alternative_sq is None: self._alternative_sq = self._make_alternative_sq() @@ -402,7 +403,7 @@ def scenario_sq(self): SELECT * FROM scenario Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._scenario_sq is None: self._scenario_sq = self._make_scenario_sq() @@ -417,7 +418,7 @@ def scenario_alternative_sq(self): SELECT * FROM scenario_alternative Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._scenario_alternative_sq is None: self._scenario_alternative_sq = self._make_scenario_alternative_sq() @@ -432,7 +433,7 @@ def entity_alternative_sq(self): SELECT * FROM entity_alternative Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_alternative_sq is None: self._entity_alternative_sq = self._make_entity_alternative_sq() @@ -447,7 +448,7 @@ def parameter_value_list_sq(self): SELECT * FROM parameter_value_list Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._parameter_value_list_sq is None: self._parameter_value_list_sq = self._subquery("parameter_value_list") @@ -462,7 +463,7 @@ def list_value_sq(self): SELECT * FROM list_value Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._list_value_sq is None: self._list_value_sq = self._subquery("list_value") @@ -477,7 +478,7 @@ def parameter_definition_sq(self): SELECT * FROM parameter_definition Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._parameter_definition_sq is None: @@ -485,7 +486,7 @@ def parameter_definition_sq(self): return self._parameter_definition_sq @property - def wide_parameter_definition_sq(self) -> Alias: + def wide_parameter_definition_sq(self) -> Subquery: """A subquery of the form: .. code-block:: sql @@ -545,7 +546,7 @@ def parameter_type_sq(self): SELECT * FROM parameter_type Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._parameter_type_sq is None: self._parameter_type_sq = self._subquery("parameter_type") @@ -560,7 +561,7 @@ def parameter_value_sq(self): SELECT * FROM parameter_value Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._parameter_value_sq is None: self._parameter_value_sq = self._make_parameter_value_sq() @@ -575,7 +576,7 @@ def metadata_sq(self): SELECT * FROM list_value Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._metadata_sq is None: self._metadata_sq = self._subquery("metadata") @@ -590,7 +591,7 @@ def parameter_value_metadata_sq(self): SELECT * FROM parameter_value_metadata Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._parameter_value_metadata_sq is None: self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata") @@ -605,7 +606,7 @@ def entity_metadata_sq(self): SELECT * FROM entity_metadata Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._entity_metadata_sq is None: self._entity_metadata_sq = self._subquery("entity_metadata") @@ -620,7 +621,7 @@ def commit_sq(self): SELECT * FROM commit Returns: - :class:`~sqlalchemy.sql.expression.Alias` + :class:`~sqlalchemy.sql.Subquery` """ if self._commit_sq is None: commit_sq = self._subquery("commit") @@ -1263,7 +1264,7 @@ def _make_entity_class_sq(self): Creates a subquery for entity classes. Returns: - Alias: an entity class subquery + Subquery: an entity class subquery """ return self._subquery("entity_class") @@ -1272,7 +1273,7 @@ def _make_entity_sq(self): Creates a subquery for entities. Returns: - Alias: an entity subquery + Subquery: an entity subquery """ return self._subquery("entity") @@ -1281,7 +1282,7 @@ def _make_entity_element_sq(self): Creates a subquery for entity-elements. Returns: - Alias: an entity_element subquery + Subquery: an entity_element subquery """ return self._subquery("entity_element") @@ -1299,7 +1300,7 @@ def _make_entity_alternative_sq(self): Creates a subquery for entity-alternatives. Returns: - Alias: an entity_alternative subquery + Subquery: an entity_alternative subquery """ return self._subquery("entity_alternative") @@ -1308,18 +1309,18 @@ def _make_parameter_definition_sq(self): Creates a subquery for parameter definitions. Returns: - Alias: a parameter definition subquery + Subquery: a parameter definition subquery """ par_def_sq = self._subquery("parameter_definition") list_value_id = case( - [(par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer()))], else_=None + (par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer())), else_=None ) default_value = case( - [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value)], + (par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value), else_=par_def_sq.c.default_value, ) default_type = case( - [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type)], + (par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type), else_=par_def_sq.c.default_type, ) return ( @@ -1343,12 +1344,12 @@ def _make_parameter_value_sq(self): Creates a subquery for parameter values. Returns: - Alias: a parameter value subquery + Subquery: a parameter value subquery """ par_val_sq = self._subquery("parameter_value") - list_value_id = case([(par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer()))], else_=None) - value = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value)], else_=par_val_sq.c.value) - type_ = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type)], else_=par_val_sq.c.type) + list_value_id = case((par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer())), else_=None) + value = case((par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value), else_=par_val_sq.c.value) + type_ = case((par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type), else_=par_val_sq.c.type) return ( self.query( par_val_sq.c.id.label("id"), @@ -1371,7 +1372,7 @@ def _make_alternative_sq(self): Creates a subquery for alternatives. Returns: - Alias: an alternative subquery + Subquery: an alternative subquery """ return self._subquery("alternative") @@ -1380,7 +1381,7 @@ def _make_scenario_sq(self): Creates a subquery for scenarios. Returns: - Alias: a scenario subquery + Subquery: a scenario subquery """ return self._subquery("scenario") @@ -1389,7 +1390,7 @@ def _make_scenario_alternative_sq(self): Creates a subquery for scenario alternatives. Returns: - Alias: a scenario alternative subquery + Subquery: a scenario alternative subquery """ return self._subquery("scenario_alternative") @@ -1399,7 +1400,7 @@ def override_entity_class_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns entity class subquery as an :class:`Alias` object + returns entity class subquery as an :class:`Subquery` object """ self._make_entity_class_sq = MethodType(method, self) self._clear_subqueries("entity_class") @@ -1410,7 +1411,7 @@ def override_entity_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns entity subquery as an :class:`Alias` object + returns entity subquery as an :class:`Subquery` object """ self._make_entity_sq = MethodType(method, self) self._clear_subqueries("entity") @@ -1421,7 +1422,7 @@ def override_entity_element_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns entity_element subquery as an :class:`Alias` object + returns entity_element subquery as an :class:`Subquery` object """ self._make_entity_element_sq = MethodType(method, self) self._clear_subqueries("entity_element") @@ -1443,7 +1444,7 @@ def override_entity_alternative_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns entity alternative subquery as an :class:`Alias` object + returns entity alternative subquery as an :class:`Subquery` object """ self._make_entity_alternative_sq = MethodType(method, self) self._clear_subqueries("entity_alternative") @@ -1454,7 +1455,7 @@ def override_parameter_definition_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns parameter definition subquery as an :class:`Alias` object + returns parameter definition subquery as an :class:`Subquery` object """ self._make_parameter_definition_sq = MethodType(method, self) self._clear_subqueries("parameter_definition") @@ -1465,7 +1466,7 @@ def override_parameter_value_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns parameter value subquery as an :class:`Alias` object + returns parameter value subquery as an :class:`Subquery` object """ self._make_parameter_value_sq = MethodType(method, self) self._clear_subqueries("parameter_value") @@ -1476,7 +1477,7 @@ def override_alternative_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns alternative subquery as an :class:`Alias` object + returns alternative subquery as an :class:`Subquery` object """ self._make_alternative_sq = MethodType(method, self) self._clear_subqueries("alternative") @@ -1487,7 +1488,7 @@ def override_scenario_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns scenario subquery as an :class:`Alias` object + returns scenario subquery as an :class:`Subquery` object """ self._make_scenario_sq = MethodType(method, self) self._clear_subqueries("scenario") @@ -1498,7 +1499,7 @@ def override_scenario_alternative_sq_maker(self, method): Args: method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and - returns scenario alternative subquery as an :class:`Alias` object + returns scenario alternative subquery as an :class:`Subquery` object """ self._make_scenario_alternative_sq = MethodType(method, self) self._clear_subqueries("scenario_alternative") @@ -1544,62 +1545,54 @@ def restore_scenario_alternative_sq_maker(self): self._clear_subqueries("scenario_alternative") def _object_class_id(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None - ) + return case((self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id), else_=None) def _relationship_class_id(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id)], else_=None - ) + return case((self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id), else_=None) def _object_id(self): - return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id)], else_=None) + return case((self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id), else_=None) def _relationship_id(self): - return case([(self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id)], else_=None) + return case((self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id), else_=None) def _object_class_name(self): return case( - [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name)], else_=None + (self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name), else_=None ) def _relationship_class_name(self): return case( - [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name)], else_=None + (self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name), else_=None ) def _object_class_id_list(self): return case( - [ - ( - self.wide_entity_class_sq.c.dimension_id_list != None, - self.wide_relationship_class_sq.c.object_class_id_list, - ) - ], + ( + self.wide_entity_class_sq.c.dimension_id_list != None, + self.wide_relationship_class_sq.c.object_class_id_list, + ), else_=None, ) def _object_class_name_list(self): return case( - [ - ( - self.wide_entity_class_sq.c.dimension_id_list != None, - self.wide_relationship_class_sq.c.object_class_name_list, - ) - ], + ( + self.wide_entity_class_sq.c.dimension_id_list != None, + self.wide_relationship_class_sq.c.object_class_name_list, + ), else_=None, ) def _object_name(self): - return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name)], else_=None) + return case((self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name), else_=None) def _object_id_list(self): return case( - [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None + (self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list), else_=None ) def _object_name_list(self): return case( - [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None + (self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list), else_=None ) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 68a1e1d7..2a772e2e 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -10,8 +10,8 @@ # this program. If not, see . ###################################################################################################################### """ Functions for exporting data from a Spine database in a standard format. """ +from collections import namedtuple from operator import itemgetter -from sqlalchemy.util import KeyedTuple from .helpers import Asterisk from .parameter_value import from_database @@ -108,6 +108,9 @@ def _make_item_processor(db_map, tablename): return lambda item: (item,) +ValueRow = namedtuple("ValueRow", ["name", "value", "type"]) + + class _ParameterValueListProcessor: def __init__(self, value_items): self._value_items_by_list_id = {} @@ -116,7 +119,7 @@ def __init__(self, value_items): def __call__(self, item): for list_value_item in sorted(self._value_items_by_list_id.get(item.id, ()), key=lambda x: x.index): - yield KeyedTuple([item.name, list_value_item.value, list_value_item.type], ["name", "value", "type"]) + yield ValueRow(item.name, list_value_item.value, list_value_item.type) def export_parameter_value_lists(db_map, ids=Asterisk, parse_value=from_database): diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 7ceac7eb..6663abea 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from itertools import cycle, dropwhile, islice from sqlalchemy import and_ -from sqlalchemy.sql.expression import literal +from sqlalchemy.sql.expression import literal, null, select from ..mapping import Mapping, Position, is_pivoted, is_regular, unflatten from ..parameter_value import ( IndexedValue, @@ -146,19 +146,13 @@ def reconstruct(cls, position, value, header, filter_re, ignorable, mapping_dict mapping.set_ignorable(ignorable) return mapping - def add_query_columns(self, db_map, query): - """Adds columns to the mapping query if needed, and returns the new query. - - The base class implementation just returns the same query without adding any new columns. + def build_query_columns(self, db_map, columns): + """Appends columns needed to query the mapping's data into a list. Args: - db_map (DatabaseMapping) - query (Alias or dict) - - Returns: - Alias: expanded query, or the same if nothing to add. + db_map (DatabaseMapping): database mapping + columns (list of Column): list of columns to append to """ - return query def filter_query(self, db_map, query): """Filters the mapping query if needed, and returns the new query. @@ -167,10 +161,10 @@ def filter_query(self, db_map, query): Args: db_map (DatabaseMapping) - query (Alias or dict) + query (Subquery or dict) Returns: - Alias: filtered query, or the same if nothing to add. + Subquery: filtered query, or the same if nothing to add. """ return query @@ -202,10 +196,11 @@ def _build_query(self, db_map, title_state): """ mappings = self.flatten() # Start with empty query - qry = db_map.query(literal(None)) # Add columns + columns = [] for m in mappings: - qry = m.add_query_columns(db_map, qry) + m.build_query_columns(db_map, columns) + qry = db_map.query(*columns) # Apply filters for m in mappings: qry = m.filter_query(db_map, qry) @@ -234,11 +229,12 @@ def _build_title_query(self, db_map): if mappings[-1].position == Position.table_name: break mappings.pop(-1) - # Start with empty query - qry = db_map.query(literal(None)) - # Add columns + columns = [] for m in mappings: - qry = m.add_query_columns(db_map, qry) + m.build_query_columns(db_map, columns) + if not columns: + return None + qry = db_map.query(*columns) # Apply filters for m in mappings: qry = m.filter_query(db_map, qry) @@ -262,18 +258,17 @@ def _build_header_query(self, db_map, title_state, buddies): if m.position in (Position.header, Position.table_name) or m in flat_buddies: break mappings.pop(-1) - # Start with empty query - qry = db_map.query(literal(None)) - # Add columns + columns = [] for m in mappings: - qry = m.add_query_columns(db_map, qry) - # Apply filters + m.build_query_columns(db_map, columns) + if not columns: + return None + qry = db_map.query(*columns) for m in mappings: qry = m.filter_query(db_map, qry) # Apply special title filters (first, so we clean up the state) for m in mappings: qry = m.filter_query_by_title(qry, title_state) - # Apply standard title filters if not title_state: return qry # Use a _FilteredQuery, since building a subquery to query it again leads to parser stack overflow @@ -307,7 +302,7 @@ def _data(self, row): The base class implementation returns the field given by ``name_field()``. Args: - row (KeyedTuple) + row (Row) Returns: any @@ -349,7 +344,7 @@ def _get_rows(self, db_row): """Yields rows issued by this mapping for given database row. Args: - db_row (KeyedTuple) + db_row (Row) Returns: generator(dict) @@ -368,7 +363,7 @@ def get_rows_recursive(self, db_row): """Takes a database row and yields rows issued by this mapping and its children combined. Args: - db_row (KeyedTuple) + db_row (Row) Returns: generator(dict) @@ -415,7 +410,7 @@ def _title_state(self, db_row): to the corresponding field from the row. Args: - db_row (KeyedTuple) + db_row (Row) Returns: dict @@ -429,7 +424,7 @@ def _get_titles(self, db_row, limit=None): """Yields pairs (title, title state) issued by this mapping for given database row. Args: - db_row (KeyedTuple) + db_row (Row) limit (int, optional): yield only this many items Returns: @@ -452,7 +447,7 @@ def get_titles_recursive(self, db_row, limit=None): """Takes a database row and yields pairs (title, title state) issued by this mapping and its children combined. Args: - db_row (KeyedTuple) + db_row (Row) limit (int, optional): yield only this many items Returns: @@ -478,6 +473,9 @@ def _non_unique_titles(self, db_map, limit=None): tuple(str,dict): title, and associated title state dictionary """ qry = self._build_title_query(db_map) + if qry is None: + yield from self.get_titles_recursive((), limit=limit) + return for db_row in qry: yield from self.get_titles_recursive(db_row, limit=limit) @@ -512,9 +510,7 @@ def make_header_recursive(self, query, buddies): """Builds the header recursively. Args: - build_header_query (callable): a function that any mapping in the hierarchy can call to get the query - db_map (DatabaseMapping): database map - title_state (dict): title state + query (Query, optional): export query buddies (list of tuple): buddy mappings Returns @@ -525,7 +521,7 @@ def make_header_recursive(self, query, buddies): return {} return {self.position: self.header} header = self.child.make_header_recursive(query, buddies) - if self.position == Position.header: + if self.position == Position.header and query is not None: buddy = find_my_buddy(self, buddies) if buddy is not None: query.rewind() @@ -547,7 +543,9 @@ def make_header(self, db_map, title_state, buddies): Returns dict: a mapping from column index to string header """ - query = _Rewindable(self._build_header_query(db_map, title_state, buddies)) + query = self._build_header_query(db_map, title_state, buddies) + if query is not None: + query = _Rewindable(query) return self.make_header_recursive(query, buddies) @@ -582,8 +580,8 @@ def __init__(self, position, value, header="", filter_re=""): Args: position (int or Position, optional): mapping's position value (Any): value to yield - header (str, optional); A string column header that's yielt as 'first row', if not empty. - The default is an empty string (so it's not yielt). + header (str, optional); A string column header that's yielded as 'first row', if not empty. + The default is an empty string (so it's not yielded). filter_re (str, optional): A regular expression to filter the mapped values by """ super().__init__(position, value, header, filter_re) @@ -609,16 +607,15 @@ def __init__(self, position, value=None, header="", filter_re="", highlight_posi super().__init__(position, value, header, filter_re) self.highlight_position = highlight_position - def add_query_columns(self, db_map, query): - query = query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.wide_entity_class_sq.c.id.label("entity_class_id"), db_map.wide_entity_class_sq.c.name.label("entity_class_name"), db_map.wide_entity_class_sq.c.dimension_id_list.label("dimension_id_list"), db_map.wide_entity_class_sq.c.dimension_name_list.label("dimension_name_list"), - ) + ] if self.highlight_position is not None: - query = query.add_columns(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id")) - return query + columns.append(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id")) def filter_query(self, db_map, query): if any(isinstance(m, (DimensionMapping, ElementMapping)) for m in self.flatten()): @@ -675,16 +672,15 @@ class EntityMapping(ExportMapping): MAP_TYPE = "Entity" - def add_query_columns(self, db_map, query): - query = query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name"), db_map.wide_entity_sq.c.element_id_list, db_map.wide_entity_sq.c.element_name_list, - ) + ] if self.query_parents("highlight_position") is not None: - query = query.add_columns(db_map.entity_element_sq.c.element_id.label("highlighted_element_id")) - return query + columns.append(db_map.entity_element_sq.c.element_id.label("highlighted_element_id")) def filter_query(self, db_map, query): query = query.outerjoin( @@ -727,8 +723,8 @@ class EntityGroupMapping(ExportMapping): MAP_TYPE = "EntityGroup" - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.ext_entity_group_sq.c.group_id, db_map.ext_entity_group_sq.c.group_name) + def build_query_columns(self, db_map, columns): + columns += [db_map.ext_entity_group_sq.c.group_id, db_map.ext_entity_group_sq.c.group_name] def filter_query(self, db_map, query): return query.outerjoin( @@ -756,10 +752,8 @@ class EntityGroupEntityMapping(ExportMapping): MAP_TYPE = "EntityGroupEntity" - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name") - ) + def build_query_columns(self, db_map, columns): + columns += [db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name")] def filter_query(self, db_map, query): return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.wide_entity_sq.c.id) @@ -864,11 +858,11 @@ class ParameterDefinitionMapping(ExportMapping): MAP_TYPE = "ParameterDefinition" - def add_query_columns(self, db_map, query): - return query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.parameter_definition_sq.c.id.label("parameter_definition_id"), db_map.parameter_definition_sq.c.name.label("parameter_definition_name"), - ) + ] def filter_query(self, db_map, query): if self.query_parents("highlight_position") is not None: @@ -898,10 +892,8 @@ class ParameterDefaultValueMapping(ExportMapping): MAP_TYPE = "ParameterDefaultValue" - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type - ) + def build_query_columns(self, db_map, columns): + columns += [db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type] @staticmethod def name_field(): @@ -986,12 +978,10 @@ class ParameterDefaultValueIndexMapping(_MappingWithLeafMixin, ExportMapping): MAP_TYPE = "ParameterDefaultValueIndex" - def add_query_columns(self, db_map, query): - if "default_value" in set(query.column_names()): - return query - return query.add_columns( - db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type - ) + def build_query_columns(self, db_map, columns): + if any(c.name == "default_value" for c in columns): + return + columns += [db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type] def _expand_data(self, data): yield from _expand_indexed_data(data, self) @@ -1046,11 +1036,11 @@ class ParameterValueMapping(ExportMapping): MAP_TYPE = "ParameterValue" _selects_value = False - def add_query_columns(self, db_map, query): - if "value" in set(query.column_names()): - return query + def build_query_columns(self, db_map, columns): + if any(c.name == "value" for c in columns): + return self._selects_value = True - return query.add_columns(db_map.parameter_value_sq.c.value, db_map.parameter_value_sq.c.type) + columns += [db_map.parameter_value_sq.c.value, db_map.parameter_value_sq.c.type] def filter_query(self, db_map, query): if not self._selects_value: @@ -1113,7 +1103,7 @@ def filter_query_by_title(self, query, title_state): pv = title_state.pop("type_and_dimensions", None) if pv is None: return query - if "value" not in set(query.column_names()): + if all(d["name"] != "value" for d in query.column_descriptions): return query return _FilteredQuery( query, lambda db_row: (db_row.type, from_database_to_dimension_count(db_row.value, db_row.type) == pv) @@ -1190,11 +1180,11 @@ class ParameterValueListMapping(ExportMapping): MAP_TYPE = "ParameterValueList" - def add_query_columns(self, db_map, query): - return query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.parameter_value_list_sq.c.id.label("parameter_value_list_id"), db_map.parameter_value_list_sq.c.name.label("parameter_value_list_name"), - ) + ] def filter_query(self, db_map, query): if self.parent is None: @@ -1222,8 +1212,8 @@ class ParameterValueListValueMapping(ExportMapping): MAP_TYPE = "ParameterValueListValue" - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.ord_list_value_sq.c.value, db_map.ord_list_value_sq.c.type) + def build_query_columns(self, db_map, columns): + columns += [db_map.ord_list_value_sq.c.value, db_map.ord_list_value_sq.c.type] def filter_query(self, db_map, query): return query.filter(db_map.ord_list_value_sq.c.parameter_value_list_id == db_map.parameter_value_list_sq.c.id) @@ -1252,12 +1242,12 @@ class AlternativeMapping(ExportMapping): MAP_TYPE = "Alternative" - def add_query_columns(self, db_map, query): - return query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.alternative_sq.c.id.label("alternative_id"), db_map.alternative_sq.c.name.label("alternative_name"), db_map.alternative_sq.c.description.label("description"), - ) + ] def filter_query(self, db_map, query): parent = self.parent @@ -1284,12 +1274,12 @@ class ScenarioMapping(ExportMapping): MAP_TYPE = "Scenario" - def add_query_columns(self, db_map, query): - return query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.scenario_sq.c.id.label("scenario_id"), db_map.scenario_sq.c.name.label("scenario_name"), db_map.scenario_sq.c.description.label("description"), - ) + ] @staticmethod def name_field(): @@ -1308,18 +1298,19 @@ class ScenarioAlternativeMapping(ExportMapping): MAP_TYPE = "ScenarioAlternative" - def add_query_columns(self, db_map, query): + def build_query_columns(self, db_map, columns): if self._child is None: - return query.add_columns( + columns += [ db_map.ext_scenario_sq.c.alternative_id, db_map.ext_scenario_sq.c.alternative_name, db_map.ext_scenario_sq.c.rank, - ) - # Legacy: expecting child to be ScenarioBeforeAlternativeMapping - return query.add_columns( - db_map.ext_linked_scenario_alternative_sq.c.alternative_id, - db_map.ext_linked_scenario_alternative_sq.c.alternative_name, - ) + ] + else: + # Legacy: expecting child to be ScenarioBeforeAlternativeMapping + columns += [ + db_map.ext_linked_scenario_alternative_sq.c.alternative_id, + db_map.ext_linked_scenario_alternative_sq.c.alternative_name, + ] def filter_query(self, db_map, query): if self._child is None: @@ -1354,11 +1345,11 @@ class ScenarioBeforeAlternativeMapping(ExportMapping): MAP_TYPE = "ScenarioBeforeAlternative" - def add_query_columns(self, db_map, query): - return query.add_columns( + def build_query_columns(self, db_map, columns): + columns += [ db_map.ext_linked_scenario_alternative_sq.c.before_alternative_id, db_map.ext_linked_scenario_alternative_sq.c.before_alternative_name, - ) + ] @staticmethod def name_field(): diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index 79e144c4..0dbad0ab 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -202,8 +202,8 @@ def _make_renaming_entity_class_sq(db_map, state): subquery = state.original_entity_class_sq if not state.id_to_name: return subquery - cases = [(subquery.c.id == id, new_name) for id, new_name in state.id_to_name.items()] - new_class_name = case(cases, else_=subquery.c.name) # if not in the name map, just keep the original name + cases = ((subquery.c.id == id_, new_name) for id_, new_name in state.id_to_name.items()) + new_class_name = case(*cases, else_=subquery.c.name) # if not in the name map, just keep the original name entity_class_sq = db_map.query( subquery.c.id, new_class_name.label("name"), @@ -262,8 +262,8 @@ def _make_renaming_parameter_definition_sq(db_map, state): subquery = state.original_parameter_definition_sq if not state.id_to_name: return subquery - cases = [(subquery.c.id == id, new_name) for id, new_name in state.id_to_name.items()] - new_parameter_name = case(cases, else_=subquery.c.name) # if not in the name map, just keep the original name + cases = ((subquery.c.id == id, new_name) for id, new_name in state.id_to_name.items()) + new_parameter_name = case(*cases, else_=subquery.c.name) # if not in the name map, just keep the original name parameter_definition_sq = db_map.query( subquery.c.id, new_parameter_name.label("name"), diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 7a9cae29..a53f53c8 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -182,8 +182,8 @@ def _make_parameter_value_transforming_sq(db_map, state): ] statements += [select([literal(i), literal(v), literal(t)]) for i, v, t in transformed_rows[1:]] temp_sq = union_all(*statements).alias("transformed_values") - new_value = case([(temp_sq.c.transformed_value != None, temp_sq.c.transformed_value)], else_=subquery.c.value) - new_type = case([(temp_sq.c.transformed_type != None, temp_sq.c.transformed_type)], else_=subquery.c.type) + new_value = case((temp_sq.c.transformed_value != None, temp_sq.c.transformed_value), else_=subquery.c.value) + new_type = case((temp_sq.c.transformed_type != None, temp_sq.c.transformed_type), else_=subquery.c.type) parameter_value_sq = ( db_map.query( subquery.c.id.label("id"), diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 9ed7b87e..ef310f4e 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -45,6 +45,7 @@ inspect, null, select, + text, true, ) from sqlalchemy.dialects.mysql import DOUBLE, TINYINT @@ -106,19 +107,6 @@ def name_from_dimensions(dimensions): return name_from_elements(dimensions) -# NOTE: Deactivated since foreign keys are too difficult to get right in the diff tables. -# For example, the diff_object table would need a `class_id` field and a `diff_class_id` field, -# plus a CHECK constraint that at least one of the two is NOT NULL. -# @event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection, connection_record): - module_name = dbapi_connection.__class__.__module__ - if not module_name.lower().startswith("sqlite"): - return - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - @compiles(TINYINT, "sqlite") def compile_TINYINT_mysql_sqlite(element, compiler, **kw): """Handles mysql TINYINT datatype as INTEGER in sqlite.""" @@ -134,6 +122,7 @@ def compile_DOUBLE_mysql_sqlite(element, compiler, **kw): class group_concat(FunctionElement): type = String() name = "group_concat" + inherit_cache = True def _parse_group_concat_clauses(clauses): @@ -514,7 +503,7 @@ def create_spine_metadata(): Column("display_order", Integer, nullable=False), Column( "display_status", - Enum(DisplayStatus, name="display_status_enum"), + Enum(DisplayStatus, name="display_status_enum", create_constraint=True), server_default=DisplayStatus.visible.name, nullable=False, ), @@ -682,17 +671,17 @@ def create_new_spine_database(db_url): def create_new_spine_database_from_bind(bind): # Drop existing tables. This is a Spine db now... - meta = MetaData(bind) - meta.reflect() - meta.drop_all() + meta = MetaData() + meta.reflect(bind) + meta.drop_all(bind) # Create new tables meta = create_spine_metadata() version = get_head_alembic_version() try: meta.create_all(bind) - bind.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") - bind.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - bind.execute(f"INSERT INTO alembic_version VALUES ('{version}')") + bind.execute(text("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')")) + bind.execute(text("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)")) + bind.execute(text(f"INSERT INTO alembic_version VALUES ('{version}')")) except DatabaseError as e: raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None @@ -706,8 +695,8 @@ def _create_first_spine_database(db_url): except DatabaseError as e: raise SpineDBAPIError(f"Could not connect to '{db_url}': {e.orig.args}") from None # Drop existing tables. This is a Spine db now... - meta = MetaData(engine) - meta.reflect() + meta = MetaData() + meta.reflect(engine) meta.drop_all(engine) # Create new tables meta = MetaData(naming_convention=naming_convention) diff --git a/spinedb_api/query.py b/spinedb_api/query.py deleted file mode 100644 index 8a78bcbc..00000000 --- a/spinedb_api/query.py +++ /dev/null @@ -1,157 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# Copyright Spine Database API contributors -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -"""The :class:`Query` class.""" - -from sqlalchemy import and_, select -from sqlalchemy.sql.functions import count -from .exception import SpineDBAPIError - - -class Query: - """A clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`.""" - - def __init__(self, bind, *entities): - """ - Args: - bind(Engine or Connection): An engine or connection to a DB against which the query will be executed. - entities(Iterable): A sequence of SQL expressions. - """ - self._bind = bind - self._entities = entities - self._select = select(entities) - self._from = None - - def __str__(self): - return str(self._select) - - @property - def column_descriptions(self): - return [{"name": c.name} for c in self._select.columns] - - def column_names(self): - yield from (c.name for c in self._select.columns) - - def subquery(self, name=None): - return self._select.alias(name) - - def add_columns(self, *columns): - self._entities += columns - self._select = select(self._entities) - return self - - def filter(self, *clauses): - for clause in clauses: - self._select = self._select.where(clause) - return self - - def filter_by(self, **kwargs): - if len(self._entities) != 1: - raise SpineDBAPIError(f"can't find a unique 'from-clause' to filter, candidates are {self._entities}") - return self.filter(and_(getattr(self._entities[0].c, k) == v for k, v in kwargs.items())) - - def _get_from(self, right, on): - if self._from is not None: - return self._from - from_candidates = (set(_get_descendant_tables(on)) - {right}) & set(self._select.get_children()) - if len(from_candidates) != 1: - raise SpineDBAPIError(f"can't find a unique 'from-clause' to join into, candidates are {from_candidates}") - return next(iter(from_candidates)) - - def join(self, right, on, isouter=False): - self._from = self._get_from(right, on).join(right, on, isouter=isouter) - self._select = self._select.select_from(self._from) - return self - - def outerjoin(self, right, on): - return self.join(right, on, isouter=True) - - def order_by(self, *args): - self._select = self._select.order_by(*args) - return self - - def group_by(self, *args): - self._select = self._select.group_by(*args) - return self - - def limit(self, *args): - self._select = self._select.limit(*args) - return self - - def offset(self, *args): - self._select = self._select.offset(*args) - return self - - def distinct(self, *args): - self._select = self._select.distinct(*args) - return self - - def having(self, *args): - self._select = self._select.having(*args) - return self - - def _result(self): - return self._bind.execute(self._select) - - def all(self): - return self._result().fetchall() - - def first(self): - return self._result().first() - - def one(self): - result = self._result() - first = result.fetchone() - if first is None: - return SpineDBAPIError("no results found for one()") - second = result.fetchone() - if second is not None: - raise SpineDBAPIError("multiple results found for one()") - return first - - def one_or_none(self): - result = self._result() - first = result.fetchone() - if first is None: - return None - second = result.fetchone() - if second is not None: - raise SpineDBAPIError("multiple results found for one_or_none()") - return first - - def scalar(self): - return self._result().scalar() - - def count(self): - return self._bind.execute(select([count()]).select_from(self._select)).scalar() - - def __iter__(self): - return self._result() or iter([]) - - -def _get_leaves(parent): - children = parent.get_children() - if not children: - try: - yield parent.table - except AttributeError: - pass - for child in children: - yield from _get_leaves(child) - - -def _get_descendant_tables(on): - for x in on.get_children(): - try: - yield x.table - except AttributeError: - yield from _get_descendant_tables(x) diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index 82e9a0c3..e8e375e4 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -39,7 +39,9 @@ class TestEntityClassRenamer(unittest.TestCase): @classmethod def setUpClass(cls): cls._temp_dir = TemporaryDirectory() - cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_entity_class_renamer.sqlite").as_posix()) + cls._db_url = URL.create( + "sqlite", database=Path(cls._temp_dir.name, "test_entity_class_renamer.sqlite").as_posix() + ) def setUp(self): create_new_spine_database(self._db_url) @@ -62,11 +64,10 @@ def test_renaming_singe_entity_class(self): classes = list(self._db_map.query(self._db_map.entity_class_sq).all()) self.assertEqual(len(classes), 1) class_row = classes[0] - keys = tuple(class_row.keys()) expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default") - self.assertEqual(len(keys), len(expected_keys)) + self.assertEqual(len(class_row._fields), len(expected_keys)) for expected_key in expected_keys: - self.assertIn(expected_key, keys) + self.assertIn(expected_key, class_row._fields) self.assertEqual(class_row.name, "new_name") def test_renaming_singe_relationship_class(self): @@ -119,11 +120,10 @@ def test_entity_class_renamer_from_dict(self): classes = list(self._db_map.query(self._db_map.entity_class_sq).all()) self.assertEqual(len(classes), 1) class_row = classes[0] - keys = tuple(class_row.keys()) expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default") - self.assertEqual(len(keys), len(expected_keys)) + self.assertEqual(len(class_row._fields), len(expected_keys)) for expected_key in expected_keys: - self.assertIn(expected_key, keys) + self.assertIn(expected_key, class_row._fields) self.assertEqual(class_row.name, "new_name") @@ -144,7 +144,9 @@ class TestParameterRenamer(unittest.TestCase): @classmethod def setUpClass(cls): cls._temp_dir = TemporaryDirectory() - cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_parameter_renamer.sqlite").as_posix()) + cls._db_url = URL.create( + "sqlite", database=Path(cls._temp_dir.name, "test_parameter_renamer.sqlite").as_posix() + ) def setUp(self): create_new_spine_database(self._db_url) @@ -168,7 +170,6 @@ def test_renaming_single_parameter(self): parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all()) self.assertEqual(len(parameters), 1) parameter_row = parameters[0] - keys = tuple(parameter_row.keys()) expected_keys = ( "id", "name", @@ -180,9 +181,9 @@ def test_renaming_single_parameter(self): "commit_id", "parameter_value_list_id", ) - self.assertEqual(len(keys), len(expected_keys)) + self.assertEqual(len(parameter_row._fields), len(expected_keys)) for expected_key in expected_keys: - self.assertIn(expected_key, keys) + self.assertIn(expected_key, parameter_row._fields) self.assertEqual(parameter_row.name, "new_name") def test_renaming_applies_to_correct_parameter(self): @@ -214,7 +215,6 @@ def test_parameter_renamer_from_dict(self): parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all()) self.assertEqual(len(parameters), 1) parameter_row = parameters[0] - keys = tuple(parameter_row.keys()) expected_keys = ( "id", "name", @@ -226,9 +226,9 @@ def test_parameter_renamer_from_dict(self): "commit_id", "parameter_value_list_id", ) - self.assertEqual(len(keys), len(expected_keys)) + self.assertEqual(len(parameter_row._fields), len(expected_keys)) for expected_key in expected_keys: - self.assertIn(expected_key, keys) + self.assertIn(expected_key, parameter_row._fields) self.assertEqual(parameter_row.name, "new_name") diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index d27cacc4..a98c0515 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -44,7 +44,7 @@ def test_filter_entities_with_default_activity_only(self): apply_filter_stack(db_map, [scenario_filter_config("S")]) entities = db_map.query(db_map.wide_entity_sq).all() self.assertEqual(len(entities), 1) - self.assertEqual(entities[0]["name"], "visible_object") + self.assertEqual(entities[0].name, "visible_object") def test_filter_entities_with_default_activity(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -79,8 +79,8 @@ def test_filter_entities_with_default_activity(self): apply_filter_stack(db_map, [scenario_filter_config("S")]) entities = db_map.query(db_map.wide_entity_sq).all() self.assertEqual(len(entities), 2) - self.assertEqual(entities[0]["name"], "visible") - self.assertEqual(entities[1]["name"], "visible") + self.assertEqual(entities[0].name, "visible") + self.assertEqual(entities[1].name, "visible") def test_filter_entity_that_is_not_active_in_scenario(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -323,14 +323,14 @@ def test_scenario_filter(self): url = "sqlite:///" + str(Path(temp_dir, "db.sqlite")) with DatabaseMapping(url, create=True) as db_map: self._build_data_with_single_scenario(db_map) - with DatabaseMapping(url, create=True): + with DatabaseMapping(url) as db_map: apply_scenario_filter_to_subqueries(db_map, "scenario") parameters = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -353,7 +353,7 @@ def test_scenario_filter_uncommitted_data(self): apply_scenario_filter_to_subqueries(db_map, "scenario") parameters = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 0) - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual( alternatives, [{"name": "Base", "description": "Base alternative", "id": 1, "commit_id": 1}] ) @@ -438,7 +438,7 @@ def test_scenario_filter_works_for_entity_sq(self): entity_names = { name for x in entities - for name in (x["element_name_list"].split(",") if x["element_name_list"] else (x["name"],)) + for name in (x.element_name_list.split(",") if x.element_name_list else (x.name,)) } self.assertFalse("obj2" in entity_names) @@ -452,9 +452,9 @@ def test_scenario_filter_works_for_object_parameter_value_sq(self): parameters = db_map.query(db_map.object_parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -514,9 +514,9 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self): parameters = db_map.query(db_map.relationship_parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual((parameters[0].value, parameters[0].type), to_database(23.0)) - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -610,7 +610,7 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): parameters = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual((parameters[0].value, parameters[0].type), to_database(2000.0)) - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual( alternatives, [ @@ -619,7 +619,7 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, ], ) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -735,7 +735,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s parameters = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual((parameters[0].value, parameters[0].type), to_database(2000.0)) - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual( alternatives, [ @@ -744,7 +744,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, ], ) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -875,9 +875,9 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): apply_scenario_filter_to_subqueries(db_map, "scenario") parameters = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 4) - object_names = {o.id: o.name for o in db_map.query(db_map.object_sq).all()} - alternative_names = {a.id: a.name for a in db_map.query(db_map.alternative_sq).all()} - parameter_names = {d.id: d.name for d in db_map.query(db_map.parameter_definition_sq).all()} + object_names = {o.id: o.name for o in db_map.query(db_map.object_sq)} + alternative_names = {a.id: a.name for a in db_map.query(db_map.alternative_sq)} + parameter_names = {d.id: d.name for d in db_map.query(db_map.parameter_definition_sq)} datamined_values = {} for parameter in parameters: self.assertEqual(alternative_names[parameter.alternative_id], "alternative") @@ -890,9 +890,9 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): "object2": {"parameter1": b"20.0", "parameter2": b"22.0"}, }, ) - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -940,7 +940,7 @@ def test_filters_scenarios_and_alternatives(self): db_map.commit_session("Add test data.") with DatabaseMapping(url, create=True) as db_map: apply_scenario_filter_to_subqueries(db_map, "scenario2") - alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)] + alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] self.assertEqual( alternatives, [ @@ -948,7 +948,7 @@ def test_filters_scenarios_and_alternatives(self): {"name": "alternative3", "description": None, "id": 4, "commit_id": 2}, ], ) - scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()] + scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)] self.assertEqual( scenarios, [ @@ -1074,10 +1074,10 @@ def test_parameter_values_for_entities_that_have_been_filtered_out_by_default(se with DatabaseMapping(filtered_url) as db_map: entities = db_map.query(db_map.entity_sq).all() self.assertEqual(len(entities), 1) - self.assertEqual(entities[0]["name"], "visible widget") + self.assertEqual(entities[0].name, "visible widget") values = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(values), 1) - self.assertEqual(from_database(values[0]["value"], values[0]["type"]), 2.3) + self.assertEqual(from_database(values[0].value, values[0].type), 2.3) def test_parameter_values_for_entities_that_swim_against_active_by_default(self): with TemporaryDirectory() as temp_dir: @@ -1136,10 +1136,10 @@ def test_parameter_values_for_entities_that_swim_against_active_by_default(self) with DatabaseMapping(filtered_url) as db_map: entities = db_map.query(db_map.entity_sq).all() self.assertEqual(len(entities), 1) - self.assertEqual(entities[0]["name"], "invisible widget") + self.assertEqual(entities[0].name, "invisible widget") values = db_map.query(db_map.parameter_value_sq).all() self.assertEqual(len(values), 1) - self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3) + self.assertEqual(from_database(values[0].value, values[0].type), -2.3) def test_parameter_values_of_multidimensional_entity_whose_elements_have_entity_alternatives(self): with DatabaseMapping("sqlite://", create=True) as db_map: diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py deleted file mode 100644 index a9921663..00000000 --- a/tests/filters/test_tool_filter.py +++ /dev/null @@ -1,221 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# Copyright Spine Database API contributors -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -""" -Unit tests for ``tool_entity_filter`` module. - -""" -from pathlib import Path -from tempfile import TemporaryDirectory -import unittest -from sqlalchemy.engine.url import URL -from spinedb_api import ( - DatabaseMapping, - SpineDBAPIError, - create_new_spine_database, - import_object_classes, - import_object_parameter_values, - import_object_parameters, - import_objects, - import_parameter_value_lists, - import_relationship_classes, - import_relationship_parameter_values, - import_relationship_parameters, - import_relationships, -) - - -@unittest.skip("obsolete, but need to adapt into the scenario filter") -class TestToolEntityFilter(unittest.TestCase): - _db_url = None - _temp_dir = None - - @classmethod - def setUpClass(cls): - cls._temp_dir = TemporaryDirectory() - cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_tool_filter_mapping.sqlite").as_posix()) - - def setUp(self): - create_new_spine_database(self._db_url) - self._db_map = DatabaseMapping(self._db_url) - - def tearDown(self): - self._db_map.close() - - def _build_data_with_tools(self): - import_object_classes(self._db_map, ["object_class"]) - import_objects( - self._db_map, - [ - ("object_class", "object1"), - ("object_class", "object2"), - ("object_class", "object3"), - ("object_class", "object4"), - ], - ) - import_parameter_value_lists( - self._db_map, [("methods", "methodA"), ("methods", "methodB"), ("methods", "methodC")] - ) - import_object_parameters( - self._db_map, - [ - ("object_class", "parameter1", "methodA", "methods"), - ("object_class", "parameter2", "methodC", "methods"), - ], - ) - import_object_parameter_values( - self._db_map, - [ - ("object_class", "object1", "parameter1", "methodA"), - ("object_class", "object2", "parameter1", "methodB"), - ("object_class", "object3", "parameter1", "methodC"), - ("object_class", "object4", "parameter1", "methodB"), - ("object_class", "object2", "parameter2", "methodA"), - ("object_class", "object3", "parameter2", "methodC"), - ], - ) - import_tools(self._db_map, ["tool1", "tool2"]) - import_features(self._db_map, [("object_class", "parameter1"), ("object_class", "parameter2")]) - import_tool_features( - self._db_map, - [("tool1", "object_class", "parameter1", False), ("tool2", "object_class", "parameter1", False)], - ) - - def test_non_existing_tool_filter_raises(self): - self._build_data_with_tools() - self._db_map.commit_session("Add test data") - self.assertRaises(SpineDBAPIError, apply_tool_filter_to_entity_sq, self._db_map, "notool") - - def test_tool_feature_no_filter(self): - self._build_data_with_tools() - self._db_map.commit_session("Add test data") - apply_tool_filter_to_entity_sq(self._db_map, "tool1") - entities = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(len(entities), 4) - names = [x.name for x in entities] - self.assertIn("object1", names) - self.assertIn("object2", names) - self.assertIn("object3", names) - self.assertIn("object4", names) - - def test_tool_feature_required(self): - self._build_data_with_tools() - import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)]) - self._db_map.commit_session("Add test data") - apply_tool_filter_to_entity_sq(self._db_map, "tool1") - entities = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(len(entities), 2) - names = [x.name for x in entities] - self.assertIn("object2", names) - self.assertIn("object3", names) - - def test_tool_feature_method(self): - self._build_data_with_tools() - import_tool_feature_methods( - self._db_map, - [("tool1", "object_class", "parameter1", "methodB"), ("tool2", "object_class", "parameter1", "methodC")], - ) - self._db_map.commit_session("Add test data") - apply_tool_filter_to_entity_sq(self._db_map, "tool1") - entities = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(len(entities), 2) - names = [x.name for x in entities] - self.assertIn("object2", names) - self.assertIn("object4", names) - - def test_tool_feature_required_and_method(self): - self._build_data_with_tools() - import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)]) - import_tool_feature_methods( - self._db_map, - [("tool1", "object_class", "parameter1", "methodB"), ("tool2", "object_class", "parameter1", "methodC")], - ) - self._db_map.commit_session("Add test data") - apply_tool_filter_to_entity_sq(self._db_map, "tool1") - entities = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(len(entities), 1) - self.assertEqual(entities[0].name, "object2") - - def test_tool_filter_config(self): - config = tool_filter_config("tool name") - self.assertEqual(config, {"type": "tool_filter", "tool": "tool name"}) - - def test_tool_filter_from_dict(self): - self._build_data_with_tools() - import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)]) - self._db_map.commit_session("Add test data") - config = tool_filter_config("tool1") - tool_filter_from_dict(self._db_map, config) - entities = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(len(entities), 2) - names = [x.name for x in entities] - self.assertIn("object2", names) - self.assertIn("object3", names) - - def test_tool_filter_config_to_shorthand(self): - config = tool_filter_config("tool name") - shorthand = tool_filter_config_to_shorthand(config) - self.assertEqual(shorthand, "tool:tool name") - - def test_tool_filter_shorthand_to_config(self): - config = tool_filter_shorthand_to_config("tool:tool name") - self.assertEqual(config, {"type": "tool_filter", "tool": "tool name"}) - - def test_object_activity_control_filter(self): - import_object_classes(self._db_map, ["node", "unit"]) - import_relationship_classes(self._db_map, [["node__unit", ["node", "unit"]]]) - import_objects(self._db_map, [("node", "node1"), ("node", "node2"), ("unit", "unita"), ("unit", "unitb")]) - import_relationships( - self._db_map, - [ - ["node__unit", ["node1", "unita"]], - ["node__unit", ["node1", "unitb"]], - ["node__unit", ["node2", "unita"]], - ], - ) - import_parameter_value_lists(self._db_map, [("boolean", True), ("boolean", False)]) - import_object_parameters(self._db_map, [("node", "is_active", True, "boolean")]) - import_relationship_parameters(self._db_map, [("node__unit", "x")]) - import_object_parameter_values(self._db_map, [("node", "node1", "is_active", False)]) - import_relationship_parameter_values( - self._db_map, - [ - ["node__unit", ["node1", "unita"], "x", 5], - ["node__unit", ["node1", "unitb"], "x", 7], - ["node__unit", ["node2", "unita"], "x", 11], - ], - ) - import_tools(self._db_map, ["obj_act_ctrl"]) - import_features(self._db_map, [("node", "is_active")]) - import_tool_features(self._db_map, [("obj_act_ctrl", "node", "is_active", False)]) - import_tool_feature_methods(self._db_map, [("obj_act_ctrl", "node", "is_active", True)]) - self._db_map.commit_session("Add obj act ctrl filter") - apply_tool_filter_to_entity_sq(self._db_map, "obj_act_ctrl") - objects = self._db_map.query(self._db_map.object_sq).all() - self.assertEqual(len(objects), 3) - object_names = [x.name for x in objects] - self.assertTrue("node1" not in object_names) - self.assertTrue("node2" in object_names) - self.assertTrue("unita" in object_names) - self.assertTrue("unitb" in object_names) - relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() - self.assertEqual(len(relationships), 1) - relationship_object_names = relationships[0].object_name_list.split(",") - self.assertTrue("node1" not in relationship_object_names) - ent_pvals = self._db_map.query(self._db_map.entity_parameter_value_sq).all() - self.assertEqual(len(ent_pvals), 1) - pval_object_names = ent_pvals[0].object_name_list.split(",") - self.assertTrue("node1" not in pval_object_names) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index 5bb71348..640de2a2 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -81,13 +81,13 @@ def test_mixture_of_files_and_shorthands(self): class TestApplyFilterStack(unittest.TestCase): - _db_url = URL("sqlite") + _db_url = URL.create("sqlite") _dir = None @classmethod def setUpClass(cls): cls._dir = TemporaryDirectory() - cls._db_url.database = os.path.join(cls._dir.name, ".sqlite") + cls._db_url = cls._db_url.set(database=os.path.join(cls._dir.name, ".sqlite")) db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") @@ -134,7 +134,7 @@ class TestFilteredDatabaseMap(unittest.TestCase): @classmethod def setUpClass(cls): cls._dir = TemporaryDirectory() - cls._db_url.database = os.path.join(cls._dir.name, "TestFilteredDatabaseMap.sqlite") + cls._db_url = cls._db_url.set(database=os.path.join(cls._dir.name, "TestFilteredDatabaseMap.sqlite")) db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index 5497567a..985fff8b 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -99,7 +99,9 @@ class TestValueTransformerUsingDatabase(unittest.TestCase): @classmethod def setUpClass(cls): cls._temp_dir = TemporaryDirectory() - cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_value_transformer.sqlite").as_posix()) + cls._db_url = URL.create( + "sqlite", database=Path(cls._temp_dir.name, "test_value_transformer.sqlite").as_posix() + ) @classmethod def tearDownClass(cls): diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py index bd8181da..3fad1e32 100644 --- a/tests/spine_io/exporters/test_sql_writer.py +++ b/tests/spine_io/exporters/test_sql_writer.py @@ -9,10 +9,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Unit tests for SQL writer. - -""" +""" Unit tests for SQL writer. """ from pathlib import Path from tempfile import TemporaryDirectory import unittest diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 0abea455..08da7560 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -10,6 +10,7 @@ # this program. If not, see . ###################################################################################################################### """ Unit tests for DatabaseMapping class. """ +from collections import namedtuple import multiprocessing import os.path from tempfile import TemporaryDirectory @@ -19,7 +20,6 @@ from unittest.mock import patch from dateutil.relativedelta import relativedelta from sqlalchemy.engine.url import URL, make_url -from sqlalchemy.util import KeyedTuple from spinedb_api import ( DatabaseMapping, SpineDBAPIError, @@ -37,6 +37,10 @@ from tests.custom_db_mapping import CustomDatabaseMapping from tests.mock_helpers import AssertSuccessTestCase +ObjectRow = namedtuple("ObjectRow", ["id", "class_id", "name"]) +ObjectClassRow = namedtuple("ObjectClassRow", ["id", "name"]) +RelationshipRow = namedtuple("RelationshipRow", ["id", "object_class_id_list", "name"]) + def create_query_wrapper(db_map): def query_wrapper(*args, orig_query=db_map.query, **kwargs): @@ -72,13 +76,12 @@ def test_construction_with_sqlalchemy_url_and_filters(self): ) as mock_load: db_map = CustomDatabaseMapping(sa_url, create=True) db_map.close() - mock_load.assert_called_once_with(["fltr1", "fltr2"]) + mock_load.assert_called_once_with(("fltr1", "fltr2")) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) def test_shorthand_filter_query_works(self): with TemporaryDirectory() as temp_dir: - url = URL("sqlite") - url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") + url = URL.create("sqlite", database=os.path.join(temp_dir, "test_shorthand_filter_query_works.json")) out_db_map = CustomDatabaseMapping(url, create=True) out_db_map.add_scenarios({"name": "scen1"}) out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) @@ -315,7 +318,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): ) self.assertEqual(len(metadata_records), 1) self.assertEqual( - dict(**metadata_records[0]), + metadata_records[0]._asdict(), { "id": 1, "entity_class_name": "my_class", @@ -416,7 +419,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): ) self.assertEqual(len(metadata_records), 1) self.assertEqual( - dict(**metadata_records[0]), + metadata_records[0]._asdict(), { "id": 1, "entity_class_name": "my_class", @@ -613,9 +616,9 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): ) ) db_map.commit_session("Add test data to see if this crashes.") - entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)} + entity_names = {entity.id: entity.name for entity in db_map.query(db_map.wide_entity_sq)} alternative_names = { - alternative["id"]: alternative["name"] for alternative in db_map.query(db_map.alternative_sq) + alternative.id: alternative.name for alternative in db_map.query(db_map.alternative_sq) } expected = { ("widget1", "Base"): True, @@ -625,9 +628,9 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): in_database = {} entity_alternatives = db_map.query(db_map.entity_alternative_sq) for entity_alternative in entity_alternatives: - entity_name = entity_names[entity_alternative["entity_id"]] - alternative_name = alternative_names[entity_alternative["alternative_id"]] - in_database[(entity_name, alternative_name)] = entity_alternative["active"] + entity_name = entity_names[entity_alternative.entity_id] + alternative_name = alternative_names[entity_alternative.alternative_id] + in_database[(entity_name, alternative_name)] = entity_alternative.active self.assertEqual(in_database, expected) self.assertEqual(db_map.query(db_map.parameter_value_sq).all(), []) @@ -666,12 +669,12 @@ def test_commit_default_value_for_parameter_called_is_active(self): ) db_map.commit_session("Add test data to see if this crashes") active_by_defaults = { - entity_class["name"]: entity_class["active_by_default"] + entity_class.name: entity_class.active_by_default for entity_class in db_map.query(db_map.wide_entity_class_sq) } self.assertEqual(active_by_defaults, {"Widget": True, "Gadget": True, "NoIsActiveDefault": False}) defaults = [ - from_database(definition["default_value"], definition["default_type"]) + from_database(definition.default_value, definition.default_type) for definition in db_map.query(db_map.parameter_definition_sq) ] self.assertEqual(defaults, [True, True, None]) @@ -1104,7 +1107,7 @@ def test_entity_item_active_in_scenario(self): scenario_alternatives = db_map.query(db_map.scenario_alternative_sq).all() self.assertEqual(len(scenario_alternatives), 5) self.assertEqual( - dict(scenario_alternatives[0]), + scenario_alternatives[0]._asdict(), {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 7}, ) import_functions.import_scenarios(db_map, ("scen1",)) @@ -1975,15 +1978,14 @@ def test_construction_with_filters(self): mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) def test_construction_with_sqlalchemy_url_and_filters(self): - sa_url = URL("sqlite") - sa_url.query = {"spinedbfilter": ["fltr1", "fltr2"]} + sa_url = URL.create("sqlite", query={"spinedbfilter": ["fltr1", "fltr2"]}) with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with patch( "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = CustomDatabaseMapping(sa_url, create=True) db_map.close() - mock_load.assert_called_once_with(["fltr1", "fltr2"]) + mock_load.assert_called_once_with(("fltr1", "fltr2")) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) def test_entity_sq(self): @@ -2737,12 +2739,11 @@ def test_add_relationship_class_with_same_name_as_existing_one(self): ): mock_query.side_effect = query_wrapper mock_object_class_sq.return_value = [ - KeyedTuple([1, "fish"], labels=["id", "name"]), - KeyedTuple([2, "dog"], labels=["id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "1,2", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ObjectClassRow(1, "fish"), + ObjectClassRow(2, "dog"), ] + WideObjectClassRow = namedtuple("WideObjectClassRow", ["id", "object_class_id_list", "name"]) + mock_wide_rel_cls_sq.return_value = [WideObjectClassRow(1, "1,2", "fish__dog")] with self.assertRaises(SpineIntegrityError): self._db_map.add_wide_relationship_classes( {"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True @@ -2757,7 +2758,7 @@ def test_add_relationship_class_with_invalid_object_class(self): mock.patch.object(CustomDatabaseMapping, "wide_relationship_class_sq"), ): mock_query.side_effect = query_wrapper - mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])] + mock_object_class_sq.return_value = [ObjectClassRow(1, "fish")] with self.assertRaises(SpineIntegrityError): self._db_map.add_wide_relationship_classes( {"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True @@ -2814,15 +2815,15 @@ def test_add_relationship_identical_to_existing_one(self): ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) - ] - mock_wide_rel_sq.return_value = [ - KeyedTuple([1, 1, "1,2", "nemo__pluto"], labels=["id", "class_id", "object_id_list", "name"]) + ObjectRow(1, 10, "nemo"), + ObjectRow(2, 20, "pluto"), ] + RelationshipClassRow = namedtuple("RelationshipClassRow", ["id", "object_class_id_list", "name"]) + mock_wide_rel_cls_sq.return_value = [RelationshipClassRow(1, "10,20", "fish__dog")] + WideRelationshipClassRow = namedtuple( + "WideRelationshipClassRow", ["id", "class_id", "object_id_list", "name"] + ) + mock_wide_rel_sq.return_value = [WideRelationshipClassRow(1, 1, "1,2", "nemo__pluto")] with self.assertRaises(SpineIntegrityError): self._db_map.add_wide_relationships( {"name": "nemoy__plutoy", "class_id": 1, "object_id_list": [1, 2]}, strict=True @@ -2839,12 +2840,10 @@ def test_add_relationship_with_invalid_class(self): ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ObjectRow(1, 10, "nemo"), + ObjectRow(2, 20, "pluto"), ] + mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "10,20", "fish__dog")] with self.assertRaises(SpineIntegrityError): self._db_map.add_wide_relationships( {"name": "nemo__pluto", "class_id": 2, "object_id_list": [1, 2]}, strict=True @@ -2861,12 +2860,10 @@ def test_add_relationship_with_invalid_object(self): ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ObjectRow(1, 10, "nemo"), + ObjectRow(2, 20, "pluto"), ] + mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "10,20", "fish__dog")] with self.assertRaises(SpineIntegrityError): self._db_map.add_wide_relationships( {"name": "nemo__pluto", "class_id": 1, "object_id_list": [1, 3]}, strict=True @@ -3113,10 +3110,10 @@ def test_add_alternative(self): alternatives = self._db_map.query(self._db_map.alternative_sq).all() self.assertEqual(len(alternatives), 2) self.assertEqual( - dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} + alternatives[0]._asdict(), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} ) self.assertEqual( - dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} + alternatives[1]._asdict(), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} ) def test_add_scenario(self): @@ -3127,7 +3124,7 @@ def test_add_scenario(self): scenarios = self._db_map.query(self._db_map.scenario_sq).all() self.assertEqual(len(scenarios), 1) self.assertEqual( - dict(scenarios[0]), + scenarios[0]._asdict(), {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, ) @@ -3141,7 +3138,7 @@ def test_add_scenario_alternative(self): scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all() self.assertEqual(len(scenario_alternatives), 1) self.assertEqual( - dict(scenario_alternatives[0]), + scenario_alternatives[0]._asdict(), {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, ) @@ -3153,7 +3150,7 @@ def test_add_metadata(self): metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) self.assertEqual( - dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} + metadata[0]._asdict(), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} ) def test_add_metadata_that_exists_does_not_add_it(self): @@ -3163,7 +3160,7 @@ def test_add_metadata_that_exists_does_not_add_it(self): self.assertEqual(items, []) metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) + self.assertEqual(metadata[0]._asdict(), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) def test_add_entity_metadata_for_object(self): import_functions.import_object_classes(self._db_map, ("fish",)) @@ -3177,7 +3174,7 @@ def test_add_entity_metadata_for_object(self): entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - dict(entity_metadata[0]), + entity_metadata[0]._asdict(), { "entity_id": 1, "entity_name": "leviathan", @@ -3203,7 +3200,7 @@ def test_add_entity_metadata_for_relationship(self): entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - dict(entity_metadata[0]), + entity_metadata[0]._asdict(), { "entity_id": 2, "entity_name": "my_object__", @@ -3233,7 +3230,7 @@ def test_add_ext_entity_metadata_for_object(self): entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - dict(entity_metadata[0]), + entity_metadata[0]._asdict(), { "entity_id": 1, "entity_name": "leviathan", @@ -3258,11 +3255,11 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an self._db_map.commit_session("Add entity metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - dict(entity_metadata[0]), + entity_metadata[0]._asdict(), { "entity_id": 1, "entity_name": "leviathan", @@ -3290,7 +3287,7 @@ def test_add_parameter_value_metadata(self): value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - dict(value_metadata[0]), + value_metadata[0]._asdict(), { "alternative_name": "Base", "entity_name": "leviathan", @@ -3326,7 +3323,7 @@ def test_add_ext_parameter_value_metadata(self): value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - dict(value_metadata[0]), + value_metadata[0]._asdict(), { "alternative_name": "Base", "entity_name": "leviathan", @@ -3355,11 +3352,11 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): self._db_map.commit_session("Add value metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - dict(value_metadata[0]), + value_metadata[0]._asdict(), { "alternative_name": "Base", "entity_name": "leviathan", @@ -3596,7 +3593,7 @@ def test_update_parameter_definition_value_list(self): pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) self.assertEqual( - dict(pdefs[0]), + pdefs[0]._asdict(), { "commit_id": 3, "default_type": None, @@ -3669,10 +3666,14 @@ def test_update_object_metadata(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}) + self.assertEqual( + metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3} + ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 3}) + self.assertEqual( + entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 3} + ) def test_update_object_metadata_reuses_existing_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -3697,12 +3698,16 @@ def test_update_object_metadata_reuses_existing_metadata(self): metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) self.assertEqual( - dict(metadata_entries[0]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2} + metadata_entries[0]._asdict(), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) - self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2}) + self.assertEqual( + entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3} + ) + self.assertEqual( + entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2} + ) def test_update_object_metadata_keeps_metadata_still_in_use(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -3724,12 +3729,20 @@ def test_update_object_metadata_keeps_metadata_still_in_use(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) - self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3}) + self.assertEqual( + metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2} + ) + self.assertEqual( + metadata_entries[1]._asdict(), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3} + ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) - self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2}) + self.assertEqual( + entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3} + ) + self.assertEqual( + entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2} + ) def test_update_parameter_value_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -3752,11 +3765,13 @@ def test_update_parameter_value_metadata(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}) + self.assertEqual( + metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3} + ) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 3} + value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 3} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): @@ -3780,16 +3795,22 @@ def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata( self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) - self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) + self.assertEqual( + metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2} + ) + self.assertEqual( + metadata_entries[1]._asdict(), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3} + ) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} + value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2}) + self.assertEqual( + entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2} + ) def test_update_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) @@ -3802,7 +3823,7 @@ def test_update_metadata(self): metadata_records = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_records), 1) self.assertEqual( - dict(metadata_records[0]), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3} + metadata_records[0]._asdict(), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3} ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 6d233472..8dc86590 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -13,6 +13,7 @@ import unittest +from sqlalchemy import text from spinedb_api.helpers import ( compare_schemas, create_new_spine_database, @@ -49,7 +50,7 @@ def test_same_schema(self): def test_different_schema(self): engine1 = create_new_spine_database("sqlite://") engine2 = create_new_spine_database("sqlite://") - engine2.execute("drop table entity") + engine2.execute(text("drop table entity")) self.assertFalse(compare_schemas(engine1, engine2)) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 74252b0a..ccbd7aa4 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -624,7 +624,7 @@ def test_import_object_parameter_definition(self): count = self._assert_imports(import_object_parameters(db_map, (("my_object_class", "my_parameter"),))) self.assertEqual(count, 1) db_map.commit_session("Add test data.") - parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)] + parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -653,7 +653,7 @@ def test_import_object_parameter_definition_with_value_list(self): ) self.assertEqual(count, 1) db_map.commit_session("Add test data.") - parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)] + parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -682,7 +682,7 @@ def test_import_object_parameter_definition_with_default_value_from_value_list(s ) self.assertEqual(count, 1) db_map.commit_session("Add test data.") - parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)] + parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -1761,7 +1761,7 @@ def test_import_object_parameter_value_metadata(self): metadata = db_map.query(db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(metadata), 2) self.assertEqual( - dict(metadata[0]), + metadata[0]._asdict(), { "alternative_name": "Base", "entity_name": "object", @@ -1775,7 +1775,7 @@ def test_import_object_parameter_value_metadata(self): }, ) self.assertEqual( - dict(metadata[1]), + metadata[1]._asdict(), { "alternative_name": "Base", "entity_name": "object", @@ -1810,7 +1810,7 @@ def test_import_relationship_parameter_value_metadata(self): metadata = db_map.query(db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(metadata), 2) self.assertEqual( - dict(metadata[0]), + metadata[0]._asdict(), { "alternative_name": "Base", "entity_name": "object__", @@ -1824,7 +1824,7 @@ def test_import_relationship_parameter_value_metadata(self): }, ) self.assertEqual( - dict(metadata[1]), + metadata[1]._asdict(), { "alternative_name": "Base", "entity_name": "object__", @@ -1864,7 +1864,7 @@ def test_import_single_entity_class_display_mode(self): display_modes = db_map.query(db_map.entity_class_display_mode_sq).all() self.assertEqual(len(display_modes), 1) self.assertEqual( - dict(display_modes[0]), + display_modes[0]._asdict(), { "id": 1, "display_mode_id": 1, diff --git a/tests/test_migration.py b/tests/test_migration.py index b845dd97..f283c5d9 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -10,89 +10,77 @@ # this program. If not, see . ###################################################################################################################### -""" -Unit tests for migration scripts. - -""" +""" Unit tests for migration scripts. """ import os.path from tempfile import TemporaryDirectory import unittest -from sqlalchemy import inspect +from sqlalchemy import inspect, text from sqlalchemy.engine.url import URL from spinedb_api import DatabaseMapping from spinedb_api.helpers import _create_first_spine_database, create_new_spine_database, is_head_engine, schema_dict class TestMigration(unittest.TestCase): - @unittest.skip( - "default_values's server_default has been changed from 0 to NULL in the create scrip, " - "but there's no associated upgrade script yet." - ) - def test_upgrade_schema(self): - """Tests that the upgrade scripts produce the same schema as the function to create - a Spine db anew. - """ - left_engine = _create_first_spine_database("sqlite://") - is_head_engine(left_engine, upgrade=True) - left_insp = inspect(left_engine) - left_dict = schema_dict(left_insp) - right_engine = create_new_spine_database("sqlite://") - right_insp = inspect(right_engine) - right_dict = schema_dict(right_insp) - self.maxDiff = None - self.assertEqual(str(left_dict), str(right_dict)) - - left_ver = left_engine.execute("SELECT version_num FROM alembic_version").fetchall() - right_ver = right_engine.execute("SELECT version_num FROM alembic_version").fetchall() - self.assertEqual(left_ver, right_ver) - - left_ent_typ = left_engine.execute("SELECT * FROM entity_type").fetchall() - right_ent_typ = right_engine.execute("SELECT * FROM entity_type").fetchall() - left_ent_cls_typ = left_engine.execute("SELECT * FROM entity_class_type").fetchall() - right_ent_cls_typ = right_engine.execute("SELECT * FROM entity_class_type").fetchall() - self.assertEqual(left_ent_typ, right_ent_typ) - self.assertEqual(left_ent_cls_typ, right_ent_cls_typ) - def test_upgrade_content(self): """Tests that the upgrade scripts when applied on a db that has some contents persist that content entirely. """ with TemporaryDirectory() as temp_dir: - db_url = URL("sqlite") - db_url.database = os.path.join(temp_dir, "test_upgrade_content.sqlite") - # Create *first* spine db + db_url = URL.create("sqlite", database=os.path.join(temp_dir, "test_upgrade_content.sqlite")) engine = _create_first_spine_database(db_url) # Insert basic stuff - engine.execute("INSERT INTO object_class (id, name) VALUES (1, 'dog')") - engine.execute("INSERT INTO object_class (id, name) VALUES (2, 'fish')") - engine.execute("INSERT INTO object (id, class_id, name) VALUES (1, 1, 'pluto')") - engine.execute("INSERT INTO object (id, class_id, name) VALUES (2, 1, 'scooby')") - engine.execute("INSERT INTO object (id, class_id, name) VALUES (3, 2, 'nemo')") + engine.execute(text("INSERT INTO object_class (id, name) VALUES (1, 'dog')")) + engine.execute(text("INSERT INTO object_class (id, name) VALUES (2, 'fish')")) + engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (1, 1, 'pluto')")) + engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (2, 1, 'scooby')")) + engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (3, 2, 'nemo')")) + engine.execute( + text( + "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 0, 1)" + ) + ) + engine.execute( + text( + "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 1, 2)" + ) + ) + engine.execute( + text( + "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 0, 1)" + ) + ) + engine.execute( + text( + "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 1, 3)" + ) + ) + engine.execute( + text( + "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 0, 2)" + ) + ) engine.execute( - "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 0, 1)" + text( + "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 1, 3)" + ) ) + engine.execute(text("INSERT INTO parameter (id, object_class_id, name) VALUES (1, 1, 'breed')")) + engine.execute(text("INSERT INTO parameter (id, object_class_id, name) VALUES (2, 2, 'water')")) engine.execute( - "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 1, 2)" + text("INSERT INTO parameter (id, relationship_class_id, name) VALUES (3, 1, 'relative_speed')") ) engine.execute( - "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 0, 1)" + text("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 1, '\"labrador\"')") ) engine.execute( - "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 1, 3)" + text("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 2, '\"big dane\"')") ) engine.execute( - "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 0, 2)" + text("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 1, '100')") ) engine.execute( - "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 1, 3)" + text("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')") ) - engine.execute("INSERT INTO parameter (id, object_class_id, name) VALUES (1, 1, 'breed')") - engine.execute("INSERT INTO parameter (id, object_class_id, name) VALUES (2, 2, 'water')") - engine.execute("INSERT INTO parameter (id, relationship_class_id, name) VALUES (3, 1, 'relative_speed')") - engine.execute("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 1, '\"labrador\"')") - engine.execute("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 2, '\"big dane\"')") - engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 1, '100')") - engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')") # Upgrade the db and check that our stuff is still there db_map = DatabaseMapping(db_url, upgrade=True) object_classes = {x.id: x.name for x in db_map.query(db_map.object_class_sq)}