Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into feature/add-endpoint-l…
Browse files Browse the repository at this point in the history
…ink-update
  • Loading branch information
TheoPascoli committed Oct 8, 2024
2 parents 06faa3e + f2460fe commit 7bd5059
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 51 deletions.
21 changes: 20 additions & 1 deletion antarest/core/utils/fastapi_sqlalchemy/middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from contextvars import ContextVar, Token
from typing import Any, Dict, Optional, Type, Union

from sqlalchemy import create_engine # type: ignore
from sqlalchemy import create_engine, event # type: ignore
from sqlalchemy.engine import Engine # type: ignore
from sqlalchemy.engine.url import URL # type: ignore
from sqlalchemy.event import listen # type: ignore
from sqlalchemy.orm import Session, sessionmaker # type: ignore
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
Expand All @@ -16,6 +17,20 @@
_session: ContextVar[Optional[Session]] = ContextVar("_session", default=None)


def _is_sqlite_engine(engine: Engine) -> bool:
return "sqlite" in engine.url.drivername.lower()


def enable_sqlite_foreign_keys(dbapi_connection: Any, connection_record: Any) -> None:
"""
By default, sqlite does not enforce foreign key constraints,
we need to tell it explicitly.
"""
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON;")
cursor.close()


class DBSessionMiddleware(BaseHTTPMiddleware):
def __init__(
self,
Expand All @@ -39,6 +54,10 @@ def __init__(
engine = create_engine(db_url, **engine_args)
else:
engine = custom_engine

if _is_sqlite_engine(engine):
listen(engine, "connect", enable_sqlite_foreign_keys)

_Session = sessionmaker(bind=engine, **session_args)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# noinspection PyUnresolvedReferences
from tests.conftest_db import *

# noinspection PyUnresolvedReferences
from tests.conftest_instances import *

# noinspection PyUnresolvedReferences
from tests.conftest_services import *

Expand Down
28 changes: 28 additions & 0 deletions tests/conftest_instances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
#
# See AUTHORS.txt
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
"""
The aim of this module is to contain fixtures for
instantiating objects such as users, studies, ...
"""
import pytest

from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db
from antarest.login.model import User


@pytest.fixture
def admin_user(db_middleware: DBSessionMiddleware) -> User:
with db(commit_on_exit=True):
user = User(id=DEFAULT_ADMIN_USER.id)
db.session.add(user)
return DEFAULT_ADMIN_USER
21 changes: 11 additions & 10 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from antarest.core.config import Config
from antarest.core.interfaces.eventbus import EventType, IEventBus
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.jwt import JWTUser
from antarest.core.model import PermissionInfo, PublicMode
from antarest.core.persistence import Base
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
Expand Down Expand Up @@ -72,8 +72,9 @@ def db_engine_fixture(tmp_path: Path) -> t.Generator[Engine, None, None]:


@with_db_context
def test_service(core_config: Config, event_bus: IEventBus) -> None:
def test_service(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None:
engine = db.session.bind

task_job_repo = TaskJobRepository()

# Prepare a TaskJob in the database
Expand All @@ -92,7 +93,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None:

tasks = service.list_tasks(
TaskListFilter(),
request_params=RequestParameters(user=DEFAULT_ADMIN_USER),
request_params=RequestParameters(user=admin_user),
)
assert len(tasks) == 1
assert tasks[0].status == TaskStatus.FAILED
Expand All @@ -101,7 +102,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None:
# Test Case: get task status
# ==========================

res = service.status_task("a", RequestParameters(user=DEFAULT_ADMIN_USER))
res = service.status_task("a", RequestParameters(user=admin_user))
assert res is not None
expected = {
"completion_date_utc": ANY,
Expand Down Expand Up @@ -134,7 +135,7 @@ def action_fail(update_msg: t.Callable[[str], None]) -> TaskResult:
None,
None,
None,
RequestParameters(user=DEFAULT_ADMIN_USER),
RequestParameters(user=admin_user),
)
service.await_task(failed_id, timeout_sec=2)

Expand Down Expand Up @@ -162,7 +163,7 @@ def action_ok(update_msg: t.Callable[[str], None]) -> TaskResult:
None,
None,
None,
request_params=RequestParameters(user=DEFAULT_ADMIN_USER),
request_params=RequestParameters(user=admin_user),
)
service.await_task(ok_id, timeout_sec=2)

Expand Down Expand Up @@ -200,7 +201,7 @@ def test_repository(db_session: Session) -> None:

# Create a RawStudy in the database
study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512"
db_session.add(RawStudy(id="study_id", name="foo", version="860"))
db_session.add(RawStudy(id=study_id, name="foo", version="860"))
db_session.commit()

# Create a TaskJobService
Expand Down Expand Up @@ -276,7 +277,7 @@ def test_repository(db_session: Session) -> None:


@with_db_context
def test_cancel(core_config: Config, event_bus: IEventBus) -> None:
def test_cancel(core_config: Config, event_bus: IEventBus, admin_user: JWTUser) -> None:
# Create a TaskJobService and add tasks
task_job_repo = TaskJobRepository()
task_job_repo.save(TaskJob(id="a", name="foo"))
Expand All @@ -296,7 +297,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None:

backend.clear_events()

service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True)
service.cancel_task("b", RequestParameters(user=admin_user), dispatch=True)

collected_events = backend.get_events()

Expand All @@ -312,7 +313,7 @@ def test_cancel(core_config: Config, event_bus: IEventBus) -> None:

backend.clear_events()

service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True)
service.cancel_task("a", RequestParameters(user=admin_user), dispatch=True)

collected_events = backend.get_events()
assert len(collected_events) == 0, "No event should have been emitted because the task is in the service map"
Expand Down
62 changes: 22 additions & 40 deletions tests/matrixstore/test_matrix_garbage_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@
#
# This file is part of the Antares project.

from datetime import datetime
from pathlib import Path
from unittest.mock import Mock

import numpy as np
import pytest
from sqlalchemy import create_engine

from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db
from antarest.dbmodel import Base
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.matrixstore.model import MatrixDataSet, MatrixDataSetRelation
from antarest.matrixstore.model import MatrixDataSetUpdateDTO, MatrixInfoDTO
from antarest.matrixstore.repository import MatrixDataSetRepository
from antarest.matrixstore.service import MatrixService
from antarest.study.storage.patch_service import PatchService
from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants
from antarest.study.storage.variantstudy.command_factory import CommandFactory
from antarest.study.storage.variantstudy.model.command.common import CommandName
from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock
from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy
from antarest.study.storage.variantstudy.repository import VariantStudyRepository


Expand Down Expand Up @@ -116,17 +116,13 @@ def test_get_matrices_used_in_raw_studies(
@pytest.mark.unit_test
def test_get_matrices_used_in_variant_studies(
matrix_garbage_collector: MatrixGarbageCollector,
variant_study_repository: VariantStudyRepository,
):
engine = create_engine("sqlite:///:memory:", echo=False)
Base.metadata.create_all(engine)
# noinspection SpellCheckingInspection
DBSessionMiddleware(
None,
custom_engine=engine,
session_args={"autocommit": False, "autoflush": False},
)
with db():
study_id = "study_id"

variant_study_repository.save(VariantStudy(id=study_id))

# TODO: add series to the command blocks
command_block1 = CommandBlock(
study_id=study_id,
Expand Down Expand Up @@ -173,38 +169,24 @@ def test_get_matrices_used_in_variant_studies(
@pytest.mark.unit_test
def test_get_matrices_used_in_dataset(
matrix_garbage_collector: MatrixGarbageCollector,
matrix_service: MatrixService,
admin_user: JWTUser,
):
matrix_garbage_collector.dataset_repository = MatrixDataSetRepository()
engine = create_engine("sqlite:///:memory:", echo=False)
Base.metadata.create_all(engine)
# noinspection SpellCheckingInspection
DBSessionMiddleware(
None,
custom_engine=engine,
session_args={"autocommit": False, "autoflush": False},
)

dataset = MatrixDataSet(
name="name",
public=True,
owner_id="owner_id",
groups=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
matrix_relation1 = MatrixDataSetRelation(name="matrix_name1")
matrix_relation1.matrix_id = "matrix_id1"
dataset.matrices.append(matrix_relation1)
matrix_relation2 = MatrixDataSetRelation(name="matrix_name2")
matrix_relation2.matrix_id = "matrix_id2"
dataset.matrices.append(matrix_relation2)
with db():
db.session.add(dataset)
db.session.commit()
matrix1_id = matrix_service.create(np.ones((1, 1)))
matrix2_id = matrix_service.create(np.ones((2, 1)))
dataset = matrix_service.create_dataset(
dataset_info=MatrixDataSetUpdateDTO(name="name", groups=[], public=True),
matrices=[MatrixInfoDTO(id=matrix1_id, name="matrix_1"), MatrixInfoDTO(id=matrix2_id, name="matrix_2")],
params=RequestParameters(admin_user),
)

matrices = matrix_garbage_collector._get_datasets_matrices()
assert len(matrices) == 2
assert "matrix_id1" in matrices
assert "matrix_id2" in matrices
assert matrix1_id in matrices
assert matrix2_id in matrices


@pytest.mark.unit_test
Expand Down

0 comments on commit 7bd5059

Please sign in to comment.