diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 1206db9fc4..6eacab3f10 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -177,7 +177,7 @@ def __repr__(self) -> str: def cancel_orphan_tasks(engine: Engine, session_args: Dict[str, bool]) -> None: updated_values = { TaskJob.status: TaskStatus.FAILED.value, - TaskJob.result: False, + TaskJob.result_status: False, TaskJob.result_msg: "Task was interrupted due to server restart", TaskJob.completion_date: datetime.utcnow(), } diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index dfad126555..73782745ce 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -6,6 +6,8 @@ import pytest from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import sessionmaker from antarest.core.config import Config, RemoteWorkerConfig, TaskConfig from antarest.core.interfaces.eventbus import Event, EventType, IEventBus @@ -13,12 +15,22 @@ from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError -from antarest.core.tasks.model import TaskDTO, TaskJob, TaskJobLog, TaskListFilter, TaskResult, TaskStatus, TaskType +from antarest.core.tasks.model import ( + TaskDTO, + TaskJob, + TaskJobLog, + TaskListFilter, + TaskResult, + TaskStatus, + TaskType, + cancel_orphan_tasks, +) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.tasks.service import TaskJobService from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.service import EventBusService +from antarest.utils import SESSION_ARGS from antarest.worker.worker import AbstractWorker, WorkerTaskCommand from tests.helpers import with_db_context @@ -453,3 +465,59 @@ def test_cancel(): service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) task.status = TaskStatus.CANCELLED.value repo_mock.save.assert_called_with(task) + + +@pytest.mark.parametrize( + ("status", "result_status", "result_msg"), + [ + (TaskStatus.RUNNING.value, False, "task ongoing"), + (TaskStatus.PENDING.value, True, "task pending"), + (TaskStatus.FAILED.value, False, "task failed"), + (TaskStatus.COMPLETED.value, True, "task finished"), + (TaskStatus.TIMEOUT.value, False, "task timed out"), + (TaskStatus.CANCELLED.value, True, "task canceled"), + ], +) +def test_cancel_orphan_tasks( + db_engine: Engine, + status: int, + result_status: bool, + result_msg: str, + max_diff_seconds: int = 6, + test_id: str = "test_cancel_orphan_tasks_id", +): + completion_date: datetime.datetime = datetime.datetime.utcnow() + task_job = TaskJob( + id=test_id, + status=status, + result_status=result_status, + result_msg=result_msg, + completion_date=completion_date, + ) + with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + if session.query(TaskJob).get(test_id) is not None: + session.merge(task_job) + session.commit() + else: + session.add(task_job) + session.commit() + cancel_orphan_tasks(engine=db_engine, session_args=dict(SESSION_ARGS)) + with sessionmaker(bind=db_engine, **dict(SESSION_ARGS))() as session: + if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]: + updated_task_job = ( + session.query(TaskJob) + .filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])) + .all() + ) + assert not updated_task_job + updated_task_job = session.query(TaskJob).get(test_id) + assert updated_task_job.status == TaskStatus.FAILED.value + assert not updated_task_job.result_status + assert updated_task_job.result_msg == "Task was interrupted due to server restart" + assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds + else: + updated_task_job = session.query(TaskJob).get(test_id) + assert updated_task_job.status == status + assert updated_task_job.result_status == result_status + assert updated_task_job.result_msg == result_msg + assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds