diff --git a/antarest/core/utils/fastapi_sqlalchemy/middleware.py b/antarest/core/utils/fastapi_sqlalchemy/middleware.py index dcc1f95b25..73a073e033 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/middleware.py +++ b/antarest/core/utils/fastapi_sqlalchemy/middleware.py @@ -1,9 +1,10 @@ from contextvars import ContextVar, Token from typing import Any, Dict, Optional, Type, Union -from sqlalchemy import create_engine # type: ignore +from sqlalchemy import create_engine, event # type: ignore from sqlalchemy.engine import Engine # type: ignore from sqlalchemy.engine.url import URL # type: ignore +from sqlalchemy.event import listen # type: ignore from sqlalchemy.orm import Session, sessionmaker # type: ignore from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request @@ -16,6 +17,20 @@ _session: ContextVar[Optional[Session]] = ContextVar("_session", default=None) +def _is_sqlite_engine(engine: Engine) -> bool: + return "sqlite" in engine.url.drivername.lower() + + +def enable_sqlite_foreign_keys(dbapi_connection: Any, connection_record: Any) -> None: + """ + By default, sqlite does not enforce foreign key constraints, + we need to tell it explicitly. + """ + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON;") + cursor.close() + + class DBSessionMiddleware(BaseHTTPMiddleware): def __init__( self, @@ -39,6 +54,10 @@ def __init__( engine = create_engine(db_url, **engine_args) else: engine = custom_engine + + if _is_sqlite_engine(engine): + listen(engine, "connect", enable_sqlite_foreign_keys) + _Session = sessionmaker(bind=engine, **session_args) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: diff --git a/tests/conftest.py b/tests/conftest.py index 3a074c8b98..cd2016daea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,9 @@ # noinspection PyUnresolvedReferences from tests.conftest_db import * +# noinspection PyUnresolvedReferences +from tests.conftest_instances import * + # noinspection PyUnresolvedReferences from tests.conftest_services import * diff --git a/tests/conftest_instances.py b/tests/conftest_instances.py new file mode 100644 index 0000000000..58237e6fc1 --- /dev/null +++ b/tests/conftest_instances.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. +""" +The aim of this module is to contain fixtures for +instantiating objects such as users, studies, ... +""" +import pytest + +from antarest.core.jwt import DEFAULT_ADMIN_USER +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db +from antarest.login.model import User + + +@pytest.fixture +def admin_user(db_middleware: DBSessionMiddleware) -> User: + with db(commit_on_exit=True): + user = User(id=DEFAULT_ADMIN_USER.id) + db.session.add(user) + return DEFAULT_ADMIN_USER diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index cb2c635177..668333e540 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -23,7 +23,7 @@ from antarest.core.config import Config from antarest.core.interfaces.eventbus import EventType, IEventBus -from antarest.core.jwt import DEFAULT_ADMIN_USER +from antarest.core.jwt import JWTUser from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError @@ -72,8 +72,9 @@ def db_engine_fixture(tmp_path: Path) -> t.Generator[Engine, None, None]: @with_db_context -def test_service(core_config: Config, event_bus: IEventBus) -> None: +def test_service(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None: engine = db.session.bind + task_job_repo = TaskJobRepository() # Prepare a TaskJob in the database @@ -92,7 +93,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: tasks = service.list_tasks( TaskListFilter(), - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), + request_params=RequestParameters(user=admin_user), ) assert len(tasks) == 1 assert tasks[0].status == TaskStatus.FAILED @@ -101,7 +102,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: # Test Case: get task status # ========================== - res = service.status_task("a", RequestParameters(user=DEFAULT_ADMIN_USER)) + res = service.status_task("a", RequestParameters(user=admin_user)) assert res is not None expected = { "completion_date_utc": ANY, @@ -134,7 +135,7 @@ def action_fail(update_msg: t.Callable[[str], None]) -> TaskResult: None, None, None, - RequestParameters(user=DEFAULT_ADMIN_USER), + RequestParameters(user=admin_user), ) service.await_task(failed_id, timeout_sec=2) @@ -162,7 +163,7 @@ def action_ok(update_msg: t.Callable[[str], None]) -> TaskResult: None, None, None, - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), + request_params=RequestParameters(user=admin_user), ) service.await_task(ok_id, timeout_sec=2) @@ -200,7 +201,7 @@ def test_repository(db_session: Session) -> None: # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id="study_id", name="foo", version="860")) + db_session.add(RawStudy(id=study_id, name="foo", version="860")) db_session.commit() # Create a TaskJobService @@ -276,7 +277,7 @@ def test_repository(db_session: Session) -> None: @with_db_context -def test_cancel(core_config: Config, event_bus: IEventBus) -> None: +def test_cancel(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None: # Create a TaskJobService and add tasks task_job_repo = TaskJobRepository() task_job_repo.save(TaskJob(id="a", name="foo")) @@ -296,7 +297,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: backend.clear_events() - service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) + service.cancel_task("b", RequestParameters(user=admin_user), dispatch=True) collected_events = backend.get_events() @@ -312,7 +313,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None: backend.clear_events() - service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) + service.cancel_task("a", RequestParameters(user=admin_user), dispatch=True) collected_events = backend.get_events() assert len(collected_events) == 0, "No event should have been emitted because the task is in the service map" diff --git a/tests/matrixstore/test_matrix_garbage_collector.py b/tests/matrixstore/test_matrix_garbage_collector.py index 1e02731910..865a6754ef 100644 --- a/tests/matrixstore/test_matrix_garbage_collector.py +++ b/tests/matrixstore/test_matrix_garbage_collector.py @@ -10,24 +10,24 @@ # # This file is part of the Antares project. -from datetime import datetime from pathlib import Path from unittest.mock import Mock +import numpy as np import pytest -from sqlalchemy import create_engine -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db -from antarest.dbmodel import Base +from antarest.core.jwt import JWTUser +from antarest.core.requests import RequestParameters +from antarest.core.utils.fastapi_sqlalchemy import db from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector -from antarest.matrixstore.model import MatrixDataSet, MatrixDataSetRelation +from antarest.matrixstore.model import MatrixDataSetUpdateDTO, MatrixInfoDTO from antarest.matrixstore.repository import MatrixDataSetRepository from antarest.matrixstore.service import MatrixService from antarest.study.storage.patch_service import PatchService from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.command_factory import CommandFactory from antarest.study.storage.variantstudy.model.command.common import CommandName -from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy from antarest.study.storage.variantstudy.repository import VariantStudyRepository @@ -116,17 +116,13 @@ def test_get_matrices_used_in_raw_studies( @pytest.mark.unit_test def test_get_matrices_used_in_variant_studies( matrix_garbage_collector: MatrixGarbageCollector, + variant_study_repository: VariantStudyRepository, ): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) with db(): study_id = "study_id" + + variant_study_repository.save(VariantStudy(id=study_id)) + # TODO: add series to the command blocks command_block1 = CommandBlock( study_id=study_id, @@ -173,38 +169,24 @@ def test_get_matrices_used_in_variant_studies( @pytest.mark.unit_test def test_get_matrices_used_in_dataset( matrix_garbage_collector: MatrixGarbageCollector, + matrix_service: MatrixService, + admin_user: JWTUser, ): matrix_garbage_collector.dataset_repository = MatrixDataSetRepository() - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - dataset = MatrixDataSet( - name="name", - public=True, - owner_id="owner_id", - groups=[], - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - matrix_relation1 = MatrixDataSetRelation(name="matrix_name1") - matrix_relation1.matrix_id = "matrix_id1" - dataset.matrices.append(matrix_relation1) - matrix_relation2 = MatrixDataSetRelation(name="matrix_name2") - matrix_relation2.matrix_id = "matrix_id2" - dataset.matrices.append(matrix_relation2) with db(): - db.session.add(dataset) - db.session.commit() + matrix1_id = matrix_service.create(np.ones((1, 1))) + matrix2_id = matrix_service.create(np.ones((2, 1))) + dataset = matrix_service.create_dataset( + dataset_info=MatrixDataSetUpdateDTO(name="name", groups=[], public=True), + matrices=[MatrixInfoDTO(id=matrix1_id, name="matrix_1"), MatrixInfoDTO(id=matrix2_id, name="matrix_2")], + params=RequestParameters(admin_user), + ) + matrices = matrix_garbage_collector._get_datasets_matrices() assert len(matrices) == 2 - assert "matrix_id1" in matrices - assert "matrix_id2" in matrices + assert matrix1_id in matrices + assert matrix2_id in matrices @pytest.mark.unit_test