fix(variant): generate cascades of variants synchronously to avoid timeout deadlocks
laurent-laporte-pro authored Nov 14, 2023
2 parents 1d5b8df + 8d42c1b commit 1874410
Showing 22 changed files with 2,451 additions and 524 deletions.
5 changes: 5 additions & 0 deletions antarest/core/exceptions.py
@@ -56,6 +56,11 @@ def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.EXPECTATION_FAILED, message)


class VariantGenerationTimeoutError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.REQUEST_TIMEOUT, message)


class NoParentStudyError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.NOT_FOUND, message)
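As a minimal sketch (not part of the commit), the new exception maps directly to HTTP 408 REQUEST_TIMEOUT, so raising it from a service layer propagates a clear timeout error to API clients; the message below is illustrative.

```python
from http import HTTPStatus

from antarest.core.exceptions import VariantGenerationTimeoutError

# Illustrative only: shows the HTTP status carried by the new exception.
try:
    raise VariantGenerationTimeoutError("Variant generation timed out after 60 seconds")
except VariantGenerationTimeoutError as exc:
    assert exc.status_code == HTTPStatus.REQUEST_TIMEOUT  # 408
    print(exc.detail)
```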
17 changes: 17 additions & 0 deletions antarest/core/interfaces/cache.py
@@ -6,6 +6,23 @@


class CacheConstants(Enum):
"""
Constants used to identify cache entries.
- `RAW_STUDY`: variable used to store JSON (or bytes) objects.
This cache is used by the `RawStudyService` or `VariantStudyService` to store
values that are retrieved from the filesystem.
Note: it is unlikely that this cache is used, because it is only used
to fetch data inside a study when the URL is "" and the depth is -1.
- `STUDY_FACTORY`: variable used to store objects of type `FileStudyTreeConfigDTO`.
This cache is used by the `create_from_fs` function when retrieving the configuration
of a study from the data on the disk.
- `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`.
This cache is used by the `get_studies_information` function to store the list of studies.
"""

RAW_STUDY = "RAW_STUDY"
STUDY_FACTORY = "STUDY_FACTORY"
STUDY_LISTING = "STUDY_LISTING"
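A rough illustration of how these constants might be combined with a study identifier to build cache keys; the key format below is an assumption, since the actual layout used by the services is not shown in this diff.

```python
from antarest.core.interfaces.cache import CacheConstants

# Hypothetical key construction: the f-string format is illustrative only.
study_id = "0c61a2d2-4e6b-4b8c-a2f7-example"
raw_key = f"{CacheConstants.RAW_STUDY.value}/{study_id}"
config_key = f"{CacheConstants.STUDY_FACTORY.value}/{study_id}"
print(raw_key, config_key)
```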
10 changes: 9 additions & 1 deletion antarest/core/tasks/repository.py
@@ -5,7 +5,7 @@

from fastapi import HTTPException

from antarest.core.tasks.model import TaskJob, TaskListFilter
from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import assert_this

@@ -82,3 +82,11 @@ def delete(self, tid: str) -> None:
if task:
db.session.delete(task)
db.session.commit()

def update_timeout(self, task_id: str, timeout: int) -> None:
"""Update task status to TIMEOUT."""
task: TaskJob = db.session.get(TaskJob, task_id)
task.status = TaskStatus.TIMEOUT
task.result_msg = f"Task '{task_id}' timeout after {timeout} seconds"
task.result_status = False
db.session.commit()
38 changes: 24 additions & 14 deletions antarest/core/tasks/service.py
@@ -34,6 +34,9 @@
TaskUpdateNotifier = Callable[[str], None]
Task = Callable[[TaskUpdateNotifier], TaskResult]

DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours
"""Default timeout for `await_task` in seconds."""


class ITaskService(ABC):
@abstractmethod
@@ -74,7 +77,7 @@ def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParamet
raise NotImplementedError()

@abstractmethod
def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None:
def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None:
raise NotImplementedError()


@@ -83,9 +86,6 @@ def noop_notifier(message: str) -> None:
"""This function is used in tasks when no notification is required."""


DEFAULT_AWAIT_MAX_TIMEOUT = 172800


class TaskJobService(ITaskService):
def __init__(
self,
@@ -141,6 +141,7 @@ def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult:
task_type,
)
while not task_result_wrapper:
logger.info("💤 Sleeping 1 second...")
time.sleep(1)
self.event_bus.remove_listener(listener_id)
return task_result_wrapper[0]
@@ -283,22 +284,31 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara
user = None if request_params.user.is_site_admin() else request_params.user.impersonator
return self.repo.list(task_filter, user)

def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None:
logger.info(f"Awaiting task {task_id}")
def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None:
if task_id in self.tasks:
self.tasks[task_id].result(timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT)
try:
logger.info(f"🤔 Awaiting task '{task_id}' {timeout_sec}s...")
self.tasks[task_id].result(timeout_sec)
logger.info(f"📌 Task '{task_id}' done.")
except Exception as exc:
logger.critical(f"🤕 Task '{task_id}' failed: {exc}.")
raise
else:
logger.warning(f"Task {task_id} not handled by this worker, will poll for task completion from db")
end = time.time() + (timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT)
logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db")
end = time.time() + timeout_sec
while time.time() < end:
with db():
task = self.repo.get(task_id)
if not task:
logger.error(f"Awaited task {task_id} was not found")
break
if task is None:
logger.error(f"Awaited task '{task_id}' was not found")
return
if TaskStatus(task.status).is_final():
break
time.sleep(2)
return
logger.info("💤 Sleeping 2 seconds...")
time.sleep(2)
logger.error(f"Timeout while awaiting task '{task_id}'")
with db():
self.repo.update_timeout(task_id, timeout_sec)

def _run_task(
self,
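Outside of the diff, the database-polling branch of `await_task` boils down to the pattern below, shown as a standalone sketch; the helper names are illustrative stand-ins, not names from the code base.

```python
import time
from typing import Callable

# Sketch of the poll-until-final-or-timeout pattern used by the new await_task:
# `is_final` stands in for `TaskStatus(task.status).is_final()`, and
# `on_timeout` stands in for `repo.update_timeout(task_id, timeout_sec)`.
def poll_until_final(is_final: Callable[[], bool], timeout_sec: int, on_timeout: Callable[[], None]) -> None:
    end = time.time() + timeout_sec
    while time.time() < end:
        if is_final():
            return
        time.sleep(2)  # same 2-second polling interval as the service
    on_timeout()  # mark the task as TIMEOUT instead of waiting forever
```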
47 changes: 39 additions & 8 deletions antarest/core/tasks/web.py
@@ -1,13 +1,15 @@
import concurrent.futures
import http
import logging
from typing import Any

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException

from antarest.core.config import Config
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters
from antarest.core.tasks.model import TaskListFilter
from antarest.core.tasks.service import TaskJobService
from antarest.core.tasks.model import TaskDTO, TaskListFilter
from antarest.core.tasks.service import DEFAULT_AWAIT_MAX_TIMEOUT, TaskJobService
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth

@@ -17,13 +19,13 @@
def create_tasks_api(service: TaskJobService, config: Config) -> APIRouter:
"""
Endpoints login implementation
Args:
service: login facade service
config: server config
jwt: jwt manager
Returns:
API router
"""
bp = APIRouter(prefix="/v1")
auth = Auth(config)
@@ -36,17 +38,46 @@ def list_tasks(
request_params = RequestParameters(user=current_user)
return service.list_tasks(filter, request_params)

@bp.get("/tasks/{task_id}", tags=[APITag.tasks])
@bp.get("/tasks/{task_id}", tags=[APITag.tasks], response_model=TaskDTO)
def get_task(
task_id: str,
wait_for_completion: bool = False,
with_logs: bool = False,
timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
) -> TaskDTO:
"""
Retrieve information about a specific task.
Args:
- `task_id`: Unique identifier of the task.
- `wait_for_completion`: Set to `True` to wait for task completion.
- `with_logs`: Set to `True` to retrieve the job logs (Antares Solver logs).
- `timeout`: Maximum time in seconds to wait for task completion.
Raises:
- 408 REQUEST_TIMEOUT: when the request times out while waiting for task completion.
Returns:
TaskDTO: Information about the specified task.
"""
request_params = RequestParameters(user=current_user)
task_status = service.status_task(task_id, request_params, with_logs)

if wait_for_completion and not task_status.status.is_final():
service.await_task(task_id)
# Ensure 0 <= timeout <= 48 h
timeout = min(max(0, timeout), DEFAULT_AWAIT_MAX_TIMEOUT)
try:
service.await_task(task_id, timeout_sec=timeout)
except concurrent.futures.TimeoutError as exc: # pragma: no cover
# Note that if the task does not complete within the specified time,
# the task will continue running but the user will receive a timeout.
# In this case, it is the user's responsibility to cancel the task.
raise HTTPException(
status_code=http.HTTPStatus.REQUEST_TIMEOUT,
detail="The request timed out while waiting for task completion.",
) from exc

return service.status_task(task_id, request_params, with_logs)

@bp.put("/tasks/{task_id}/cancel", tags=[APITag.tasks])
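A possible way to exercise the new query parameters from a client; the server address, port, and token below are placeholders, not values taken from the commit.

```python
import requests

# Hypothetical client call: waits up to 60 seconds for the task to finish.
response = requests.get(
    "http://localhost:8080/v1/tasks/some-task-id",
    params={"wait_for_completion": True, "with_logs": False, "timeout": 60},
    headers={"Authorization": "Bearer <token>"},
)
if response.status_code == 408:
    print("Task is still running; consider cancelling it or polling again.")
else:
    print(response.json())
```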
1 change: 1 addition & 0 deletions antarest/core/utils/utils.py
@@ -114,6 +114,7 @@ def retry(func: Callable[[], T], attempts: int = 10, interval: float = 0.5) -> T
attempt += 1
return func()
except Exception as e:
logger.info(f"💤 Sleeping {interval} second(s)...")
time.sleep(interval)
caught_exception = e
raise caught_exception or ShouldNotHappenException()
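For reference, a small usage sketch of `retry` based on the signature shown in the hunk header above; the flaky callable is made up for illustration.

```python
from antarest.core.utils.utils import retry

_calls = {"count": 0}

def flaky() -> str:
    # Fails the first two times, then succeeds.
    _calls["count"] += 1
    if _calls["count"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

# Retries up to 5 times, sleeping 0.1 s between failures (now logged).
print(retry(flaky, attempts=5, interval=0.1))
```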
2 changes: 1 addition & 1 deletion antarest/matrixstore/matrix_garbage_collector.py
@@ -55,7 +55,7 @@ def _get_raw_studies_matrices(self) -> Set[str]:

def _get_variant_studies_matrices(self) -> Set[str]:
logger.info("Getting all matrices used in variant studies")
command_blocks: List[CommandBlock] = self.variant_study_service.repository.get_all_commandblocks()
command_blocks: List[CommandBlock] = self.variant_study_service.repository.get_all_command_blocks()

def transform_to_command(command_dto: CommandDTO, study_ref: str) -> List[ICommand]:
try:
87 changes: 67 additions & 20 deletions antarest/study/repository.py
@@ -1,8 +1,8 @@
import datetime
import logging
from datetime import datetime
from typing import List, Optional
import typing as t

from sqlalchemy.orm import with_polymorphic # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.utils.fastapi_sqlalchemy import db
@@ -17,8 +17,30 @@ class StudyMetadataRepository:
Database connector to manage Study entity
"""

def __init__(self, cache_service: ICache):
def __init__(self, cache_service: ICache, session: t.Optional[Session] = None):
"""
Initialize the repository.
Args:
cache_service: Cache service for the repository.
session: Optional SQLAlchemy session to be used.
"""
self.cache_service = cache_service
self._session = session

@property
def session(self) -> Session:
"""
Get the SQLAlchemy session for the repository.
Returns:
SQLAlchemy session.
"""
if self._session is None:
# Get or create the session from a context variable (thread local variable)
return db.session
# Get the user-defined session
return self._session

def save(
self,
@@ -29,7 +51,7 @@ def save(
metadata_id = metadata.id or metadata.name
logger.debug(f"Saving study {metadata_id}")
if update_modification_date:
metadata.updated_at = datetime.utcnow()
metadata.updated_at = datetime.datetime.utcnow()

metadata.groups = [db.session.merge(g) for g in metadata.groups]
if metadata.owner:
@@ -44,35 +66,60 @@ def save(
def refresh(self, metadata: Study) -> None:
db.session.refresh(metadata)

def get(self, id: str) -> Optional[Study]:
def get(self, id: str) -> t.Optional[Study]:
"""Get the study by ID or return `None` if not found in database."""
metadata: Study = db.session.query(Study).get(id)
return metadata
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
study: Study = (
# fmt: off
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.get(id)
# fmt: on
)
return study

def one(self, id: str) -> Study:
"""Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database."""
study: Study = db.session.query(Study).filter_by(id=id).one()
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
study: Study = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.filter_by(id=id)
.one()
)
return study

def get_list(self, study_id: List[str]) -> List[Study]:
studies: List[Study] = db.session.query(Study).where(Study.id.in_(study_id)).all()
def get_list(self, study_id: t.List[str]) -> t.List[Study]:
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
studies: t.List[Study] = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.where(Study.id.in_(study_id))
.all()
)
return studies

def get_additional_data(self, study_id: str) -> Optional[StudyAdditionalData]:
metadata: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id)
return metadata
def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]:
study: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id)
return study

def get_all(self) -> List[Study]:
def get_all(self) -> t.List[Study]:
entity = with_polymorphic(Study, "*")
metadatas: List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
return metadatas
studies: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
return studies

def get_all_raw(self, show_missing: bool = True) -> List[RawStudy]:
def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]:
query = db.session.query(RawStudy)
if not show_missing:
query = query.filter(RawStudy.missing.is_(None))
metadatas: List[RawStudy] = query.all()
return metadatas
studies: t.List[RawStudy] = query.all()
return studies

def delete(self, id: str) -> None:
logger.debug(f"Deleting study {id}")
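One way the new optional `session` argument could be used, for example in a test with an in-memory engine; this is a sketch under that assumption, and the cache stub below is a made-up stand-in rather than part of the commit.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from antarest.study.repository import StudyMetadataRepository

class _NoCache:
    """Minimal, illustrative stand-in for the cache service dependency."""
    def get(self, key, refresh_timeout=None):
        return None
    def put(self, key, value, timeout=600):
        pass

# Sketch only: the repository can now be built with an explicitly managed
# SQLAlchemy session, exposed through its `session` property.
engine = create_engine("sqlite:///:memory:")
session = sessionmaker(bind=engine)()
repo = StudyMetadataRepository(cache_service=_NoCache(), session=session)
assert repo.session is session  # injected session takes precedence over db.session
```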