Skip to content

Commit

Permalink
Merge branch 'dev' into feat/use-lower-case-for-groups-and-names
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Dec 9, 2024
2 parents d21b267 + 7ad6803 commit 7e4f32b
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 83 deletions.
15 changes: 1 addition & 14 deletions antarest/core/tasks/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ class TaskJobRepository:
Database connector to manage Tasks/Jobs entities.
"""

def __init__(self, session: t.Optional[Session] = None):
"""
Initialize the repository.
Args:
session: Optional SQLAlchemy session to be used.
"""
self._session = session

@property
def session(self) -> Session:
"""
Expand All @@ -44,11 +35,7 @@ def session(self) -> Session:
Returns:
SQLAlchemy session.
"""
if self._session is None:
# Get or create the session from a context variable (thread local variable)
return db.session
# Get the user-defined session
return self._session
return db.session

def save(self, task: TaskJob) -> TaskJob:
session = self.session
Expand Down
105 changes: 59 additions & 46 deletions antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from antarest.core.tasks.repository import TaskJobRepository
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import retry
from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -390,35 +391,41 @@ def _run_task(
task_id: str,
custom_event_messages: t.Optional[CustomTaskEventMessages] = None,
) -> None:
# attention: this function is executed in a thread, not in the main process
with db():
task = db.session.query(TaskJob).get(task_id)
task_type = task.type
study_id = task.ref_id
# We need to catch all exceptions so that the calling thread is guaranteed
# to not die
try:
# attention: this function is executed in a thread, not in the main process
with db():
# Important to keep this retry for now,
# in case commit is not visible (read from replica ...)
task = retry(lambda: self.repo.get_or_raise(task_id))
task_type = task.type
study_id = task.ref_id

self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running",
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running",
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
)

logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value})
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")
logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{TaskJob.status: TaskStatus.RUNNING.value}
)
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")

try:
with db():
# We must use the DB session attached to the current thread
result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus))
Expand Down Expand Up @@ -463,29 +470,35 @@ def _run_task(
err_msg = f"Task {task_id} failed: Unhandled exception {exc}"
logger.error(err_msg, exc_info=exc)

with db():
result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback."
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_msg: result_msg,
TaskJob.result_status: False,
TaskJob.completion_date: datetime.datetime.utcnow(),
}
try:
with db():
result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback."
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_msg: result_msg,
TaskJob.result_status: False,
TaskJob.completion_date: datetime.datetime.utcnow(),
}
)
db.session.commit()

message = err_msg if custom_event_messages is None else custom_event_messages.end
self.event_bus.push(
Event(
type=EventType.TASK_FAILED,
payload=TaskEventPayload(
id=task_id, message=message, type=task_type, study_id=study_id
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
db.session.commit()

message = err_msg if custom_event_messages is None else custom_event_messages.end
self.event_bus.push(
Event(
type=EventType.TASK_FAILED,
payload=TaskEventPayload(
id=task_id, message=message, type=task_type, study_id=study_id
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
except Exception as inner_exc:
logger.error(
f"An exception occurred while handling execution error of task {task_id}: {inner_exc}",
exc_info=inner_exc,
)
)

def get_task_progress(self, task_id: str, params: RequestParameters) -> t.Optional[int]:
task = self.repo.get_or_raise(task_id)
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def retry(func: t.Callable[[], T], attempts: int = 10, interval: float = 0.5) ->
attempt += 1
return func()
except Exception as e:
logger.info(f"💤 Sleeping {interval} second(s)...")
logger.info(f"💤 Sleeping {interval} second(s) before retry...", exc_info=e)
time.sleep(interval)
caught_exception = e
raise caught_exception or ShouldNotHappenException()
Expand Down
46 changes: 24 additions & 22 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,22 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
return TaskResult(success=True, message="")


def test_repository(db_session: Session) -> None:
@with_db_context
def test_repository() -> None:
# Prepare two users in the database
user1_id = 9
db_session.add(User(id=user1_id, name="John"))
db.session.add(User(id=user1_id, name="John"))
user2_id = 10
db_session.add(User(id=user2_id, name="Jane"))
db_session.commit()
db.session.add(User(id=user2_id, name="Jane"))
db.session.commit()

# Create a RawStudy in the database
study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512"
db_session.add(RawStudy(id=study_id, name="foo", version="860"))
db_session.commit()
db.session.add(RawStudy(id=study_id, name="foo", version="860"))
db.session.commit()

# Create a TaskJobService
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY)

Expand Down Expand Up @@ -282,10 +283,10 @@ def test_repository(db_session: Session) -> None:
assert len(new_task.logs) == 2
assert new_task.logs[0].message == "hello"

assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2
assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2

task_job_repo.delete(new_task.id)
assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0
assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0
assert task_job_repo.get(new_task.id) is None


Expand Down Expand Up @@ -390,21 +391,22 @@ def test_cancel_orphan_tasks(
assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds


def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None:
@with_db_context
def test_get_progress(admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None:
# Prepare two users in the database
user1_id = 9
db_session.add(User(id=user1_id, name="John"))
db.session.add(User(id=user1_id, name="John"))
user2_id = 10
db_session.add(User(id=user2_id, name="Jane"))
db_session.commit()
db.session.add(User(id=user2_id, name="Jane"))
db.session.commit()

# Create a RawStudy in the database
study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512"
db_session.add(RawStudy(id=study_id, name="foo", version="860"))
db_session.commit()
db.session.add(RawStudy(id=study_id, name="foo", version="860"))
db.session.commit()

# Create a TaskJobService
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

# User 1 launches a ts generation
first_task = TaskJob(
Expand Down Expand Up @@ -451,12 +453,12 @@ def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Con
service.get_task_progress(wrong_id, RequestParameters(user))


@with_db_context
def test_ts_generation_task(
tmp_path: Path,
core_config: Config,
admin_user: JWTUser,
raw_study_service: RawStudyService,
db_session: Session,
) -> None:
# =======================
# SET UP
Expand All @@ -465,7 +467,7 @@ def test_ts_generation_task(
event_bus = DummyEventBusService()

# Create a TaskJobService and add tasks
task_job_repo = TaskJobRepository(db_session)
task_job_repo = TaskJobRepository()

# Create a TaskJobService
task_job_service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus)
Expand All @@ -474,8 +476,8 @@ def test_ts_generation_task(
raw_study_path = tmp_path / "study"

regular_user = User(id=99, name="regular")
db_session.add(regular_user)
db_session.commit()
db.session.add(regular_user)
db.session.commit()

raw_study = RawStudy(
id="my_raw_study",
Expand All @@ -490,8 +492,8 @@ def test_ts_generation_task(
path=str(raw_study_path),
)
study_metadata_repository = StudyMetadataRepository(Mock(), None)
db_session.add(raw_study)
db_session.commit()
db.session.add(raw_study)
db.session.commit()

# Set up the Raw Study
raw_study_service.create(raw_study)
Expand Down

0 comments on commit 7e4f32b

Please sign in to comment.