Skip to content

Commit

Permalink
feat(db-init): separate database initialization from global database …
Browse files Browse the repository at this point in the history
…session (#1805)
  • Loading branch information
mabw-rte authored and laurent-laporte-pro committed Dec 7, 2023
1 parent be89cf5 commit 2785a89
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 72 deletions.
4 changes: 2 additions & 2 deletions antarest/core/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from pydantic import BaseModel

from antarest.core.roles import RoleType
from antarest.login.model import Group, Identity
from antarest.login.model import USER_ID, Group, Identity

ADMIN_ID = 1
ADMIN_ID = USER_ID


class JWTGroup(BaseModel):
Expand Down
51 changes: 50 additions & 1 deletion antarest/login/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging
import typing as t
import uuid

import bcrypt
from pydantic.main import BaseModel
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.exc import IntegrityError # type: ignore
from sqlalchemy.ext.hybrid import hybrid_property # type: ignore
from sqlalchemy.orm import relationship # type: ignore
from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore

from antarest.core.persistence import Base
from antarest.core.roles import RoleType
Expand All @@ -15,6 +18,16 @@
from antarest.launcher.model import JobResult


logger = logging.getLogger(__name__)


GROUP_ID = "admin"
GROUP_NAME = "admin"

USER_ID = 1
USER_NAME = "admin"


class UserInfo(BaseModel):
id: int
name: str
Expand Down Expand Up @@ -282,3 +295,39 @@ class CredentialsDTO(BaseModel):
user: int
access_token: str
refresh_token: str


def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None:
with sessionmaker(bind=engine, **session_args)() as session:
group = Group(id=GROUP_ID, name=GROUP_NAME)
user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password))
role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID))

existing_group = session.query(Group).get(group.id)
if not existing_group:
session.add(group)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")

existing_user = session.query(User).get(user.id)
if not existing_user:
session.add(user)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")

existing_role = session.query(Role).get((USER_ID, GROUP_ID))
if not existing_role:
role.group = session.merge(role.group)
role.identity = session.merge(role.identity)
session.add(role)
try:
session.commit()
except IntegrityError as e:
session.rollback() # Rollback any changes made before the error
logger.error(f"IntegrityError: {e}")
58 changes: 4 additions & 54 deletions antarest/login/repository.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,14 @@
import logging
from typing import Dict, List, Optional
from typing import List, Optional

from sqlalchemy import exists # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import joinedload, Session, sessionmaker # type: ignore
from sqlalchemy.orm import joinedload, Session # type: ignore

from antarest.core.jwt import ADMIN_ID
from antarest.core.roles import RoleType
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Bot, Group, Password, Role, User, UserLdap
from antarest.login.model import Bot, Group, Role, User, UserLdap

logger = logging.getLogger(__name__)

DB_INIT_DEFAULT_GROUP_ID = "admin"
DB_INIT_DEFAULT_GROUP_NAME = "admin"

DB_INIT_DEFAULT_USER_ID = ADMIN_ID
DB_INIT_DEFAULT_USER_NAME = "admin"

DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID
DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin"


def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None:
with sessionmaker(bind=engine, **session_args)() as session:
group = Group(
id=DB_INIT_DEFAULT_GROUP_ID,
name=DB_INIT_DEFAULT_GROUP_NAME,
)
user = User(
id=DB_INIT_DEFAULT_USER_ID,
name=DB_INIT_DEFAULT_USER_NAME,
password=Password(admin_password),
)
role = Role(
type=RoleType.ADMIN,
identity=User(id=DB_INIT_DEFAULT_USER_ID),
group=Group(
id=DB_INIT_DEFAULT_GROUP_ID,
),
)

existing_group = session.query(Group).get(group.id)
if not existing_group:
session.add(group)
session.commit()

existing_user = session.query(User).get(user.id)
if not existing_user:
session.add(user)
session.commit()

existing_role = session.query(Role).get((DB_INIT_DEFAULT_USER_ID, DB_INIT_DEFAULT_GROUP_ID))
if not existing_role:
role.group = session.merge(role.group)
role.identity = session.merge(role.identity)
session.add(role)

session.commit()


class GroupRepository:
"""
Expand Down Expand Up @@ -143,7 +93,7 @@ def get(self, id_number: int) -> Optional[User]:
return user

def get_by_name(self, name: str) -> Optional[User]:
user: User = db.session.query(User).filter_by(name=name).first()
user: User = self.session.query(User).filter_by(name=name).first()
return user

def get_all(self) -> List[User]:
Expand Down
2 changes: 1 addition & 1 deletion antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from antarest.core.utils.utils import get_local_path
from antarest.core.utils.web import tags_metadata
from antarest.login.auth import Auth, JwtSettings
from antarest.login.repository import init_admin_user
from antarest.login.model import init_admin_user
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.singleton_services import start_all_services
from antarest.study.storage.auto_archive_service import AutoArchiveService
Expand Down
2 changes: 1 addition & 1 deletion antarest/study/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def build_study_service(
)

generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service)
generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir)
generator_matrix_constants.init_constant_matrices()
command_factory = CommandFactory(
generator_matrix_constants=generator_matrix_constants,
matrix_service=matrix_service,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ class CommandExtractor(ICommandExtractor):
def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService):
self.matrix_service = matrix_service
self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service)
self.generator_matrix_constants.init_constant_matrices(
bucket_dir=self.generator_matrix_constants.matrix_service.bucket_dir
)
self.generator_matrix_constants.init_constant_matrices()
self.patch_service = patch_service
self.command_context = CommandContext(
generator_matrix_constants=self.generator_matrix_constants,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ class GeneratorMatrixConstants:
def __init__(self, matrix_service: ISimpleMatrixService) -> None:
self.hashes: Dict[str, str] = {}
self.matrix_service: ISimpleMatrixService = matrix_service
self._lock_dir = tempfile.gettempdir()

def init_constant_matrices(self, bucket_dir: Path) -> None:
bucket_dir.mkdir(parents=True, exist_ok=True)
with FileLock(bucket_dir / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME):
def init_constant_matrices(
self,
) -> None:
with FileLock(str(Path(self._lock_dir) / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME)):
self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create(
matrix_constants.hydro.v7.max_power
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class VariantCommandsExtractor:
def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService):
self.matrix_service = matrix_service
self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service)
self.generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir)
self.generator_matrix_constants.init_constant_matrices()
self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service)

def extract(self, study: FileStudy) -> List[CommandDTO]:
Expand Down
2 changes: 1 addition & 1 deletion antarest/tools/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene
)
generator = VariantCommandGenerator(study_factory)
generator_matrix_constants = GeneratorMatrixConstants(matrix_service)
generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir)
generator_matrix_constants.init_constant_matrices()
command_factory = CommandFactory(
generator_matrix_constants=generator_matrix_constants,
matrix_service=matrix_service,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def generator_matrix_constants_fixture(
An instance of the GeneratorMatrixConstants class representing the matrix constants generator.
"""
out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service)
out_generator_matrix_constants.init_constant_matrices(bucket_dir=simple_matrix_service.bucket_dir)
out_generator_matrix_constants.init_constant_matrices()
return out_generator_matrix_constants


Expand Down
33 changes: 32 additions & 1 deletion tests/login/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
from antarest.login.model import Password
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import sessionmaker # type: ignore

from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user
from antarest.utils import SESSION_ARGS

TEST_ADMIN_PASS_WORD = "test"


def test_password():
assert Password("pwd").check("pwd")


class TestInitAdminUser:
def test_nominal_init_admin_user(self, db_engine: Engine):
init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD)
make_session = sessionmaker(bind=db_engine)
with make_session() as session:
user = session.query(User).get(USER_ID)
assert user is not None
assert user.id == USER_ID
assert user.name == USER_NAME
assert user.password.check(TEST_ADMIN_PASS_WORD)
group = session.query(Group).get(GROUP_ID)
assert group is not None
assert group.id == GROUP_ID
assert group.name == GROUP_NAME
role = session.query(Role).get((USER_ID, GROUP_ID))
assert role is not None
assert role.identity is not None
assert role.identity.id == USER_ID
assert role.identity.name == USER_NAME
assert role.identity.password.check(TEST_ADMIN_PASS_WORD)
assert role.group is not None
assert role.group.id == GROUP_ID
assert role.group.name == GROUP_NAME
2 changes: 1 addition & 1 deletion tests/storage/business/test_arealink_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService):
raw_study_service.get_raw.return_value = empty_study
raw_study_service.cache = Mock()
generator_matrix_constants = GeneratorMatrixConstants(matrix_service)
generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir)
generator_matrix_constants.init_constant_matrices()
variant_study_service.command_factory = CommandFactory(
generator_matrix_constants,
matrix_service,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_get_st_storage(self, tmp_path):
matrix_content_repository=matrix_content_repository,
)
)
generator.init_constant_matrices()

ref1 = generator.get_st_storage_pmax_injection()
matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1]
Expand Down Expand Up @@ -54,6 +55,7 @@ def test_get_binding_constraint(self, tmp_path):
matrix_content_repository=matrix_content_repository,
)
)
generator.init_constant_matrices()
series = matrix_constants.binding_constraint.series

hourly = generator.get_binding_constraint_hourly()
Expand Down
4 changes: 2 additions & 2 deletions tests/variantstudy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext:
"""
# sourcery skip: inline-immediately-returned-variable
generator_matrix_constants = GeneratorMatrixConstants(matrix_service)
generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir)
generator_matrix_constants.init_constant_matrices()
command_context = CommandContext(
generator_matrix_constants=generator_matrix_constants,
matrix_service=matrix_service,
Expand All @@ -113,7 +113,7 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory:
CommandFactory: The CommandFactory object.
"""
generator_matrix_constants = GeneratorMatrixConstants(matrix_service)
generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir)
generator_matrix_constants.init_constant_matrices()
return CommandFactory(
generator_matrix_constants=generator_matrix_constants,
matrix_service=matrix_service,
Expand Down

0 comments on commit 2785a89

Please sign in to comment.