feat(db-init): separate database initialization from global database session (#1805)

laurent-laporte-pro authored and mabw-rte committed Nov 19, 2023
1 parent 993d1d9 commit ffa6d0b
Showing 6 changed files with 148 additions and 154 deletions.
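
At a glance, this commit moves database bootstrapping out of the repository constructors (which ran implicitly under the request-scoped db() session) and into an explicit sequence driven by the application entry points. The essence of the new startup flow, as it appears in the antarest/main.py hunk further down, annotated here for orientation (arguments abbreviated; not runnable on its own):

    # 1. Build the engine once (optionally running schema upgrades).
    engine = init_db_engine(config_file, config, auto_upgrade_db)
    # 2. Hand that engine to the session middleware for request-scoped sessions.
    application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=dict(SESSION_ARGS))
    # 3. Seed the admin group/user/role once, in a dedicated short-lived session.
    init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd)
    # 4. Build the services; repositories no longer perform any seeding themselves.
    services = create_services(config, application)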
2 changes: 1 addition & 1 deletion antarest/login/main.py
@@ -37,7 +37,7 @@ def build_login(
"""

if service is None:
user_repo = UserRepository(config)
user_repo = UserRepository()
bot_repo = BotRepository()
group_repo = GroupRepository()
role_repo = RoleRepository()
88 changes: 55 additions & 33 deletions antarest/login/repository.py
@@ -1,26 +1,75 @@
import logging
from typing import List, Optional
from typing import Dict, List, Optional, Tuple, Union

from sqlalchemy import exists # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import sessionmaker # type: ignore

from antarest.core.config import Config
from antarest.core.jwt import ADMIN_ID
from antarest.core.roles import RoleType
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.login.model import Bot, Group, Password, Role, User, UserLdap

logger = logging.getLogger(__name__)

DB_INIT_DEFAULT_GROUP_ID = "admin"
DB_INIT_DEFAULT_GROUP_NAME = "admin"

DB_INIT_DEFAULT_USER_ID = ADMIN_ID
DB_INIT_DEFAULT_USER_NAME = "admin"

DB_INIT_DEFAULT_ROLE_ID = ADMIN_ID
DB_INIT_DEFAULT_ROLE_GROUP_ID = "admin"


def init_admin_user(engine: Engine, session_args: Dict[str, bool], admin_password: str) -> None:
with sessionmaker(bind=engine, **session_args)() as session:
group = Group(
id=DB_INIT_DEFAULT_GROUP_ID,
name=DB_INIT_DEFAULT_GROUP_NAME,
)
user = User(
id=DB_INIT_DEFAULT_USER_ID,
name=DB_INIT_DEFAULT_USER_NAME,
password=Password(admin_password),
)
role = Role(
type=RoleType.ADMIN,
identity_id=DB_INIT_DEFAULT_ROLE_ID,
group_id=DB_INIT_DEFAULT_ROLE_GROUP_ID,
)

if session.query(exists().where(Group.id == group.id)).scalar():
session.merge(group)
else:
session.add(group)

if session.query(exists().where(User.id == user.id)).scalar():
session.merge(user)
else:
session.add(user)

if (
session.query(Role).get(
(
DB_INIT_DEFAULT_USER_ID,
DB_INIT_DEFAULT_GROUP_NAME,
)
)
is None
):
role.group = session.merge(role.group)
role.identity = session.merge(role.identity)
session.add(role)

session.commit()


class GroupRepository:
"""
Database connector to manage Group entity.
"""

def __init__(self) -> None:
with db():
self.save(Group(id="admin", name="admin"))

def save(self, group: Group) -> Group:
res = db.session.query(exists().where(Group.id == group.id)).scalar()
if res:
@@ -57,22 +106,6 @@ class UserRepository:
Database connector to manage User entity.
"""

def __init__(self, config: Config) -> None:
# init seed admin user from conf
with db():
admin_user = self.get_by_name("admin")
if admin_user is None:
self.save(
User(
id=ADMIN_ID,
name="admin",
password=Password(config.security.admin_pwd),
)
)
elif not admin_user.password.check(config.security.admin_pwd): # type: ignore
admin_user.password = Password(config.security.admin_pwd) # type: ignore
self.save(admin_user)

def save(self, user: User) -> User:
res = db.session.query(exists().where(User.id == user.id)).scalar()
if res:
@@ -193,17 +226,6 @@ class RoleRepository:
Database connector to manage Role entity.
"""

def __init__(self) -> None:
with db():
if self.get(1, "admin") is None:
self.save(
Role(
type=RoleType.ADMIN,
identity=User(id=1),
group=Group(id="admin"),
)
)

def save(self, role: Role) -> Role:
role.group = db.session.merge(role.group)
role.identity = db.session.merge(role.identity)
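For orientation, init_admin_user opens its own short-lived session bound to the provided engine (sessionmaker(bind=engine, **session_args)) instead of going through the request-scoped db() wrapper, and inserts or merges the default admin group, user and role so repeated start-ups stay idempotent. A minimal, self-contained sketch of that pattern, using a throwaway SQLAlchemy model rather than the real antarest entities (model, table name and session arguments here are illustrative):

    from sqlalchemy import Column, String, create_engine, exists
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()

    class Group(Base):
        __tablename__ = "groups"  # illustrative model, not the antarest Group entity
        id = Column(String, primary_key=True)
        name = Column(String)

    def seed_default_group(engine, session_args):
        # Dedicated, short-lived session bound to the engine (no global/request-scoped session).
        with sessionmaker(bind=engine, **session_args)() as session:
            group = Group(id="admin", name="admin")
            if session.query(exists().where(Group.id == group.id)).scalar():
                session.merge(group)  # already present: update in place (idempotent re-run)
            else:
                session.add(group)    # first run: insert
            session.commit()

    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    seed_default_group(engine, {"autoflush": False, "expire_on_commit": False})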
17 changes: 12 additions & 5 deletions antarest/main.py
@@ -30,15 +30,17 @@
from antarest.core.logging.utils import LoggingMiddleware, configure_logger
from antarest.core.requests import RATE_LIMIT_CONFIG
from antarest.core.swagger import customize_openapi
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.core.utils.utils import get_local_path
from antarest.core.utils.web import tags_metadata
from antarest.login.auth import Auth, JwtSettings
from antarest.login.repository import init_admin_user
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.singleton_services import SingletonServices
from antarest.singleton_services import start_all_services
from antarest.study.storage.auto_archive_service import AutoArchiveService
from antarest.study.storage.rawstudy.watcher import Watcher
from antarest.tools.admin_lib import clean_locks
from antarest.utils import Module, create_services, init_db
from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine

logger = logging.getLogger(__name__)

@@ -246,7 +248,12 @@ def fastapi_app(
)

# Database
init_db(config_file, config, auto_upgrade_db, application)
engine = init_db_engine(config_file, config, auto_upgrade_db)
application.add_middleware(
DBSessionMiddleware,
custom_engine=engine,
session_args=dict(SESSION_ARGS),
)

application.add_middleware(LoggingMiddleware)

@@ -401,6 +408,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
config=RATE_LIMIT_CONFIG,
)

init_admin_user(engine=engine, session_args=dict(SESSION_ARGS), admin_password=config.security.admin_pwd)
services = create_services(config, application)

if mount_front:
@@ -455,8 +463,7 @@ def main() -> None:
# noinspection PyTypeChecker
uvicorn.run(app, host="0.0.0.0", port=8080, log_config=LOGGING_CONFIG)
else:
services = SingletonServices(arguments.config_file, [arguments.module])
services.start()
start_all_services(arguments.config_file, [arguments.module])


if __name__ == "__main__":
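The same wiring can be sketched with the public fastapi_sqlalchemy package, which antarest's antarest.core.utils.fastapi_sqlalchemy module mirrors; the engine URL, session arguments and endpoint below are placeholders, not antarest's actual values:

    from fastapi import FastAPI
    from fastapi_sqlalchemy import DBSessionMiddleware, db
    from sqlalchemy import create_engine

    SESSION_ARGS = {"autoflush": False, "expire_on_commit": False}  # stand-in for antarest.utils.SESSION_ARGS

    engine = create_engine("sqlite:///./app.db")  # init_db_engine would also handle schema upgrades
    app = FastAPI()
    app.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=dict(SESSION_ARGS))

    @app.get("/ping")
    def ping() -> dict:
        # Inside a request, db.session is the request-scoped session created by the middleware.
        return {"connected": db.session.is_active}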
130 changes: 58 additions & 72 deletions antarest/singleton_services.py
@@ -1,90 +1,76 @@
import logging
import time
from pathlib import Path
from typing import Dict, List

from antarest.core.config import Config
from antarest.core.interfaces.service import IService
from antarest.core.logging.utils import configure_logger
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.core.utils.utils import get_local_path
from antarest.study.storage.auto_archive_service import AutoArchiveService
from antarest.utils import (
SESSION_ARGS,
Module,
create_archive_worker,
create_core_services,
create_matrix_gc,
create_simulator_worker,
create_watcher,
init_db,
init_db_engine,
)

logger = logging.getLogger(__name__)


class SingletonServices:
def __init__(self, config_file: Path, services_list: List[Module]) -> None:
self.services_list = self._init(config_file, services_list)

@staticmethod
def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]:
res = get_local_path() / "resources"
config = Config.from_yaml_file(res=res, file=config_file)
init_db(config_file, config, False, None)
configure_logger(config)

(
cache,
event_bus,
task_service,
ft_manager,
login_service,
matrix_service,
study_service,
) = create_core_services(None, config)

services: Dict[Module, IService] = {}

if Module.WATCHER in services_list:
watcher = create_watcher(config=config, application=None, study_service=study_service)
services[Module.WATCHER] = watcher

if Module.MATRIX_GC in services_list:
matrix_gc = create_matrix_gc(
config=config,
application=None,
study_service=study_service,
matrix_service=matrix_service,
)
services[Module.MATRIX_GC] = matrix_gc

if Module.ARCHIVE_WORKER in services_list:
worker = create_archive_worker(config, "test", event_bus=event_bus)
services[Module.ARCHIVE_WORKER] = worker

if Module.SIMULATOR_WORKER in services_list:
worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus)
services[Module.SIMULATOR_WORKER] = worker

if Module.AUTO_ARCHIVER in services_list:
auto_archive_service = AutoArchiveService(study_service, config)
services[Module.AUTO_ARCHIVER] = auto_archive_service

return services

def start(self) -> None:
for service in self.services_list:
self.services_list[service].start(threaded=True)

self._loop()

def _loop(self) -> None:
while True:
try:
pass
except Exception as e:
logger.error(
"Unexpected error happened while processing service manager loop",
exc_info=e,
)
finally:
time.sleep(2)
def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]:
res = get_local_path() / "resources"
config = Config.from_yaml_file(res=res, file=config_file)
engine = init_db_engine(
config_file,
config,
False,
)
DBSessionMiddleware(None, custom_engine=engine, session_args=dict(SESSION_ARGS))
configure_logger(config)

(
cache,
event_bus,
task_service,
ft_manager,
login_service,
matrix_service,
study_service,
) = create_core_services(None, config)

services: Dict[Module, IService] = {}

if Module.WATCHER in services_list:
watcher = create_watcher(config=config, application=None, study_service=study_service)
services[Module.WATCHER] = watcher

if Module.MATRIX_GC in services_list:
matrix_gc = create_matrix_gc(
config=config,
application=None,
study_service=study_service,
matrix_service=matrix_service,
)
services[Module.MATRIX_GC] = matrix_gc

if Module.ARCHIVE_WORKER in services_list:
worker = create_archive_worker(config, "test", event_bus=event_bus)
services[Module.ARCHIVE_WORKER] = worker

if Module.SIMULATOR_WORKER in services_list:
worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus)
services[Module.SIMULATOR_WORKER] = worker

if Module.AUTO_ARCHIVER in services_list:
auto_archive_service = AutoArchiveService(study_service, config)
services[Module.AUTO_ARCHIVER] = auto_archive_service

return services


def start_all_services(config_file: Path, services_list: List[Module]) -> None:
services = _init(config_file, services_list)
for service in services:
services[service].start(threaded=True)
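
The class-based SingletonServices wrapper is thus replaced by the module-level _init/start_all_services pair; _init still instantiates DBSessionMiddleware directly (with a None app), apparently so the db() session wrapper remains usable outside an ASGI application. A hedged usage sketch of the new entry point, with a hypothetical config path and module selection:

    from pathlib import Path

    from antarest.singleton_services import start_all_services
    from antarest.utils import Module

    # Start only the watcher and the auto-archiver from a hypothetical config file.
    start_all_services(Path("/etc/antarest/application.yaml"), [Module.WATCHER, Module.AUTO_ARCHIVER])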