Skip to content

Commit

Permalink
Merge pull request #663 from AntaresSimulatorTeam/dev
Browse files Browse the repository at this point in the history
v2.1.5
  • Loading branch information
pl-buiquang authored Dec 9, 2021
2 parents c704dba + 06eea06 commit 2c7ccf5
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 43 deletions.
2 changes: 1 addition & 1 deletion antarest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.1.4"
__version__ = "2.1.5"

from pathlib import Path

Expand Down
1 change: 1 addition & 0 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get(
else refresh_timeout,
)
return redis_element.data
logger.info(f"Cache key {id} not found")
return None

def invalidate(self, id: str) -> None:
Expand Down
16 changes: 15 additions & 1 deletion antarest/core/core_blueprint.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import logging
import subprocess
from pathlib import Path
from typing import Any, Optional

from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel

from antarest import __version__
from antarest.core.config import Config
from antarest.core.jwt import JWTUser
from antarest.core.requests import UserHasNotPermissionError
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth


def get_commit_id(path_resources: Path) -> Optional[str]:
Expand Down Expand Up @@ -54,6 +58,7 @@ def create_utils_routes(config: Config) -> APIRouter:
"""
bp = APIRouter()
auth = Auth(config)

@bp.get("/health", tags=[APITag.misc], response_model=StatusDTO)
def health() -> Any:
Expand All @@ -70,4 +75,13 @@ def version() -> Any:
version=__version__, gitcommit=get_commit_id(config.resources_path)
)

@bp.get("/kill", include_in_schema=False)
def kill_worker(
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
if not current_user.is_site_admin():
raise UserHasNotPermissionError()
logging.getLogger(__name__).warning("Killing the worker")
exit(1)

return bp
1 change: 1 addition & 0 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class EventType:
TASK_RUNNING = "TASK_RUNNING"
TASK_COMPLETED = "TASK_COMPLETED"
TASK_FAILED = "TASK_FAILED"
TASK_CANCEL_REQUEST = "TASK_CANCEL_REQUEST"
DOWNLOAD_CREATED = "DOWNLOAD_CREATED"
DOWNLOAD_READY = "DOWNLOAD_READY"
DOWNLOAD_EXPIRED = "DOWNLOAD_EXPIRED"
Expand Down
42 changes: 40 additions & 2 deletions antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, Future
from http import HTTPStatus
from typing import Callable, Optional, List, Dict
from typing import Callable, Optional, List, Dict, Awaitable

from fastapi import HTTPException

Expand All @@ -21,6 +21,7 @@
from antarest.core.requests import (
RequestParameters,
MustBeAuthenticatedError,
UserHasNotPermissionError,
)
from antarest.core.tasks.model import (
TaskDTO,
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(
self.threadpool = ThreadPoolExecutor(
max_workers=config.tasks.max_workers, thread_name_prefix="taskjob_"
)
self.event_bus.add_listener(self.create_task_event_callback())
# set the status of previously running job to FAILED due to server restart
self._fix_running_status()

Expand Down Expand Up @@ -135,6 +137,36 @@ def add_task(
self.tasks[task.id] = future
return str(task.id)

def create_task_event_callback(self) -> Callable[[Event], Awaitable[None]]:
async def task_event_callback(event: Event) -> None:
if event.type == EventType.TASK_CANCEL_REQUEST:
self._cancel_task(str(event.payload), dispatch=False)

return task_event_callback

def cancel_task(
self, task_id: str, params: RequestParameters, dispatch: bool = False
) -> None:
task = self.repo.get_or_raise(task_id)
if params.user and (
params.user.is_site_admin()
or task.owner_id == params.user.impersonator
):
self._cancel_task(task_id, dispatch)
else:
raise UserHasNotPermissionError()

def _cancel_task(self, task_id: str, dispatch: bool = False) -> None:
task = self.repo.get_or_raise(task_id)
if task_id in self.tasks:
self.tasks[task_id].cancel()
task.status = TaskStatus.CANCELLED
self.repo.save(task)
elif dispatch:
self.event_bus.push(
Event(type=EventType.TASK_CANCEL_REQUEST, payload=task_id)
)

def status_task(
self,
task_id: str,
Expand Down Expand Up @@ -175,8 +207,11 @@ def list_db_tasks(
def await_task(
self, task_id: str, timeout_sec: Optional[int] = None
) -> None:
logger.info(f"Awaiting task {task_id}")
if task_id in self.tasks:
self.tasks[task_id].result(timeout_sec)
self.tasks[task_id].result(
timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT
)
else:
logger.warning(
f"Task {task_id} not handled by this worker, will poll for task completion from db"
Expand Down Expand Up @@ -211,11 +246,14 @@ def _run_task(
)

with db():
logger.info(f"Starting task {task_id}")
task = retry(lambda: self.repo.get_or_raise(task_id))
task.status = TaskStatus.RUNNING.value
self.repo.save(task)
logger.info(f"Task {task_id} set to RUNNING")
try:
result = callback(self._task_logger(task_id))
logger.info(f"Task {task_id} ended")
self._update_task_status(
task_id,
TaskStatus.COMPLETED
Expand Down
8 changes: 8 additions & 0 deletions antarest/core/tasks/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,12 @@ def get_task(
service.await_task(task_id)
return service.status_task(task_id, request_params, with_logs)

@bp.put("/tasks/{task_id}/cancel")
def cancel_task(
task_id: str,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
request_params = RequestParameters(user=current_user)
return service.cancel_task(task_id, request_params, dispatch=True)

return bp
2 changes: 0 additions & 2 deletions antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import sqlalchemy.ext.baked # type: ignore
import uvicorn # type: ignore
from dateutil import tz
from fastapi import FastAPI, HTTPException
from fastapi_jwt_auth import AuthJWT # type: ignore
from sqlalchemy import create_engine, text
Expand All @@ -23,7 +22,6 @@
from antarest.core.config import Config
from antarest.core.core_blueprint import create_utils_routes
from antarest.core.filetransfer.main import build_filetransfer_service
from antarest.core.filetransfer.web import create_file_transfer_api
from antarest.core.logging.utils import configure_logger, LoggingMiddleware
from antarest.core.persistence import upgrade_db
from antarest.core.swagger import customize_openapi
Expand Down
5 changes: 4 additions & 1 deletion antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,10 @@ def get_studies_information(
for k in cached_studies:
studies[k] = StudyMetadataDTO.parse_obj(cached_studies[k])
else:
for study in self.repository.get_all():
logger.info("Retrieving all studies")
all_studies = self.repository.get_all()
logger.info("Studies retrieved")
for study in all_studies:
if not managed or is_managed(study):
study_metadata = self._try_get_studies_information(
study, summary
Expand Down
6 changes: 5 additions & 1 deletion antarest/study/storage/abstract_storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,11 @@ def export_output(self, metadata: T, output_id: str, target: Path) -> Path:

@abstractmethod
def export_study_flat(
self, metadata: T, dest: Path, outputs: bool = True
self,
metadata: T,
dest: Path,
outputs: bool = True,
denormalize: bool = True,
) -> None:
raise NotImplementedError()

Expand Down
6 changes: 5 additions & 1 deletion antarest/study/storage/rawstudy/raw_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ def import_study(self, metadata: RawStudy, stream: IO[bytes]) -> Study:
return metadata

def export_study_flat(
self, metadata: RawStudy, dest: Path, outputs: bool = True
self,
metadata: RawStudy,
dest: Path,
outputs: bool = True,
denormalize: bool = True,
) -> None:
path_study = Path(metadata.path)
start_time = time.time()
Expand Down
40 changes: 16 additions & 24 deletions antarest/study/storage/variantstudy/variant_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,23 +456,6 @@ def _get_variants_parents(

return output_list

def get_study_information(
self, study: VariantStudy, summary: bool = False
) -> StudyMetadataDTO:
"""
Get information present in study.antares file
Args:
study: study
summary: if true, only retrieve basic info from database
Returns: study metadata
"""
return super().get_study_information(
study,
summary,
)

def get(
self,
metadata: VariantStudy,
Expand All @@ -493,7 +476,7 @@ def get(
Returns: study data formatted in json
"""
self._safe_generation(metadata, timeout=30)
self._safe_generation(metadata, timeout=60)
self.repository.refresh(metadata)
return super().get(
metadata=metadata,
Expand Down Expand Up @@ -567,6 +550,7 @@ def generate_task(
)
):
logger.info(f"Starting variant study {metadata.id} generation")
self.repository.refresh(metadata)
if metadata.generation_task:
try:
previous_task = self.task_service.status_task(
Expand Down Expand Up @@ -670,7 +654,10 @@ def _generate(
)
else:
self.raw_study_service.export_study_flat(
metadata=parent_study, dest=dest_path, outputs=False
metadata=parent_study,
dest=dest_path,
outputs=False,
denormalize=False,
)

results = self._generate_snapshot(variant_study, notifier)
Expand Down Expand Up @@ -863,7 +850,7 @@ def get_study_sim_result(
study: study
Returns: study output data
"""
self._safe_generation(study, timeout=30)
self._safe_generation(study, timeout=60)
return super().get_study_sim_result(study=study)

def set_reference_output(
Expand Down Expand Up @@ -917,7 +904,11 @@ def get_study_path(self, metadata: Study) -> Path:
return Path(metadata.path) / SNAPSHOT_RELATIVE_PATH

def export_study_flat(
self, metadata: VariantStudy, dest: Path, outputs: bool = True
self,
metadata: VariantStudy,
dest: Path,
outputs: bool = True,
denormalize: bool = True,
) -> None:

self._safe_generation(metadata)
Expand All @@ -938,6 +929,7 @@ def export_study_flat(
duration = "{:.3f}".format(stop_time - start_time)
logger.info(f"Study {path_study} exported (flat mode) in {duration}s")
_, study = self.study_factory.create_from_fs(dest, "", use_cache=False)
study.denormalize()
duration = "{:.3f}".format(time.time() - stop_time)
logger.info(f"Study {path_study} denormalized in {duration}s")
if denormalize:
study.denormalize()
duration = "{:.3f}".format(time.time() - stop_time)
logger.info(f"Study {path_study} denormalized in {duration}s")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="AntaREST",
version="2.1.4",
version="2.1.5",
description="Antares Server",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion sonar-project.properties
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ sonar.sources=antarest
sonar.language=python
sonar.exclusions=antarest/gui.py,antarest/main.py
sonar.python.coverage.reportPaths=coverage.xml
sonar.projectVersion=2.1.4
sonar.projectVersion=2.1.5
41 changes: 40 additions & 1 deletion tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import Callable
from unittest.mock import Mock, ANY, call

import pytest
from sqlalchemy import create_engine

from antarest.core.config import Config
from antarest.core.interfaces.eventbus import EventType, Event
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.persistence import Base
from antarest.core.requests import RequestParameters
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.tasks.model import (
TaskJob,
TaskStatus,
Expand Down Expand Up @@ -344,3 +346,40 @@ def test_repository():
== 0
)
assert task_repository.get(new_task.id) is None


def test_cancel():
engine = create_engine("sqlite:///:memory:", echo=True)
Base.metadata.create_all(engine)
DBSessionMiddleware(
Mock(),
custom_engine=engine,
session_args={"autocommit": False, "autoflush": False},
)

repo_mock = Mock(spec=TaskJobRepository)
repo_mock.list.return_value = []
service = TaskJobService(
config=Config(), repository=repo_mock, event_bus=Mock()
)

with pytest.raises(UserHasNotPermissionError):
service.cancel_task("a", RequestParameters())

service.cancel_task(
"b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True
)
service.event_bus.push.assert_called_with(
Event(type=EventType.TASK_CANCEL_REQUEST, payload="b")
)

creation_date = datetime.datetime.utcnow()
task = TaskJob(id="a", name="b", status=2, creation_date=creation_date)
repo_mock.list.return_value = [task]
repo_mock.get_or_raise.return_value = task
service.tasks["a"] = Mock()
service.cancel_task(
"a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True
)
task.status = TaskStatus.CANCELLED.value
repo_mock.save.assert_called_with(task)
2 changes: 1 addition & 1 deletion webapp/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "antares-web",
"version": "2.1.4",
"version": "2.1.5",
"private": true,
"dependencies": {
"@fortawesome/fontawesome-svg-core": "^1.2.36",
Expand Down
2 changes: 1 addition & 1 deletion webapp/public/locales/en/data.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
"jsonFormat": "JSON Format",
"graphSelector": "Columns",
"monotonicView": "Monotonic",
"matrixEmpty": "No data"
"matrixEmpty": "Empty matrix"
}
Loading

0 comments on commit 2c7ccf5

Please sign in to comment.