From f2460fe0c2870746961c9c2908c5b9190feeb21e Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Tue, 8 Oct 2024 09:45:15 +0200 Subject: [PATCH] fix: enable foreign keys for sqlite (#2172) By default, sqlite does not enforce foreign key constraints, this needs to be enabled explicitly. This is in particular important to catch those errors in our unit tests, but also for desktop mode. --------- Signed-off-by: Sylvain Leclerc --- .../utils/fastapi_sqlalchemy/middleware.py | 21 ++++++- tests/conftest.py | 3 + tests/conftest_instances.py | 28 +++++++++ tests/core/test_tasks.py | 21 ++++--- .../test_matrix_garbage_collector.py | 62 +++++++------------ 5 files changed, 84 insertions(+), 51 deletions(-) create mode 100644 tests/conftest_instances.py diff --git a/antarest/core/utils/fastapi_sqlalchemy/middleware.py b/antarest/core/utils/fastapi_sqlalchemy/middleware.py index dcc1f95b25..73a073e033 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/middleware.py +++ b/antarest/core/utils/fastapi_sqlalchemy/middleware.py @@ -1,9 +1,10 @@ from contextvars import ContextVar, Token from typing import Any, Dict, Optional, Type, Union -from sqlalchemy import create_engine # type: ignore +from sqlalchemy import create_engine, event # type: ignore from sqlalchemy.engine import Engine # type: ignore from sqlalchemy.engine.url import URL # type: ignore +from sqlalchemy.event import listen # type: ignore from sqlalchemy.orm import Session, sessionmaker # type: ignore from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request @@ -16,6 +17,20 @@ _session: ContextVar[Optional[Session]] = ContextVar("_session", default=None) +def _is_sqlite_engine(engine: Engine) -> bool: + return "sqlite" in engine.url.drivername.lower() + + +def enable_sqlite_foreign_keys(dbapi_connection: Any, connection_record: Any) -> None: + """ + By default, sqlite does not enforce foreign key constraints, + we need to tell it explicitly. + """ + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON;") + cursor.close() + + class DBSessionMiddleware(BaseHTTPMiddleware): def __init__( self, @@ -39,6 +54,10 @@ def __init__( engine = create_engine(db_url, **engine_args) else: engine = custom_engine + + if _is_sqlite_engine(engine): + listen(engine, "connect", enable_sqlite_foreign_keys) + _Session = sessionmaker(bind=engine, **session_args) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: diff --git a/tests/conftest.py b/tests/conftest.py index 3a074c8b98..cd2016daea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,9 @@ # noinspection PyUnresolvedReferences from tests.conftest_db import * +# noinspection PyUnresolvedReferences +from tests.conftest_instances import * + # noinspection PyUnresolvedReferences from tests.conftest_services import * diff --git a/tests/conftest_instances.py b/tests/conftest_instances.py new file mode 100644 index 0000000000..58237e6fc1 --- /dev/null +++ b/tests/conftest_instances.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. +""" +The aim of this module is to contain fixtures for +instantiating objects such as users, studies, ... +""" +import pytest + +from antarest.core.jwt import DEFAULT_ADMIN_USER +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db +from antarest.login.model import User + + +@pytest.fixture +def admin_user(db_middleware: DBSessionMiddleware) -> User: + with db(commit_on_exit=True): + user = User(id=DEFAULT_ADMIN_USER.id) + db.session.add(user) + return DEFAULT_ADMIN_USER diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index cb2c635177..668333e540 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -23,7 +23,7 @@ from antarest.core.config import Config from antarest.core.interfaces.eventbus import EventType, IEventBus -from antarest.core.jwt import DEFAULT_ADMIN_USER +from antarest.core.jwt import JWTUser from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError @@ -72,8 +72,9 @@ def db_engine_fixture(tmp_path: Path) -> t.Generator[Engine, None, None]: @with_db_context -def test_service(core_config: Config, event_bus: IEventBus) -> None: +def test_service(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None: engine = db.session.bind + task_job_repo = TaskJobRepository() # Prepare a TaskJob in the database @@ -92,7 +93,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: tasks = service.list_tasks( TaskListFilter(), - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), + request_params=RequestParameters(user=admin_user), ) assert len(tasks) == 1 assert tasks[0].status == TaskStatus.FAILED @@ -101,7 +102,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: # Test Case: get task status # ========================== - res = service.status_task("a", RequestParameters(user=DEFAULT_ADMIN_USER)) + res = service.status_task("a", RequestParameters(user=admin_user)) assert res is not None expected = { "completion_date_utc": ANY, @@ -134,7 +135,7 @@ def action_fail(update_msg: t.Callable[[str], None]) -> TaskResult: None, None, None, - RequestParameters(user=DEFAULT_ADMIN_USER), + RequestParameters(user=admin_user), ) service.await_task(failed_id, timeout_sec=2) @@ -162,7 +163,7 @@ def action_ok(update_msg: t.Callable[[str], None]) -> TaskResult: None, None, None, - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), + request_params=RequestParameters(user=admin_user), ) service.await_task(ok_id, timeout_sec=2) @@ -200,7 +201,7 @@ def test_repository(db_session: Session) -> None: # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id="study_id", name="foo", version="860")) + db_session.add(RawStudy(id=study_id, name="foo", version="860")) db_session.commit() # Create a TaskJobService @@ -276,7 +277,7 @@ def test_repository(db_session: Session) -> None: @with_db_context -def test_cancel(core_config: Config, event_bus: IEventBus) -> None: +def test_cancel(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None: # Create a TaskJobService and add tasks task_job_repo = TaskJobRepository() task_job_repo.save(TaskJob(id="a", name="foo")) @@ -296,7 +297,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: backend.clear_events() - service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) + service.cancel_task("b", RequestParameters(user=admin_user), dispatch=True) collected_events = backend.get_events() @@ -312,7 +313,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: backend.clear_events() - service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) + service.cancel_task("a", RequestParameters(user=admin_user), dispatch=True) collected_events = backend.get_events() assert len(collected_events) == 0, "No event should have been emitted because the task is in the service map" diff --git a/tests/matrixstore/test_matrix_garbage_collector.py b/tests/matrixstore/test_matrix_garbage_collector.py index 1e02731910..865a6754ef 100644 --- a/tests/matrixstore/test_matrix_garbage_collector.py +++ b/tests/matrixstore/test_matrix_garbage_collector.py @@ -10,24 +10,24 @@ # # This file is part of the Antares project. -from datetime import datetime from pathlib import Path from unittest.mock import Mock +import numpy as np import pytest -from sqlalchemy import create_engine -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db -from antarest.dbmodel import Base +from antarest.core.jwt import JWTUser +from antarest.core.requests import RequestParameters +from antarest.core.utils.fastapi_sqlalchemy import db from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector -from antarest.matrixstore.model import MatrixDataSet, MatrixDataSetRelation +from antarest.matrixstore.model import MatrixDataSetUpdateDTO, MatrixInfoDTO from antarest.matrixstore.repository import MatrixDataSetRepository from antarest.matrixstore.service import MatrixService from antarest.study.storage.patch_service import PatchService from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.command_factory import CommandFactory from antarest.study.storage.variantstudy.model.command.common import CommandName -from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy from antarest.study.storage.variantstudy.repository import VariantStudyRepository @@ -116,17 +116,13 @@ def test_get_matrices_used_in_raw_studies( @pytest.mark.unit_test def test_get_matrices_used_in_variant_studies( matrix_garbage_collector: MatrixGarbageCollector, + variant_study_repository: VariantStudyRepository, ): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) with db(): study_id = "study_id" + + variant_study_repository.save(VariantStudy(id=study_id)) + # TODO: add series to the command blocks command_block1 = CommandBlock( study_id=study_id, @@ -173,38 +169,24 @@ def test_get_matrices_used_in_variant_studies( @pytest.mark.unit_test def test_get_matrices_used_in_dataset( matrix_garbage_collector: MatrixGarbageCollector, + matrix_service: MatrixService, + admin_user: JWTUser, ): matrix_garbage_collector.dataset_repository = MatrixDataSetRepository() - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - dataset = MatrixDataSet( - name="name", - public=True, - owner_id="owner_id", - groups=[], - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - matrix_relation1 = MatrixDataSetRelation(name="matrix_name1") - matrix_relation1.matrix_id = "matrix_id1" - dataset.matrices.append(matrix_relation1) - matrix_relation2 = MatrixDataSetRelation(name="matrix_name2") - matrix_relation2.matrix_id = "matrix_id2" - dataset.matrices.append(matrix_relation2) with db(): - db.session.add(dataset) - db.session.commit() + matrix1_id = matrix_service.create(np.ones((1, 1))) + matrix2_id = matrix_service.create(np.ones((2, 1))) + dataset = matrix_service.create_dataset( + dataset_info=MatrixDataSetUpdateDTO(name="name", groups=[], public=True), + matrices=[MatrixInfoDTO(id=matrix1_id, name="matrix_1"), MatrixInfoDTO(id=matrix2_id, name="matrix_2")], + params=RequestParameters(admin_user), + ) + matrices = matrix_garbage_collector._get_datasets_matrices() assert len(matrices) == 2 - assert "matrix_id1" in matrices - assert "matrix_id2" in matrices + assert matrix1_id in matrices + assert matrix2_id in matrices @pytest.mark.unit_test