diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 6eacab3f10..8f5be488f5 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, List, Mapping, Optional from pydantic import BaseModel, Extra from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore @@ -174,7 +174,7 @@ def __repr__(self) -> str: ) -def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: +def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None: updated_values = { TaskJob.status: TaskStatus.FAILED.value, TaskJob.result_status: False, diff --git a/antarest/login/model.py b/antarest/login/model.py index 50c62f8295..dbccb08e34 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,3 +1,4 @@ +import contextlib import logging import typing as t import uuid @@ -298,36 +299,21 @@ class CredentialsDTO(BaseModel): def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: - with sessionmaker(bind=engine, **session_args)() as session: + make_session = sessionmaker(bind=engine, **session_args) + with make_session() as session: group = Group(id=GROUP_ID, name=GROUP_NAME) - user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) - role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID)) - - existing_group = session.query(Group).get(group.id) - if not existing_group: + with contextlib.suppress(IntegrityError): session.add(group) - try: - session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") - - existing_user = session.query(User).get(user.id) - if not existing_user: + session.commit() + + with make_session() as session: + user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) + with contextlib.suppress(IntegrityError): session.add(user) - try: - session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") - - existing_role = session.query(Role).get((USER_ID, GROUP_ID)) - if not existing_role: - role.group = session.merge(role.group) - role.identity = session.merge(role.identity) + session.commit() + + with make_session() as session: + role = Role(type=RoleType.ADMIN, identity_id=USER_ID, group_id=GROUP_ID) + with contextlib.suppress(IntegrityError): session.add(role) - try: session.commit() - except IntegrityError as e: - session.rollback() # Rollback any changes made before the error - logger.error(f"IntegrityError: {e}") diff --git a/antarest/main.py b/antarest/main.py index bf233260f2..a661e51bff 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -253,7 +253,7 @@ def fastapi_app( application.add_middleware( DBSessionMiddleware, custom_engine=engine, - session_args=dict(SESSION_ARGS), + session_args=SESSION_ARGS, ) application.add_middleware(LoggingMiddleware) @@ -409,7 +409,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) - init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd) + init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) services = create_services(config, application) if mount_front: @@ -439,7 +439,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: customize_openapi(application) cancel_orphan_tasks( engine=engine, - session_args=dict(SESSION_ARGS), + session_args=SESSION_ARGS, ) return application, services diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 70a791002d..f106099523 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List +from typing import Dict, List, cast from antarest.core.config import Config from antarest.core.interfaces.service import IService @@ -27,7 +27,7 @@ def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IServi config, False, ) - DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS)) + DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) configure_logger(config) ( diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 73782745ce..1d05cb4887 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -483,9 +483,10 @@ def test_cancel_orphan_tasks( status: int, result_status: bool, result_msg: str, - max_diff_seconds: int = 6, - test_id: str = "test_cancel_orphan_tasks_id", ): + max_diff_seconds: int = 1 + test_id: str = "test_cancel_orphan_tasks_id" + completion_date: datetime.datetime = datetime.datetime.utcnow() task_job = TaskJob( id=test_id, @@ -494,15 +495,12 @@ def test_cancel_orphan_tasks( result_msg=result_msg, completion_date=completion_date, ) - with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: - if session.query(TaskJob).get(test_id) is not None: - session.merge(task_job) - session.commit() - else: - session.add(task_job) - session.commit() - cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS)) - with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + make_session = sessionmaker(bind=db_engine, **SESSION_ARGS) + with make_session() as session: + session.add(task_job) + session.commit() + cancel_orphan_tasks(engine=db_engine, session_args=SESSION_ARGS) + with make_session() as session: if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]: updated_task_job = ( session.query(TaskJob) diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 0b1da1c8f2..72fced9478 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -1,7 +1,21 @@ +import contextlib + from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.orm import sessionmaker # type: ignore -from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user +from antarest.login.model import ( + GROUP_ID, + GROUP_NAME, + USER_ID, + USER_NAME, + Group, + Password, + Role, + RoleType, + User, + init_admin_user, +) from antarest.utils import SESSION_ARGS TEST_ADMIN_PASS_WORD = "test" @@ -12,8 +26,8 @@ def test_password(): class TestInitAdminUser: - def test_nominal_init_admin_user(self, db_engine: Engine): - init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD) + def test_init_admin_user_nominal(self, db_engine: Engine): + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) make_session = sessionmaker(bind=db_engine) with make_session() as session: user = session.query(User).get(USER_ID) @@ -27,10 +41,29 @@ def test_nominal_init_admin_user(self, db_engine: Engine): assert group.name == GROUP_NAME role = session.query(Role).get((USER_ID, GROUP_ID)) assert role is not None - assert role.identity is not None - assert role.identity.id == USER_ID - assert role.identity.name == USER_NAME - assert role.identity.password.check(TEST_ADMIN_PASS_WORD) - assert role.group is not None - assert role.group.id == GROUP_ID - assert role.group.name == GROUP_NAME + assert role.identity is user + assert role.group is group + + def test_init_admin_user_redundancy_check(self, db_engine: Engine): + # run first time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + # run second time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_group(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + with contextlib.suppress(IntegrityError): + session.add(group) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_user(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = User(id=USER_ID, name=USER_NAME, password=Password(TEST_ADMIN_PASS_WORD)) + with contextlib.suppress(IntegrityError): + session.add(user) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)