fix(variant): generates cascades of variants synchronously to avoid timeout deadlocks #1806

Merged
22 commits merged on Nov 14, 2023
Changes from all commits
Commits (22)
da13c45
chore(typing): simplify typing in pytest fixture
laurent-laporte-pro Oct 23, 2023
1e2e617
chore(variant-study-service): correct the `timeout` parameter to use …
laurent-laporte-pro Oct 23, 2023
38addea
feat(repository): add a method to update a task status to TIMEOUT in …
laurent-laporte-pro Oct 23, 2023
ff357bc
fix(task-job-service): update the task status to TIMEOUT if the proce…
laurent-laporte-pro Oct 23, 2023
959b2c0
refactor(variant-study-service): improve implementation of safe gener…
laurent-laporte-pro Oct 23, 2023
eb87bb8
chore: improve log message readability by adding prominent emojis
laurent-laporte-pro Oct 24, 2023
979ac6a
feat(variant-generation): improve error handling for timeout tasks
laurent-laporte-pro Oct 24, 2023
0ba6320
test(variant-generation): add a unit test to check that the variant g…
laurent-laporte-pro Oct 24, 2023
0cac11c
feat(tasks): add timeout parameter to control task completion wait time
laurent-laporte-pro Nov 6, 2023
cd15eec
refactor(variant-study-repo): optionally use a user-defined SqlAlchem…
laurent-laporte-pro Nov 7, 2023
8262e14
feat(variant-study-repo): add the `get_ancestor_or_self_ids` to retri…
laurent-laporte-pro Nov 7, 2023
4c9c081
style(db): add type hints in database model classes
laurent-laporte-pro Nov 8, 2023
21823ba
test(hydro): improves testing of hydraulic allocation variants
laurent-laporte-pro Nov 8, 2023
94d3cdb
docs(cache): add the documentation of cache constants
laurent-laporte-pro Nov 9, 2023
ea69c38
feat(permission): add the `assert_permission_on_studies` to check use…
laurent-laporte-pro Nov 9, 2023
e4f7756
feat(db): add methods to check the variant study snapshot status
laurent-laporte-pro Nov 9, 2023
70595c9
feat(variant): add the `SnapshotGenerator` class
laurent-laporte-pro Nov 10, 2023
be805dc
feat(variant): use the `SnapshotGenerator` class in the `VariantStudy…
laurent-laporte-pro Nov 10, 2023
340881e
test(variant): correct the `test_variant_model` unit test to use fixt…
laurent-laporte-pro Nov 10, 2023
16cf54f
perf(db): improved study query performance using owner and groups pre…
laurent-laporte-pro Nov 12, 2023
7dff31a
style: correct return type of the get task endpoint
laurent-laporte-pro Nov 14, 2023
8d42c1b
style: correct variable naming
laurent-laporte-pro Nov 14, 2023
5 changes: 5 additions & 0 deletions antarest/core/exceptions.py
@@ -56,6 +56,11 @@ def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.EXPECTATION_FAILED, message)


class VariantGenerationTimeoutError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.REQUEST_TIMEOUT, message)


class NoParentStudyError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.NOT_FOUND, message)
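For context, the new `VariantGenerationTimeoutError` maps a variant-generation timeout onto an HTTP 408 response. A minimal sketch of how a service could raise it; the helper function and message below are illustrative, not part of this PR:

```python
# Illustrative only: shows how raising the new exception surfaces as HTTP 408
# through FastAPI's HTTPException handling. The helper itself is hypothetical.
from antarest.core.exceptions import VariantGenerationTimeoutError


def check_generation_completed(succeeded: bool, variant_id: str, timeout: int) -> None:
    if not succeeded:
        raise VariantGenerationTimeoutError(
            f"Variant generation for study '{variant_id}' timed out after {timeout} seconds"
        )
```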
17 changes: 17 additions & 0 deletions antarest/core/interfaces/cache.py
@@ -6,6 +6,23 @@


class CacheConstants(Enum):
"""
Constants used to identify cache entries.

- `RAW_STUDY`: variable used to store JSON (or bytes) objects.
This cache is used by the `RawStudyService` or `VariantStudyService` to store
values that are retrieved from the filesystem.
Note: it is unlikely that this cache is used, because it is only used
to fetch data inside a study when the URL is "" and the depth is -1.

- `STUDY_FACTORY`: variable used to store objects of type `FileStudyTreeConfigDTO`.
This cache is used by the `create_from_fs` function when retrieving the configuration
of a study from the data on the disk.

- `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`.
This cache is used by the `get_studies_information` function to store the list of studies.
"""

RAW_STUDY = "RAW_STUDY"
STUDY_FACTORY = "STUDY_FACTORY"
STUDY_LISTING = "STUDY_LISTING"
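As a usage note, a hedged sketch of how these constants are typically combined with a study identifier to build cache keys; the exact key format is an assumption, not shown in this diff:

```python
# Sketch under assumption: cache keys join a CacheConstants value with a study id.
# The real key format used by the services may differ.
from antarest.core.interfaces.cache import CacheConstants


def study_factory_cache_key(study_id: str) -> str:
    return f"{CacheConstants.STUDY_FACTORY.value}/{study_id}"


print(study_factory_cache_key("0b1c2d3e"))  # -> "STUDY_FACTORY/0b1c2d3e"
```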
10 changes: 9 additions & 1 deletion antarest/core/tasks/repository.py
@@ -5,7 +5,7 @@

from fastapi import HTTPException

from antarest.core.tasks.model import TaskJob, TaskListFilter
from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import assert_this

@@ -82,3 +82,11 @@ def delete(self, tid: str) -> None:
if task:
db.session.delete(task)
db.session.commit()

def update_timeout(self, task_id: str, timeout: int) -> None:
"""Update task status to TIMEOUT."""
task: TaskJob = db.session.get(TaskJob, task_id)
task.status = TaskStatus.TIMEOUT
task.result_msg = f"Task '{task_id}' timeout after {timeout} seconds"
task.result_status = False
db.session.commit()
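`update_timeout` is intended to be called from the database-polling branch of `await_task` once the deadline has passed (see the `service.py` diff below). A minimal sketch of that wiring, assuming the thread-local `db()` session context used elsewhere in the codebase:

```python
# Sketch mirroring the call site added in service.py below; `repo` is assumed to be
# the task repository instance owned by the service.
from antarest.core.utils.fastapi_sqlalchemy import db


def mark_timed_out(repo, task_id: str, timeout_sec: int) -> None:
    with db():
        repo.update_timeout(task_id, timeout_sec)
```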
38 changes: 24 additions & 14 deletions antarest/core/tasks/service.py
@@ -34,6 +34,9 @@
TaskUpdateNotifier = Callable[[str], None]
Task = Callable[[TaskUpdateNotifier], TaskResult]

DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours
"""Default timeout for `await_task` in seconds."""


class ITaskService(ABC):
@abstractmethod
@@ -74,7 +77,7 @@ def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParamet
raise NotImplementedError()

@abstractmethod
def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None:
def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None:
raise NotImplementedError()


@@ -83,9 +86,6 @@ def noop_notifier(message: str) -> None:
"""This function is used in tasks when no notification is required."""


DEFAULT_AWAIT_MAX_TIMEOUT = 172800


class TaskJobService(ITaskService):
def __init__(
self,
@@ -141,6 +141,7 @@ def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult:
task_type,
)
while not task_result_wrapper:
logger.info("💤 Sleeping 1 second...")
time.sleep(1)
self.event_bus.remove_listener(listener_id)
return task_result_wrapper[0]
@@ -283,22 +284,31 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara
user = None if request_params.user.is_site_admin() else request_params.user.impersonator
return self.repo.list(task_filter, user)

def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None:
logger.info(f"Awaiting task {task_id}")
def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None:
if task_id in self.tasks:
self.tasks[task_id].result(timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT)
try:
logger.info(f"🤔 Awaiting task '{task_id}' {timeout_sec}s...")
self.tasks[task_id].result(timeout_sec)
logger.info(f"📌 Task '{task_id}' done.")
except Exception as exc:
logger.critical(f"🤕 Task '{task_id}' failed: {exc}.")
raise
else:
logger.warning(f"Task {task_id} not handled by this worker, will poll for task completion from db")
end = time.time() + (timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT)
logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db")
end = time.time() + timeout_sec
while time.time() < end:
with db():
task = self.repo.get(task_id)
if not task:
logger.error(f"Awaited task {task_id} was not found")
break
if task is None:
logger.error(f"Awaited task '{task_id}' was not found")
return
if TaskStatus(task.status).is_final():
break
time.sleep(2)
return
logger.info("💤 Sleeping 2 seconds...")
time.sleep(2)
logger.error(f"Timeout while awaiting task '{task_id}'")
with db():
self.repo.update_timeout(task_id, timeout_sec)

def _run_task(
self,
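With these changes, `await_task` signals a timeout in two different ways: when the current worker owns the future, `concurrent.futures.TimeoutError` is raised; in the database-polling branch it returns normally after marking the task `TIMEOUT` via `update_timeout`. A hedged caller sketch (the helper name and the 600-second budget are illustrative, not part of this PR):

```python
# Illustrative caller: the in-memory path signals timeout by raising, while the
# database-polling path signals it through the task status written by
# `update_timeout`, so a caller needing certainty should re-check that status.
import concurrent.futures


def wait_for_task(task_service, task_id: str, budget_sec: int = 600) -> bool:
    """Return False if the in-memory future timed out, True otherwise."""
    try:
        task_service.await_task(task_id, timeout_sec=budget_sec)
    except concurrent.futures.TimeoutError:
        # The underlying task keeps running; cancelling it is the caller's decision.
        return False
    return True
```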
47 changes: 39 additions & 8 deletions antarest/core/tasks/web.py
@@ -1,13 +1,15 @@
import concurrent.futures
import http
import logging
from typing import Any

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException

from antarest.core.config import Config
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters
from antarest.core.tasks.model import TaskListFilter
from antarest.core.tasks.service import TaskJobService
from antarest.core.tasks.model import TaskDTO, TaskListFilter
from antarest.core.tasks.service import DEFAULT_AWAIT_MAX_TIMEOUT, TaskJobService
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth

@@ -17,13 +19,13 @@
def create_tasks_api(service: TaskJobService, config: Config) -> APIRouter:
"""
Endpoints login implementation

Args:
service: login facade service
config: server config
jwt: jwt manager

Returns:

API router
"""
bp = APIRouter(prefix="/v1")
auth = Auth(config)
@@ -36,17 +38,46 @@ def list_tasks(
request_params = RequestParameters(user=current_user)
return service.list_tasks(filter, request_params)

@bp.get("/tasks/{task_id}", tags=[APITag.tasks])
@bp.get("/tasks/{task_id}", tags=[APITag.tasks], response_model=TaskDTO)
def get_task(
task_id: str,
wait_for_completion: bool = False,
with_logs: bool = False,
timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
) -> TaskDTO:
"""
Retrieve information about a specific task.

Args:
- `task_id`: Unique identifier of the task.
- `wait_for_completion`: Set to `True` to wait for task completion.
- `with_logs`: Set to `True` to retrieve the job logs (Antares Solver logs).
- `timeout`: Maximum time in seconds to wait for task completion.

Raises:
- 408 REQUEST_TIMEOUT: when the request times out while waiting for task completion.

Returns:
TaskDTO: Information about the specified task.
"""
request_params = RequestParameters(user=current_user)
task_status = service.status_task(task_id, request_params, with_logs)

if wait_for_completion and not task_status.status.is_final():
service.await_task(task_id)
# Ensure 0 <= timeout <= 48 h
timeout = min(max(0, timeout), DEFAULT_AWAIT_MAX_TIMEOUT)
try:
service.await_task(task_id, timeout_sec=timeout)
except concurrent.futures.TimeoutError as exc: # pragma: no cover
# Note that if the task does not complete within the specified time,
# the task will continue running but the user will receive a timeout.
# In this case, it is the user's responsibility to cancel the task.
raise HTTPException(
status_code=http.HTTPStatus.REQUEST_TIMEOUT,
detail="The request timed out while waiting for task completion.",
) from exc

return service.status_task(task_id, request_params, with_logs)

@bp.put("/tasks/{task_id}/cancel", tags=[APITag.tasks])
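From the client side, the reworked endpoint can now block until a task finishes, with the 408 status signalling that the wait expired. A hedged example using `requests`; the base URL, port, and token handling are assumptions:

```python
# Client-side sketch; the base URL and Authorization header are placeholders.
import requests


def get_task_blocking(task_id: str, token: str, timeout_sec: int = 60) -> dict:
    response = requests.get(
        f"http://localhost:8080/v1/tasks/{task_id}",
        params={"wait_for_completion": True, "with_logs": False, "timeout": timeout_sec},
        headers={"Authorization": f"Bearer {token}"},
        timeout=timeout_sec + 5,  # leave headroom for the server-side wait
    )
    if response.status_code == 408:
        raise TimeoutError(f"Task '{task_id}' did not complete within {timeout_sec}s")
    response.raise_for_status()
    return response.json()  # TaskDTO payload
```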
1 change: 1 addition & 0 deletions antarest/core/utils/utils.py
@@ -114,6 +114,7 @@ def retry(func: Callable[[], T], attempts: int = 10, interval: float = 0.5) -> T
attempt += 1
return func()
except Exception as e:
logger.info(f"💤 Sleeping {interval} second(s)...")
time.sleep(interval)
caught_exception = e
raise caught_exception or ShouldNotHappenException()
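For reference, a small usage sketch of the `retry` helper shown above; the wrapped callable and file path are illustrative:

```python
# Usage sketch matching the signature visible in the diff: retry a flaky
# zero-argument callable a few times, sleeping between attempts.
from antarest.core.utils.utils import retry


def load_snapshot_config(path: str) -> str:
    return retry(lambda: open(path, encoding="utf-8").read(), attempts=5, interval=0.2)
```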
2 changes: 1 addition & 1 deletion antarest/matrixstore/matrix_garbage_collector.py
@@ -55,7 +55,7 @@ def _get_raw_studies_matrices(self) -> Set[str]:

def _get_variant_studies_matrices(self) -> Set[str]:
logger.info("Getting all matrices used in variant studies")
command_blocks: List[CommandBlock] = self.variant_study_service.repository.get_all_commandblocks()
command_blocks: List[CommandBlock] = self.variant_study_service.repository.get_all_command_blocks()

def transform_to_command(command_dto: CommandDTO, study_ref: str) -> List[ICommand]:
try:
87 changes: 67 additions & 20 deletions antarest/study/repository.py
@@ -1,8 +1,8 @@
import datetime
import logging
from datetime import datetime
from typing import List, Optional
import typing as t

from sqlalchemy.orm import with_polymorphic # type: ignore
from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore

from antarest.core.interfaces.cache import CacheConstants, ICache
from antarest.core.utils.fastapi_sqlalchemy import db
@@ -17,8 +17,30 @@ class StudyMetadataRepository:
Database connector to manage Study entity
"""

def __init__(self, cache_service: ICache):
def __init__(self, cache_service: ICache, session: t.Optional[Session] = None):
"""
Initialize the repository.

Args:
cache_service: Cache service for the repository.
session: Optional SQLAlchemy session to be used.
"""
self.cache_service = cache_service
self._session = session

@property
def session(self) -> Session:
"""
Get the SQLAlchemy session for the repository.

Returns:
SQLAlchemy session.
"""
if self._session is None:
# Get or create the session from a context variable (thread local variable)
return db.session
# Get the user-defined session
return self._session

def save(
self,
@@ -29,7 +51,7 @@ def save(
metadata_id = metadata.id or metadata.name
logger.debug(f"Saving study {metadata_id}")
if update_modification_date:
metadata.updated_at = datetime.utcnow()
metadata.updated_at = datetime.datetime.utcnow()

metadata.groups = [db.session.merge(g) for g in metadata.groups]
if metadata.owner:
@@ -44,35 +66,60 @@ def save(
def refresh(self, metadata: Study) -> None:
db.session.refresh(metadata)

def get(self, id: str) -> Optional[Study]:
def get(self, id: str) -> t.Optional[Study]:
"""Get the study by ID or return `None` if not found in database."""
metadata: Study = db.session.query(Study).get(id)
return metadata
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
study: Study = (
# fmt: off
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.get(id)
# fmt: on
)
return study

def one(self, id: str) -> Study:
"""Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database."""
study: Study = db.session.query(Study).filter_by(id=id).one()
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
study: Study = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.filter_by(id=id)
.one()
)
return study

def get_list(self, study_id: List[str]) -> List[Study]:
studies: List[Study] = db.session.query(Study).where(Study.id.in_(study_id)).all()
def get_list(self, study_id: t.List[str]) -> t.List[Study]:
# When we fetch a study, we also need to fetch the associated owner and groups
# to check the permissions of the current user efficiently.
studies: t.List[Study] = (
db.session.query(Study)
.options(joinedload(Study.owner))
.options(joinedload(Study.groups))
.where(Study.id.in_(study_id))
.all()
)
return studies

def get_additional_data(self, study_id: str) -> Optional[StudyAdditionalData]:
metadata: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id)
return metadata
def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]:
study: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id)
return study

def get_all(self) -> List[Study]:
def get_all(self) -> t.List[Study]:
entity = with_polymorphic(Study, "*")
metadatas: List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
return metadatas
studies: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all()
return studies

def get_all_raw(self, show_missing: bool = True) -> List[RawStudy]:
def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]:
query = db.session.query(RawStudy)
if not show_missing:
query = query.filter(RawStudy.missing.is_(None))
metadatas: List[RawStudy] = query.all()
return metadatas
studies: t.List[RawStudy] = query.all()
return studies

def delete(self, id: str) -> None:
logger.debug(f"Deleting study {id}")
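Two things change in this repository: queries eager-load `owner` and `groups` with `joinedload` so permission checks avoid extra round-trips, and the repository can be built around a user-defined SQLAlchemy session instead of the request-scoped `db.session`. A hedged construction sketch; the in-memory engine and the cast-to-`ICache` placeholder stand in for the real database and cache service:

```python
# Sketch under assumptions: an in-memory SQLite engine and a placeholder cache
# replace the real infrastructure; only the session-selection behaviour is shown.
import typing as t

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from antarest.core.interfaces.cache import ICache
from antarest.study.repository import StudyMetadataRepository

engine = create_engine("sqlite:///:memory:")
session_factory = sessionmaker(bind=engine)

cache: ICache = t.cast(ICache, None)  # placeholder; a real ICache instance in practice

with session_factory() as session:
    repo = StudyMetadataRepository(cache_service=cache, session=session)
    assert repo.session is session  # the user-defined session takes precedence over db.session
```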