diff --git a/antarest/core/tasks/repository.py b/antarest/core/tasks/repository.py index 0a2028db33..74c9e84b78 100644 --- a/antarest/core/tasks/repository.py +++ b/antarest/core/tasks/repository.py @@ -27,15 +27,6 @@ class TaskJobRepository: Database connector to manage Tasks/Jobs entities. """ - def __init__(self, session: t.Optional[Session] = None): - """ - Initialize the repository. - - Args: - session: Optional SQLAlchemy session to be used. - """ - self._session = session - @property def session(self) -> Session: """ @@ -44,11 +35,7 @@ def session(self) -> Session: Returns: SQLAlchemy session. """ - if self._session is None: - # Get or create the session from a context variable (thread local variable) - return db.session - # Get the user-defined session - return self._session + return db.session def save(self, task: TaskJob) -> TaskJob: session = self.session diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index f992e227c6..a97a2fedf5 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -38,6 +38,7 @@ ) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.core.utils.utils import retry from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult logger = logging.getLogger(__name__) @@ -390,35 +391,41 @@ def _run_task( task_id: str, custom_event_messages: t.Optional[CustomTaskEventMessages] = None, ) -> None: - # attention: this function is executed in a thread, not in the main process - with db(): - task = db.session.query(TaskJob).get(task_id) - task_type = task.type - study_id = task.ref_id + # We need to catch all exceptions so that the calling thread is guaranteed + # to not die + try: + # attention: this function is executed in a thread, not in the main process + with db(): + # Important to keep this retry for now, + # in case commit is not visible (read from replica ...) + task = retry(lambda: self.repo.get_or_raise(task_id)) + task_type = task.type + study_id = task.ref_id - self.event_bus.push( - Event( - type=EventType.TASK_RUNNING, - payload=TaskEventPayload( - id=task_id, - message=custom_event_messages.running - if custom_event_messages is not None - else f"Task {task_id} is running", - type=task_type, - study_id=study_id, - ).model_dump(), - permissions=PermissionInfo(public_mode=PublicMode.READ), - channel=EventChannelDirectory.TASK + task_id, + self.event_bus.push( + Event( + type=EventType.TASK_RUNNING, + payload=TaskEventPayload( + id=task_id, + message=custom_event_messages.running + if custom_event_messages is not None + else f"Task {task_id} is running", + type=task_type, + study_id=study_id, + ).model_dump(), + permissions=PermissionInfo(public_mode=PublicMode.READ), + channel=EventChannelDirectory.TASK + task_id, + ) ) - ) - logger.info(f"Starting task {task_id}") - with db(): - db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value}) - db.session.commit() - logger.info(f"Task {task_id} set to RUNNING") + logger.info(f"Starting task {task_id}") + with db(): + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + {TaskJob.status: TaskStatus.RUNNING.value} + ) + db.session.commit() + logger.info(f"Task {task_id} set to RUNNING") - try: with db(): # We must use the DB session attached to the current thread result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus)) @@ -463,29 +470,35 @@ def _run_task( err_msg = f"Task {task_id} failed: Unhandled exception {exc}" logger.error(err_msg, exc_info=exc) - with db(): - result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." - db.session.query(TaskJob).filter(TaskJob.id == task_id).update( - { - TaskJob.status: TaskStatus.FAILED.value, - TaskJob.result_msg: result_msg, - TaskJob.result_status: False, - TaskJob.completion_date: datetime.datetime.utcnow(), - } + try: + with db(): + result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result_msg: result_msg, + TaskJob.result_status: False, + TaskJob.completion_date: datetime.datetime.utcnow(), + } + ) + db.session.commit() + + message = err_msg if custom_event_messages is None else custom_event_messages.end + self.event_bus.push( + Event( + type=EventType.TASK_FAILED, + payload=TaskEventPayload( + id=task_id, message=message, type=task_type, study_id=study_id + ).model_dump(), + permissions=PermissionInfo(public_mode=PublicMode.READ), + channel=EventChannelDirectory.TASK + task_id, + ) ) - db.session.commit() - - message = err_msg if custom_event_messages is None else custom_event_messages.end - self.event_bus.push( - Event( - type=EventType.TASK_FAILED, - payload=TaskEventPayload( - id=task_id, message=message, type=task_type, study_id=study_id - ).model_dump(), - permissions=PermissionInfo(public_mode=PublicMode.READ), - channel=EventChannelDirectory.TASK + task_id, + except Exception as inner_exc: + logger.error( + f"An exception occurred while handling execution error of task {task_id}: {inner_exc}", + exc_info=inner_exc, ) - ) def get_task_progress(self, task_id: str, params: RequestParameters) -> t.Optional[int]: task = self.repo.get_or_raise(task_id) diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index c748420549..2940db582a 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -105,7 +105,7 @@ def retry(func: t.Callable[[], T], attempts: int = 10, interval: float = 0.5) -> attempt += 1 return func() except Exception as e: - logger.info(f"💤 Sleeping {interval} second(s)...") + logger.info(f"💤 Sleeping {interval} second(s) before retry...", exc_info=e) time.sleep(interval) caught_exception = e raise caught_exception or ShouldNotHappenException() diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 139db56691..5a8b490047 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -204,21 +204,22 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return TaskResult(success=True, message="") -def test_repository(db_session: Session) -> None: +@with_db_context +def test_repository() -> None: # Prepare two users in the database user1_id = 9 - db_session.add(User(id=user1_id, name="John")) + db.session.add(User(id=user1_id, name="John")) user2_id = 10 - db_session.add(User(id=user2_id, name="Jane")) - db_session.commit() + db.session.add(User(id=user2_id, name="Jane")) + db.session.commit() # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id=study_id, name="foo", version="860")) - db_session.commit() + db.session.add(RawStudy(id=study_id, name="foo", version="860")) + db.session.commit() # Create a TaskJobService - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY) @@ -282,10 +283,10 @@ def test_repository(db_session: Session) -> None: assert len(new_task.logs) == 2 assert new_task.logs[0].message == "hello" - assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 + assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 task_job_repo.delete(new_task.id) - assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 + assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 assert task_job_repo.get(new_task.id) is None @@ -390,21 +391,22 @@ def test_cancel_orphan_tasks( assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds -def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None: +@with_db_context +def test_get_progress(admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None: # Prepare two users in the database user1_id = 9 - db_session.add(User(id=user1_id, name="John")) + db.session.add(User(id=user1_id, name="John")) user2_id = 10 - db_session.add(User(id=user2_id, name="Jane")) - db_session.commit() + db.session.add(User(id=user2_id, name="Jane")) + db.session.commit() # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id=study_id, name="foo", version="860")) - db_session.commit() + db.session.add(RawStudy(id=study_id, name="foo", version="860")) + db.session.commit() # Create a TaskJobService - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() # User 1 launches a ts generation first_task = TaskJob( @@ -451,12 +453,12 @@ def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Con service.get_task_progress(wrong_id, RequestParameters(user)) +@with_db_context def test_ts_generation_task( tmp_path: Path, core_config: Config, admin_user: JWTUser, raw_study_service: RawStudyService, - db_session: Session, ) -> None: # ======================= # SET UP @@ -465,7 +467,7 @@ def test_ts_generation_task( event_bus = DummyEventBusService() # Create a TaskJobService and add tasks - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() # Create a TaskJobService task_job_service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) @@ -474,8 +476,8 @@ def test_ts_generation_task( raw_study_path = tmp_path / "study" regular_user = User(id=99, name="regular") - db_session.add(regular_user) - db_session.commit() + db.session.add(regular_user) + db.session.commit() raw_study = RawStudy( id="my_raw_study", @@ -490,8 +492,8 @@ def test_ts_generation_task( path=str(raw_study_path), ) study_metadata_repository = StudyMetadataRepository(Mock(), None) - db_session.add(raw_study) - db_session.commit() + db.session.add(raw_study) + db.session.commit() # Set up the Raw Study raw_study_service.create(raw_study)