From 6a5c49de1161457aae835d3d58b217cbac178af7 Mon Sep 17 00:00:00 2001 From: Mohamed Abdel Wedoud Date: Mon, 20 Nov 2023 18:41:52 +0100 Subject: [PATCH] feat(db-init): separate database initialization from global database session (#1805) --- antarest/core/tasks/model.py | 19 +- antarest/login/main.py | 2 +- antarest/login/repository.py | 270 ++++++++++++------ antarest/main.py | 22 +- antarest/matrixstore/repository.py | 62 ++-- antarest/matrixstore/service.py | 22 +- antarest/singleton_services.py | 130 ++++----- antarest/study/main.py | 1 + .../business/command_extractor.py | 3 + .../business/matrix_constants_generator.py | 95 +++--- .../variantstudy/variant_command_extractor.py | 1 + antarest/tools/lib.py | 34 ++- antarest/utils.py | 51 ++-- tests/conftest_services.py | 10 +- tests/login/test_repository.py | 85 ++---- tests/matrixstore/test_repository.py | 23 +- .../storage/business/test_arealink_manager.py | 10 +- tests/storage/integration/conftest.py | 6 +- .../test_matrix_constants_generator.py | 19 +- .../test_variant_study_service.py | 1 + tests/variantstudy/conftest.py | 8 +- 21 files changed, 503 insertions(+), 371 deletions(-) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index af3a46b8f7..1206db9fc4 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -1,11 +1,12 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore from antarest.core.persistence import Base @@ -171,3 +172,17 @@ def __repr__(self) -> str: f" result_msg={self.result_msg}," f" result_status={self.result_status}" ) + + +def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: + updated_values = { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result: False, + TaskJob.result_msg: "Task was interrupted due to server restart", + TaskJob.completion_date: datetime.utcnow(), + } + with sessionmaker(bind=engine, **session_args)() as session: + session.query(TaskJob).filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])).update( + updated_values, synchronize_session=False + ) + session.commit() diff --git a/antarest/login/main.py b/antarest/login/main.py index 9b487de5b7..d87a082abd 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -37,7 +37,7 @@ def build_login( """ if service is None: - user_repo = UserRepository(config) + user_repo = UserRepository() bot_repo = BotRepository() group_repo = GroupRepository() role_repo = RoleRepository() diff --git a/antarest/login/repository.py b/antarest/login/repository.py index 3d9cb80fc9..ab6de74027 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -1,9 +1,10 @@ import logging -from typing import List, Optional +from typing import Dict, List, Optional, Tuple, Union from sqlalchemy import exists # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, sessionmaker # type: ignore -from antarest.core.config import Config from antarest.core.jwt import ADMIN_ID from antarest.core.roles import RoleType from antarest.core.utils.fastapi_sqlalchemy import db @@ -11,43 +12,99 @@ logger = logging.getLogger(__name__) +DB_INIT_DEFAULT_GROUP_ID = "admin" +DB_INIT_DEFAULT_GROUP_NAME = "admin" + +DB_INIT_DEFAULT_USER_ID = ADMIN_ID +DB_INIT_DEFAULT_USER_NAME = "admin" + +DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID +DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin" + + +def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None: + with sessionmaker(bind=engine, **session_args)() as session: + group = Group( + id=DB_INIT_DEFAULT_GROUP_ID, + name=DB_INIT_DEFAULT_GROUP_NAME, + ) + user = User( + id=DB_INIT_DEFAULT_USER_ID, + name=DB_INIT_DEFAULT_USER_NAME, + password=Password(admin_password), + ) + role = Role( + type=RoleType.ADMIN, + identity=User(id=DB_INIT_DEFAULT_USER_ID), + group=Group( + id=DB_INIT_DEFAULT_GROUP_ID, + ), + ) + + existing_group = session.query(Group).get(group.id) + if not existing_group: + session.add(group) + session.commit() + + existing_user = session.query(User).get(user.id) + if not existing_user: + session.add(user) + session.commit() + + existing_role = session.query(Role).get((DB_INIT_DEFAULT_USER_ID, DB_INIT_DEFAULT_GROUP_ID)) + if not existing_role: + role.group = session.merge(role.group) + role.identity = session.merge(role.identity) + session.add(role) + + session.commit() + class GroupRepository: """ Database connector to manage Group entity. """ - def __init__(self) -> None: - with db(): - self.save(Group(id="admin", name="admin")) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, group: Group) -> Group: - res = db.session.query(exists().where(Group.id == group.id)).scalar() + res = self.session.query(exists().where(Group.id == group.id)).scalar() if res: - db.session.merge(group) + self.session.merge(group) else: - db.session.add(group) - db.session.commit() + self.session.add(group) + self.session.commit() logger.debug(f"Group {group.id} saved") return group def get(self, id: str) -> Optional[Group]: - group: Group = db.session.query(Group).get(id) + group: Group = self.session.query(Group).get(id) return group def get_by_name(self, name: str) -> Group: - group: Group = db.session.query(Group).filter_by(name=name).first() + group: Group = self.session.query(Group).filter_by(name=name).first() return group def get_all(self) -> List[Group]: - groups: List[Group] = db.session.query(Group).all() + groups: List[Group] = self.session.query(Group).all() return groups def delete(self, id: str) -> None: - g = db.session.query(Group).get(id) - db.session.delete(g) - db.session.commit() + g = self.session.query(Group).get(id) + self.session.delete(g) + self.session.commit() logger.debug(f"Group {id} deleted") @@ -57,49 +114,46 @@ class UserRepository: Database connector to manage User entity. """ - def __init__(self, config: Config) -> None: - # init seed admin user from conf - with db(): - admin_user = self.get_by_name("admin") - if admin_user is None: - self.save( - User( - id=ADMIN_ID, - name="admin", - password=Password(config.security.admin_pwd), - ) - ) - elif not admin_user.password.check(config.security.admin_pwd): # type: ignore - admin_user.password = Password(config.security.admin_pwd) # type: ignore - self.save(admin_user) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, user: User) -> User: - res = db.session.query(exists().where(User.id == user.id)).scalar() + res = self.session.query(exists().where(User.id == user.id)).scalar() if res: - db.session.merge(user) + self.session.merge(user) else: - db.session.add(user) - db.session.commit() + self.session.add(user) + self.session.commit() logger.debug(f"User {user.id} saved") return user - def get(self, id: int) -> Optional[User]: - user: User = db.session.query(User).get(id) + def get(self, id_number: int) -> Optional[User]: + user: User = self.session.query(User).get(id_number) return user def get_by_name(self, name: str) -> User: - user: User = db.session.query(User).filter_by(name=name).first() + user: User = self.session.query(User).filter_by(name=name).first() return user def get_all(self) -> List[User]: - users: List[User] = db.session.query(User).all() + users: List[User] = self.session.query(User).all() return users def delete(self, id: int) -> None: - u: User = db.session.query(User).get(id) - db.session.delete(u) - db.session.commit() + u: User = self.session.query(User).get(id) + self.session.delete(u) + self.session.commit() logger.debug(f"User {id} deleted") @@ -109,39 +163,54 @@ class UserLdapRepository: Database connector to manage UserLdap entity. """ + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, user_ldap: UserLdap) -> UserLdap: - res = db.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() + res = self.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() if res: - db.session.merge(user_ldap) + self.session.merge(user_ldap) else: - db.session.add(user_ldap) - db.session.commit() + self.session.add(user_ldap) + self.session.commit() logger.debug(f"User LDAP {user_ldap.id} saved") return user_ldap - def get(self, id: int) -> Optional[UserLdap]: - user_ldap: Optional[UserLdap] = db.session.query(UserLdap).get(id) + def get(self, id_number: int) -> Optional[UserLdap]: + user_ldap: Optional[UserLdap] = self.session.query(UserLdap).get(id_number) return user_ldap def get_by_name(self, name: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(name=name).first() + user: UserLdap = self.session.query(UserLdap).filter_by(name=name).first() return user def get_by_external_id(self, external_id: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(external_id=external_id).first() + user: UserLdap = self.session.query(UserLdap).filter_by(external_id=external_id).first() return user - def get_all(self) -> List[UserLdap]: - users_ldap: List[UserLdap] = db.session.query(UserLdap).all() + def get_all( + self, + ) -> List[UserLdap]: + users_ldap: List[UserLdap] = self.session.query(UserLdap).all() return users_ldap - def delete(self, id: int) -> None: - u: UserLdap = db.session.query(UserLdap).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: UserLdap = self.session.query(UserLdap).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"User LDAP {id} deleted") + logger.debug(f"User LDAP {id_number} deleted") class BotRepository: @@ -149,42 +218,57 @@ class BotRepository: Database connector to manage Bot entity. """ + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, bot: Bot) -> Bot: - res = db.session.query(exists().where(Bot.id == bot.id)).scalar() + res = self.session.query(exists().where(Bot.id == bot.id)).scalar() if res: raise ValueError("Bot already exist") else: - db.session.add(bot) - db.session.commit() + self.session.add(bot) + self.session.commit() logger.debug(f"Bot {bot.id} saved") return bot - def get(self, id: int) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).get(id) + def get(self, id_number: int) -> Optional[Bot]: + bot: Bot = self.session.query(Bot).get(id_number) return bot - def get_all(self) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).all() + def get_all( + self, + ) -> List[Bot]: + bots: List[Bot] = self.session.query(Bot).all() return bots - def delete(self, id: int) -> None: - u: Bot = db.session.query(Bot).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: Bot = self.session.query(Bot).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"Bot {id} deleted") + logger.debug(f"Bot {id_number} deleted") def get_all_by_owner(self, owner: int) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).filter_by(owner=owner).all() + bots: List[Bot] = self.session.query(Bot).filter_by(owner=owner).all() return bots def get_by_name_and_owner(self, owner: int, name: str) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).filter_by(owner=owner, name=name).first() + bot: Bot = self.session.query(Bot).filter_by(owner=owner, name=name).first() return bot - def exists(self, id: int) -> bool: - res: bool = db.session.query(exists().where(Bot.id == id)).scalar() + def exists(self, id_number: int) -> bool: + res: bool = self.session.query(exists().where(Bot.id == id_number)).scalar() return res @@ -193,42 +277,44 @@ class RoleRepository: Database connector to manage Role entity. """ - def __init__(self) -> None: - with db(): - if self.get(1, "admin") is None: - self.save( - Role( - type=RoleType.ADMIN, - identity=User(id=1), - group=Group(id="admin"), - ) - ) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, role: Role) -> Role: - role.group = db.session.merge(role.group) - role.identity = db.session.merge(role.identity) + role.group = self.session.merge(role.group) + role.identity = self.session.merge(role.identity) - db.session.add(role) - db.session.commit() + self.session.add(role) + self.session.commit() logger.debug(f"Role (user={role.identity}, group={role.group} saved") return role def get(self, user: int, group: str) -> Optional[Role]: - role: Role = db.session.query(Role).get((user, group)) + role: Role = self.session.query(Role).get((user, group)) return role def get_all_by_user(self, user: int) -> List[Role]: - roles: List[Role] = db.session.query(Role).filter_by(identity_id=user).all() + roles: List[Role] = self.session.query(Role).filter_by(identity_id=user).all() return roles def get_all_by_group(self, group: str) -> List[Role]: - roles: List[Role] = db.session.query(Role).filter_by(group_id=group).all() + roles: List[Role] = self.session.query(Role).filter_by(group_id=group).all() return roles def delete(self, user: int, group: str) -> None: - r = db.session.query(Role).get((user, group)) - db.session.delete(r) - db.session.commit() + r = self.session.query(Role).get((user, group)) + self.session.delete(r) + self.session.commit() logger.debug(f"Role (user={user}, group={group} deleted") diff --git a/antarest/main.py b/antarest/main.py index 1e0c9183dd..5e1c1ec850 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -30,15 +30,18 @@ from antarest.core.logging.utils import LoggingMiddleware, configure_logger from antarest.core.requests import RATE_LIMIT_CONFIG from antarest.core.swagger import customize_openapi +from antarest.core.tasks.model import cancel_orphan_tasks +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.login.auth import Auth, JwtSettings +from antarest.login.repository import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector -from antarest.singleton_services import SingletonServices +from antarest.singleton_services import start_all_services from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.study.storage.rawstudy.watcher import Watcher from antarest.tools.admin_lib import clean_locks -from antarest.utils import Module, create_services, init_db +from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine logger = logging.getLogger(__name__) @@ -246,7 +249,12 @@ def fastapi_app( ) # Database - init_db(config_file, config, auto_upgrade_db, application) + engine = init_db_engine(config_file, config, auto_upgrade_db) + application.add_middleware( + DBSessionMiddleware, + custom_engine=engine, + session_args=dict(SESSION_ARGS), + ) application.add_middleware(LoggingMiddleware) @@ -401,6 +409,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) + init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd) services = create_services(config, application) if mount_front: @@ -428,6 +437,10 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) + cancel_orphan_tasks( + engine=engine, + session_args=dict(SESSION_ARGS), + ) return application, services @@ -455,8 +468,7 @@ def main() -> None: # noinspection PyTypeChecker uvicorn.run(app, host="0.0.0.0", port=8080, log_config=LOGGING_CONFIG) else: - services = SingletonServices(arguments.config_file, [arguments.module]) - services.start() + start_all_services(arguments.config_file, [arguments.module]) if __name__ == "__main__": diff --git a/antarest/matrixstore/repository.py b/antarest/matrixstore/repository.py index 6301e39c7f..9ab44a69ec 100644 --- a/antarest/matrixstore/repository.py +++ b/antarest/matrixstore/repository.py @@ -7,7 +7,7 @@ from filelock import FileLock from numpy import typing as npt from sqlalchemy import and_, exists # type: ignore -from sqlalchemy.orm import aliased # type: ignore +from sqlalchemy.orm import Session, aliased # type: ignore from antarest.core.utils.fastapi_sqlalchemy import db from antarest.matrixstore.model import Matrix, MatrixContent, MatrixData, MatrixDataSet @@ -20,23 +20,33 @@ class MatrixDataSetRepository: Database connector to manage Matrix metadata entity """ + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix_user_metadata: MatrixDataSet) -> MatrixDataSet: - res: bool = db.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() + res: bool = self.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() if res: - matrix_user_metadata = db.session.merge(matrix_user_metadata) + matrix_user_metadata = self.session.merge(matrix_user_metadata) else: - db.session.add(matrix_user_metadata) - db.session.commit() + self.session.add(matrix_user_metadata) + self.session.commit() logger.debug(f"Matrix dataset {matrix_user_metadata.id} for user {matrix_user_metadata.owner_id} saved") return matrix_user_metadata - def get(self, id: str) -> t.Optional[MatrixDataSet]: - matrix: MatrixDataSet = db.session.query(MatrixDataSet).get(id) + def get(self, id_number: str) -> t.Optional[MatrixDataSet]: + matrix: MatrixDataSet = self.session.query(MatrixDataSet).get(id_number) return matrix def get_all_datasets(self) -> t.List[MatrixDataSet]: - matrix_datasets: t.List[MatrixDataSet] = db.session.query(MatrixDataSet).all() + matrix_datasets: t.List[MatrixDataSet] = self.session.query(MatrixDataSet).all() return matrix_datasets def query( @@ -54,7 +64,7 @@ def query( Returns: the list of metadata per user, matching the query """ - query = db.session.query(MatrixDataSet) + query = self.session.query(MatrixDataSet) if name is not None: query = query.filter(MatrixDataSet.name.ilike(f"%{name}%")) # type: ignore if owner is not None: @@ -63,9 +73,9 @@ def query( return datasets def delete(self, dataset_id: str) -> None: - dataset = db.session.query(MatrixDataSet).get(dataset_id) - db.session.delete(dataset) - db.session.commit() + dataset = self.session.query(MatrixDataSet).get(dataset_id) + self.session.delete(dataset) + self.session.commit() class MatrixRepository: @@ -73,28 +83,38 @@ class MatrixRepository: Database connector to manage Matrix entity. """ + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix: Matrix) -> Matrix: - if db.session.query(exists().where(Matrix.id == matrix.id)).scalar(): - db.session.merge(matrix) + if self.session.query(exists().where(Matrix.id == matrix.id)).scalar(): + self.session.merge(matrix) else: - db.session.add(matrix) - db.session.commit() + self.session.add(matrix) + self.session.commit() logger.debug(f"Matrix {matrix.id} saved") return matrix def get(self, matrix_hash: str) -> t.Optional[Matrix]: - matrix: Matrix = db.session.query(Matrix).get(matrix_hash) + matrix: Matrix = self.session.query(Matrix).get(matrix_hash) return matrix def exists(self, matrix_hash: str) -> bool: - res: bool = db.session.query(exists().where(Matrix.id == matrix_hash)).scalar() + res: bool = self.session.query(exists().where(Matrix.id == matrix_hash)).scalar() return res def delete(self, matrix_hash: str) -> None: - if g := db.session.query(Matrix).get(matrix_hash): - db.session.delete(g) - db.session.commit() + if g := self.session.query(Matrix).get(matrix_hash): + self.session.delete(g) + self.session.commit() else: logger.warning(f"Trying to delete matrix {matrix_hash}, but was not found in database!") logger.debug(f"Matrix {matrix_hash} deleted") diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 4869ed11fa..c7030160ad 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -54,6 +54,13 @@ class ISimpleMatrixService(ABC): + def __init__(self, matrix_content_repository: MatrixContentRepository) -> None: + self.matrix_content_repository = matrix_content_repository + + @property + def bucket_dir(self) -> Path: + return self.matrix_content_repository.bucket_dir + @abstractmethod def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: raise NotImplementedError() @@ -72,15 +79,14 @@ def delete(self, matrix_id: str) -> None: class SimpleMatrixService(ISimpleMatrixService): - def __init__(self, bucket_dir: Path): - self.bucket_dir = bucket_dir - self.content_repo = MatrixContentRepository(bucket_dir) + def __init__(self, matrix_content_repository: MatrixContentRepository): + super().__init__(matrix_content_repository=matrix_content_repository) def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: - return self.content_repo.save(data) + return self.matrix_content_repository.save(data) def get(self, matrix_id: str) -> MatrixDTO: - data = self.content_repo.get(matrix_id) + data = self.matrix_content_repository.get(matrix_id) return MatrixDTO.construct( id=matrix_id, width=len(data.columns), @@ -91,10 +97,10 @@ def get(self, matrix_id: str) -> MatrixDTO: ) def exists(self, matrix_id: str) -> bool: - return self.content_repo.exists(matrix_id) + return self.matrix_content_repository.exists(matrix_id) def delete(self, matrix_id: str) -> None: - self.content_repo.delete(matrix_id) + self.matrix_content_repository.delete(matrix_id) class MatrixService(ISimpleMatrixService): @@ -108,9 +114,9 @@ def __init__( config: Config, user_service: LoginService, ): + super().__init__(matrix_content_repository=matrix_content_repository) self.repo = repo self.repo_dataset = repo_dataset - self.matrix_content_repository = matrix_content_repository self.user_service = user_service self.file_transfer_manager = file_transfer_manager self.task_service = task_service diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 9b702a346b..70a791002d 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -1,90 +1,76 @@ -import logging -import time from pathlib import Path from typing import Dict, List from antarest.core.config import Config from antarest.core.interfaces.service import IService from antarest.core.logging.utils import configure_logger +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.utils import ( + SESSION_ARGS, Module, create_archive_worker, create_core_services, create_matrix_gc, create_simulator_worker, create_watcher, - init_db, + init_db_engine, ) -logger = logging.getLogger(__name__) - -class SingletonServices: - def __init__(self, config_file: Path, services_list: List[Module]) -> None: - self.services_list = self._init(config_file, services_list) - - @staticmethod - def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - init_db(config_file, config, False, None) - configure_logger(config) - - ( - cache, - event_bus, - task_service, - ft_manager, - login_service, - matrix_service, - study_service, - ) = create_core_services(None, config) - - services: Dict[Module, IService] = {} - - if Module.WATCHER in services_list: - watcher = create_watcher(config=config, application=None, study_service=study_service) - services[Module.WATCHER] = watcher - - if Module.MATRIX_GC in services_list: - matrix_gc = create_matrix_gc( - config=config, - application=None, - study_service=study_service, - matrix_service=matrix_service, - ) - services[Module.MATRIX_GC] = matrix_gc - - if Module.ARCHIVE_WORKER in services_list: - worker = create_archive_worker(config, "test", event_bus=event_bus) - services[Module.ARCHIVE_WORKER] = worker - - if Module.SIMULATOR_WORKER in services_list: - worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) - services[Module.SIMULATOR_WORKER] = worker - - if Module.AUTO_ARCHIVER in services_list: - auto_archive_service = AutoArchiveService(study_service, config) - services[Module.AUTO_ARCHIVER] = auto_archive_service - - return services - - def start(self) -> None: - for service in self.services_list: - self.services_list[service].start(threaded=True) - - self._loop() - - def _loop(self) -> None: - while True: - try: - pass - except Exception as e: - logger.error( - "Unexpected error happened while processing service manager loop", - exc_info=e, - ) - finally: - time.sleep(2) +def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: + res = get_local_path() / "resources" + config = Config.from_yaml_file(res=res, file=config_file) + engine = init_db_engine( + config_file, + config, + False, + ) + DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS)) + configure_logger(config) + + ( + cache, + event_bus, + task_service, + ft_manager, + login_service, + matrix_service, + study_service, + ) = create_core_services(None, config) + + services: Dict[Module, IService] = {} + + if Module.WATCHER in services_list: + watcher = create_watcher(config=config, application=None, study_service=study_service) + services[Module.WATCHER] = watcher + + if Module.MATRIX_GC in services_list: + matrix_gc = create_matrix_gc( + config=config, + application=None, + study_service=study_service, + matrix_service=matrix_service, + ) + services[Module.MATRIX_GC] = matrix_gc + + if Module.ARCHIVE_WORKER in services_list: + worker = create_archive_worker(config, "test", event_bus=event_bus) + services[Module.ARCHIVE_WORKER] = worker + + if Module.SIMULATOR_WORKER in services_list: + worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) + services[Module.SIMULATOR_WORKER] = worker + + if Module.AUTO_ARCHIVER in services_list: + auto_archive_service = AutoArchiveService(study_service, config) + services[Module.AUTO_ARCHIVER] = auto_archive_service + + return services + + +def start_all_services(config_file: Path, services_list: List[Module]) -> None: + services = _init(config_file, services_list) + for service in services: + services[service].start(threaded=True) diff --git a/antarest/study/main.py b/antarest/study/main.py index e4a981afd2..c3b48356af 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -81,6 +81,7 @@ def build_study_service( ) generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index e0fd1d1e3c..4d3c563799 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -48,6 +48,9 @@ class CommandExtractor(ICommandExtractor): def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices( + bucket_dir=self.generator_matrix_constants.matrix_service.bucket_dir + ) self.patch_service = patch_service self.command_context = CommandContext( generator_matrix_constants=self.generator_matrix_constants, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 8cb973785e..f5d63d5d8b 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -47,6 +47,7 @@ ST_STORAGE_INFLOWS = EMPTY_SCENARIO_MATRIX MATRIX_PROTOCOL_PREFIX = "matrix://" +MATRIX_CONSTANT_INIT_LOCK_FILE_NAME = "matrix_constant_init.lock" # noinspection SpellCheckingInspection @@ -54,50 +55,56 @@ class GeneratorMatrixConstants: def __init__(self, matrix_service: ISimpleMatrixService) -> None: self.hashes: Dict[str, str] = {} self.matrix_service: ISimpleMatrixService = matrix_service - with FileLock(str(Path(tempfile.gettempdir()) / "matrix_constant_init.lock")): - self._init() - - def _init(self) -> None: - self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.max_power - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( - matrix_constants.hydro.v6.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( - matrix_constants.hydro.v7.inflow_pattern - ) - self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( - matrix_constants.hydro.v7.credit_modulations - ) - self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) - self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) - self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) - - self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create(matrix_constants.thermals.prepro.modulation) - self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) - self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) - self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) - self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) - - self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) - self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) - self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) - self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) - - # Binding constraint matrices - series = matrix_constants.binding_constraint.series - self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_binding_constraint_hourly) - self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily) - self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create(series.default_binding_constraint_weekly) - - # Some short-term storage matrices use np.ones((8760, 1)) - self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( - matrix_constants.st_storage.series.pmax_injection - ) + + def init_constant_matrices(self, bucket_dir: Path) -> None: + bucket_dir.mkdir(parents=True, exist_ok=True) + with FileLock(bucket_dir / MATRIX_CONSTANT_INIT_LOCK_FILE_NAME): + self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.max_power + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( + matrix_constants.hydro.v6.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( + matrix_constants.hydro.v7.inflow_pattern + ) + self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( + matrix_constants.hydro.v7.credit_modulations + ) + self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) + self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) + self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) + + self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create( + matrix_constants.thermals.prepro.modulation + ) + self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) + self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) + self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) + self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) + + self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) + self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) + self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) + self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) + + # Binding constraint matrices + series = matrix_constants.binding_constraint.series + self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create( + series.default_binding_constraint_hourly + ) + self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily) + self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create( + series.default_binding_constraint_weekly + ) + + # Some short-term storage matrices use np.ones((8760, 1)) + self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( + matrix_constants.st_storage.series.pmax_injection + ) def get_hydro_max_power(self, version: int) -> str: if version > 650: diff --git a/antarest/study/storage/variantstudy/variant_command_extractor.py b/antarest/study/storage/variantstudy/variant_command_extractor.py index 5a88dde857..33ee3ff49f 100644 --- a/antarest/study/storage/variantstudy/variant_command_extractor.py +++ b/antarest/study/storage/variantstudy/variant_command_extractor.py @@ -20,6 +20,7 @@ class VariantCommandsExtractor: def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service) def extract(self, study: FileStudy) -> List[CommandDTO]: diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 5ade3d214b..058a3402fa 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -24,6 +24,7 @@ from antarest.core.config import CacheConfig from antarest.core.tasks.model import TaskDTO from antarest.core.utils.utils import StopWatch, get_local_path +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.model import NEW_DEFAULT_STUDY_VERSION, STUDY_REFERENCE_TEMPLATES @@ -140,7 +141,12 @@ def render_template(self, study_version: str = NEW_DEFAULT_STUDY_VERSION) -> Non def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> GenerationResultInfoDTO: stopwatch = StopWatch() - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) local_cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -149,8 +155,10 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene cache=local_cache, ) generator = VariantCommandGenerator(study_factory) + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) command_factory = CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), ) @@ -176,8 +184,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: commands_output_dir.mkdir(parents=True) matrices_dir = commands_output_dir / MATRIX_STORE_DIR matrices_dir.mkdir() - - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -187,7 +199,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: ) study = study_factory.create_from_fs(study_path, str(study_path), use_cache=False) - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) extractor = VariantCommandsExtractor(local_matrix_service, patch_service=PatchService()) command_list = extractor.extract(study) @@ -233,7 +250,12 @@ def generate_diff( study_id = "empty_base" path_study = output_dir / study_id - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) resolver = UriResolverService(matrix_service=local_matrix_service) cache = LocalCache() diff --git a/antarest/utils.py b/antarest/utils.py index d49951017f..39ea094168 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -1,7 +1,8 @@ +import datetime import logging from enum import Enum from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Mapping, Optional, Tuple import redis import sqlalchemy.ext.baked # type: ignore @@ -12,6 +13,7 @@ from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore from antarest.core.cache.main import build_cache @@ -20,13 +22,11 @@ from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import IEventBus -from antarest.core.logging.utils import configure_logger from antarest.core.maintenance.main import build_maintenance_manager from antarest.core.persistence import upgrade_db from antarest.core.tasks.main import build_taskjob_manager from antarest.core.tasks.service import ITaskService -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware -from antarest.core.utils.utils import get_local_path, new_redis_instance +from antarest.core.utils.utils import new_redis_instance from antarest.eventbus.main import build_eventbus from antarest.launcher.main import build_launcher from antarest.login.main import build_login @@ -46,6 +46,19 @@ logger = logging.getLogger(__name__) +SESSION_ARGS: Mapping[str, bool] = { + "autocommit": False, + "expire_on_commit": False, + "autoflush": False, +} +""" +This mapping can be used to instantiate a new session, for example: + +>>> with sessionmaker(engine, **SESSION_ARGS)() as session: +... session.execute("SELECT 1") +""" + + class Module(str, Enum): APP = "app" WATCHER = "watcher" @@ -55,12 +68,11 @@ class Module(str, Enum): SIMULATOR_WORKER = "simulator_worker" -def init_db( +def init_db_engine( config_file: Path, config: Config, auto_upgrade_db: bool, - application: Optional[FastAPI], -) -> None: +) -> Engine: if auto_upgrade_db: upgrade_db(config_file) connect_args: Dict[str, Any] = {} @@ -86,19 +98,7 @@ def init_db( engine = create_engine(config.db.db_url, echo=config.debug, connect_args=connect_args, **extra) - session_args = { - "autocommit": False, - "expire_on_commit": False, - "autoflush": False, - } - if application: - application.add_middleware( - DBSessionMiddleware, - custom_engine=engine, - session_args=session_args, - ) - else: - DBSessionMiddleware(None, custom_engine=engine, session_args=session_args) + return engine def create_event_bus( @@ -264,14 +264,3 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: services["cache"] = cache services["maintenance"] = maintenance_service return services - - -def create_env(config_file: Path) -> Dict[str, Any]: - """ - Create application services env for testing and scripting purpose - """ - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - configure_logger(config) - init_db(config_file, config, False, None) - return create_services(config, None) diff --git a/tests/conftest_services.py b/tests/conftest_services.py index ee2fea2057..7fa50f6f86 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -18,6 +18,7 @@ from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskListFilter, TaskResult, TaskStatus, TaskType from antarest.core.tasks.service import ITaskService, Task from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.storage.patch_service import PatchService @@ -128,7 +129,10 @@ def simple_matrix_service_fixture(bucket_dir: Path) -> SimpleMatrixService: Returns: An instance of the SimpleMatrixService class representing the matrix service. """ - return SimpleMatrixService(bucket_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=bucket_dir, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @pytest.fixture(name="generator_matrix_constants", scope="session") @@ -144,7 +148,9 @@ def generator_matrix_constants_fixture( Returns: An instance of the GeneratorMatrixConstants class representing the matrix constants generator. """ - return GeneratorMatrixConstants(matrix_service=simple_matrix_service) + out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service) + out_generator_matrix_constants.init_constant_matrices(bucket_dir=simple_matrix_service.bucket_dir) + return out_generator_matrix_constants @pytest.fixture(name="uri_resolver_service", scope="session") diff --git a/tests/login/test_repository.py b/tests/login/test_repository.py index 3599ec437c..2019f6e940 100644 --- a/tests/login/test_repository.py +++ b/tests/login/test_repository.py @@ -1,29 +1,14 @@ import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore +from sqlalchemy.orm import Session, scoped_session, sessionmaker # type: ignore -from antarest.core.config import Config, SecurityConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db from antarest.login.model import Bot, Group, Password, Role, RoleType, User, UserLdap from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository @pytest.mark.unit_test -def test_users(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserRepository( - config=Config(security=SecurityConfig(admin_pwd="admin")), - ) +def test_users(db_session: Session): + with db_session: + repo = UserRepository(session=db_session) a = User( name="a", password=Password("a"), @@ -43,18 +28,9 @@ def test_users(): @pytest.mark.unit_test -def test_users_ldap(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserLdapRepository() +def test_users_ldap(db_session: Session): + repo = UserLdapRepository(session=db_session) + with repo.session: a = UserLdap(name="a", external_id="b") a = repo.save(a) @@ -67,18 +43,9 @@ def test_users_ldap(): @pytest.mark.unit_test -def test_bots(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = BotRepository() +def test_bots(db_session: Session): + repo = BotRepository(session=db_session) + with repo.session: a = Bot(name="a", owner=1) a = repo.save(a) assert a.id @@ -98,19 +65,9 @@ def test_bots(): @pytest.mark.unit_test -def test_groups(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = GroupRepository() - +def test_groups(db_session: Session): + repo = GroupRepository(session=db_session) + with repo.session: a = Group(name="a") a = repo.save(a) @@ -125,19 +82,9 @@ def test_groups(): @pytest.mark.unit_test -def test_roles(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = RoleRepository() - +def test_roles(db_session: Session): + repo = RoleRepository(session=db_session) + with repo.session: a = Role(type=RoleType.ADMIN, identity=User(id=0), group=Group(id="group")) a = repo.save(a) diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 3973a18d39..3825924f85 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -5,6 +5,7 @@ import numpy as np import pytest from numpy import typing as npt +from sqlalchemy.orm import Session # ignore type from antarest.core.config import Config, SecurityConfig from antarest.core.utils.fastapi_sqlalchemy import db @@ -51,20 +52,20 @@ def test_bucket_lifecycle(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError): repo.get(aid) - def test_dataset(self) -> None: - with db(): + def test_dataset(self, db_session: Session) -> None: + with db_session: # sourcery skip: extract-duplicate-method, extract-method - repo = MatrixRepository() + repo = MatrixRepository(session=db_session) - user_repo = UserRepository(Config(security=SecurityConfig())) + user_repo = UserRepository(session=db_session) # noinspection PyArgumentList user = user_repo.save(User(name="foo", password=Password("bar"))) - group_repo = GroupRepository() + group_repo = GroupRepository(session=db_session) # noinspection PyArgumentList group = group_repo.save(Group(name="group")) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) @@ -105,22 +106,22 @@ def test_dataset(self) -> None: assert dataset_query_result.name == "some name change" assert dataset_query_result.owner_id == user.id - def test_datastore_query(self) -> None: + def test_datastore_query(self, db_session: Session) -> None: # sourcery skip: extract-duplicate-method with db(): - user_repo = UserRepository(Config(security=SecurityConfig())) + user_repo = UserRepository(session=db_session) # noinspection PyArgumentList user1 = user_repo.save(User(name="foo", password=Password("bar"))) # noinspection PyArgumentList user2 = user_repo.save(User(name="hello", password=Password("world"))) - repo = MatrixRepository() + repo = MatrixRepository(session=db_session) m1 = Matrix(id="hello", created_at=datetime.now()) repo.save(m1) m2 = Matrix(id="world", created_at=datetime.now()) repo.save(m2) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) dataset = MatrixDataSet( name="some name", @@ -165,7 +166,7 @@ def test_datastore_query(self) -> None: assert ( len( # fmt: off - db.session + db_session .query(MatrixDataSetRelation) .filter(MatrixDataSetRelation.dataset_id == dataset.id) .all() diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 9f8e0be884..314e77e500 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -9,6 +9,7 @@ from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.area_management import AreaCreationDTO, AreaManager, AreaType, AreaUI from antarest.study.business.link_management import LinkInfoDTO, LinkManager @@ -66,7 +67,10 @@ def matrix_service_fixture(tmp_path: Path) -> SimpleMatrixService: """ matrix_path = tmp_path.joinpath("matrix-store") matrix_path.mkdir() - return SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @with_db_context @@ -94,8 +98,10 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): raw_study_service.get_raw.return_value = empty_study raw_study_service.cache = Mock() + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=generator_matrix_constants.matrix_service.bucket_dir) variant_study_service.command_factory = CommandFactory( - GeneratorMatrixConstants(matrix_service), + generator_matrix_constants, matrix_service, patch_service=Mock(spec=PatchService), ) diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index 4ff8fbf888..197be27144 100644 --- a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -12,6 +12,7 @@ from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base from antarest.login.model import User +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.main import build_study_service from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, StudyAdditionalData @@ -87,7 +88,10 @@ def storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) matrix_path = tmp_path / "matrices" matrix_path.mkdir() - matrix_service = SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + matrix_service = SimpleMatrixService(matrix_content_repository=matrix_content_repository) storage_service = build_study_service( application=Mock(), cache=LocalCache(config=config.cache), diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index 93a3262259..fe679a821a 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -1,5 +1,6 @@ import numpy as np +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.storage.variantstudy.business import matrix_constants from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( @@ -10,7 +11,14 @@ class TestGeneratorMatrixConstants: def test_get_st_storage(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) ref1 = generator.get_st_storage_pmax_injection() matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1] @@ -38,7 +46,14 @@ def test_get_st_storage(self, tmp_path): assert np.array(matrix_dto5.data).all() == matrix_constants.st_storage.series.inflows.all() def test_get_binding_constraint(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) series = matrix_constants.binding_constraint.series hourly = generator.get_binding_constraint_hourly() diff --git a/tests/study/storage/variantstudy/test_variant_study_service.py b/tests/study/storage/variantstudy/test_variant_study_service.py index 25317a9589..8766bfd308 100644 --- a/tests/study/storage/variantstudy/test_variant_study_service.py +++ b/tests/study/storage/variantstudy/test_variant_study_service.py @@ -11,6 +11,7 @@ from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, User +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy, StudyAdditionalData diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index 9db21ab220..b069e029d8 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -91,8 +91,10 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext: CommandContext: The CommandContext object. """ # sourcery skip: inline-immediately-returned-variable + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) command_context = CommandContext( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(repository=Mock(spec=StudyMetadataRepository)), ) @@ -110,8 +112,10 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory: Returns: CommandFactory: The CommandFactory object. """ + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices(bucket_dir=matrix_service.bucket_dir) return CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), )