Skip to content

Commit

Permalink
feat(db-init): separate database initialization from global database …
Browse files Browse the repository at this point in the history
…session (#1805)
  • Loading branch information
mabw-rte authored and laurent-laporte-pro committed Dec 9, 2023
1 parent d9557dc commit ef908f4
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 56 deletions.
4 changes: 2 additions & 2 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, List, Mapping, Optional

from pydantic import BaseModel, Extra
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
Expand Down Expand Up @@ -174,7 +174,7 @@ def __repr__(self) -> str:
)


def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None:
def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None:
updated_values = {
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_status: False,
Expand Down
42 changes: 14 additions & 28 deletions antarest/login/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging
import typing as t
import uuid
Expand Down Expand Up @@ -298,36 +299,21 @@ class CredentialsDTO(BaseModel):


def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None:
with sessionmaker(bind=engine, **session_args)() as session:
make_session = sessionmaker(bind=engine, **session_args)
with make_session() as session:
group = Group(id=GROUP_ID, name=GROUP_NAME)
user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password))
role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID))

existing_group = session.query(Group).get(group.id)
if not existing_group:
with contextlib.suppress(IntegrityError):
session.add(group)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")

existing_user = session.query(User).get(user.id)
if not existing_user:
session.commit()

with make_session() as session:
user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password))
with contextlib.suppress(IntegrityError):
session.add(user)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")

existing_role = session.query(Role).get((USER_ID, GROUP_ID))
if not existing_role:
role.group = session.merge(role.group)
role.identity = session.merge(role.identity)
session.commit()

with make_session() as session:
role = Role(type=RoleType.ADMIN, identity_id=USER_ID, group_id=GROUP_ID)
with contextlib.suppress(IntegrityError):
session.add(role)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")
6 changes: 3 additions & 3 deletions antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def fastapi_app(
application.add_middleware(
DBSessionMiddleware,
custom_engine=engine,
session_args=dict(SESSION_ARGS),
session_args=SESSION_ARGS,
)

application.add_middleware(LoggingMiddleware)
Expand Down Expand Up @@ -409,7 +409,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
config=RATE_LIMIT_CONFIG,
)

init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd)
init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd)
services = create_services(config, application)

if mount_front:
Expand Down Expand Up @@ -439,7 +439,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
customize_openapi(application)
cancel_orphan_tasks(
engine=engine,
session_args=dict(SESSION_ARGS),
session_args=SESSION_ARGS,
)
return application, services

Expand Down
4 changes: 2 additions & 2 deletions antarest/singleton_services.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, cast

from antarest.core.config import Config
from antarest.core.interfaces.service import IService
Expand Down Expand Up @@ -27,7 +27,7 @@ def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IServi
config,
False,
)
DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS))
DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS))
configure_logger(config)

(
Expand Down
20 changes: 9 additions & 11 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,10 @@ def test_cancel_orphan_tasks(
status: int,
result_status: bool,
result_msg: str,
max_diff_seconds: int = 6,
test_id: str = "test_cancel_orphan_tasks_id",
):
max_diff_seconds: int = 1
test_id: str = "test_cancel_orphan_tasks_id"

completion_date: datetime.datetime = datetime.datetime.utcnow()
task_job = TaskJob(
id=test_id,
Expand All @@ -494,15 +495,12 @@ def test_cancel_orphan_tasks(
result_msg=result_msg,
completion_date=completion_date,
)
with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session:
if session.query(TaskJob).get(test_id) is not None:
session.merge(task_job)
session.commit()
else:
session.add(task_job)
session.commit()
cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS))
with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session:
make_session = sessionmaker(bind=db_engine, **SESSION_ARGS)
with make_session() as session:
session.add(task_job)
session.commit()
cancel_orphan_tasks(engine=db_engine, session_args=SESSION_ARGS)
with make_session() as session:
if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]:
updated_task_job = (
session.query(TaskJob)
Expand Down
53 changes: 43 additions & 10 deletions tests/login/test_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import contextlib

from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.exc import IntegrityError # type: ignore
from sqlalchemy.orm import sessionmaker # type: ignore

from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user
from antarest.login.model import (
GROUP_ID,
GROUP_NAME,
USER_ID,
USER_NAME,
Group,
Password,
Role,
RoleType,
User,
init_admin_user,
)
from antarest.utils import SESSION_ARGS

TEST_ADMIN_PASS_WORD = "test"
Expand All @@ -12,8 +26,8 @@ def test_password():


class TestInitAdminUser:
def test_nominal_init_admin_user(self, db_engine: Engine):
init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD)
def test_init_admin_user_nominal(self, db_engine: Engine):
init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)
make_session = sessionmaker(bind=db_engine)
with make_session() as session:
user = session.query(User).get(USER_ID)
Expand All @@ -27,10 +41,29 @@ def test_nominal_init_admin_user(self, db_engine: Engine):
assert group.name == GROUP_NAME
role = session.query(Role).get((USER_ID, GROUP_ID))
assert role is not None
assert role.identity is not None
assert role.identity.id == USER_ID
assert role.identity.name == USER_NAME
assert role.identity.password.check(TEST_ADMIN_PASS_WORD)
assert role.group is not None
assert role.group.id == GROUP_ID
assert role.group.name == GROUP_NAME
assert role.identity is user
assert role.group is group

def test_init_admin_user_redundancy_check(self, db_engine: Engine):
# run first time
init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)
# run second time
init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)

def test_init_admin_user_existing_group(self, db_engine: Engine):
make_session = sessionmaker(bind=db_engine)
with make_session() as session:
group = Group(id=GROUP_ID, name=GROUP_NAME)
with contextlib.suppress(IntegrityError):
session.add(group)
session.commit()
init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)

def test_init_admin_user_existing_user(self, db_engine: Engine):
make_session = sessionmaker(bind=db_engine)
with make_session() as session:
user = User(id=USER_ID, name=USER_NAME, password=Password(TEST_ADMIN_PASS_WORD))
with contextlib.suppress(IntegrityError):
session.add(user)
session.commit()
init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD)

0 comments on commit ef908f4

Please sign in to comment.