diff --git a/antarest/core/jwt.py b/antarest/core/jwt.py index ff9ffd1187..4fb8b8fcb1 100644 --- a/antarest/core/jwt.py +++ b/antarest/core/jwt.py @@ -3,9 +3,9 @@ from pydantic import BaseModel from antarest.core.roles import RoleType -from antarest.login.model import Group, Identity +from antarest.login.model import USER_ID, Group, Identity -ADMIN_ID = 1 +ADMIN_ID = USER_ID class JWTGroup(BaseModel): diff --git a/antarest/login/model.py b/antarest/login/model.py index 52106685bc..50c62f8295 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,11 +1,14 @@ +import logging import typing as t import uuid import bcrypt from pydantic.main import BaseModel from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, Sequence, String # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.ext.hybrid import hybrid_property # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore from antarest.core.persistence import Base from antarest.core.roles import RoleType @@ -15,6 +18,16 @@ from antarest.launcher.model import JobResult +logger = logging.getLogger(__name__) + + +GROUP_ID = "admin" +GROUP_NAME = "admin" + +USER_ID = 1 +USER_NAME = "admin" + + class UserInfo(BaseModel): id: int name: str @@ -282,3 +295,39 @@ class CredentialsDTO(BaseModel): user: int access_token: str refresh_token: str + + +def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: + with sessionmaker(bind=engine, **session_args)() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + user = User(id=USER_ID, name=USER_NAME, password=Password(admin_password)) + role = Role(type=RoleType.ADMIN, identity=User(id=USER_ID), group=Group(id=GROUP_ID)) + + existing_group = session.query(Group).get(group.id) + if not existing_group: + session.add(group) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") + + existing_user = session.query(User).get(user.id) + if not existing_user: + session.add(user) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") + + existing_role = session.query(Role).get((USER_ID, GROUP_ID)) + if not existing_role: + role.group = session.merge(role.group) + role.identity = session.merge(role.identity) + session.add(role) + try: + session.commit() + except IntegrityError as e: + session.rollback() # Rollback any changes made before the error + logger.error(f"IntegrityError: {e}") diff --git a/antarest/login/repository.py b/antarest/login/repository.py index edac68d495..b2058952b1 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -1,64 +1,14 @@ import logging -from typing import Dict, List, Optional +from typing import List, Optional from sqlalchemy import exists # type: ignore -from sqlalchemy.engine.base import Engine # type: ignore -from sqlalchemy.orm import joinedload, Session, sessionmaker # type: ignore +from sqlalchemy.orm import joinedload, Session # type: ignore -from antarest.core.jwt import ADMIN_ID -from antarest.core.roles import RoleType from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.login.model import Bot, Group, Password, Role, User, UserLdap +from antarest.login.model import Bot, Group, Role, User, UserLdap logger = logging.getLogger(__name__) -DB_INIT_DEFAULT_GROUP_ID = "admin" -DB_INIT_DEFAULT_GROUP_NAME = "admin" - -DB_INIT_DEFAULT_USER_ID = ADMIN_ID -DB_INIT_DEFAULT_USER_NAME = "admin" - -DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID -DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin" - - -def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None: - with sessionmaker(bind=engine, **session_args)() as session: - group = Group( - id=DB_INIT_DEFAULT_GROUP_ID, - name=DB_INIT_DEFAULT_GROUP_NAME, - ) - user = User( - id=DB_INIT_DEFAULT_USER_ID, - name=DB_INIT_DEFAULT_USER_NAME, - password=Password(admin_password), - ) - role = Role( - type=RoleType.ADMIN, - identity=User(id=DB_INIT_DEFAULT_USER_ID), - group=Group( - id=DB_INIT_DEFAULT_GROUP_ID, - ), - ) - - existing_group = session.query(Group).get(group.id) - if not existing_group: - session.add(group) - session.commit() - - existing_user = session.query(User).get(user.id) - if not existing_user: - session.add(user) - session.commit() - - existing_role = session.query(Role).get((DB_INIT_DEFAULT_USER_ID, DB_INIT_DEFAULT_GROUP_ID)) - if not existing_role: - role.group = session.merge(role.group) - role.identity = session.merge(role.identity) - session.add(role) - - session.commit() - class GroupRepository: """ @@ -143,7 +93,7 @@ def get(self, id_number: int) -> Optional[User]: return user def get_by_name(self, name: str) -> Optional[User]: - user: User = db.session.query(User).filter_by(name=name).first() + user: User = self.session.query(User).filter_by(name=name).first() return user def get_all(self) -> List[User]: diff --git a/antarest/main.py b/antarest/main.py index 5e1c1ec850..bf233260f2 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -35,7 +35,7 @@ from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.login.auth import Auth, JwtSettings -from antarest.login.repository import init_admin_user +from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector from antarest.singleton_services import start_all_services from antarest.study.storage.auto_archive_service import AutoArchiveService diff --git a/antarest/study/main.py b/antarest/study/main.py index c3b48356af..0758c6d070 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -81,7 +81,7 @@ def build_study_service( ) generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index 4d3c563799..9aa5a9b397 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -48,9 +48,7 @@ class CommandExtractor(ICommandExtractor): def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) - self.generator_matrix_constants.init_constant_matrices( - bucket_dir=self.generator_matrix_constants.matrix_service.bucket_dir - ) + self.generator_matrix_constants.init_constant_matrices() self.patch_service = patch_service self.command_context = CommandContext( generator_matrix_constants=self.generator_matrix_constants, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index f5d63d5d8b..a0e7515f1e 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -55,10 +55,12 @@ class GeneratorMatrixConstants: def __init__(self, matrix_service: ISimpleMatrixService) -> None: self.hashes: Dict[str, str] = {} self.matrix_service: ISimpleMatrixService = matrix_service + self._lock_dir = tempfile.gettempdir() - def init_constant_matrices(self, bucket_dir: Path) -> None: - bucket_dir.mkdir(parents=True, exist_ok=True) - with FileLock(bucket_dir / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME): + def init_constant_matrices( + self, + ) -> None: + with FileLock(str(Path(self._lock_dir) / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME)): self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( matrix_constants.hydro.v7.max_power ) diff --git a/antarest/study/storage/variantstudy/variant_command_extractor.py b/antarest/study/storage/variantstudy/variant_command_extractor.py index 33ee3ff49f..bd052a6c0a 100644 --- a/antarest/study/storage/variantstudy/variant_command_extractor.py +++ b/antarest/study/storage/variantstudy/variant_command_extractor.py @@ -20,7 +20,7 @@ class VariantCommandsExtractor: def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) - self.generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + self.generator_matrix_constants.init_constant_matrices() self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service) def extract(self, study: FileStudy) -> List[CommandDTO]: diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 058a3402fa..c3c5db9dff 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -156,7 +156,7 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene ) generator = VariantCommandGenerator(study_factory) generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/tests/conftest_services.py b/tests/conftest_services.py index 7fa50f6f86..5afb53460b 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -149,7 +149,7 @@ def generator_matrix_constants_fixture( An instance of the GeneratorMatrixConstants class representing the matrix constants generator. """ out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service) - out_generator_matrix_constants.init_constant_matrices(bucket_dir=simple_matrix_service.bucket_dir) + out_generator_matrix_constants.init_constant_matrices() return out_generator_matrix_constants diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 787f4f2d6a..0b1da1c8f2 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -1,5 +1,36 @@ -from antarest.login.model import Password +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import sessionmaker # type: ignore + +from antarest.login.model import GROUP_ID, GROUP_NAME, USER_ID, USER_NAME, Group, Password, Role, User, init_admin_user +from antarest.utils import SESSION_ARGS + +TEST_ADMIN_PASS_WORD = "test" def test_password(): assert Password("pwd").check("pwd") + + +class TestInitAdminUser: + def test_nominal_init_admin_user(self, db_engine: Engine): + init_admin_user(db_engine, dict(SESSION_ARGS), admin_password=TEST_ADMIN_PASS_WORD) + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = session.query(User).get(USER_ID) + assert user is not None + assert user.id == USER_ID + assert user.name == USER_NAME + assert user.password.check(TEST_ADMIN_PASS_WORD) + group = session.query(Group).get(GROUP_ID) + assert group is not None + assert group.id == GROUP_ID + assert group.name == GROUP_NAME + role = session.query(Role).get((USER_ID, GROUP_ID)) + assert role is not None + assert role.identity is not None + assert role.identity.id == USER_ID + assert role.identity.name == USER_NAME + assert role.identity.password.check(TEST_ADMIN_PASS_WORD) + assert role.group is not None + assert role.group.id == GROUP_ID + assert role.group.name == GROUP_NAME diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 314e77e500..4caee7b7bd 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -99,7 +99,7 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): raw_study_service.get_raw.return_value = empty_study raw_study_service.cache = Mock() generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() variant_study_service.command_factory = CommandFactory( generator_matrix_constants, matrix_service, diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index fe679a821a..67b0c15e74 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -19,6 +19,7 @@ def test_get_st_storage(self, tmp_path): matrix_content_repository=matrix_content_repository, ) ) + generator.init_constant_matrices() ref1 = generator.get_st_storage_pmax_injection() matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1] @@ -54,6 +55,7 @@ def test_get_binding_constraint(self, tmp_path): matrix_content_repository=matrix_content_repository, ) ) + generator.init_constant_matrices() series = matrix_constants.binding_constraint.series hourly = generator.get_binding_constraint_hourly() diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index b069e029d8..011a6bb68d 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -92,7 +92,7 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext: """ # sourcery skip: inline-immediately-returned-variable generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() command_context = CommandContext( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, @@ -113,7 +113,7 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory: CommandFactory: The CommandFactory object. """ generator_matrix_constants = GeneratorMatrixConstants(matrix_service) - generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) + generator_matrix_constants.init_constant_matrices() return CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service,