Skip to content

Commit

Permalink
Merge pull request #679 from AntaresSimulatorTeam/dev
Browse files Browse the repository at this point in the history
v2.2.0
  • Loading branch information
pl-buiquang authored Dec 15, 2021
2 parents 2c7ccf5 + 2f85ddc commit 992106a
Show file tree
Hide file tree
Showing 101 changed files with 1,905 additions and 545 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
with:
options: --check --diff
- name: Check Typing (mypy)
continue-on-error: true
#continue-on-error: true
run: |
mypy --install-types --non-interactive
mypy
Expand Down
2 changes: 1 addition & 1 deletion antarest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.1.5"
__version__ = "2.2.0"

from pathlib import Path

Expand Down
2 changes: 1 addition & 1 deletion antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class RedisCacheElement(BaseModel):


class RedisCache(ICache):
def __init__(self, redis_client: Redis):
def __init__(self, redis_client: Redis): # type: ignore
self.redis = redis_client

def start(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/cache/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def build_cache(
config: Config, redis_client: Optional[Redis] = None
config: Config, redis_client: Optional[Redis] = None # type: ignore
) -> ICache:
cache = (
RedisCache(redis_client)
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class SlurmConfig:
password: str = ""
default_wait_time: int = 0
default_time_limit: int = 0
default_n_cpu: int = 0
default_n_cpu: int = 1
default_json_db_name: str = ""
slurm_script_path: str = ""
antares_versions_on_remote_server: List[str] = field(
Expand Down
5 changes: 5 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.EXPECTATION_FAILED, message)


class CommandApplicationError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message)


class CommandUpdateAuthorizationError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.LOCKED, message)
Expand Down
10 changes: 5 additions & 5 deletions antarest/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from io import StringIO
from pathlib import Path

from alembic import command # type: ignore
from alembic.config import Config # type: ignore
from alembic.util import CommandError # type: ignore
from alembic import command
from alembic.config import Config
from alembic.util import CommandError
from sqlalchemy.ext.declarative import declarative_base # type: ignore

from antarest.core.utils.utils import get_local_path
Expand All @@ -18,7 +18,7 @@
def upgrade_db(config_file: Path) -> None:

os.environ.setdefault("ANTAREST_CONF", str(config_file))
alembic_cfg = Config(get_local_path() / "alembic.ini")
alembic_cfg = Config(str(get_local_path() / "alembic.ini"))
alembic_cfg.stdout = StringIO()
alembic_cfg.set_main_option(
"script_location", str(get_local_path() / "alembic")
Expand All @@ -35,7 +35,7 @@ def upgrade_db(config_file: Path) -> None:
raise e

alembic_cfg.stdout = StringIO()
command.heads(alembic_cfg)
command.heads(alembic_cfg) # type: ignore
head_output = alembic_cfg.stdout.getvalue()
head = head_output.split(" ")[0].strip()
if current_version != head:
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from fastapi import HTTPException
from markupsafe import escape
from markupsafe import escape # type: ignore

from antarest.core.jwt import JWTUser

Expand Down
2 changes: 1 addition & 1 deletion antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO:
if self.completion_date
else None,
logs=sorted(
[log.to_dto() for log in self.logs], key=lambda l: l.id
[log.to_dto() for log in self.logs], key=lambda l: l.id # type: ignore
)
if with_logs
else None,
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_local_path() -> Path:
return filepath


def new_redis_instance(config: RedisConfig) -> redis.Redis:
def new_redis_instance(config: RedisConfig) -> redis.Redis: # type: ignore
return redis.Redis(host=config.host, port=config.port, db=0)


Expand Down
2 changes: 1 addition & 1 deletion antarest/eventbus/business/redis_eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class RedisEventBus(IEventBusBackend):
def __init__(self, redis_client: Redis) -> None:
def __init__(self, redis_client: Redis) -> None: # type: ignore
self.redis = redis_client
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe(REDIS_STORE_KEY)
Expand Down
2 changes: 1 addition & 1 deletion antarest/eventbus/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def build_eventbus(
application: FastAPI,
config: Config,
autostart: bool = True,
redis_client: Optional[Redis] = None,
redis_client: Optional[Redis] = None, # type: ignore
) -> IEventBus:

eventbus = EventBusService(
Expand Down
1 change: 0 additions & 1 deletion antarest/eventbus/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
self.start()

def push(self, event: Event) -> None:
# TODO add arg permissions with group/role, user, public
self.backend.push_event(event)

def add_listener(
Expand Down
5 changes: 2 additions & 3 deletions antarest/eventbus/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from antarest.core.config import Config
from antarest.core.interfaces.eventbus import IEventBus, Event
from antarest.core.jwt import JWTUser, DEFAULT_ADMIN_USER
from antarest.core.model import PermissionInfo, StudyPermissionType
from antarest.core.model import PermissionInfo, StudyPermissionType, PublicMode
from antarest.core.permissions import check_permission
from antarest.login.auth import Auth

Expand Down Expand Up @@ -82,7 +82,7 @@ async def broadcast(
self, message: str, permissions: PermissionInfo, channel: Optional[str]
) -> None:
for connection in self.active_connections:
if check_permission(
if channel is not None or check_permission(
connection.user, permissions, StudyPermissionType.READ
):
if (
Expand Down Expand Up @@ -119,7 +119,6 @@ async def connect(
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
user = Auth.get_user_from_token(token, jwt_manager)
if user is None:
# TODO check auth and subscribe to rooms
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
except Exception as e:
logger.error(
Expand Down
4 changes: 2 additions & 2 deletions antarest/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def open_app() -> None:
menu = QMenu()
openapp = QAction("Open application")
menu.addAction(openapp)
openapp.triggered.connect(open_app)
openapp.triggered.connect(open_app) # type: ignore

# To quit the app
quit = QAction("Quit")
quit.triggered.connect(app.quit)
quit.triggered.connect(app.quit) # type: ignore
menu.addAction(quit)

# Adding options to the System Tray
Expand Down
9 changes: 7 additions & 2 deletions antarest/launcher/adapters/abstractlauncher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Callable, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional, Any
from uuid import UUID

from antarest.core.config import Config
from antarest.core.model import JSON
from antarest.core.requests import RequestParameters
from antarest.launcher.model import JobStatus, LogType
from antarest.study.service import StudyService
Expand Down Expand Up @@ -31,7 +32,11 @@ def __init__(

@abstractmethod
def run_study(
self, study_uuid: str, version: str, params: RequestParameters
self,
study_uuid: str,
version: str,
launcher_parameters: Optional[JSON],
params: RequestParameters,
) -> UUID:
raise NotImplementedError()

Expand Down
7 changes: 6 additions & 1 deletion antarest/launcher/adapters/local_launcher/local_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from uuid import UUID, uuid4

from antarest.core.config import Config
from antarest.core.model import JSON
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.launcher.adapters.abstractlauncher import (
Expand All @@ -31,7 +32,11 @@ def __init__(
self.job_id_to_study_id: Dict[str, str] = {}

def run_study(
self, study_uuid: str, version: str, params: RequestParameters
self,
study_uuid: str,
version: str,
launcher_parameters: Optional[JSON],
params: RequestParameters,
) -> UUID:
if self.config.launcher.local is None:
raise LauncherInitException()
Expand Down
105 changes: 96 additions & 9 deletions antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
EventChannelDirectory,
)
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.model import JSON
from antarest.core.requests import RequestParameters
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.launcher.adapters.abstractlauncher import (
Expand All @@ -39,6 +40,11 @@
logging.getLogger("paramiko").setLevel("WARN")


MAX_NB_CPU = 24
MAX_TIME_LIMIT = 604800
MIN_TIME_LIMIT = 3600


class VersionNotSupportedError(Exception):
pass

Expand Down Expand Up @@ -155,14 +161,45 @@ def _delete_study(self, study_path: Path) -> None:
if study_path.exists():
shutil.rmtree(study_path)

def _import_study_output(self, job_id: str) -> Optional[str]:
def _import_study_output(
self, job_id: str, xpansion_mode: bool = False
) -> Optional[str]:
study_id = self.job_id_to_study_id[job_id]
if xpansion_mode:
self._import_xpansion_result(job_id, study_id)
return self.storage_service.import_output(
study_id,
self.slurm_config.local_workspace / "OUTPUT" / job_id / "output",
params=RequestParameters(DEFAULT_ADMIN_USER),
)

def _import_xpansion_result(self, job_id: str, study_id: str) -> None:
output_path = (
self.slurm_config.local_workspace / "OUTPUT" / job_id / "output"
)
if output_path.exists() and len(os.listdir(output_path)) == 1:
output_path = output_path / os.listdir(output_path)[0]
shutil.copytree(
self.slurm_config.local_workspace
/ "OUTPUT"
/ job_id
/ "input"
/ "links",
output_path / "updated_links",
)
study = self.storage_service.get_study(study_id)
if int(study.version) < 800:
shutil.copytree(
self.slurm_config.local_workspace
/ "OUTPUT"
/ job_id
/ "user"
/ "expansion",
output_path / "results",
)
else:
logger.warning("Output path in xpansion result not found")

def _check_studies_state(self) -> None:
try:
run_with(
Expand Down Expand Up @@ -191,11 +228,13 @@ def _check_studies_state(self) -> None:
with db():
output_id: Optional[str] = None
if not study.with_error:
output_id = self._import_study_output(study.name)
output_id = self._import_study_output(
study.name, study.xpansion_study
)
self.callbacks.update_status(
study.name,
JobStatus.FAILED
if study.with_error
if study.with_error or output_id is None
else JobStatus.SUCCESS,
None,
output_id,
Expand Down Expand Up @@ -283,7 +322,11 @@ def _clean_up_study(self, launch_id: str) -> None:
del self.job_id_to_study_id[launch_id]

def _run_study(
self, study_uuid: str, launch_uuid: str, params: RequestParameters
self,
study_uuid: str,
launch_uuid: str,
launcher_params: Optional[JSON],
params: RequestParameters,
) -> None:
with db():
study_path = Path(self.launcher_args.studies_in) / str(launch_uuid)
Expand All @@ -301,8 +344,11 @@ def _run_study(

self._assert_study_version_is_supported(study_uuid, params)

launcher_args = self._check_and_apply_launcher_params(
launcher_params
)
run_with(
self.launcher_args, self.launcher_params, show_banner=False
launcher_args, self.launcher_params, show_banner=False
)
self.callbacks.update_status(
str(launch_uuid), JobStatus.RUNNING, None, None
Expand All @@ -322,13 +368,47 @@ def _run_study(

self._delete_study(study_path)

def _check_and_apply_launcher_params(
self, launcher_params: Optional[JSON]
) -> argparse.Namespace:
if launcher_params:
launcher_args = deepcopy(self.launcher_args)
if launcher_params.get("xpansion", False):
launcher_args.xpansion_mode = True
time_limit = launcher_params.get("time_limit", None)
if time_limit and isinstance(time_limit, int):
if MIN_TIME_LIMIT < time_limit < MAX_TIME_LIMIT:
launcher_args.time_limit = time_limit
else:
logger.warning(
f"Invalid slurm launcher time limit ({time_limit}), should be between {MIN_TIME_LIMIT} and {MAX_TIME_LIMIT}"
)
post_processing = launcher_params.get("post_processing", False)
if isinstance(post_processing, bool):
launcher_args.post_processing = post_processing
nb_cpu = launcher_params.get("nb_cpu", None)
if nb_cpu and isinstance(nb_cpu, int):
if 0 < nb_cpu <= MAX_NB_CPU:
launcher_args.n_cpu = nb_cpu
else:
logger.warning(
f"Invalid slurm launcher nb_cpu ({nb_cpu}), should be between 1 and 24"
)
return launcher_args
return self.launcher_args

def run_study(
self, study_uuid: str, version: str, params: RequestParameters
self,
study_uuid: str,
version: str,
launcher_parameters: Optional[JSON],
params: RequestParameters,
) -> UUID: # TODO: version ?
launch_uuid = uuid4()

thread = threading.Thread(
target=self._run_study, args=(study_uuid, launch_uuid, params)
target=self._run_study,
args=(study_uuid, launch_uuid, launcher_parameters, params),
)
thread.start()

Expand All @@ -351,5 +431,12 @@ def kill_job(self, job_id: str) -> None:
launcher_args, self.launcher_params, show_banner=False
)
return

raise JobIdNotFound()
logger.warning(
"Failed to retrieve job id in antares launcher database"
)
self.callbacks.update_status(
job_id,
JobStatus.FAILED,
None,
None,
)
Loading

0 comments on commit 992106a

Please sign in to comment.