Skip to content

Commit

Permalink
fix(tasks): frozen tasks with pgpool (#2264)
Browse files Browse the repository at this point in the history
This work aims at fixing issues encountered with load balanced pgpool,
where a commit is not always instantly visible in another session (when
read from replica).

Fixes are:
- restore a workaround which consists in re-trying the read until success
- ensure exceptions occurring in the task execution thread are
  ALL caught

Also removing a technical debt:
- remove optional session from task job repo which was just used in tests.
  It did not work with code which uses both the repo and also the session
  from the singleton db.session.

Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl authored and skamril committed Dec 11, 2024
1 parent 3cbd45a commit 53667c0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 83 deletions.
15 changes: 1 addition & 14 deletions antarest/core/tasks/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ class TaskJobRepository:
Database connector to manage Tasks/Jobs entities.
"""

def __init__(self, session: t.Optional[Session] = None):
"""
Initialize the repository.
Args:
session: Optional SQLAlchemy session to be used.
"""
self._session = session

@property
def session(self) -> Session:
"""
Expand All @@ -44,11 +35,7 @@ def session(self) -> Session:
Returns:
SQLAlchemy session.
"""
if self._session is None:
# Get or create the session from a context variable (thread local variable)
return db.session
# Get the user-defined session
return self._session
return db.session

def save(self, task: TaskJob) -> TaskJob:
session = self.session
Expand Down
105 changes: 59 additions & 46 deletions antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from antarest.core.tasks.repository import TaskJobRepository
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import retry
from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -390,35 +391,41 @@ def _run_task(
task_id: str,
custom_event_messages: t.Optional[CustomTaskEventMessages] = None,
) -> None:
# attention: this function is executed in a thread, not in the main process
with db():
task = db.session.query(TaskJob).get(task_id)
task_type = task.type
study_id = task.ref_id
# We need to catch all exceptions so that the calling thread is guaranteed
# to not die
try:
# attention: this function is executed in a thread, not in the main process
with db():
# Important to keep this retry for now,
# in case commit is not visible (read from replica ...)
task = retry(lambda: self.repo.get_or_raise(task_id))
task_type = task.type
study_id = task.ref_id

self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running",
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running",
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
)

logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value})
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")
logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{TaskJob.status: TaskStatus.RUNNING.value}
)
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")

try:
with db():
# We must use the DB session attached to the current thread
result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus))
Expand Down Expand Up @@ -463,29 +470,35 @@ def _run_task(
err_msg = f"Task {task_id} failed: Unhandled exception {exc}"
logger.error(err_msg, exc_info=exc)

with db():
result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback."
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_msg: result_msg,
TaskJob.result_status: False,
TaskJob.completion_date: datetime.datetime.utcnow(),
}
try:
with db():
result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback."
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_msg: result_msg,
TaskJob.result_status: False,
TaskJob.completion_date: datetime.datetime.utcnow(),
}
)
db.session.commit()

message = err_msg if custom_event_messages is None else custom_event_messages.end
self.event_bus.push(
Event(
type=EventType.TASK_FAILED,
payload=TaskEventPayload(
id=task_id, message=message, type=task_type, study_id=study_id
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
db.session.commit()

message = err_msg if custom_event_messages is None else custom_event_messages.end
self.event_bus.push(
Event(
type=EventType.TASK_FAILED,
payload=TaskEventPayload(
id=task_id, message=message, type=task_type, study_id=study_id
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
except Exception as inner_exc:
logger.error(
f"An exception occurred while handling execution error of task {task_id}: {inner_exc}",
exc_info=inner_exc,
)
)

def get_task_progress(self, task_id: str, params: RequestParameters) -> t.Optional[int]:
task = self.repo.get_or_raise(task_id)
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def retry(func: t.Callable[[], T], attempts: int = 10, interval: float = 0.5) ->
attempt += 1
return func()
except Exception as e:
logger.info(f"💤 Sleeping {interval} second(s)...")
logger.info(f"💤 Sleeping {interval} second(s) before retry...", exc_info=e)
time.sleep(interval)
caught_exception = e
raise caught_exception or ShouldNotHappenException()
Expand Down
46 changes: 24 additions & 22 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,22 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
return TaskResult(success=True, message="")


def test_repository(db_session: Session) -> None:
@with_db_context
def test_repository() -> None:
# Prepare two users in the database
user1_id = 9
db_session.add(User(id=user1_id, name="John"))
db.session.add(User(id=user1_id, name="John"))
user2_id = 10
db_session.add(User(id=user2_id, name="Jane"))
db_session.commit()
db.session.add(User(id=user2_id, name="Jane"))
db.session.commit()

# Create a RawStudy in the database
study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512"
db_session.add(RawStudy(id=study_id, name="foo", version="860"))
db_session.commit()
db.session.add(RawStudy(id=study_id, name="foo", version="860"))
db.session.commit()

# Create a TaskJobService
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY)

Expand Down Expand Up @@ -282,10 +283,10 @@ def test_repository(db_session: Session) -> None:
assert len(new_task.logs) == 2
assert new_task.logs[0].message == "hello"

assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2
assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2

task_job_repo.delete(new_task.id)
assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0
assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0
assert task_job_repo.get(new_task.id) is None


Expand Down Expand Up @@ -390,21 +391,22 @@ def test_cancel_orphan_tasks(
assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds


def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None:
@with_db_context
def test_get_progress(admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None:
# Prepare two users in the database
user1_id = 9
db_session.add(User(id=user1_id, name="John"))
db.session.add(User(id=user1_id, name="John"))
user2_id = 10
db_session.add(User(id=user2_id, name="Jane"))
db_session.commit()
db.session.add(User(id=user2_id, name="Jane"))
db.session.commit()

# Create a RawStudy in the database
study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512"
db_session.add(RawStudy(id=study_id, name="foo", version="860"))
db_session.commit()
db.session.add(RawStudy(id=study_id, name="foo", version="860"))
db.session.commit()

# Create a TaskJobService
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

# User 1 launches a ts generation
first_task = TaskJob(
Expand Down Expand Up @@ -451,12 +453,12 @@ def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Con
service.get_task_progress(wrong_id, RequestParameters(user))


@with_db_context
def test_ts_generation_task(
tmp_path: Path,
core_config: Config,
admin_user: JWTUser,
raw_study_service: RawStudyService,
db_session: Session,
) -> None:
# =======================
# SET UP
Expand All @@ -465,7 +467,7 @@ def test_ts_generation_task(
event_bus = DummyEventBusService()

# Create a TaskJobService and add tasks
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

# Create a TaskJobService
task_job_service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus)
Expand All @@ -474,8 +476,8 @@ def test_ts_generation_task(
raw_study_path = tmp_path / "study"

regular_user = User(id=99, name="regular")
db_session.add(regular_user)
db_session.commit()
db.session.add(regular_user)
db.session.commit()

raw_study = RawStudy(
id="my_raw_study",
Expand All @@ -490,8 +492,8 @@ def test_ts_generation_task(
path=str(raw_study_path),
)
study_metadata_repository = StudyMetadataRepository(Mock(), None)
db_session.add(raw_study)
db_session.commit()
db.session.add(raw_study)
db.session.commit()

# Set up the Raw Study
raw_study_service.create(raw_study)
Expand Down

0 comments on commit 53667c0

Please sign in to comment.