diff --git a/pyproject.toml b/pyproject.toml
index f0f43ad7..2b17c1eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,8 +13,7 @@ classifiers = [
]
requires-python = ">=3.9"
dependencies = [
- # v1.4 does not pass tests
- "SQLAlchemy >=1.3, <1.4",
+ "SQLAlchemy >=1.4, <1.5",
"alembic >=1.7",
"datapackage >=1.15.2",
"python-dateutil >=2.8.1",
diff --git a/spinedb_api/alembic/env.py b/spinedb_api/alembic/env.py
index e503d1f2..c129e357 100644
--- a/spinedb_api/alembic/env.py
+++ b/spinedb_api/alembic/env.py
@@ -1,4 +1,3 @@
-from __future__ import with_statement
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
diff --git a/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py b/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py
index 5a5bbdac..b7a548aa 100644
--- a/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py
+++ b/spinedb_api/alembic/versions/02581198a2d8_create_entity_class_display_order_tables.py
@@ -32,20 +32,20 @@ def upgrade():
sa.Column("display_order", sa.Integer, nullable=False),
sa.Column(
"display_status",
- sa.Enum(DisplayStatus, name="display_status_enum"),
+ sa.Enum(DisplayStatus, name="display_status_enum", create_constraint=True),
server_default=DisplayStatus.visible.name,
nullable=False,
),
sa.Column("display_font_color", sa.String(6), server_default=sa.null()),
sa.Column("display_background_color", sa.String(6), server_default=sa.null()),
sa.ForeignKeyConstraint(
- ["entity_class_display_mode_id"],
+ ("entity_class_display_mode_id",),
["entity_class_display_mode.id"],
onupdate="CASCADE",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
- ["entity_class_id"],
+ ("entity_class_id",),
["entity_class.id"],
onupdate="CASCADE",
ondelete="CASCADE",
diff --git a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py
index 9e641ab1..e2fe132f 100644
--- a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py
+++ b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py
@@ -30,8 +30,8 @@ def upgrade():
tfm = session.query(Base.classes.tool_feature_method).all()
session.query(Base.classes.parameter_value_list).delete()
session.query(Base.classes.tool_feature_method).delete()
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
# Change schema
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
diff --git a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py
index 957a971f..52cb8234 100644
--- a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py
+++ b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py
@@ -17,8 +17,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.add_column(sa.Column("metadata_id", sa.Integer, server_default=sa.null()))
diff --git a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py
index 6ea3bad9..181e8338 100644
--- a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py
+++ b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py
@@ -9,6 +9,7 @@
from datetime import datetime, timezone
from alembic import op
import sqlalchemy as sa
+from sqlalchemy import text
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import sessionmaker
@@ -90,8 +91,8 @@ def alter_tables_after_update():
None, "alternative", ("alternative_id",), ("id",), onupdate="CASCADE", ondelete="CASCADE"
)
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.add_column(sa.Column("alternative_id", sa.Integer, server_default=sa.null()))
@@ -101,7 +102,8 @@ def alter_tables_after_update():
date = datetime.utcnow()
conn = op.get_bind()
conn.execute(
- """
+ text(
+ """
UPDATE next_id
SET
user = :user,
@@ -109,7 +111,8 @@ def alter_tables_after_update():
alternative_id = 2,
scenario_id = 1,
scenario_alternative_id = 1
- """,
+ """
+ ),
user=user,
date=date,
)
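
Note: since 1.4, passing a plain SQL string to Connection.execute() is deprecated (removed in 2.0), hence the text() wrappers. The keyword-style bind parameters kept here (user=user, date=date) are accepted by the legacy 1.4 engine this pin targets; the forward-compatible spelling passes a dict instead. A sketch with a throwaway table and values:

    import sqlalchemy as sa

    engine = sa.create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(sa.text("CREATE TABLE next_id_demo (user TEXT, date TEXT)"))
        # 2.0-compatible parameter style: a dict instead of keyword arguments.
        conn.execute(
            sa.text("INSERT INTO next_id_demo (user, date) VALUES (:user, :date)"),
            {"user": "alembic", "date": "2024-01-01"},
        )
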
diff --git a/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py b/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py
index ef1ea2a5..5be0e6a8 100644
--- a/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py
+++ b/spinedb_api/alembic/versions/51fd7b69acf7_add_parameter_tag_and_parameter_value_list.py
@@ -17,8 +17,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.add_column(sa.Column("parameter_tag_id", sa.Integer))
diff --git a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
index cec2f6d9..749bbd98 100644
--- a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
+++ b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
@@ -8,6 +8,7 @@
from alembic import op
import sqlalchemy as sa
+from sqlalchemy import text
from spinedb_api.helpers import naming_convention
# revision identifiers, used by Alembic.
@@ -106,28 +107,28 @@ def upgrade():
def _get_constraints():
conn = op.get_bind()
- meta = sa.MetaData(conn)
- meta.reflect()
+ meta = sa.MetaData()
+ meta.reflect(conn)
return [[c.name for c in meta.tables[tname].constraints] for tname in ["entity_class", "entity"]]
def _persist_data():
conn = op.get_bind()
- meta = sa.MetaData(conn)
- meta.reflect()
+ meta = sa.MetaData()
+ meta.reflect(conn)
ecd_items = [
- {"entity_class_id": x["entity_class_id"], "dimension_id": x["member_class_id"], "position": x["dimension"]}
- for x in conn.execute("SELECT * FROM relationship_entity_class")
+ {"entity_class_id": x.entity_class_id, "dimension_id": x.member_class_id, "position": x.dimension}
+ for x in conn.execute(text("SELECT * FROM relationship_entity_class"))
]
ee_items = [
{
- "entity_id": x["entity_id"],
- "entity_class_id": x["entity_class_id"],
- "element_id": x["member_id"],
- "dimension_id": x["member_class_id"],
- "position": x["dimension"],
+ "entity_id": x.entity_id,
+ "entity_class_id": x.entity_class_id,
+ "element_id": x.member_id,
+ "dimension_id": x.member_class_id,
+ "position": x.dimension,
}
- for x in conn.execute("SELECT * FROM relationship_entity")
+ for x in conn.execute(text("SELECT * FROM relationship_entity"))
]
op.bulk_insert(meta.tables["entity_class_dimension"], ecd_items)
op.bulk_insert(meta.tables["entity_element"], ee_items)
diff --git a/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py b/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py
index 827e40b8..65c5fdbe 100644
--- a/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py
+++ b/spinedb_api/alembic/versions/7e2e66ae0f8f_rename_entity_class_display_mode_tables.py
@@ -8,6 +8,7 @@
from alembic import op
import sqlalchemy as sa
+from spinedb_api.helpers import naming_convention
# revision identifiers, used by Alembic.
revision = "7e2e66ae0f8f"
@@ -17,8 +18,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
with op.batch_alter_table("entity_class_display_mode") as batch_op:
batch_op.drop_constraint("pk_entity_class_display_mode", type_="primary")
batch_op.create_primary_key("pk_display_mode", ["id"])
diff --git a/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py b/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py
index 903f3db9..d7c5793c 100644
--- a/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py
+++ b/spinedb_api/alembic/versions/8c19c53d5701_rename_parameter_to_parameter_definition.py
@@ -22,8 +22,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.alter_column("parameter_id", new_column_name="parameter_definition_id", type_=sa.Integer)
diff --git a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py
index b4c0dbd4..a0275621 100644
--- a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py
+++ b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py
@@ -25,8 +25,8 @@ def upgrade():
conn = op.get_bind()
Session = sessionmaker(bind=conn)
session = Session()
- meta = MetaData(conn)
- meta.reflect()
+ meta = MetaData()
+ meta.reflect(conn)
list_value = meta.tables["list_value"]
parameter_definition = meta.tables["parameter_definition"]
parameter_value = meta.tables["parameter_value"]
diff --git a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py
index dbff1e34..cdc4e55d 100644
--- a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py
+++ b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py
@@ -17,8 +17,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.add_column(sa.Column("entity_group_id", sa.Integer))
diff --git a/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py b/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py
index 4a2902fd..e1ff6c50 100644
--- a/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py
+++ b/spinedb_api/alembic/versions/bba1e2ef5153_move_to_entity_based_design.py
@@ -9,6 +9,7 @@
from datetime import datetime
from alembic import op
import sqlalchemy as sa
+from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "bba1e2ef5153"
@@ -111,128 +112,138 @@ def insert_into_new_tables():
op.execute("""INSERT INTO entity_type (id, name) VALUES (2, "relationship")""")
# More difficult ones
conn = op.get_bind()
- meta = sa.MetaData(conn)
- meta.reflect()
+ meta = sa.MetaData()
+ meta.reflect(conn)
# entity class level
entity_classes = [
{
"type_id": 1,
- "name": r["name"],
- "description": r["description"],
- "display_order": r["display_order"],
- "display_icon": r["display_icon"],
- "hidden": r["hidden"],
- "commit_id": r["commit_id"],
+ "name": r.name,
+ "description": r.description,
+ "display_order": r.display_order,
+ "display_icon": r.display_icon,
+ "hidden": r.hidden,
+ "commit_id": r.commit_id,
}
for r in conn.execute(
- "SELECT name, description, display_order, display_icon, hidden, commit_id FROM object_class"
+ text("SELECT name, description, display_order, display_icon, hidden, commit_id FROM object_class")
)
] + [
{
"type_id": 2,
- "name": r["name"],
+ "name": r.name,
"description": None,
"display_order": None,
"display_icon": None,
- "hidden": r["hidden"],
- "commit_id": r["commit_id"],
+ "hidden": r.hidden,
+ "commit_id": r.commit_id,
}
- for r in conn.execute("SELECT name, hidden, commit_id FROM relationship_class GROUP BY name")
+ for r in conn.execute(text("SELECT name, hidden, commit_id FROM relationship_class GROUP BY name"))
]
op.bulk_insert(meta.tables["entity_class"], entity_classes)
# Id mappings
obj_cls_to_ent_cls = {
- r["object_class_id"]: r["entity_class_id"]
+ r.object_class_id: r.entity_class_id
for r in conn.execute(
- """
+ text(
+ """
SELECT object_class.id AS object_class_id, entity_class.id AS entity_class_id
FROM object_class, entity_class
WHERE entity_class.type_id = 1
AND object_class.name = entity_class.name
"""
+ )
)
}
rel_cls_to_ent_cls = {
- r["relationship_class_id"]: r["entity_class_id"]
+ r.relationship_class_id: r.entity_class_id
for r in conn.execute(
- """
+ text(
+ """
SELECT relationship_class.id AS relationship_class_id, entity_class.id AS entity_class_id
FROM relationship_class, entity_class
WHERE entity_class.type_id = 2
AND relationship_class.name = entity_class.name
- GROUP BY relationship_class_id, entity_class_id
+ GROUP BY relationship_class_id, entity_class_id
"""
+ )
)
}
temp_relationship_classes = [
- {"entity_class_id": r["id"], "type_id": 2, "commit_id": r["commit_id"]}
- for r in conn.execute("SELECT id, commit_id FROM entity_class WHERE type_id = 2")
+ {"entity_class_id": r.id, "type_id": 2, "commit_id": r.commit_id}
+ for r in conn.execute(text("SELECT id, commit_id FROM entity_class WHERE type_id = 2"))
]
op.bulk_insert(meta.tables["temp_relationship_class"], temp_relationship_classes)
relationship_entity_classes = [
{
- "entity_class_id": rel_cls_to_ent_cls[r["id"]],
- "dimension": r["dimension"],
- "member_class_id": obj_cls_to_ent_cls[r["object_class_id"]],
+ "entity_class_id": rel_cls_to_ent_cls[r.id],
+ "dimension": r.dimension,
+ "member_class_id": obj_cls_to_ent_cls[r.object_class_id],
"member_class_type_id": 1,
- "commit_id": r["commit_id"],
+ "commit_id": r.commit_id,
}
- for r in conn.execute("SELECT id, dimension, object_class_id, commit_id FROM relationship_class")
+ for r in conn.execute(text("SELECT id, dimension, object_class_id, commit_id FROM relationship_class"))
]
op.bulk_insert(meta.tables["relationship_entity_class"], relationship_entity_classes)
# entity level
entities = [
- {"type_id": 1, "class_id": obj_cls_to_ent_cls[r["class_id"]], "name": r["name"], "commit_id": r["commit_id"]}
- for r in conn.execute("SELECT class_id, name, commit_id FROM object")
+ {"type_id": 1, "class_id": obj_cls_to_ent_cls[r.class_id], "name": r.name, "commit_id": r.commit_id}
+ for r in conn.execute(text("SELECT class_id, name, commit_id FROM object"))
] + [
- {"type_id": 2, "class_id": rel_cls_to_ent_cls[r["class_id"]], "name": r["name"], "commit_id": r["commit_id"]}
- for r in conn.execute("SELECT class_id, name, commit_id FROM relationship GROUP BY class_id, name")
+ {"type_id": 2, "class_id": rel_cls_to_ent_cls[r.class_id], "name": r.name, "commit_id": r.commit_id}
+ for r in conn.execute(text("SELECT class_id, name, commit_id FROM relationship GROUP BY class_id, name"))
]
op.bulk_insert(meta.tables["entity"], entities)
# Id mappings
obj_to_ent = {
- r["object_id"]: r["entity_id"]
+ r.object_id: r.entity_id
for r in conn.execute(
- """
+ text(
+ """
SELECT object.id AS object_id, entity.id AS entity_id
FROM object, entity
WHERE entity.type_id = 1
AND object.name = entity.name
"""
+ )
)
}
rel_to_ent = {
- r["relationship_id"]: r["entity_id"]
+ r.relationship_id: r.entity_id
for r in conn.execute(
- """
+ text(
+ """
SELECT relationship.id AS relationship_id, entity.id AS entity_id
FROM relationship, entity
WHERE entity.type_id = 2
AND relationship.name = entity.name
- GROUP BY relationship_id, entity_id
+ GROUP BY relationship_id, entity_id
"""
+ )
)
}
temp_relationships = [
- {"entity_id": r["id"], "entity_class_id": r["class_id"], "type_id": 2, "commit_id": r["commit_id"]}
- for r in conn.execute("SELECT id, class_id, commit_id FROM entity WHERE type_id = 2")
+ {"entity_id": r.id, "entity_class_id": r.class_id, "type_id": 2, "commit_id": r.commit_id}
+ for r in conn.execute(text("SELECT id, class_id, commit_id FROM entity WHERE type_id = 2"))
]
op.bulk_insert(meta.tables["temp_relationship"], temp_relationships)
relationship_entities = [
{
- "entity_id": rel_to_ent[r["id"]],
- "entity_class_id": rel_cls_to_ent_cls[r["class_id"]],
- "dimension": r["dimension"],
- "member_id": obj_to_ent[r["object_id"]],
- "member_class_id": obj_cls_to_ent_cls[r["object_class_id"]],
- "commit_id": r["commit_id"],
+ "entity_id": rel_to_ent[r.id],
+ "entity_class_id": rel_cls_to_ent_cls[r.class_id],
+ "dimension": r.dimension,
+ "member_id": obj_to_ent[r.object_id],
+ "member_class_id": obj_cls_to_ent_cls[r.object_class_id],
+ "commit_id": r.commit_id,
}
for r in conn.execute(
- """
+ text(
+ """
SELECT r.id, r.class_id, r.dimension, o.class_id AS object_class_id, r.object_id, r.commit_id
FROM relationship AS r, object AS o
WHERE r.object_id = o.id
"""
+ )
)
]
op.bulk_insert(meta.tables["relationship_entity"], relationship_entities)
@@ -291,42 +302,48 @@ def alter_tables_before_update(meta):
def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_to_ent):
conn = op.get_bind()
- ent_to_ent_cls = {r["id"]: r["class_id"] for r in conn.execute("SELECT id, class_id FROM entity")}
+ ent_to_ent_cls = {r.id: r.class_id for r in conn.execute(text("SELECT id, class_id FROM entity"))}
for object_class_id, entity_class_id in obj_cls_to_ent_cls.items():
conn.execute(
- "UPDATE object_class SET entity_class_id = :entity_class_id, type_id = 1 WHERE id = :object_class_id",
+ text("UPDATE object_class SET entity_class_id = :entity_class_id, type_id = 1 WHERE id = :object_class_id"),
entity_class_id=entity_class_id,
object_class_id=object_class_id,
)
conn.execute(
- """
+ text(
+ """
UPDATE parameter_definition SET entity_class_id = :entity_class_id
WHERE object_class_id = :object_class_id
- """,
+ """
+ ),
entity_class_id=entity_class_id,
object_class_id=object_class_id,
)
for relationship_class_id, entity_class_id in rel_cls_to_ent_cls.items():
conn.execute(
- """
+ text(
+ """
UPDATE parameter_definition SET entity_class_id = :entity_class_id
WHERE relationship_class_id = :relationship_class_id
- """,
+ """
+ ),
entity_class_id=entity_class_id,
relationship_class_id=relationship_class_id,
)
for object_id, entity_id in obj_to_ent.items():
conn.execute(
- "UPDATE object SET entity_id = :entity_id, type_id = 1 WHERE id = :object_id",
+ text("UPDATE object SET entity_id = :entity_id, type_id = 1 WHERE id = :object_id"),
entity_id=entity_id,
object_id=object_id,
)
entity_class_id = ent_to_ent_cls[entity_id]
conn.execute(
- """
+ text(
+ """
UPDATE parameter_value SET entity_id = :entity_id, entity_class_id = :entity_class_id
WHERE object_id = :object_id
- """,
+ """
+ ),
entity_id=entity_id,
entity_class_id=entity_class_id,
object_id=object_id,
@@ -334,28 +351,31 @@ def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_
for relationship_id, entity_id in rel_to_ent.items():
entity_class_id = ent_to_ent_cls[entity_id]
conn.execute(
- """
+ text(
+ """
UPDATE parameter_value SET entity_id = :entity_id, entity_class_id = :entity_class_id
WHERE relationship_id = :relationship_id
- """,
+ """
+ ),
entity_id=entity_id,
entity_class_id=entity_class_id,
relationship_id=relationship_id,
)
# Clean our potential mess.
# E.g., I've seen parameter definitions with an invalid relationship_class_id for some reason...!
- conn.execute("DELETE FROM parameter_definition WHERE entity_class_id IS NULL")
- conn.execute("DELETE FROM parameter_value WHERE entity_class_id IS NULL OR entity_id IS NULL")
+ conn.execute(text("DELETE FROM parameter_definition WHERE entity_class_id IS NULL"))
+ conn.execute(text("DELETE FROM parameter_value WHERE entity_class_id IS NULL OR entity_id IS NULL"))
if "next_id" not in meta.tables:
return
- row = conn.execute("SELECT MAX(id) FROM entity_class").fetchone()
+ row = conn.execute(text("SELECT MAX(id) FROM entity_class")).fetchone()
entity_class_id = row[0] + 1 if row else 1
- row = conn.execute("SELECT MAX(id) FROM entity").fetchone()
+ row = conn.execute(text("SELECT MAX(id) FROM entity")).fetchone()
entity_id = row[0] + 1 if row else 1
user = "alembic"
date = datetime.utcnow()
conn.execute(
- """
+ text(
+ """
UPDATE next_id
SET
user = :user,
@@ -364,7 +384,8 @@ def update_tables(meta, obj_cls_to_ent_cls, rel_cls_to_ent_cls, obj_to_ent, rel_
entity_type_id = 3,
entity_class_id = :entity_class_id,
entity_id = :entity_id
- """,
+ """
+ ),
user=user,
date=date,
entity_class_id=entity_class_id,
diff --git a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py
index d42bc00a..03b758ca 100644
--- a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py
+++ b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py
@@ -48,6 +48,6 @@ def _get_scalar_values_by_id(table, value_label, type_label, connection):
value_column = getattr(table.c, value_label)
type_column = getattr(table.c, type_label)
return {
- row["id"]: row[value_label]
+ row.id: row._mapping[value_label]
for row in connection.execute(sa.select([table.c.id, value_column]).where(type_column == None))
}
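
Note: the sa.select([...]) list form left in place here is 1.4's legacy spelling and still works under the >=1.4,<1.5 pin; the 2.0-style call takes columns positionally. A sketch of the equivalent modern spelling, with throwaway names:

    import sqlalchemy as sa

    t = sa.table("t", sa.column("id"), sa.column("value"), sa.column("type"))
    # 1.4 legacy: sa.select([t.c.id, t.c.value]).where(t.c.type == None)
    stmt = sa.select(t.c.id, t.c.value).where(t.c.type == None)
    print(stmt)
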
diff --git a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py
index 822e994b..68ebe50f 100644
--- a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py
+++ b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py
@@ -17,8 +17,8 @@
def upgrade():
- m = sa.MetaData(op.get_bind())
- m.reflect()
+ m = sa.MetaData()
+ m.reflect(op.get_bind())
if "next_id" in m.tables:
with op.batch_alter_table("next_id") as batch_op:
batch_op.add_column(sa.Column("tool_id", sa.Integer, server_default=sa.null()))
diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py
index 81133bb8..34b2122a 100644
--- a/spinedb_api/compatibility.py
+++ b/spinedb_api/compatibility.py
@@ -27,8 +27,8 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea
Returns:
tuple: list of entity classes to add, update and ids to remove
"""
- meta = sa.MetaData(conn)
- meta.reflect()
+ meta = sa.MetaData()
+ meta.reflect(conn)
lv_table = meta.tables["list_value"]
pd_table = meta.tables["parameter_definition"]
if use_existing_tool_feature_method:
@@ -53,7 +53,7 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea
# It's a new DB without tool/feature/method or we don't want to use them...
# we take 'is_active' as feature and JSON "yes" and true as methods
lv_id_by_pdef_id = {
- x["parameter_definition_id"]: x["id"]
+ x.parameter_definition_id: x.id
for x in conn.execute(
sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")])
.where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id)
@@ -63,10 +63,10 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea
}
# Collect 'is_active' default values
list_value_id = sa.case(
- [(pd_table.c.default_type == "list_value_ref", sa.cast(pd_table.c.default_value, sa.Integer()))], else_=None
+ (pd_table.c.default_type == "list_value_ref", sa.cast(pd_table.c.default_value, sa.Integer())), else_=None
)
is_active_default_vals = [
- {c: x[c] for c in ("entity_class_id", "parameter_definition_id", "list_value_id")}
+ {c: x._mapping[c] for c in ("entity_class_id", "parameter_definition_id", "list_value_id")}
for x in conn.execute(
sa.select(
[
@@ -114,8 +114,8 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe
list: entity_alternative items to update
list: parameter_value ids to remove
"""
- meta = sa.MetaData(conn)
- meta.reflect()
+ meta = sa.MetaData()
+ meta.reflect(conn)
ea_table = meta.tables["entity_alternative"]
lv_table = meta.tables["list_value"]
pv_table = meta.tables["parameter_value"]
@@ -142,7 +142,7 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe
# we take 'is_active' as feature and JSON "yes" and true as methods
pd_table = meta.tables["parameter_definition"]
lv_id_by_pdef_id = {
- x["parameter_definition_id"]: x["id"]
+ x.parameter_definition_id: x.id
for x in conn.execute(
sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")])
.where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id)
@@ -151,11 +151,9 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe
)
}
# Collect 'is_active' parameter values
- list_value_id = sa.case(
- [(pv_table.c.type == "list_value_ref", sa.cast(pv_table.c.value, sa.Integer()))], else_=None
- )
+ list_value_id = sa.case((pv_table.c.type == "list_value_ref", sa.cast(pv_table.c.value, sa.Integer())), else_=None)
is_active_pvals = [
- {c: x[c] for c in ("id", "entity_id", "alternative_id", "parameter_definition_id", "list_value_id")}
+ {c: x._mapping[c] for c in ("id", "entity_id", "alternative_id", "parameter_definition_id", "list_value_id")}
for x in conn.execute(
sa.select([pv_table, list_value_id.label("list_value_id")]).where(
pv_table.c.parameter_definition_id.in_(lv_id_by_pdef_id)
@@ -164,7 +162,7 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe
]
# Compute new entity_alternative items from 'is_active' parameter values,
# where 'active' is True if the value of 'is_active' is the one from the tool_feature_method specification
- current_ea_ids = {(x["entity_id"], x["alternative_id"]): x["id"] for x in conn.execute(sa.select([ea_table]))}
+ current_ea_ids = {(x.entity_id, x.alternative_id): x.id for x in conn.execute(sa.select([ea_table]))}
new_ea_items = {
(x["entity_id"], x["alternative_id"]): {
"entity_id": x["entity_id"],
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 11df2525..3e777694 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -26,10 +26,11 @@
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.util.exc import CommandError
-from sqlalchemy import MetaData, create_engine, inspect
+from sqlalchemy import MetaData, create_engine, inspect, text
from sqlalchemy.engine.url import URL, make_url
from sqlalchemy.event import listen
from sqlalchemy.exc import ArgumentError, DatabaseError, DBAPIError
+from sqlalchemy.orm import Query, Session
from sqlalchemy.pool import NullPool
from .compatibility import compatibility_transformations
from .db_mapping_base import DatabaseMappingBase, Status
@@ -46,7 +47,6 @@
model_meta,
)
from .mapped_items import item_factory
-from .query import Query
from .spine_db_client import get_db_url_from_server
logging.getLogger("alembic").setLevel(logging.CRITICAL)
@@ -152,7 +152,8 @@ def __init__(
if isinstance(db_url, str):
filter_configs, db_url = pop_filter_configs(db_url)
elif isinstance(db_url, URL):
- filter_configs = db_url.query.pop("spinedbfilter", [])
+ raw_filters = db_url.query.get("spinedbfilter", ())
+ filter_configs = [raw_filters] if isinstance(raw_filters, str) else list(raw_filters)
+ db_url = db_url.difference_update_query(["spinedbfilter"])
else:
filter_configs = []
self._filter_configs = filter_configs if apply_filters else None
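
Note: sqlalchemy.engine.URL became immutable in 1.4, so query.pop() is gone: entries are read via URL.query (values are a string or a tuple of strings) and removed by building a new URL with difference_update_query(), which expects an iterable of parameter names; a bare string would be iterated character by character. A sketch:

    from sqlalchemy.engine import make_url

    url = make_url("sqlite:///db.sqlite?spinedbfilter=cfg.json")
    raw = url.query.get("spinedbfilter", ())
    configs = [raw] if isinstance(raw, str) else list(raw)
    url = url.difference_update_query(["spinedbfilter"])
    print(configs, url)  # ['cfg.json'] sqlite:///db.sqlite
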
@@ -171,18 +172,25 @@ def __init__(
listen(self.engine, "close", self._receive_engine_close)
if self._memory:
copy_database_bind(self.engine, self._original_engine)
- self._metadata = MetaData(self.engine)
- self._metadata.reflect()
+ self._metadata = MetaData()
+ self._metadata.reflect(self.engine)
self._tablenames = [t.name for t in self._metadata.sorted_tables]
+ self._connection = self.engine.connect()
+ self._session = Session(self._connection)
if self._filter_configs is not None:
stack = load_filters(self._filter_configs)
apply_filter_stack(self, stack)
def __enter__(self):
+ if self._connection is None:
+ self._connection = self.engine.connect()
+ if self._session is None:
+ self._session = Session(self._connection)
return self
def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.close()
+ return False
def __del__(self):
self.close()
@@ -302,7 +310,7 @@ def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_tim
) from None
with engine.begin() as connection:
if sa_url.drivername == "sqlite":
- connection.execute("BEGIN IMMEDIATE")
+ connection.execute(text("BEGIN IMMEDIATE"))
# TODO: Do other dialects need to lock?
migration_context = MigrationContext.configure(connection)
try:
@@ -777,7 +785,7 @@ def fetch_all(self, *item_types):
item_type = self.real_item_type(item_type)
self.do_fetch_all(item_type, commit_count)
- def query(self, *args, **kwargs):
+ def query(self, *entities, **kwargs):
"""Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB.
To perform custom ``SELECT`` statements, call this method with one or more of the documented
@@ -805,9 +813,9 @@ def query(self, *args, **kwargs):
).group_by(db_map.entity_class_sq.c.name).all()
Returns:
- :class:`~spinedb_api.query.Query`: The resulting query.
+ :class:`~sqlalchemy.orm.Query`: The resulting query.
"""
- return Query(self.engine, *args)
+ return self._session.query(*entities, **kwargs)
def commit_session(self, comment, apply_compatibility_transforms=True):
"""Commits the changes from the in-memory mapping to the database.
@@ -866,7 +874,7 @@ def has_external_commits(self):
return self._commit_count != self._query_commit_count()
def close(self):
- """Closes this DB mapping. This is only needed if you're keeping a long-lived session.
+ """Closes this DB mapping.
For instance::
class MyDBMappingWrapper:
@@ -885,6 +893,12 @@ def __del__(self):
...
# db_map.close() is automatically called when leaving this block
"""
+ if self._session is not None:
+ self._session.close()
+ self._session = None
+ if self._connection is not None:
+ self._connection.close()
+ self._connection = None
self.closed = True
def add_ext_entity_metadata(self, *items, **kwargs):
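
Note (usage sketch, not part of the patch): query() now delegates to a SQLAlchemy ORM Session living on a Connection owned by the mapping; __enter__ re-creates both after a close() and close() releases them, so the context-manager form remains the recommended usage ("sqlite://" is a placeholder URL):

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        for row in db_map.query(db_map.entity_class_sq):
            print(row.name)
    # Leaving the block calls close(): the Session and Connection are released.
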
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 311a7052..cf8c8ed4 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -298,8 +298,8 @@ def _get_next_chunk(self, item_type, offset, limit, **kwargs):
if not qry:
return []
if not limit:
- return [dict(x) for x in qry]
- return [dict(x) for x in qry.limit(limit).offset(offset)]
+ return [x._asdict() for x in qry]
+ return [x._asdict() for x in qry.limit(limit).offset(offset)]
def do_fetch_more(self, item_type, offset=0, limit=None, real_commit_count=None, **kwargs):
"""Fetches items from the DB and adds them to the mapping.
diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py
index 3fa4e698..e888b422 100644
--- a/spinedb_api/db_mapping_commit_mixin.py
+++ b/spinedb_api/db_mapping_commit_mixin.py
@@ -47,7 +47,7 @@ def _do_add_items(self, connection, tablename, *items_to_add):
if id_items:
connection.execute(table.insert(), [x.resolve() for x in id_items])
if temp_id_items:
- current_ids = {x["id"] for x in connection.execute(table.select())}
+ current_ids = {x.id for x in connection.execute(table.select())}
next_id = max(current_ids, default=0) + 1
available_ids = set(range(1, next_id)) - current_ids
required_id_count = len(temp_id_items) - len(available_ids)
diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py
index da606d3c..29b9ef5b 100644
--- a/spinedb_api/db_mapping_query_mixin.py
+++ b/spinedb_api/db_mapping_query_mixin.py
@@ -13,7 +13,8 @@
from types import MethodType
from sqlalchemy import Integer, Table, and_, case, cast, func, or_
from sqlalchemy.orm import aliased
-from sqlalchemy.sql.expression import Alias, label
+from sqlalchemy.sql import Subquery
+from sqlalchemy.sql.expression import label
from .helpers import forward_sweep, group_concat
@@ -95,7 +96,7 @@ def _func(x, tables):
getattr(self, attr)
table_to_sq_attr = {}
for attr, val in vars(self).items():
- if not isinstance(val, Alias):
+ if not isinstance(val, Subquery):
continue
tables = set()
forward_sweep(val, _func, tables)
@@ -124,7 +125,7 @@ def _subquery(self, tablename):
tablename (str): the table to be queried.
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
table = self._metadata.tables[tablename]
return self.query(table).subquery(tablename + "_sq")
@@ -138,7 +139,7 @@ def superclass_subclass_sq(self):
SELECT * FROM superclass_subclass
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._superclass_subclass_sq is None:
self._superclass_subclass_sq = self._subquery("superclass_subclass")
@@ -153,7 +154,7 @@ def entity_class_sq(self):
SELECT * FROM entity_class
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_class_sq is None:
self._entity_class_sq = self._make_entity_class_sq()
@@ -168,7 +169,7 @@ def entity_class_dimension_sq(self):
SELECT * FROM entity_class_dimension
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_class_dimension_sq is None:
self._entity_class_dimension_sq = self._subquery("entity_class_dimension")
@@ -191,7 +192,7 @@ def wide_entity_class_sq(self):
ec.id == ecd.entity_class_id
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._wide_entity_class_sq is None:
entity_class_dimension_sq = (
@@ -258,7 +259,7 @@ def entity_sq(self):
SELECT * FROM entity
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_sq is None:
self._entity_sq = self._make_entity_sq()
@@ -273,7 +274,7 @@ def entity_element_sq(self):
SELECT * FROM entity_element
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_element_sq is None:
self._entity_element_sq = self._make_entity_element_sq()
@@ -295,7 +296,7 @@ def wide_entity_sq(self):
e.id == ee.entity_id
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._wide_entity_sq is None:
entity_element_sq = (
@@ -342,7 +343,7 @@ def entity_group_sq(self):
SELECT * FROM entity_group
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_group_sq is None:
self._entity_group_sq = self._make_entity_group_sq()
@@ -357,7 +358,7 @@ def display_mode_sq(self):
SELECT * FROM display_mode
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._display_mode_sq is None:
self._display_mode_sq = self._subquery("display_mode")
@@ -372,7 +373,7 @@ def entity_class_display_mode_sq(self):
SELECT * FROM entity_class_display_mode
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_class_display_mode_sq is None:
self._entity_class_display_mode_sq = self._subquery("entity_class_display_mode")
@@ -387,7 +388,7 @@ def alternative_sq(self):
SELECT * FROM alternative
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._alternative_sq is None:
self._alternative_sq = self._make_alternative_sq()
@@ -402,7 +403,7 @@ def scenario_sq(self):
SELECT * FROM scenario
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._scenario_sq is None:
self._scenario_sq = self._make_scenario_sq()
@@ -417,7 +418,7 @@ def scenario_alternative_sq(self):
SELECT * FROM scenario_alternative
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._scenario_alternative_sq is None:
self._scenario_alternative_sq = self._make_scenario_alternative_sq()
@@ -432,7 +433,7 @@ def entity_alternative_sq(self):
SELECT * FROM entity_alternative
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_alternative_sq is None:
self._entity_alternative_sq = self._make_entity_alternative_sq()
@@ -447,7 +448,7 @@ def parameter_value_list_sq(self):
SELECT * FROM parameter_value_list
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._parameter_value_list_sq is None:
self._parameter_value_list_sq = self._subquery("parameter_value_list")
@@ -462,7 +463,7 @@ def list_value_sq(self):
SELECT * FROM list_value
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._list_value_sq is None:
self._list_value_sq = self._subquery("list_value")
@@ -477,7 +478,7 @@ def parameter_definition_sq(self):
SELECT * FROM parameter_definition
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._parameter_definition_sq is None:
@@ -485,7 +486,7 @@ def parameter_definition_sq(self):
return self._parameter_definition_sq
@property
- def wide_parameter_definition_sq(self) -> Alias:
+ def wide_parameter_definition_sq(self) -> Subquery:
"""A subquery of the form:
.. code-block:: sql
@@ -545,7 +546,7 @@ def parameter_type_sq(self):
SELECT * FROM parameter_type
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._parameter_type_sq is None:
self._parameter_type_sq = self._subquery("parameter_type")
@@ -560,7 +561,7 @@ def parameter_value_sq(self):
SELECT * FROM parameter_value
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._parameter_value_sq is None:
self._parameter_value_sq = self._make_parameter_value_sq()
@@ -575,7 +576,7 @@ def metadata_sq(self):
SELECT * FROM list_value
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._metadata_sq is None:
self._metadata_sq = self._subquery("metadata")
@@ -590,7 +591,7 @@ def parameter_value_metadata_sq(self):
SELECT * FROM parameter_value_metadata
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._parameter_value_metadata_sq is None:
self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata")
@@ -605,7 +606,7 @@ def entity_metadata_sq(self):
SELECT * FROM entity_metadata
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._entity_metadata_sq is None:
self._entity_metadata_sq = self._subquery("entity_metadata")
@@ -620,7 +621,7 @@ def commit_sq(self):
SELECT * FROM commit
Returns:
- :class:`~sqlalchemy.sql.expression.Alias`
+ :class:`~sqlalchemy.sql.Subquery`
"""
if self._commit_sq is None:
commit_sq = self._subquery("commit")
@@ -1263,7 +1264,7 @@ def _make_entity_class_sq(self):
Creates a subquery for entity classes.
Returns:
- Alias: an entity class subquery
+ Subquery: an entity class subquery
"""
return self._subquery("entity_class")
@@ -1272,7 +1273,7 @@ def _make_entity_sq(self):
Creates a subquery for entities.
Returns:
- Alias: an entity subquery
+ Subquery: an entity subquery
"""
return self._subquery("entity")
@@ -1281,7 +1282,7 @@ def _make_entity_element_sq(self):
Creates a subquery for entity-elements.
Returns:
- Alias: an entity_element subquery
+ Subquery: an entity_element subquery
"""
return self._subquery("entity_element")
@@ -1299,7 +1300,7 @@ def _make_entity_alternative_sq(self):
Creates a subquery for entity-alternatives.
Returns:
- Alias: an entity_alternative subquery
+ Subquery: an entity_alternative subquery
"""
return self._subquery("entity_alternative")
@@ -1308,18 +1309,18 @@ def _make_parameter_definition_sq(self):
Creates a subquery for parameter definitions.
Returns:
- Alias: a parameter definition subquery
+ Subquery: a parameter definition subquery
"""
par_def_sq = self._subquery("parameter_definition")
list_value_id = case(
- [(par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer()))], else_=None
+ (par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer())), else_=None
)
default_value = case(
- [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value)],
+ (par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value),
else_=par_def_sq.c.default_value,
)
default_type = case(
- [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type)],
+ (par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type),
else_=par_def_sq.c.default_type,
)
return (
@@ -1343,12 +1344,12 @@ def _make_parameter_value_sq(self):
Creates a subquery for parameter values.
Returns:
- Alias: a parameter value subquery
+ Subquery: a parameter value subquery
"""
par_val_sq = self._subquery("parameter_value")
- list_value_id = case([(par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer()))], else_=None)
- value = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value)], else_=par_val_sq.c.value)
- type_ = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type)], else_=par_val_sq.c.type)
+ list_value_id = case((par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer())), else_=None)
+ value = case((par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value), else_=par_val_sq.c.value)
+ type_ = case((par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type), else_=par_val_sq.c.type)
return (
self.query(
par_val_sq.c.id.label("id"),
@@ -1371,7 +1372,7 @@ def _make_alternative_sq(self):
Creates a subquery for alternatives.
Returns:
- Alias: an alternative subquery
+ Subquery: an alternative subquery
"""
return self._subquery("alternative")
@@ -1380,7 +1381,7 @@ def _make_scenario_sq(self):
Creates a subquery for scenarios.
Returns:
- Alias: a scenario subquery
+ Subquery: a scenario subquery
"""
return self._subquery("scenario")
@@ -1389,7 +1390,7 @@ def _make_scenario_alternative_sq(self):
Creates a subquery for scenario alternatives.
Returns:
- Alias: a scenario alternative subquery
+ Subquery: a scenario alternative subquery
"""
return self._subquery("scenario_alternative")
@@ -1399,7 +1400,7 @@ def override_entity_class_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns entity class subquery as an :class:`Alias` object
+ returns entity class subquery as an :class:`Subquery` object
"""
self._make_entity_class_sq = MethodType(method, self)
self._clear_subqueries("entity_class")
@@ -1410,7 +1411,7 @@ def override_entity_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns entity subquery as an :class:`Alias` object
+ returns entity subquery as an :class:`Subquery` object
"""
self._make_entity_sq = MethodType(method, self)
self._clear_subqueries("entity")
@@ -1421,7 +1422,7 @@ def override_entity_element_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns entity_element subquery as an :class:`Alias` object
+ returns entity_element subquery as an :class:`Subquery` object
"""
self._make_entity_element_sq = MethodType(method, self)
self._clear_subqueries("entity_element")
@@ -1443,7 +1444,7 @@ def override_entity_alternative_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns entity alternative subquery as an :class:`Alias` object
+ returns entity alternative subquery as an :class:`Subquery` object
"""
self._make_entity_alternative_sq = MethodType(method, self)
self._clear_subqueries("entity_alternative")
@@ -1454,7 +1455,7 @@ def override_parameter_definition_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns parameter definition subquery as an :class:`Alias` object
+ returns parameter definition subquery as an :class:`Subquery` object
"""
self._make_parameter_definition_sq = MethodType(method, self)
self._clear_subqueries("parameter_definition")
@@ -1465,7 +1466,7 @@ def override_parameter_value_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns parameter value subquery as an :class:`Alias` object
+ returns parameter value subquery as an :class:`Subquery` object
"""
self._make_parameter_value_sq = MethodType(method, self)
self._clear_subqueries("parameter_value")
@@ -1476,7 +1477,7 @@ def override_alternative_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns alternative subquery as an :class:`Alias` object
+ returns alternative subquery as an :class:`Subquery` object
"""
self._make_alternative_sq = MethodType(method, self)
self._clear_subqueries("alternative")
@@ -1487,7 +1488,7 @@ def override_scenario_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns scenario subquery as an :class:`Alias` object
+ returns scenario subquery as an :class:`Subquery` object
"""
self._make_scenario_sq = MethodType(method, self)
self._clear_subqueries("scenario")
@@ -1498,7 +1499,7 @@ def override_scenario_alternative_sq_maker(self, method):
Args:
method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and
- returns scenario alternative subquery as an :class:`Alias` object
+ returns scenario alternative subquery as an :class:`Subquery` object
"""
self._make_scenario_alternative_sq = MethodType(method, self)
self._clear_subqueries("scenario_alternative")
@@ -1544,62 +1545,54 @@ def restore_scenario_alternative_sq_maker(self):
self._clear_subqueries("scenario_alternative")
def _object_class_id(self):
- return case(
- [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None
- )
+ return case((self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id), else_=None)
def _relationship_class_id(self):
- return case(
- [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id)], else_=None
- )
+ return case((self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id), else_=None)
def _object_id(self):
- return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id)], else_=None)
+ return case((self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id), else_=None)
def _relationship_id(self):
- return case([(self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id)], else_=None)
+ return case((self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id), else_=None)
def _object_class_name(self):
return case(
- [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name)], else_=None
+ (self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name), else_=None
)
def _relationship_class_name(self):
return case(
- [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name)], else_=None
+ (self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name), else_=None
)
def _object_class_id_list(self):
return case(
- [
- (
- self.wide_entity_class_sq.c.dimension_id_list != None,
- self.wide_relationship_class_sq.c.object_class_id_list,
- )
- ],
+ (
+ self.wide_entity_class_sq.c.dimension_id_list != None,
+ self.wide_relationship_class_sq.c.object_class_id_list,
+ ),
else_=None,
)
def _object_class_name_list(self):
return case(
- [
- (
- self.wide_entity_class_sq.c.dimension_id_list != None,
- self.wide_relationship_class_sq.c.object_class_name_list,
- )
- ],
+ (
+ self.wide_entity_class_sq.c.dimension_id_list != None,
+ self.wide_relationship_class_sq.c.object_class_name_list,
+ ),
else_=None,
)
def _object_name(self):
- return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name)], else_=None)
+ return case((self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name), else_=None)
def _object_id_list(self):
return case(
- [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None
+ (self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list), else_=None
)
def _object_name_list(self):
return case(
- [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None
+ (self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list), else_=None
)
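
Note: the Alias -> Subquery changes in this file follow from 1.4 splitting the two types: select(...).subquery() now returns a Subquery, and Alias is reserved for aliased tables, which is why both the isinstance check and the docstrings were updated. A sketch:

    import sqlalchemy as sa
    from sqlalchemy.sql import Subquery
    from sqlalchemy.sql.expression import Alias

    t = sa.table("entity", sa.column("id"), sa.column("name"))
    sq = sa.select(t).subquery("entity_sq")
    assert isinstance(sq, Subquery)
    assert not isinstance(sq, Alias)  # on 1.3 this would have been an Alias
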
diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py
index 68a1e1d7..2a772e2e 100644
--- a/spinedb_api/export_functions.py
+++ b/spinedb_api/export_functions.py
@@ -10,8 +10,8 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
""" Functions for exporting data from a Spine database in a standard format. """
+from collections import namedtuple
from operator import itemgetter
-from sqlalchemy.util import KeyedTuple
from .helpers import Asterisk
from .parameter_value import from_database
@@ -108,6 +108,9 @@ def _make_item_processor(db_map, tablename):
return lambda item: (item,)
+ValueRow = namedtuple("ValueRow", ["name", "value", "type"])
+
+
class _ParameterValueListProcessor:
def __init__(self, value_items):
self._value_items_by_list_id = {}
@@ -116,7 +119,7 @@ def __init__(self, value_items):
def __call__(self, item):
for list_value_item in sorted(self._value_items_by_list_id.get(item.id, ()), key=lambda x: x.index):
- yield KeyedTuple([item.name, list_value_item.value, list_value_item.type], ["name", "value", "type"])
+ yield ValueRow(item.name, list_value_item.value, list_value_item.type)
def export_parameter_value_lists(db_map, ids=Asterisk, parse_value=from_database):
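
Note: sqlalchemy.util.KeyedTuple is no longer available in 1.4; a plain collections.namedtuple preserves the attribute-style rows these export functions yield. A sketch:

    from collections import namedtuple

    ValueRow = namedtuple("ValueRow", ["name", "value", "type"])
    row = ValueRow("my_list", b'"high"', "str")
    assert row.name == "my_list" and row.type == "str"
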
diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py
index 7ceac7eb..80dd35de 100644
--- a/spinedb_api/export_mapping/export_mapping.py
+++ b/spinedb_api/export_mapping/export_mapping.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
from itertools import cycle, dropwhile, islice
from sqlalchemy import and_
-from sqlalchemy.sql.expression import literal
+from sqlalchemy.sql.expression import literal, null
from ..mapping import Mapping, Position, is_pivoted, is_regular, unflatten
from ..parameter_value import (
IndexedValue,
@@ -146,19 +146,13 @@ def reconstruct(cls, position, value, header, filter_re, ignorable, mapping_dict
mapping.set_ignorable(ignorable)
return mapping
- def add_query_columns(self, db_map, query):
- """Adds columns to the mapping query if needed, and returns the new query.
-
- The base class implementation just returns the same query without adding any new columns.
+ def build_query_columns(self, db_map, columns):
+ """Appends columns needed to query the mapping's data into a list.
Args:
- db_map (DatabaseMapping)
- query (Alias or dict)
-
- Returns:
- Alias: expanded query, or the same if nothing to add.
+ db_map (DatabaseMapping): database mapping
+ columns (list of Column): list of columns to append to
"""
- return query
def filter_query(self, db_map, query):
"""Filters the mapping query if needed, and returns the new query.
@@ -167,10 +161,10 @@ def filter_query(self, db_map, query):
Args:
db_map (DatabaseMapping)
- query (Alias or dict)
+ query (Subquery or dict)
Returns:
- Alias: filtered query, or the same if nothing to add.
+ Subquery: filtered query, or the same if nothing to add.
"""
return query
@@ -202,10 +196,11 @@ def _build_query(self, db_map, title_state):
"""
mappings = self.flatten()
# Start with empty query
- qry = db_map.query(literal(None))
# Add columns
+ columns = []
for m in mappings:
- qry = m.add_query_columns(db_map, qry)
+ m.build_query_columns(db_map, columns)
+ qry = db_map.query(*columns)
# Apply filters
for m in mappings:
qry = m.filter_query(db_map, qry)
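
Note: a sketch of the pattern this refactor relies on, building an ORM query from labelled columns collected up front instead of widening a literal(None) seed (table and label names below are illustrative):

    import sqlalchemy as sa
    from sqlalchemy.orm import Session

    engine = sa.create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(sa.text("CREATE TABLE entity (id INTEGER, name TEXT)"))
        conn.execute(sa.text("INSERT INTO entity VALUES (1, 'spoon')"))
        meta = sa.MetaData()
        meta.reflect(conn)
        entity = meta.tables["entity"]
        columns = [entity.c.id.label("entity_id"), entity.c.name.label("entity_name")]
        row = Session(conn).query(*columns).one()
        assert row.entity_name == "spoon"
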
@@ -234,11 +229,12 @@ def _build_title_query(self, db_map):
if mappings[-1].position == Position.table_name:
break
mappings.pop(-1)
- # Start with empty query
- qry = db_map.query(literal(None))
- # Add columns
+ columns = []
for m in mappings:
- qry = m.add_query_columns(db_map, qry)
+ m.build_query_columns(db_map, columns)
+ if not columns:
+ return None
+ qry = db_map.query(*columns)
# Apply filters
for m in mappings:
qry = m.filter_query(db_map, qry)
@@ -262,18 +258,17 @@ def _build_header_query(self, db_map, title_state, buddies):
if m.position in (Position.header, Position.table_name) or m in flat_buddies:
break
mappings.pop(-1)
- # Start with empty query
- qry = db_map.query(literal(None))
- # Add columns
+ columns = []
for m in mappings:
- qry = m.add_query_columns(db_map, qry)
- # Apply filters
+ m.build_query_columns(db_map, columns)
+ if not columns:
+ return None
+ qry = db_map.query(*columns)
for m in mappings:
qry = m.filter_query(db_map, qry)
# Apply special title filters (first, so we clean up the state)
for m in mappings:
qry = m.filter_query_by_title(qry, title_state)
- # Apply standard title filters
if not title_state:
return qry
# Use a _FilteredQuery, since building a subquery to query it again leads to parser stack overflow
@@ -307,7 +302,7 @@ def _data(self, row):
The base class implementation returns the field given by ``name_field()``.
Args:
- row (KeyedTuple)
+ row (Row)
Returns:
any
@@ -349,7 +344,7 @@ def _get_rows(self, db_row):
"""Yields rows issued by this mapping for given database row.
Args:
- db_row (KeyedTuple)
+ db_row (Row)
Returns:
generator(dict)
@@ -368,7 +363,7 @@ def get_rows_recursive(self, db_row):
"""Takes a database row and yields rows issued by this mapping and its children combined.
Args:
- db_row (KeyedTuple)
+ db_row (Row)
Returns:
generator(dict)
@@ -415,7 +410,7 @@ def _title_state(self, db_row):
to the corresponding field from the row.
Args:
- db_row (KeyedTuple)
+ db_row (Row)
Returns:
dict
@@ -429,7 +424,7 @@ def _get_titles(self, db_row, limit=None):
"""Yields pairs (title, title state) issued by this mapping for given database row.
Args:
- db_row (KeyedTuple)
+ db_row (Row)
limit (int, optional): yield only this many items
Returns:
@@ -452,7 +447,7 @@ def get_titles_recursive(self, db_row, limit=None):
"""Takes a database row and yields pairs (title, title state) issued by this mapping and its children combined.
Args:
- db_row (KeyedTuple)
+ db_row (Row)
limit (int, optional): yield only this many items
Returns:
@@ -478,6 +473,8 @@ def _non_unique_titles(self, db_map, limit=None):
tuple(str,dict): title, and associated title state dictionary
"""
qry = self._build_title_query(db_map)
+ if qry is None:
+ return
for db_row in qry:
yield from self.get_titles_recursive(db_row, limit=limit)
@@ -512,9 +509,7 @@ def make_header_recursive(self, query, buddies):
"""Builds the header recursively.
Args:
- build_header_query (callable): a function that any mapping in the hierarchy can call to get the query
- db_map (DatabaseMapping): database map
- title_state (dict): title state
+ query (Query, optional): export query
buddies (list of tuple): buddy mappings
Returns
@@ -525,7 +520,7 @@ def make_header_recursive(self, query, buddies):
return {}
return {self.position: self.header}
header = self.child.make_header_recursive(query, buddies)
- if self.position == Position.header:
+ if self.position == Position.header and query is not None:
buddy = find_my_buddy(self, buddies)
if buddy is not None:
query.rewind()
@@ -547,7 +542,9 @@ def make_header(self, db_map, title_state, buddies):
Returns
dict: a mapping from column index to string header
"""
- query = _Rewindable(self._build_header_query(db_map, title_state, buddies))
+ query = self._build_header_query(db_map, title_state, buddies)
+ if query is not None:
+ query = _Rewindable(query)
return self.make_header_recursive(query, buddies)
@@ -609,16 +606,15 @@ def __init__(self, position, value=None, header="", filter_re="", highlight_posi
super().__init__(position, value, header, filter_re)
self.highlight_position = highlight_position
- def add_query_columns(self, db_map, query):
- query = query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.wide_entity_class_sq.c.id.label("entity_class_id"),
db_map.wide_entity_class_sq.c.name.label("entity_class_name"),
db_map.wide_entity_class_sq.c.dimension_id_list.label("dimension_id_list"),
db_map.wide_entity_class_sq.c.dimension_name_list.label("dimension_name_list"),
- )
+ ]
if self.highlight_position is not None:
- query = query.add_columns(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id"))
- return query
+ columns.append(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id"))
def filter_query(self, db_map, query):
if any(isinstance(m, (DimensionMapping, ElementMapping)) for m in self.flatten()):
@@ -675,16 +671,15 @@ class EntityMapping(ExportMapping):
MAP_TYPE = "Entity"
- def add_query_columns(self, db_map, query):
- query = query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.wide_entity_sq.c.id.label("entity_id"),
db_map.wide_entity_sq.c.name.label("entity_name"),
db_map.wide_entity_sq.c.element_id_list,
db_map.wide_entity_sq.c.element_name_list,
- )
+ ]
if self.query_parents("highlight_position") is not None:
- query = query.add_columns(db_map.entity_element_sq.c.element_id.label("highlighted_element_id"))
- return query
+ columns.append(db_map.entity_element_sq.c.element_id.label("highlighted_element_id"))
def filter_query(self, db_map, query):
query = query.outerjoin(
@@ -727,8 +722,8 @@ class EntityGroupMapping(ExportMapping):
MAP_TYPE = "EntityGroup"
- def add_query_columns(self, db_map, query):
- return query.add_columns(db_map.ext_entity_group_sq.c.group_id, db_map.ext_entity_group_sq.c.group_name)
+ def build_query_columns(self, db_map, columns):
+ columns += [db_map.ext_entity_group_sq.c.group_id, db_map.ext_entity_group_sq.c.group_name]
def filter_query(self, db_map, query):
return query.outerjoin(
@@ -756,10 +751,8 @@ class EntityGroupEntityMapping(ExportMapping):
MAP_TYPE = "EntityGroupEntity"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
- db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name")
- )
+ def build_query_columns(self, db_map, columns):
+ columns += [db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name")]
def filter_query(self, db_map, query):
return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.wide_entity_sq.c.id)
@@ -864,11 +857,11 @@ class ParameterDefinitionMapping(ExportMapping):
MAP_TYPE = "ParameterDefinition"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.parameter_definition_sq.c.id.label("parameter_definition_id"),
db_map.parameter_definition_sq.c.name.label("parameter_definition_name"),
- )
+ ]
def filter_query(self, db_map, query):
if self.query_parents("highlight_position") is not None:
@@ -898,10 +891,8 @@ class ParameterDefaultValueMapping(ExportMapping):
MAP_TYPE = "ParameterDefaultValue"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
- db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type
- )
+ def build_query_columns(self, db_map, columns):
+ columns += [db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type]
@staticmethod
def name_field():
@@ -986,12 +977,10 @@ class ParameterDefaultValueIndexMapping(_MappingWithLeafMixin, ExportMapping):
MAP_TYPE = "ParameterDefaultValueIndex"
- def add_query_columns(self, db_map, query):
- if "default_value" in set(query.column_names()):
- return query
- return query.add_columns(
- db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type
- )
+ def build_query_columns(self, db_map, columns):
+ if any(c.name == "default_value" for c in columns):
+ return
+ columns += [db_map.parameter_definition_sq.c.default_value, db_map.parameter_definition_sq.c.default_type]
def _expand_data(self, data):
yield from _expand_indexed_data(data, self)
@@ -1046,11 +1035,11 @@ class ParameterValueMapping(ExportMapping):
MAP_TYPE = "ParameterValue"
_selects_value = False
- def add_query_columns(self, db_map, query):
- if "value" in set(query.column_names()):
- return query
+ def build_query_columns(self, db_map, columns):
+ if any(c.name == "value" for c in columns):
+ return
self._selects_value = True
- return query.add_columns(db_map.parameter_value_sq.c.value, db_map.parameter_value_sq.c.type)
+ columns += [db_map.parameter_value_sq.c.value, db_map.parameter_value_sq.c.type]
def filter_query(self, db_map, query):
if not self._selects_value:
@@ -1113,7 +1102,7 @@ def filter_query_by_title(self, query, title_state):
pv = title_state.pop("type_and_dimensions", None)
if pv is None:
return query
- if "value" not in set(query.column_names()):
+ if all(d["name"] != "value" for d in query.column_descriptions):
return query
return _FilteredQuery(
query, lambda db_row: (db_row.type, from_database_to_dimension_count(db_row.value, db_row.type) == pv)
@@ -1190,11 +1179,11 @@ class ParameterValueListMapping(ExportMapping):
MAP_TYPE = "ParameterValueList"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.parameter_value_list_sq.c.id.label("parameter_value_list_id"),
db_map.parameter_value_list_sq.c.name.label("parameter_value_list_name"),
- )
+ ]
def filter_query(self, db_map, query):
if self.parent is None:
@@ -1222,8 +1211,8 @@ class ParameterValueListValueMapping(ExportMapping):
MAP_TYPE = "ParameterValueListValue"
- def add_query_columns(self, db_map, query):
- return query.add_columns(db_map.ord_list_value_sq.c.value, db_map.ord_list_value_sq.c.type)
+ def build_query_columns(self, db_map, columns):
+ columns += [db_map.ord_list_value_sq.c.value, db_map.ord_list_value_sq.c.type]
def filter_query(self, db_map, query):
return query.filter(db_map.ord_list_value_sq.c.parameter_value_list_id == db_map.parameter_value_list_sq.c.id)
@@ -1252,12 +1241,12 @@ class AlternativeMapping(ExportMapping):
MAP_TYPE = "Alternative"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.alternative_sq.c.id.label("alternative_id"),
db_map.alternative_sq.c.name.label("alternative_name"),
db_map.alternative_sq.c.description.label("description"),
- )
+ ]
def filter_query(self, db_map, query):
parent = self.parent
@@ -1284,12 +1273,12 @@ class ScenarioMapping(ExportMapping):
MAP_TYPE = "Scenario"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.scenario_sq.c.id.label("scenario_id"),
db_map.scenario_sq.c.name.label("scenario_name"),
db_map.scenario_sq.c.description.label("description"),
- )
+ ]
@staticmethod
def name_field():
@@ -1308,18 +1297,19 @@ class ScenarioAlternativeMapping(ExportMapping):
MAP_TYPE = "ScenarioAlternative"
- def add_query_columns(self, db_map, query):
+ def build_query_columns(self, db_map, columns):
if self._child is None:
- return query.add_columns(
+ columns += [
db_map.ext_scenario_sq.c.alternative_id,
db_map.ext_scenario_sq.c.alternative_name,
db_map.ext_scenario_sq.c.rank,
- )
- # Legacy: expecting child to be ScenarioBeforeAlternativeMapping
- return query.add_columns(
- db_map.ext_linked_scenario_alternative_sq.c.alternative_id,
- db_map.ext_linked_scenario_alternative_sq.c.alternative_name,
- )
+ ]
+ else:
+ # Legacy: expecting child to be ScenarioBeforeAlternativeMapping
+ columns += [
+ db_map.ext_linked_scenario_alternative_sq.c.alternative_id,
+ db_map.ext_linked_scenario_alternative_sq.c.alternative_name,
+ ]
def filter_query(self, db_map, query):
if self._child is None:
@@ -1354,11 +1344,11 @@ class ScenarioBeforeAlternativeMapping(ExportMapping):
MAP_TYPE = "ScenarioBeforeAlternative"
- def add_query_columns(self, db_map, query):
- return query.add_columns(
+ def build_query_columns(self, db_map, columns):
+ columns += [
db_map.ext_linked_scenario_alternative_sq.c.before_alternative_id,
db_map.ext_linked_scenario_alternative_sq.c.before_alternative_name,
- )
+ ]
@staticmethod
def name_field():
diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py
index 79e144c4..0dbad0ab 100644
--- a/spinedb_api/filters/renamer.py
+++ b/spinedb_api/filters/renamer.py
@@ -202,8 +202,8 @@ def _make_renaming_entity_class_sq(db_map, state):
subquery = state.original_entity_class_sq
if not state.id_to_name:
return subquery
- cases = [(subquery.c.id == id, new_name) for id, new_name in state.id_to_name.items()]
- new_class_name = case(cases, else_=subquery.c.name) # if not in the name map, just keep the original name
+ cases = ((subquery.c.id == id_, new_name) for id_, new_name in state.id_to_name.items())
+ new_class_name = case(*cases, else_=subquery.c.name) # if not in the name map, just keep the original name
entity_class_sq = db_map.query(
subquery.c.id,
new_class_name.label("name"),
@@ -262,8 +262,8 @@ def _make_renaming_parameter_definition_sq(db_map, state):
subquery = state.original_parameter_definition_sq
if not state.id_to_name:
return subquery
- cases = [(subquery.c.id == id, new_name) for id, new_name in state.id_to_name.items()]
- new_parameter_name = case(cases, else_=subquery.c.name) # if not in the name map, just keep the original name
+ cases = ((subquery.c.id == id_, new_name) for id_, new_name in state.id_to_name.items())
+ new_parameter_name = case(*cases, else_=subquery.c.name) # if not in the name map, just keep the original name
parameter_definition_sq = db_map.query(
subquery.c.id,
new_parameter_name.label("name"),
diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py
index 7a9cae29..a53f53c8 100644
--- a/spinedb_api/filters/value_transformer.py
+++ b/spinedb_api/filters/value_transformer.py
@@ -182,8 +182,8 @@ def _make_parameter_value_transforming_sq(db_map, state):
]
statements += [select([literal(i), literal(v), literal(t)]) for i, v, t in transformed_rows[1:]]
temp_sq = union_all(*statements).alias("transformed_values")
- new_value = case([(temp_sq.c.transformed_value != None, temp_sq.c.transformed_value)], else_=subquery.c.value)
- new_type = case([(temp_sq.c.transformed_type != None, temp_sq.c.transformed_type)], else_=subquery.c.type)
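+ # Same positional-whens form of case() as in renamer.py; the "!= None" comparisons intentionally compile to IS NOT NULL.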
+ new_value = case((temp_sq.c.transformed_value != None, temp_sq.c.transformed_value), else_=subquery.c.value)
+ new_type = case((temp_sq.c.transformed_type != None, temp_sq.c.transformed_type), else_=subquery.c.type)
parameter_value_sq = (
db_map.query(
subquery.c.id.label("id"),
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index 9ed7b87e..ef310f4e 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -45,6 +45,7 @@
inspect,
null,
select,
+ text,
true,
)
from sqlalchemy.dialects.mysql import DOUBLE, TINYINT
@@ -106,19 +107,6 @@ def name_from_dimensions(dimensions):
return name_from_elements(dimensions)
-# NOTE: Deactivated since foreign keys are too difficult to get right in the diff tables.
-# For example, the diff_object table would need a `class_id` field and a `diff_class_id` field,
-# plus a CHECK constraint that at least one of the two is NOT NULL.
-# @event.listens_for(Engine, "connect")
-def set_sqlite_pragma(dbapi_connection, connection_record):
- module_name = dbapi_connection.__class__.__module__
- if not module_name.lower().startswith("sqlite"):
- return
- cursor = dbapi_connection.cursor()
- cursor.execute("PRAGMA foreign_keys=ON")
- cursor.close()
-
-
@compiles(TINYINT, "sqlite")
def compile_TINYINT_mysql_sqlite(element, compiler, **kw):
"""Handles mysql TINYINT datatype as INTEGER in sqlite."""
@@ -134,6 +122,7 @@ def compile_DOUBLE_mysql_sqlite(element, compiler, **kw):
class group_concat(FunctionElement):
type = String()
name = "group_concat"
+ inherit_cache = True
def _parse_group_concat_clauses(clauses):
@@ -514,7 +503,7 @@ def create_spine_metadata():
Column("display_order", Integer, nullable=False),
Column(
"display_status",
- Enum(DisplayStatus, name="display_status_enum"),
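+ # SQLAlchemy 1.4 changed Enum.create_constraint's default to False; request the CHECK constraint explicitly.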
+ Enum(DisplayStatus, name="display_status_enum", create_constraint=True),
server_default=DisplayStatus.visible.name,
nullable=False,
),
@@ -682,17 +671,17 @@ def create_new_spine_database(db_url):
def create_new_spine_database_from_bind(bind):
# Drop existing tables. This is a Spine db now...
- meta = MetaData(bind)
- meta.reflect()
- meta.drop_all()
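+ # Bound MetaData is deprecated in SQLAlchemy 1.4; pass the bind explicitly to reflect() and drop_all().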
+ meta = MetaData()
+ meta.reflect(bind)
+ meta.drop_all(bind)
# Create new tables
meta = create_spine_metadata()
version = get_head_alembic_version()
try:
meta.create_all(bind)
- bind.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')")
- bind.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)")
- bind.execute(f"INSERT INTO alembic_version VALUES ('{version}')")
+ bind.execute(text("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')"))
+ bind.execute(text("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)"))
+ bind.execute(text(f"INSERT INTO alembic_version VALUES ('{version}')"))
except DatabaseError as e:
raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None
@@ -706,8 +695,8 @@ def _create_first_spine_database(db_url):
except DatabaseError as e:
raise SpineDBAPIError(f"Could not connect to '{db_url}': {e.orig.args}") from None
# Drop existing tables. This is a Spine db now...
- meta = MetaData(engine)
- meta.reflect()
+ meta = MetaData()
+ meta.reflect(engine)
meta.drop_all(engine)
# Create new tables
meta = MetaData(naming_convention=naming_convention)
diff --git a/spinedb_api/query.py b/spinedb_api/query.py
deleted file mode 100644
index 8a78bcbc..00000000
--- a/spinedb_api/query.py
+++ /dev/null
@@ -1,157 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# Copyright Spine Database API contributors
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
-######################################################################################################################
-
-"""The :class:`Query` class."""
-
-from sqlalchemy import and_, select
-from sqlalchemy.sql.functions import count
-from .exception import SpineDBAPIError
-
-
-class Query:
- """A clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`."""
-
- def __init__(self, bind, *entities):
- """
- Args:
- bind(Engine or Connection): An engine or connection to a DB against which the query will be executed.
- entities(Iterable): A sequence of SQL expressions.
- """
- self._bind = bind
- self._entities = entities
- self._select = select(entities)
- self._from = None
-
- def __str__(self):
- return str(self._select)
-
- @property
- def column_descriptions(self):
- return [{"name": c.name} for c in self._select.columns]
-
- def column_names(self):
- yield from (c.name for c in self._select.columns)
-
- def subquery(self, name=None):
- return self._select.alias(name)
-
- def add_columns(self, *columns):
- self._entities += columns
- self._select = select(self._entities)
- return self
-
- def filter(self, *clauses):
- for clause in clauses:
- self._select = self._select.where(clause)
- return self
-
- def filter_by(self, **kwargs):
- if len(self._entities) != 1:
- raise SpineDBAPIError(f"can't find a unique 'from-clause' to filter, candidates are {self._entities}")
- return self.filter(and_(getattr(self._entities[0].c, k) == v for k, v in kwargs.items()))
-
- def _get_from(self, right, on):
- if self._from is not None:
- return self._from
- from_candidates = (set(_get_descendant_tables(on)) - {right}) & set(self._select.get_children())
- if len(from_candidates) != 1:
- raise SpineDBAPIError(f"can't find a unique 'from-clause' to join into, candidates are {from_candidates}")
- return next(iter(from_candidates))
-
- def join(self, right, on, isouter=False):
- self._from = self._get_from(right, on).join(right, on, isouter=isouter)
- self._select = self._select.select_from(self._from)
- return self
-
- def outerjoin(self, right, on):
- return self.join(right, on, isouter=True)
-
- def order_by(self, *args):
- self._select = self._select.order_by(*args)
- return self
-
- def group_by(self, *args):
- self._select = self._select.group_by(*args)
- return self
-
- def limit(self, *args):
- self._select = self._select.limit(*args)
- return self
-
- def offset(self, *args):
- self._select = self._select.offset(*args)
- return self
-
- def distinct(self, *args):
- self._select = self._select.distinct(*args)
- return self
-
- def having(self, *args):
- self._select = self._select.having(*args)
- return self
-
- def _result(self):
- return self._bind.execute(self._select)
-
- def all(self):
- return self._result().fetchall()
-
- def first(self):
- return self._result().first()
-
- def one(self):
- result = self._result()
- first = result.fetchone()
- if first is None:
- return SpineDBAPIError("no results found for one()")
- second = result.fetchone()
- if second is not None:
- raise SpineDBAPIError("multiple results found for one()")
- return first
-
- def one_or_none(self):
- result = self._result()
- first = result.fetchone()
- if first is None:
- return None
- second = result.fetchone()
- if second is not None:
- raise SpineDBAPIError("multiple results found for one_or_none()")
- return first
-
- def scalar(self):
- return self._result().scalar()
-
- def count(self):
- return self._bind.execute(select([count()]).select_from(self._select)).scalar()
-
- def __iter__(self):
- return self._result() or iter([])
-
-
-def _get_leaves(parent):
- children = parent.get_children()
- if not children:
- try:
- yield parent.table
- except AttributeError:
- pass
- for child in children:
- yield from _get_leaves(child)
-
-
-def _get_descendant_tables(on):
- for x in on.get_children():
- try:
- yield x.table
- except AttributeError:
- yield from _get_descendant_tables(x)
diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py
index 82e9a0c3..e8e375e4 100644
--- a/tests/filters/test_renamer.py
+++ b/tests/filters/test_renamer.py
@@ -39,7 +39,9 @@ class TestEntityClassRenamer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._temp_dir = TemporaryDirectory()
- cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_entity_class_renamer.sqlite").as_posix())
+ cls._db_url = URL.create(
+ "sqlite", database=Path(cls._temp_dir.name, "test_entity_class_renamer.sqlite").as_posix()
+ )
def setUp(self):
create_new_spine_database(self._db_url)
@@ -62,11 +64,10 @@ def test_renaming_singe_entity_class(self):
classes = list(self._db_map.query(self._db_map.entity_class_sq).all())
self.assertEqual(len(classes), 1)
class_row = classes[0]
- keys = tuple(class_row.keys())
expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default")
- self.assertEqual(len(keys), len(expected_keys))
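+ # SQLAlchemy 1.4 rows behave like named tuples: ._fields and ._asdict() replace keys() and dict(row).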
+ self.assertEqual(len(class_row._fields), len(expected_keys))
for expected_key in expected_keys:
- self.assertIn(expected_key, keys)
+ self.assertIn(expected_key, class_row._fields)
self.assertEqual(class_row.name, "new_name")
def test_renaming_singe_relationship_class(self):
@@ -119,11 +120,10 @@ def test_entity_class_renamer_from_dict(self):
classes = list(self._db_map.query(self._db_map.entity_class_sq).all())
self.assertEqual(len(classes), 1)
class_row = classes[0]
- keys = tuple(class_row.keys())
expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default")
- self.assertEqual(len(keys), len(expected_keys))
+ self.assertEqual(len(class_row._fields), len(expected_keys))
for expected_key in expected_keys:
- self.assertIn(expected_key, keys)
+ self.assertIn(expected_key, class_row._fields)
self.assertEqual(class_row.name, "new_name")
@@ -144,7 +144,9 @@ class TestParameterRenamer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._temp_dir = TemporaryDirectory()
- cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_parameter_renamer.sqlite").as_posix())
+ cls._db_url = URL.create(
+ "sqlite", database=Path(cls._temp_dir.name, "test_parameter_renamer.sqlite").as_posix()
+ )
def setUp(self):
create_new_spine_database(self._db_url)
@@ -168,7 +170,6 @@ def test_renaming_single_parameter(self):
parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all())
self.assertEqual(len(parameters), 1)
parameter_row = parameters[0]
- keys = tuple(parameter_row.keys())
expected_keys = (
"id",
"name",
@@ -180,9 +181,9 @@ def test_renaming_single_parameter(self):
"commit_id",
"parameter_value_list_id",
)
- self.assertEqual(len(keys), len(expected_keys))
+ self.assertEqual(len(parameter_row._fields), len(expected_keys))
for expected_key in expected_keys:
- self.assertIn(expected_key, keys)
+ self.assertIn(expected_key, parameter_row._fields)
self.assertEqual(parameter_row.name, "new_name")
def test_renaming_applies_to_correct_parameter(self):
@@ -214,7 +215,6 @@ def test_parameter_renamer_from_dict(self):
parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all())
self.assertEqual(len(parameters), 1)
parameter_row = parameters[0]
- keys = tuple(parameter_row.keys())
expected_keys = (
"id",
"name",
@@ -226,9 +226,9 @@ def test_parameter_renamer_from_dict(self):
"commit_id",
"parameter_value_list_id",
)
- self.assertEqual(len(keys), len(expected_keys))
+ self.assertEqual(len(parameter_row._fields), len(expected_keys))
for expected_key in expected_keys:
- self.assertIn(expected_key, keys)
+ self.assertIn(expected_key, parameter_row._fields)
self.assertEqual(parameter_row.name, "new_name")
diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py
index d27cacc4..a98c0515 100644
--- a/tests/filters/test_scenario_filter.py
+++ b/tests/filters/test_scenario_filter.py
@@ -44,7 +44,7 @@ def test_filter_entities_with_default_activity_only(self):
apply_filter_stack(db_map, [scenario_filter_config("S")])
entities = db_map.query(db_map.wide_entity_sq).all()
self.assertEqual(len(entities), 1)
- self.assertEqual(entities[0]["name"], "visible_object")
+ self.assertEqual(entities[0].name, "visible_object")
def test_filter_entities_with_default_activity(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
@@ -79,8 +79,8 @@ def test_filter_entities_with_default_activity(self):
apply_filter_stack(db_map, [scenario_filter_config("S")])
entities = db_map.query(db_map.wide_entity_sq).all()
self.assertEqual(len(entities), 2)
- self.assertEqual(entities[0]["name"], "visible")
- self.assertEqual(entities[1]["name"], "visible")
+ self.assertEqual(entities[0].name, "visible")
+ self.assertEqual(entities[1].name, "visible")
def test_filter_entity_that_is_not_active_in_scenario(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
@@ -323,14 +323,14 @@ def test_scenario_filter(self):
url = "sqlite:///" + str(Path(temp_dir, "db.sqlite"))
with DatabaseMapping(url, create=True) as db_map:
self._build_data_with_single_scenario(db_map)
- with DatabaseMapping(url, create=True):
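+ # Reopen the database without re-creating it, and actually bind the mapping to db_map (the old handle was discarded).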
+ with DatabaseMapping(url) as db_map:
apply_scenario_filter_to_subqueries(db_map, "scenario")
parameters = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(parameters), 1)
self.assertEqual(parameters[0].value, b"23.0")
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}])
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -353,7 +353,7 @@ def test_scenario_filter_uncommitted_data(self):
apply_scenario_filter_to_subqueries(db_map, "scenario")
parameters = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(parameters), 0)
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(
alternatives, [{"name": "Base", "description": "Base alternative", "id": 1, "commit_id": 1}]
)
@@ -438,7 +438,7 @@ def test_scenario_filter_works_for_entity_sq(self):
entity_names = {
name
for x in entities
- for name in (x["element_name_list"].split(",") if x["element_name_list"] else (x["name"],))
+ for name in (x.element_name_list.split(",") if x.element_name_list else (x.name,))
}
self.assertFalse("obj2" in entity_names)
@@ -452,9 +452,9 @@ def test_scenario_filter_works_for_object_parameter_value_sq(self):
parameters = db_map.query(db_map.object_parameter_value_sq).all()
self.assertEqual(len(parameters), 1)
self.assertEqual(parameters[0].value, b"23.0")
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}])
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -514,9 +514,9 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self):
parameters = db_map.query(db_map.relationship_parameter_value_sq).all()
self.assertEqual(len(parameters), 1)
self.assertEqual((parameters[0].value, parameters[0].type), to_database(23.0))
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}])
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -610,7 +610,7 @@ def test_scenario_filter_selects_highest_ranked_alternative(self):
parameters = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(parameters), 1)
self.assertEqual((parameters[0].value, parameters[0].type), to_database(2000.0))
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(
alternatives,
[
@@ -619,7 +619,7 @@ def test_scenario_filter_selects_highest_ranked_alternative(self):
{"name": "alternative2", "description": None, "id": 4, "commit_id": 2},
],
)
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -735,7 +735,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s
parameters = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(parameters), 1)
self.assertEqual((parameters[0].value, parameters[0].type), to_database(2000.0))
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(
alternatives,
[
@@ -744,7 +744,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s
{"name": "alternative2", "description": None, "id": 4, "commit_id": 2},
],
)
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -875,9 +875,9 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self):
apply_scenario_filter_to_subqueries(db_map, "scenario")
parameters = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(parameters), 4)
- object_names = {o.id: o.name for o in db_map.query(db_map.object_sq).all()}
- alternative_names = {a.id: a.name for a in db_map.query(db_map.alternative_sq).all()}
- parameter_names = {d.id: d.name for d in db_map.query(db_map.parameter_definition_sq).all()}
+ object_names = {o.id: o.name for o in db_map.query(db_map.object_sq)}
+ alternative_names = {a.id: a.name for a in db_map.query(db_map.alternative_sq)}
+ parameter_names = {d.id: d.name for d in db_map.query(db_map.parameter_definition_sq)}
datamined_values = {}
for parameter in parameters:
self.assertEqual(alternative_names[parameter.alternative_id], "alternative")
@@ -890,9 +890,9 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self):
"object2": {"parameter1": b"20.0", "parameter2": b"22.0"},
},
)
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}])
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -940,7 +940,7 @@ def test_filters_scenarios_and_alternatives(self):
db_map.commit_session("Add test data.")
with DatabaseMapping(url, create=True) as db_map:
apply_scenario_filter_to_subqueries(db_map, "scenario2")
- alternatives = [dict(a) for a in db_map.query(db_map.alternative_sq)]
+ alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)]
self.assertEqual(
alternatives,
[
@@ -948,7 +948,7 @@ def test_filters_scenarios_and_alternatives(self):
{"name": "alternative3", "description": None, "id": 4, "commit_id": 2},
],
)
- scenarios = [dict(s) for s in db_map.query(db_map.wide_scenario_sq).all()]
+ scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq)]
self.assertEqual(
scenarios,
[
@@ -1074,10 +1074,10 @@ def test_parameter_values_for_entities_that_have_been_filtered_out_by_default(se
with DatabaseMapping(filtered_url) as db_map:
entities = db_map.query(db_map.entity_sq).all()
self.assertEqual(len(entities), 1)
- self.assertEqual(entities[0]["name"], "visible widget")
+ self.assertEqual(entities[0].name, "visible widget")
values = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(values), 1)
- self.assertEqual(from_database(values[0]["value"], values[0]["type"]), 2.3)
+ self.assertEqual(from_database(values[0].value, values[0].type), 2.3)
def test_parameter_values_for_entities_that_swim_against_active_by_default(self):
with TemporaryDirectory() as temp_dir:
@@ -1136,10 +1136,10 @@ def test_parameter_values_for_entities_that_swim_against_active_by_default(self)
with DatabaseMapping(filtered_url) as db_map:
entities = db_map.query(db_map.entity_sq).all()
self.assertEqual(len(entities), 1)
- self.assertEqual(entities[0]["name"], "invisible widget")
+ self.assertEqual(entities[0].name, "invisible widget")
values = db_map.query(db_map.parameter_value_sq).all()
self.assertEqual(len(values), 1)
- self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3)
+ self.assertEqual(from_database(values[0].value, values[0].type), -2.3)
def test_parameter_values_of_multidimensional_entity_whose_elements_have_entity_alternatives(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py
deleted file mode 100644
index a9921663..00000000
--- a/tests/filters/test_tool_filter.py
+++ /dev/null
@@ -1,221 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# Copyright Spine Database API contributors
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
-######################################################################################################################
-
-"""
-Unit tests for ``tool_entity_filter`` module.
-
-"""
-from pathlib import Path
-from tempfile import TemporaryDirectory
-import unittest
-from sqlalchemy.engine.url import URL
-from spinedb_api import (
- DatabaseMapping,
- SpineDBAPIError,
- create_new_spine_database,
- import_object_classes,
- import_object_parameter_values,
- import_object_parameters,
- import_objects,
- import_parameter_value_lists,
- import_relationship_classes,
- import_relationship_parameter_values,
- import_relationship_parameters,
- import_relationships,
-)
-
-
-@unittest.skip("obsolete, but need to adapt into the scenario filter")
-class TestToolEntityFilter(unittest.TestCase):
- _db_url = None
- _temp_dir = None
-
- @classmethod
- def setUpClass(cls):
- cls._temp_dir = TemporaryDirectory()
- cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_tool_filter_mapping.sqlite").as_posix())
-
- def setUp(self):
- create_new_spine_database(self._db_url)
- self._db_map = DatabaseMapping(self._db_url)
-
- def tearDown(self):
- self._db_map.close()
-
- def _build_data_with_tools(self):
- import_object_classes(self._db_map, ["object_class"])
- import_objects(
- self._db_map,
- [
- ("object_class", "object1"),
- ("object_class", "object2"),
- ("object_class", "object3"),
- ("object_class", "object4"),
- ],
- )
- import_parameter_value_lists(
- self._db_map, [("methods", "methodA"), ("methods", "methodB"), ("methods", "methodC")]
- )
- import_object_parameters(
- self._db_map,
- [
- ("object_class", "parameter1", "methodA", "methods"),
- ("object_class", "parameter2", "methodC", "methods"),
- ],
- )
- import_object_parameter_values(
- self._db_map,
- [
- ("object_class", "object1", "parameter1", "methodA"),
- ("object_class", "object2", "parameter1", "methodB"),
- ("object_class", "object3", "parameter1", "methodC"),
- ("object_class", "object4", "parameter1", "methodB"),
- ("object_class", "object2", "parameter2", "methodA"),
- ("object_class", "object3", "parameter2", "methodC"),
- ],
- )
- import_tools(self._db_map, ["tool1", "tool2"])
- import_features(self._db_map, [("object_class", "parameter1"), ("object_class", "parameter2")])
- import_tool_features(
- self._db_map,
- [("tool1", "object_class", "parameter1", False), ("tool2", "object_class", "parameter1", False)],
- )
-
- def test_non_existing_tool_filter_raises(self):
- self._build_data_with_tools()
- self._db_map.commit_session("Add test data")
- self.assertRaises(SpineDBAPIError, apply_tool_filter_to_entity_sq, self._db_map, "notool")
-
- def test_tool_feature_no_filter(self):
- self._build_data_with_tools()
- self._db_map.commit_session("Add test data")
- apply_tool_filter_to_entity_sq(self._db_map, "tool1")
- entities = self._db_map.query(self._db_map.entity_sq).all()
- self.assertEqual(len(entities), 4)
- names = [x.name for x in entities]
- self.assertIn("object1", names)
- self.assertIn("object2", names)
- self.assertIn("object3", names)
- self.assertIn("object4", names)
-
- def test_tool_feature_required(self):
- self._build_data_with_tools()
- import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)])
- self._db_map.commit_session("Add test data")
- apply_tool_filter_to_entity_sq(self._db_map, "tool1")
- entities = self._db_map.query(self._db_map.entity_sq).all()
- self.assertEqual(len(entities), 2)
- names = [x.name for x in entities]
- self.assertIn("object2", names)
- self.assertIn("object3", names)
-
- def test_tool_feature_method(self):
- self._build_data_with_tools()
- import_tool_feature_methods(
- self._db_map,
- [("tool1", "object_class", "parameter1", "methodB"), ("tool2", "object_class", "parameter1", "methodC")],
- )
- self._db_map.commit_session("Add test data")
- apply_tool_filter_to_entity_sq(self._db_map, "tool1")
- entities = self._db_map.query(self._db_map.entity_sq).all()
- self.assertEqual(len(entities), 2)
- names = [x.name for x in entities]
- self.assertIn("object2", names)
- self.assertIn("object4", names)
-
- def test_tool_feature_required_and_method(self):
- self._build_data_with_tools()
- import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)])
- import_tool_feature_methods(
- self._db_map,
- [("tool1", "object_class", "parameter1", "methodB"), ("tool2", "object_class", "parameter1", "methodC")],
- )
- self._db_map.commit_session("Add test data")
- apply_tool_filter_to_entity_sq(self._db_map, "tool1")
- entities = self._db_map.query(self._db_map.entity_sq).all()
- self.assertEqual(len(entities), 1)
- self.assertEqual(entities[0].name, "object2")
-
- def test_tool_filter_config(self):
- config = tool_filter_config("tool name")
- self.assertEqual(config, {"type": "tool_filter", "tool": "tool name"})
-
- def test_tool_filter_from_dict(self):
- self._build_data_with_tools()
- import_tool_features(self._db_map, [("tool1", "object_class", "parameter2", True)])
- self._db_map.commit_session("Add test data")
- config = tool_filter_config("tool1")
- tool_filter_from_dict(self._db_map, config)
- entities = self._db_map.query(self._db_map.entity_sq).all()
- self.assertEqual(len(entities), 2)
- names = [x.name for x in entities]
- self.assertIn("object2", names)
- self.assertIn("object3", names)
-
- def test_tool_filter_config_to_shorthand(self):
- config = tool_filter_config("tool name")
- shorthand = tool_filter_config_to_shorthand(config)
- self.assertEqual(shorthand, "tool:tool name")
-
- def test_tool_filter_shorthand_to_config(self):
- config = tool_filter_shorthand_to_config("tool:tool name")
- self.assertEqual(config, {"type": "tool_filter", "tool": "tool name"})
-
- def test_object_activity_control_filter(self):
- import_object_classes(self._db_map, ["node", "unit"])
- import_relationship_classes(self._db_map, [["node__unit", ["node", "unit"]]])
- import_objects(self._db_map, [("node", "node1"), ("node", "node2"), ("unit", "unita"), ("unit", "unitb")])
- import_relationships(
- self._db_map,
- [
- ["node__unit", ["node1", "unita"]],
- ["node__unit", ["node1", "unitb"]],
- ["node__unit", ["node2", "unita"]],
- ],
- )
- import_parameter_value_lists(self._db_map, [("boolean", True), ("boolean", False)])
- import_object_parameters(self._db_map, [("node", "is_active", True, "boolean")])
- import_relationship_parameters(self._db_map, [("node__unit", "x")])
- import_object_parameter_values(self._db_map, [("node", "node1", "is_active", False)])
- import_relationship_parameter_values(
- self._db_map,
- [
- ["node__unit", ["node1", "unita"], "x", 5],
- ["node__unit", ["node1", "unitb"], "x", 7],
- ["node__unit", ["node2", "unita"], "x", 11],
- ],
- )
- import_tools(self._db_map, ["obj_act_ctrl"])
- import_features(self._db_map, [("node", "is_active")])
- import_tool_features(self._db_map, [("obj_act_ctrl", "node", "is_active", False)])
- import_tool_feature_methods(self._db_map, [("obj_act_ctrl", "node", "is_active", True)])
- self._db_map.commit_session("Add obj act ctrl filter")
- apply_tool_filter_to_entity_sq(self._db_map, "obj_act_ctrl")
- objects = self._db_map.query(self._db_map.object_sq).all()
- self.assertEqual(len(objects), 3)
- object_names = [x.name for x in objects]
- self.assertTrue("node1" not in object_names)
- self.assertTrue("node2" in object_names)
- self.assertTrue("unita" in object_names)
- self.assertTrue("unitb" in object_names)
- relationships = self._db_map.query(self._db_map.wide_relationship_sq).all()
- self.assertEqual(len(relationships), 1)
- relationship_object_names = relationships[0].object_name_list.split(",")
- self.assertTrue("node1" not in relationship_object_names)
- ent_pvals = self._db_map.query(self._db_map.entity_parameter_value_sq).all()
- self.assertEqual(len(ent_pvals), 1)
- pval_object_names = ent_pvals[0].object_name_list.split(",")
- self.assertTrue("node1" not in pval_object_names)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py
index 5bb71348..640de2a2 100644
--- a/tests/filters/test_tools.py
+++ b/tests/filters/test_tools.py
@@ -81,13 +81,13 @@ def test_mixture_of_files_and_shorthands(self):
class TestApplyFilterStack(unittest.TestCase):
- _db_url = URL("sqlite")
+ _db_url = URL.create("sqlite")
_dir = None
@classmethod
def setUpClass(cls):
cls._dir = TemporaryDirectory()
- cls._db_url.database = os.path.join(cls._dir.name, ".sqlite")
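+ # Immutable URLs cannot be mutated in place; .set() returns a new URL with the given field replaced.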
+ cls._db_url = cls._db_url.set(database=os.path.join(cls._dir.name, ".sqlite"))
db_map = DatabaseMapping(cls._db_url, create=True)
import_object_classes(db_map, ("object_class",))
db_map.commit_session("Add test data.")
@@ -134,7 +134,7 @@ class TestFilteredDatabaseMap(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._dir = TemporaryDirectory()
- cls._db_url.database = os.path.join(cls._dir.name, "TestFilteredDatabaseMap.sqlite")
+ cls._db_url = cls._db_url.set(database=os.path.join(cls._dir.name, "TestFilteredDatabaseMap.sqlite"))
db_map = DatabaseMapping(cls._db_url, create=True)
import_object_classes(db_map, ("object_class",))
db_map.commit_session("Add test data.")
diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py
index 5497567a..985fff8b 100644
--- a/tests/filters/test_value_transformer.py
+++ b/tests/filters/test_value_transformer.py
@@ -99,7 +99,9 @@ class TestValueTransformerUsingDatabase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._temp_dir = TemporaryDirectory()
- cls._db_url = URL("sqlite", database=Path(cls._temp_dir.name, "test_value_transformer.sqlite").as_posix())
+ cls._db_url = URL.create(
+ "sqlite", database=Path(cls._temp_dir.name, "test_value_transformer.sqlite").as_posix()
+ )
@classmethod
def tearDownClass(cls):
diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py
index bd8181da..3fad1e32 100644
--- a/tests/spine_io/exporters/test_sql_writer.py
+++ b/tests/spine_io/exporters/test_sql_writer.py
@@ -9,10 +9,7 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
-"""
-Unit tests for SQL writer.
-
-"""
+""" Unit tests for SQL writer. """
from pathlib import Path
from tempfile import TemporaryDirectory
import unittest
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 0abea455..08da7560 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -10,6 +10,7 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
""" Unit tests for DatabaseMapping class. """
+from collections import namedtuple
import multiprocessing
import os.path
from tempfile import TemporaryDirectory
@@ -19,7 +20,6 @@
from unittest.mock import patch
from dateutil.relativedelta import relativedelta
from sqlalchemy.engine.url import URL, make_url
-from sqlalchemy.util import KeyedTuple
from spinedb_api import (
DatabaseMapping,
SpineDBAPIError,
@@ -37,6 +37,10 @@
from tests.custom_db_mapping import CustomDatabaseMapping
from tests.mock_helpers import AssertSuccessTestCase
+ObjectRow = namedtuple("ObjectRow", ["id", "class_id", "name"])
+ObjectClassRow = namedtuple("ObjectClassRow", ["id", "name"])
+RelationshipRow = namedtuple("RelationshipRow", ["id", "object_class_id_list", "name"])
+
def create_query_wrapper(db_map):
def query_wrapper(*args, orig_query=db_map.query, **kwargs):
@@ -72,13 +76,12 @@ def test_construction_with_sqlalchemy_url_and_filters(self):
) as mock_load:
db_map = CustomDatabaseMapping(sa_url, create=True)
db_map.close()
- mock_load.assert_called_once_with(["fltr1", "fltr2"])
+ mock_load.assert_called_once_with(("fltr1", "fltr2"))
mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}])
def test_shorthand_filter_query_works(self):
with TemporaryDirectory() as temp_dir:
- url = URL("sqlite")
- url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json")
+ url = URL.create("sqlite", database=os.path.join(temp_dir, "test_shorthand_filter_query_works.json"))
out_db_map = CustomDatabaseMapping(url, create=True)
out_db_map.add_scenarios({"name": "scen1"})
out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1})
@@ -315,7 +318,7 @@ def test_update_entity_metadata_by_changing_its_entity(self):
)
self.assertEqual(len(metadata_records), 1)
self.assertEqual(
- dict(**metadata_records[0]),
+ metadata_records[0]._asdict(),
{
"id": 1,
"entity_class_name": "my_class",
@@ -416,7 +419,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self):
)
self.assertEqual(len(metadata_records), 1)
self.assertEqual(
- dict(**metadata_records[0]),
+ metadata_records[0]._asdict(),
{
"id": 1,
"entity_class_name": "my_class",
@@ -613,9 +616,9 @@ def test_commit_parameter_value_coincidentally_called_is_active(self):
)
)
db_map.commit_session("Add test data to see if this crashes.")
- entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)}
+ entity_names = {entity.id: entity.name for entity in db_map.query(db_map.wide_entity_sq)}
alternative_names = {
- alternative["id"]: alternative["name"] for alternative in db_map.query(db_map.alternative_sq)
+ alternative.id: alternative.name for alternative in db_map.query(db_map.alternative_sq)
}
expected = {
("widget1", "Base"): True,
@@ -625,9 +628,9 @@ def test_commit_parameter_value_coincidentally_called_is_active(self):
in_database = {}
entity_alternatives = db_map.query(db_map.entity_alternative_sq)
for entity_alternative in entity_alternatives:
- entity_name = entity_names[entity_alternative["entity_id"]]
- alternative_name = alternative_names[entity_alternative["alternative_id"]]
- in_database[(entity_name, alternative_name)] = entity_alternative["active"]
+ entity_name = entity_names[entity_alternative.entity_id]
+ alternative_name = alternative_names[entity_alternative.alternative_id]
+ in_database[(entity_name, alternative_name)] = entity_alternative.active
self.assertEqual(in_database, expected)
self.assertEqual(db_map.query(db_map.parameter_value_sq).all(), [])
@@ -666,12 +669,12 @@ def test_commit_default_value_for_parameter_called_is_active(self):
)
db_map.commit_session("Add test data to see if this crashes")
active_by_defaults = {
- entity_class["name"]: entity_class["active_by_default"]
+ entity_class.name: entity_class.active_by_default
for entity_class in db_map.query(db_map.wide_entity_class_sq)
}
self.assertEqual(active_by_defaults, {"Widget": True, "Gadget": True, "NoIsActiveDefault": False})
defaults = [
- from_database(definition["default_value"], definition["default_type"])
+ from_database(definition.default_value, definition.default_type)
for definition in db_map.query(db_map.parameter_definition_sq)
]
self.assertEqual(defaults, [True, True, None])
@@ -1104,7 +1107,7 @@ def test_entity_item_active_in_scenario(self):
scenario_alternatives = db_map.query(db_map.scenario_alternative_sq).all()
self.assertEqual(len(scenario_alternatives), 5)
self.assertEqual(
- dict(scenario_alternatives[0]),
+ scenario_alternatives[0]._asdict(),
{"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 7},
)
import_functions.import_scenarios(db_map, ("scen1",))
@@ -1975,15 +1978,14 @@ def test_construction_with_filters(self):
mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}])
def test_construction_with_sqlalchemy_url_and_filters(self):
- sa_url = URL("sqlite")
- sa_url.query = {"spinedbfilter": ["fltr1", "fltr2"]}
+ sa_url = URL.create("sqlite", query={"spinedbfilter": ["fltr1", "fltr2"]})
with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply:
with patch(
"spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}]
) as mock_load:
db_map = CustomDatabaseMapping(sa_url, create=True)
db_map.close()
- mock_load.assert_called_once_with(["fltr1", "fltr2"])
+ mock_load.assert_called_once_with(("fltr1", "fltr2"))
mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}])
def test_entity_sq(self):
@@ -2737,12 +2739,11 @@ def test_add_relationship_class_with_same_name_as_existing_one(self):
):
mock_query.side_effect = query_wrapper
mock_object_class_sq.return_value = [
- KeyedTuple([1, "fish"], labels=["id", "name"]),
- KeyedTuple([2, "dog"], labels=["id", "name"]),
- ]
- mock_wide_rel_cls_sq.return_value = [
- KeyedTuple([1, "1,2", "fish__dog"], labels=["id", "object_class_id_list", "name"])
+ ObjectClassRow(1, "fish"),
+ ObjectClassRow(2, "dog"),
]
+ mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "1,2", "fish__dog")]
with self.assertRaises(SpineIntegrityError):
self._db_map.add_wide_relationship_classes(
{"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True
@@ -2757,7 +2758,7 @@ def test_add_relationship_class_with_invalid_object_class(self):
mock.patch.object(CustomDatabaseMapping, "wide_relationship_class_sq"),
):
mock_query.side_effect = query_wrapper
- mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])]
+ mock_object_class_sq.return_value = [ObjectClassRow(1, "fish")]
with self.assertRaises(SpineIntegrityError):
self._db_map.add_wide_relationship_classes(
{"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True
@@ -2814,15 +2815,15 @@ def test_add_relationship_identical_to_existing_one(self):
):
mock_query.side_effect = query_wrapper
mock_object_sq.return_value = [
- KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]),
- KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]),
- ]
- mock_wide_rel_cls_sq.return_value = [
- KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"])
- ]
- mock_wide_rel_sq.return_value = [
- KeyedTuple([1, 1, "1,2", "nemo__pluto"], labels=["id", "class_id", "object_id_list", "name"])
+ ObjectRow(1, 10, "nemo"),
+ ObjectRow(2, 20, "pluto"),
]
+ mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "10,20", "fish__dog")]
+ WideRelationshipRow = namedtuple(
+ "WideRelationshipRow", ["id", "class_id", "object_id_list", "name"]
+ )
+ mock_wide_rel_sq.return_value = [WideRelationshipRow(1, 1, "1,2", "nemo__pluto")]
with self.assertRaises(SpineIntegrityError):
self._db_map.add_wide_relationships(
{"name": "nemoy__plutoy", "class_id": 1, "object_id_list": [1, 2]}, strict=True
@@ -2839,12 +2840,10 @@ def test_add_relationship_with_invalid_class(self):
):
mock_query.side_effect = query_wrapper
mock_object_sq.return_value = [
- KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]),
- KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]),
- ]
- mock_wide_rel_cls_sq.return_value = [
- KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"])
+ ObjectRow(1, 10, "nemo"),
+ ObjectRow(2, 20, "pluto"),
]
+ mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "10,20", "fish__dog")]
with self.assertRaises(SpineIntegrityError):
self._db_map.add_wide_relationships(
{"name": "nemo__pluto", "class_id": 2, "object_id_list": [1, 2]}, strict=True
@@ -2861,12 +2860,10 @@ def test_add_relationship_with_invalid_object(self):
):
mock_query.side_effect = query_wrapper
mock_object_sq.return_value = [
- KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]),
- KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]),
- ]
- mock_wide_rel_cls_sq.return_value = [
- KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"])
+ ObjectRow(1, 10, "nemo"),
+ ObjectRow(2, 20, "pluto"),
]
+ mock_wide_rel_cls_sq.return_value = [RelationshipRow(1, "10,20", "fish__dog")]
with self.assertRaises(SpineIntegrityError):
self._db_map.add_wide_relationships(
{"name": "nemo__pluto", "class_id": 1, "object_id_list": [1, 3]}, strict=True
@@ -3113,10 +3110,10 @@ def test_add_alternative(self):
alternatives = self._db_map.query(self._db_map.alternative_sq).all()
self.assertEqual(len(alternatives), 2)
self.assertEqual(
- dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}
+ alternatives[0]._asdict(), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}
)
self.assertEqual(
- dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2}
+ alternatives[1]._asdict(), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2}
)
def test_add_scenario(self):
@@ -3127,7 +3124,7 @@ def test_add_scenario(self):
scenarios = self._db_map.query(self._db_map.scenario_sq).all()
self.assertEqual(len(scenarios), 1)
self.assertEqual(
- dict(scenarios[0]),
+ scenarios[0]._asdict(),
{"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2},
)
@@ -3141,7 +3138,7 @@ def test_add_scenario_alternative(self):
scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all()
self.assertEqual(len(scenario_alternatives), 1)
self.assertEqual(
- dict(scenario_alternatives[0]),
+ scenario_alternatives[0]._asdict(),
{"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3},
)
@@ -3153,7 +3150,7 @@ def test_add_metadata(self):
metadata = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata), 1)
self.assertEqual(
- dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2}
+ metadata[0]._asdict(), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2}
)
def test_add_metadata_that_exists_does_not_add_it(self):
@@ -3163,7 +3160,7 @@ def test_add_metadata_that_exists_does_not_add_it(self):
self.assertEqual(items, [])
metadata = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata), 1)
- self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2})
+ self.assertEqual(metadata[0]._asdict(), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2})
def test_add_entity_metadata_for_object(self):
import_functions.import_object_classes(self._db_map, ("fish",))
@@ -3177,7 +3174,7 @@ def test_add_entity_metadata_for_object(self):
entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all()
self.assertEqual(len(entity_metadata), 1)
self.assertEqual(
- dict(entity_metadata[0]),
+ entity_metadata[0]._asdict(),
{
"entity_id": 1,
"entity_name": "leviathan",
@@ -3203,7 +3200,7 @@ def test_add_entity_metadata_for_relationship(self):
entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all()
self.assertEqual(len(entity_metadata), 1)
self.assertEqual(
- dict(entity_metadata[0]),
+ entity_metadata[0]._asdict(),
{
"entity_id": 2,
"entity_name": "my_object__",
@@ -3233,7 +3230,7 @@ def test_add_ext_entity_metadata_for_object(self):
entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all()
self.assertEqual(len(entity_metadata), 1)
self.assertEqual(
- dict(entity_metadata[0]),
+ entity_metadata[0]._asdict(),
{
"entity_id": 1,
"entity_name": "leviathan",
@@ -3258,11 +3255,11 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an
self._db_map.commit_session("Add entity metadata")
metadata = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata), 1)
- self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
+ self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all()
self.assertEqual(len(entity_metadata), 1)
self.assertEqual(
- dict(entity_metadata[0]),
+ entity_metadata[0]._asdict(),
{
"entity_id": 1,
"entity_name": "leviathan",
@@ -3290,7 +3287,7 @@ def test_add_parameter_value_metadata(self):
value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all()
self.assertEqual(len(value_metadata), 1)
self.assertEqual(
- dict(value_metadata[0]),
+ value_metadata[0]._asdict(),
{
"alternative_name": "Base",
"entity_name": "leviathan",
@@ -3326,7 +3323,7 @@ def test_add_ext_parameter_value_metadata(self):
value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all()
self.assertEqual(len(value_metadata), 1)
self.assertEqual(
- dict(value_metadata[0]),
+ value_metadata[0]._asdict(),
{
"alternative_name": "Base",
"entity_name": "leviathan",
@@ -3355,11 +3352,11 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self):
self._db_map.commit_session("Add value metadata")
metadata = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata), 1)
- self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
+ self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all()
self.assertEqual(len(value_metadata), 1)
self.assertEqual(
- dict(value_metadata[0]),
+ value_metadata[0]._asdict(),
{
"alternative_name": "Base",
"entity_name": "leviathan",
@@ -3596,7 +3593,7 @@ def test_update_parameter_definition_value_list(self):
pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all()
self.assertEqual(len(pdefs), 1)
self.assertEqual(
- dict(pdefs[0]),
+ pdefs[0]._asdict(),
{
"commit_id": 3,
"default_type": None,
@@ -3669,10 +3666,14 @@ def test_update_object_metadata(self):
self._db_map.commit_session("Update data")
metadata_entries = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_entries), 1)
- self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3})
+ self.assertEqual(
+ metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}
+ )
entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all()
self.assertEqual(len(entity_metadata_entries), 1)
- self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 3})
+ self.assertEqual(
+ entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 3}
+ )
def test_update_object_metadata_reuses_existing_metadata(self):
import_functions.import_object_classes(self._db_map, ("my_class",))
@@ -3697,12 +3698,16 @@ def test_update_object_metadata_reuses_existing_metadata(self):
metadata_entries = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_entries), 1)
self.assertEqual(
- dict(metadata_entries[0]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2}
+ metadata_entries[0]._asdict(), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2}
)
entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all()
self.assertEqual(len(entity_metadata_entries), 2)
- self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3})
- self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2})
+ self.assertEqual(
+ entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}
+ )
+ self.assertEqual(
+ entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2}
+ )
def test_update_object_metadata_keeps_metadata_still_in_use(self):
import_functions.import_object_classes(self._db_map, ("my_class",))
@@ -3724,12 +3729,20 @@ def test_update_object_metadata_keeps_metadata_still_in_use(self):
self._db_map.commit_session("Update data")
metadata_entries = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_entries), 2)
- self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
- self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3})
+ self.assertEqual(
+ metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}
+ )
+ self.assertEqual(
+ metadata_entries[1]._asdict(), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3}
+ )
entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all()
self.assertEqual(len(entity_metadata_entries), 2)
- self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3})
- self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2})
+ self.assertEqual(
+ entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}
+ )
+ self.assertEqual(
+ entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2}
+ )
def test_update_parameter_value_metadata(self):
import_functions.import_object_classes(self._db_map, ("my_class",))
@@ -3752,11 +3765,13 @@ def test_update_parameter_value_metadata(self):
self._db_map.commit_session("Update data")
metadata_entries = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_entries), 1)
- self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3})
+ self.assertEqual(
+ metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}
+ )
value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all()
self.assertEqual(len(value_metadata_entries), 1)
self.assertEqual(
- dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 3}
+ value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 3}
)
def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self):
@@ -3780,16 +3795,22 @@ def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(
self._db_map.commit_session("Update data")
metadata_entries = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_entries), 2)
- self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2})
- self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3})
+ self.assertEqual(
+ metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}
+ )
+ self.assertEqual(
+ metadata_entries[1]._asdict(), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}
+ )
value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all()
self.assertEqual(len(value_metadata_entries), 1)
self.assertEqual(
- dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3}
+ value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3}
)
entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all()
self.assertEqual(len(entity_metadata_entries), 1)
- self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2})
+ self.assertEqual(
+ entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2}
+ )
def test_update_metadata(self):
import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
@@ -3802,7 +3823,7 @@ def test_update_metadata(self):
metadata_records = self._db_map.query(self._db_map.metadata_sq).all()
self.assertEqual(len(metadata_records), 1)
self.assertEqual(
- dict(metadata_records[0]), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3}
+ metadata_records[0]._asdict(), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3}
)
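The `dict(row)` → `row._asdict()` replacements above track the SQLAlchemy 1.4 result API: `Query` rows are now named-tuple-like `Row` objects rather than the dict-like `RowProxy` of 1.3, so iterating a row yields values, not key/value pairs, and the namedtuple-style `_asdict()` is the supported way to get a plain dict. A minimal sketch, using a stand-in table instead of the Spine `metadata_sq` subquery:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Metadata(Base):
    __tablename__ = "metadata"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    value = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Metadata(name="title", value="My metadata."))
    session.commit()
    # Multi-column Query results come back as Row objects on 1.4:
    row = session.query(Metadata.id, Metadata.name, Metadata.value).one()
    # Row behaves like a named tuple; dict(row) is deprecated on 1.4
    # (and gone in 2.0), so the tests compare against _asdict():
    assert row._asdict() == {"id": 1, "name": "title", "value": "My metadata."}
```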
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 6d233472..8dc86590 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -13,6 +13,7 @@
import unittest
+from sqlalchemy import text
from spinedb_api.helpers import (
compare_schemas,
create_new_spine_database,
@@ -49,7 +50,7 @@ def test_same_schema(self):
def test_different_schema(self):
engine1 = create_new_spine_database("sqlite://")
engine2 = create_new_spine_database("sqlite://")
- engine2.execute("drop table entity")
+ engine2.execute(text("drop table entity"))
self.assertFalse(compare_schemas(engine1, engine2))
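The `text()` wrapper is needed because SQLAlchemy 1.4 deprecates passing plain SQL strings to `Engine.execute()` / `Connection.execute()` (removed in 2.0). A minimal, self-contained sketch against an in-memory SQLite engine; the `entity` table here is only an illustration, not the Spine schema:

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as connection:
    connection.execute(text("CREATE TABLE entity (id INTEGER PRIMARY KEY)"))
    # text() also enables bound parameters instead of string formatting:
    connection.execute(text("INSERT INTO entity (id) VALUES (:id)"), {"id": 1})
    count = connection.execute(text("SELECT COUNT(*) FROM entity")).scalar()
    assert count == 1
```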
diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py
index 74252b0a..ccbd7aa4 100644
--- a/tests/test_import_functions.py
+++ b/tests/test_import_functions.py
@@ -624,7 +624,7 @@ def test_import_object_parameter_definition(self):
count = self._assert_imports(import_object_parameters(db_map, (("my_object_class", "my_parameter"),)))
self.assertEqual(count, 1)
db_map.commit_session("Add test data.")
- parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)]
+ parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)]
self.assertEqual(
parameter_definitions,
[
@@ -653,7 +653,7 @@ def test_import_object_parameter_definition_with_value_list(self):
)
self.assertEqual(count, 1)
db_map.commit_session("Add test data.")
- parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)]
+ parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)]
self.assertEqual(
parameter_definitions,
[
@@ -682,7 +682,7 @@ def test_import_object_parameter_definition_with_default_value_from_value_list(s
)
self.assertEqual(count, 1)
db_map.commit_session("Add test data.")
- parameter_definitions = [dict(row) for row in db_map.query(db_map.object_parameter_definition_sq)]
+ parameter_definitions = [row._asdict() for row in db_map.query(db_map.object_parameter_definition_sq)]
self.assertEqual(
parameter_definitions,
[
@@ -1761,7 +1761,7 @@ def test_import_object_parameter_value_metadata(self):
metadata = db_map.query(db_map.ext_parameter_value_metadata_sq).all()
self.assertEqual(len(metadata), 2)
self.assertEqual(
- dict(metadata[0]),
+ metadata[0]._asdict(),
{
"alternative_name": "Base",
"entity_name": "object",
@@ -1775,7 +1775,7 @@ def test_import_object_parameter_value_metadata(self):
},
)
self.assertEqual(
- dict(metadata[1]),
+ metadata[1]._asdict(),
{
"alternative_name": "Base",
"entity_name": "object",
@@ -1810,7 +1810,7 @@ def test_import_relationship_parameter_value_metadata(self):
metadata = db_map.query(db_map.ext_parameter_value_metadata_sq).all()
self.assertEqual(len(metadata), 2)
self.assertEqual(
- dict(metadata[0]),
+ metadata[0]._asdict(),
{
"alternative_name": "Base",
"entity_name": "object__",
@@ -1824,7 +1824,7 @@ def test_import_relationship_parameter_value_metadata(self):
},
)
self.assertEqual(
- dict(metadata[1]),
+ metadata[1]._asdict(),
{
"alternative_name": "Base",
"entity_name": "object__",
@@ -1864,7 +1864,7 @@ def test_import_single_entity_class_display_mode(self):
display_modes = db_map.query(db_map.entity_class_display_mode_sq).all()
self.assertEqual(len(display_modes), 1)
self.assertEqual(
- dict(display_modes[0]),
+ display_modes[0]._asdict(),
{
"id": 1,
"display_mode_id": 1,
diff --git a/tests/test_migration.py b/tests/test_migration.py
index b845dd97..f283c5d9 100644
--- a/tests/test_migration.py
+++ b/tests/test_migration.py
@@ -10,89 +10,77 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
-"""
-Unit tests for migration scripts.
-
-"""
+""" Unit tests for migration scripts. """
import os.path
from tempfile import TemporaryDirectory
import unittest
-from sqlalchemy import inspect
+from sqlalchemy import inspect, text
from sqlalchemy.engine.url import URL
from spinedb_api import DatabaseMapping
from spinedb_api.helpers import _create_first_spine_database, create_new_spine_database, is_head_engine, schema_dict
class TestMigration(unittest.TestCase):
- @unittest.skip(
- "default_values's server_default has been changed from 0 to NULL in the create scrip, "
- "but there's no associated upgrade script yet."
- )
- def test_upgrade_schema(self):
- """Tests that the upgrade scripts produce the same schema as the function to create
- a Spine db anew.
- """
- left_engine = _create_first_spine_database("sqlite://")
- is_head_engine(left_engine, upgrade=True)
- left_insp = inspect(left_engine)
- left_dict = schema_dict(left_insp)
- right_engine = create_new_spine_database("sqlite://")
- right_insp = inspect(right_engine)
- right_dict = schema_dict(right_insp)
- self.maxDiff = None
- self.assertEqual(str(left_dict), str(right_dict))
-
- left_ver = left_engine.execute("SELECT version_num FROM alembic_version").fetchall()
- right_ver = right_engine.execute("SELECT version_num FROM alembic_version").fetchall()
- self.assertEqual(left_ver, right_ver)
-
- left_ent_typ = left_engine.execute("SELECT * FROM entity_type").fetchall()
- right_ent_typ = right_engine.execute("SELECT * FROM entity_type").fetchall()
- left_ent_cls_typ = left_engine.execute("SELECT * FROM entity_class_type").fetchall()
- right_ent_cls_typ = right_engine.execute("SELECT * FROM entity_class_type").fetchall()
- self.assertEqual(left_ent_typ, right_ent_typ)
- self.assertEqual(left_ent_cls_typ, right_ent_cls_typ)
-
def test_upgrade_content(self):
"""Tests that the upgrade scripts when applied on a db that has some contents
persist that content entirely.
"""
with TemporaryDirectory() as temp_dir:
- db_url = URL("sqlite")
- db_url.database = os.path.join(temp_dir, "test_upgrade_content.sqlite")
- # Create *first* spine db
+ db_url = URL.create("sqlite", database=os.path.join(temp_dir, "test_upgrade_content.sqlite"))
engine = _create_first_spine_database(db_url)
# Insert basic stuff
- engine.execute("INSERT INTO object_class (id, name) VALUES (1, 'dog')")
- engine.execute("INSERT INTO object_class (id, name) VALUES (2, 'fish')")
- engine.execute("INSERT INTO object (id, class_id, name) VALUES (1, 1, 'pluto')")
- engine.execute("INSERT INTO object (id, class_id, name) VALUES (2, 1, 'scooby')")
- engine.execute("INSERT INTO object (id, class_id, name) VALUES (3, 2, 'nemo')")
+ engine.execute(text("INSERT INTO object_class (id, name) VALUES (1, 'dog')"))
+ engine.execute(text("INSERT INTO object_class (id, name) VALUES (2, 'fish')"))
+ engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (1, 1, 'pluto')"))
+ engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (2, 1, 'scooby')"))
+ engine.execute(text("INSERT INTO object (id, class_id, name) VALUES (3, 2, 'nemo')"))
+ engine.execute(
+ text(
+ "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 0, 1)"
+ )
+ )
+ engine.execute(
+ text(
+ "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 1, 2)"
+ )
+ )
+ engine.execute(
+ text(
+ "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 0, 1)"
+ )
+ )
+ engine.execute(
+ text(
+ "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 1, 3)"
+ )
+ )
+ engine.execute(
+ text(
+ "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 0, 2)"
+ )
+ )
engine.execute(
- "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 0, 1)"
+ text(
+ "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 1, 3)"
+ )
)
+ engine.execute(text("INSERT INTO parameter (id, object_class_id, name) VALUES (1, 1, 'breed')"))
+ engine.execute(text("INSERT INTO parameter (id, object_class_id, name) VALUES (2, 2, 'water')"))
engine.execute(
- "INSERT INTO relationship_class (id, name, dimension, object_class_id) VALUES (1, 'dog__fish', 1, 2)"
+ text("INSERT INTO parameter (id, relationship_class_id, name) VALUES (3, 1, 'relative_speed')")
)
engine.execute(
- "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 0, 1)"
+ text("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 1, '\"labrador\"')")
)
engine.execute(
- "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (1, 1, 'pluto__nemo', 1, 3)"
+ text("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 2, '\"big dane\"')")
)
engine.execute(
- "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 0, 2)"
+ text("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 1, '100')")
)
engine.execute(
- "INSERT INTO relationship (id, class_id, name, dimension, object_id) VALUES (2, 1, 'scooby__nemo', 1, 3)"
+ text("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')")
)
- engine.execute("INSERT INTO parameter (id, object_class_id, name) VALUES (1, 1, 'breed')")
- engine.execute("INSERT INTO parameter (id, object_class_id, name) VALUES (2, 2, 'water')")
- engine.execute("INSERT INTO parameter (id, relationship_class_id, name) VALUES (3, 1, 'relative_speed')")
- engine.execute("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 1, '\"labrador\"')")
- engine.execute("INSERT INTO parameter_value (parameter_id, object_id, value) VALUES (1, 2, '\"big dane\"')")
- engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 1, '100')")
- engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')")
# Upgrade the db and check that our stuff is still there
db_map = DatabaseMapping(db_url, upgrade=True)
object_classes = {x.id: x.name for x in db_map.query(db_map.object_class_sq)}
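The `URL` rewrite above is forced by 1.4 making `sqlalchemy.engine.url.URL` an immutable named tuple: attribute assignment like `db_url.database = ...` now raises `AttributeError`, components are passed to `URL.create()` up front, and `URL.set()` returns a modified copy. A small sketch, with illustrative paths:

```python
from sqlalchemy.engine.url import URL

db_url = URL.create("sqlite", database="/tmp/test_upgrade_content.sqlite")
assert str(db_url) == "sqlite:////tmp/test_upgrade_content.sqlite"
# URL is immutable on 1.4; derive changed copies with .set():
moved = db_url.set(database="other.sqlite")
assert str(moved) == "sqlite:///other.sqlite"
```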