Skip to content

Commit

Permalink
feat(db-init): separate database initialization from global database …
Browse files Browse the repository at this point in the history
…session (#1805)
  • Loading branch information
mabw-rte authored and laurent-laporte-pro committed Dec 7, 2023
1 parent 63c61b9 commit 9a4e736
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
2 changes: 1 addition & 1 deletion antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __repr__(self) -> str:
def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None:
updated_values = {
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result: False,
TaskJob.result_status: False,
TaskJob.result_msg: "Task was interrupted due to server restart",
TaskJob.completion_date: datetime.utcnow(),
}
Expand Down
70 changes: 69 additions & 1 deletion tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,31 @@

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import sessionmaker

from antarest.core.config import Config, RemoteWorkerConfig, TaskConfig
from antarest.core.interfaces.eventbus import Event, EventType, IEventBus
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.model import PermissionInfo, PublicMode
from antarest.core.persistence import Base
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.tasks.model import TaskDTO, TaskJob, TaskJobLog, TaskListFilter, TaskResult, TaskStatus, TaskType
from antarest.core.tasks.model import (
TaskDTO,
TaskJob,
TaskJobLog,
TaskListFilter,
TaskResult,
TaskStatus,
TaskType,
cancel_orphan_tasks,
)
from antarest.core.tasks.repository import TaskJobRepository
from antarest.core.tasks.service import TaskJobService
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db
from antarest.eventbus.business.local_eventbus import LocalEventBus
from antarest.eventbus.service import EventBusService
from antarest.utils import SESSION_ARGS
from antarest.worker.worker import AbstractWorker, WorkerTaskCommand
from tests.helpers import with_db_context

Expand Down Expand Up @@ -453,3 +465,59 @@ def test_cancel():
service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True)
task.status = TaskStatus.CANCELLED.value
repo_mock.save.assert_called_with(task)


@pytest.mark.parametrize(
("status", "result_status", "result_msg"),
[
(TaskStatus.RUNNING.value, False, "task ongoing"),
(TaskStatus.PENDING.value, True, "task pending"),
(TaskStatus.FAILED.value, False, "task failed"),
(TaskStatus.COMPLETED.value, True, "task finished"),
(TaskStatus.TIMEOUT.value, False, "task timed out"),
(TaskStatus.CANCELLED.value, True, "task canceled"),
],
)
def test_cancel_orphan_tasks(
db_engine: Engine,
status: int,
result_status: bool,
result_msg: str,
max_diff_seconds: int = 6,
test_id: str = "test_cancel_orphan_tasks_id",
):
completion_date: datetime.datetime = datetime.datetime.utcnow()
task_job = TaskJob(
id=test_id,
status=status,
result_status=result_status,
result_msg=result_msg,
completion_date=completion_date,
)
with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session:
if session.query(TaskJob).get(test_id) is not None:
session.merge(task_job)
session.commit()
else:
session.add(task_job)
session.commit()
cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS))
with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session:
if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]:
updated_task_job = (
session.query(TaskJob)
.filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value]))
.all()
)
assert not updated_task_job
updated_task_job = session.query(TaskJob).get(test_id)
assert updated_task_job.status == TaskStatus.FAILED.value
assert not updated_task_job.result_status
assert updated_task_job.result_msg == "Task was interrupted due to server restart"
assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds
else:
updated_task_job = session.query(TaskJob).get(test_id)
assert updated_task_job.status == status
assert updated_task_job.result_status == result_status
assert updated_task_job.result_msg == result_msg
assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds

0 comments on commit 9a4e736

Please sign in to comment.