Skip to content

Commit

Permalink
build(python): bump project dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Aug 29, 2024
1 parent adf9146 commit 92f9e96
Show file tree
Hide file tree
Showing 184 changed files with 2,859 additions and 3,467 deletions.
2 changes: 1 addition & 1 deletion antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def put(self, id: str, data: JSON, duration: int = 3600) -> None:
redis_element = RedisCacheElement(duration=duration, data=data)
redis_key = f"cache:{id}"
logger.info(f"Adding cache key {id}")
self.redis.set(redis_key, redis_element.json())
self.redis.set(redis_key, redis_element.model_dump_json())
self.redis.expire(redis_key, duration)

def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
Expand Down
17 changes: 12 additions & 5 deletions antarest/core/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import multiprocessing
import tempfile
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

Expand All @@ -12,6 +13,12 @@
DEFAULT_WORKSPACE_NAME = "default"


class Launcher(str, Enum):
SLURM = "slurm"
LOCAL = "local"
DEFAULT = "default"


@dataclass(frozen=True)
class ExternalAuthConfig:
"""
Expand Down Expand Up @@ -387,7 +394,7 @@ def __post_init__(self) -> None:
msg = f"Invalid configuration: {self.default=} must be one of {possible!r}"
raise ValueError(msg)

def get_nb_cores(self, launcher: str) -> "NbCoresConfig":
def get_nb_cores(self, launcher: Launcher) -> "NbCoresConfig":
"""
Retrieve the number of cores configuration for a given launcher: "local" or "slurm".
If "default" is specified, retrieve the configuration of the default launcher.
Expand All @@ -404,12 +411,12 @@ def get_nb_cores(self, launcher: str) -> "NbCoresConfig":
"""
config_map = {"local": self.local, "slurm": self.slurm}
config_map["default"] = config_map[self.default]
launcher_config = config_map.get(launcher)
launcher_config = config_map.get(launcher.value)
if launcher_config is None:
raise InvalidConfigurationError(launcher)
raise InvalidConfigurationError(launcher.value)
return launcher_config.nb_cores

def get_time_limit(self, launcher: str) -> TimeLimitConfig:
def get_time_limit(self, launcher: Launcher) -> TimeLimitConfig:
"""
Retrieve the time limit for a job of the given launcher: "local" or "slurm".
If "default" is specified, retrieve the configuration of the default launcher.
Expand All @@ -426,7 +433,7 @@ def get_time_limit(self, launcher: str) -> TimeLimitConfig:
"""
config_map = {"local": self.local, "slurm": self.slurm}
config_map["default"] = config_map[self.default]
launcher_config = config_map.get(launcher)
launcher_config = config_map.get(launcher.value)
if launcher_config is None:
raise InvalidConfigurationError(launcher)
return launcher_config.time_limit
Expand Down
16 changes: 1 addition & 15 deletions antarest/core/core_blueprint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import logging
from typing import Any

from fastapi import APIRouter, Depends
from fastapi import APIRouter
from pydantic import BaseModel

from antarest.core.config import Config
from antarest.core.jwt import JWTUser
from antarest.core.requests import UserHasNotPermissionError
from antarest.core.utils.web import APITag
from antarest.core.version_info import VersionInfoDTO, get_commit_id, get_dependencies
from antarest.login.auth import Auth
Expand Down Expand Up @@ -54,15 +51,4 @@ def version_info() -> Any:
dependencies=get_dependencies(),
)

@bp.get("/kill", include_in_schema=False)
def kill_worker(
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
if not current_user.is_site_admin():
raise UserHasNotPermissionError()
logging.getLogger(__name__).critical("Killing the worker")
# PyInstaller modifies the behavior of built-in functions, such as `exit`.
# It is advisable to use `sys.exit` or raise the `SystemExit` exception instead.
raise SystemExit(f"Worker killed by the user #{current_user.id}")

return bp
28 changes: 15 additions & 13 deletions antarest/core/filesystem_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@

import typing_extensions as te
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, Field
from starlette.responses import PlainTextResponse, StreamingResponse

from antarest.core.config import Config
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth

FilesystemName = te.Annotated[str, Field(regex=r"^\w+$", description="Filesystem name")]
MountPointName = te.Annotated[str, Field(regex=r"^\w+$", description="Mount point name")]
FilesystemName = te.Annotated[str, Field(pattern=r"^\w+$", description="Filesystem name")]
MountPointName = te.Annotated[str, Field(pattern=r"^\w+$", description="Mount point name")]


class FilesystemDTO(
BaseModel,
extra=Extra.forbid,
schema_extra={
extra="forbid",
json_schema_extra={
"example": {
"name": "ws",
"mount_dirs": {
Expand All @@ -50,8 +50,8 @@ class FilesystemDTO(

class MountPointDTO(
BaseModel,
extra=Extra.forbid,
schema_extra={
extra="forbid",
json_schema_extra={
"example": {
"name": "default",
"path": "/path/to/workspaces/internal_studies",
Expand All @@ -77,10 +77,10 @@ class MountPointDTO(

name: MountPointName
path: Path = Field(description="Full path of the mount point in Antares Web Server")
total_bytes: int = Field(0, description="Total size of the mount point in bytes")
used_bytes: int = Field(0, description="Used size of the mount point in bytes")
free_bytes: int = Field(0, description="Free size of the mount point in bytes")
message: str = Field("", description="A message describing the status of the mount point")
total_bytes: t.Optional[int] = 0 # Total size of the mount point in bytes
used_bytes: t.Optional[int] = 0 # Used size of the mount point in bytes
free_bytes: t.Optional[int] = 0 # Free size of the mount point in bytes
message: t.Optional[str] = "" # A message describing the status of the mount point

@classmethod
async def from_path(cls, name: str, path: Path) -> "MountPointDTO":
Expand All @@ -98,8 +98,8 @@ async def from_path(cls, name: str, path: Path) -> "MountPointDTO":

class FileInfoDTO(
BaseModel,
extra=Extra.forbid,
schema_extra={
extra="forbid",
json_schema_extra={
"example": {
"path": "/path/to/workspaces/internal_studies/5a503c20-24a3-4734-9cf8-89565c9db5ec/study.antares",
"file_type": "file",
Expand Down Expand Up @@ -148,6 +148,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo
path=full_path,
file_type="unknown",
file_count=0, # missing
size_bytes=0, # missing
created=datetime.datetime.min,
modified=datetime.datetime.min,
accessed=datetime.datetime.min,
Expand All @@ -162,6 +163,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo
created=datetime.datetime.fromtimestamp(file_stat.st_ctime),
modified=datetime.datetime.fromtimestamp(file_stat.st_mtime),
accessed=datetime.datetime.fromtimestamp(file_stat.st_atime),
message="OK",
)

if stat.S_ISDIR(file_stat.st_mode):
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/filetransfer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class FileDownloadDTO(BaseModel):
id: str
name: str
filename: str
expiration_date: Optional[str]
expiration_date: Optional[str] = None
ready: bool
failed: bool = False
error_message: str = ""
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

JSON = Dict[str, Any]
ELEMENT = Union[str, int, float, bool, bytes]
SUB_JSON = Union[ELEMENT, JSON, List, None]
SUB_JSON = Union[ELEMENT, JSON, List[Any], None]


class PublicMode(str, enum.Enum):
Expand Down
15 changes: 8 additions & 7 deletions antarest/core/permissions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import typing as t

from antarest.core.jwt import JWTUser
from antarest.core.model import PermissionInfo, PublicMode, StudyPermissionType
Expand All @@ -7,8 +8,8 @@
logger = logging.getLogger(__name__)


permission_matrix = {
StudyPermissionType.READ: {
permission_matrix: t.Dict[str, t.Dict[str, t.Sequence[t.Union[RoleType, PublicMode]]]] = {
StudyPermissionType.READ.value: {
"roles": [
RoleType.ADMIN,
RoleType.RUNNER,
Expand All @@ -22,15 +23,15 @@
PublicMode.READ,
],
},
StudyPermissionType.RUN: {
StudyPermissionType.RUN.value: {
"roles": [RoleType.ADMIN, RoleType.RUNNER, RoleType.WRITER],
"public_modes": [PublicMode.FULL, PublicMode.EDIT, PublicMode.EXECUTE],
},
StudyPermissionType.WRITE: {
StudyPermissionType.WRITE.value: {
"roles": [RoleType.ADMIN, RoleType.WRITER],
"public_modes": [PublicMode.FULL, PublicMode.EDIT],
},
StudyPermissionType.MANAGE_PERMISSIONS: {
StudyPermissionType.MANAGE_PERMISSIONS.value: {
"roles": [RoleType.ADMIN],
"public_modes": [],
},
Expand Down Expand Up @@ -65,11 +66,11 @@ def check_permission(

allowed_roles = permission_matrix[permission]["roles"]
group_permission = any(
role in allowed_roles # type: ignore
role in allowed_roles
for role in [group.role for group in (user.groups or []) if group.id in permission_info.groups]
)
if group_permission:
return True

allowed_public_modes = permission_matrix[permission]["public_modes"]
return permission_info.public_mode in allowed_public_modes # type: ignore
return permission_info.public_mode in allowed_public_modes
45 changes: 43 additions & 2 deletions antarest/core/requests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing as t
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional
from typing import Any, Generator, Tuple

from fastapi import HTTPException
from markupsafe import escape
Expand All @@ -17,13 +19,52 @@
}


class CaseInsensitiveDict(t.MutableMapping[str, t.Any]): # copy of the requests class to avoid importing the package
def __init__(self, data=None, **kwargs) -> None: # type: ignore
self._store: OrderedDict[str, t.Any] = OrderedDict()
if data is None:
data = {}
self.update(data, **kwargs)

def __setitem__(self, key: str, value: t.Any) -> None:
self._store[key.lower()] = (key, value)

def __getitem__(self, key: str) -> t.Any:
return self._store[key.lower()][1]

def __delitem__(self, key: str) -> None:
del self._store[key.lower()]

def __iter__(self) -> t.Any:
return (casedkey for casedkey, mappedvalue in self._store.values())

def __len__(self) -> int:
return len(self._store)

def lower_items(self) -> Generator[Tuple[Any, Any], Any, None]:
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())

def __eq__(self, other: t.Any) -> bool:
if isinstance(other, t.Mapping):
other = CaseInsensitiveDict(other)
else:
return NotImplemented
return dict(self.lower_items()) == dict(other.lower_items())

def copy(self) -> "CaseInsensitiveDict":
return CaseInsensitiveDict(self._store.values())

def __repr__(self) -> str:
return str(dict(self.items()))


@dataclass
class RequestParameters:
"""
DTO object to handle data inside request to send to service
"""

user: Optional[JWTUser] = None
user: t.Optional[JWTUser] = None

def get_user_id(self) -> str:
return str(escape(str(self.user.id))) if self.user else "Unknown"
Expand Down
41 changes: 22 additions & 19 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from enum import Enum

from pydantic import BaseModel, Extra
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore
Expand Down Expand Up @@ -44,43 +44,43 @@ def is_final(self) -> bool:
]


class TaskResult(BaseModel, extra=Extra.forbid):
class TaskResult(BaseModel, extra="forbid"):
success: bool
message: str
# Can be used to store json serialized result
return_value: t.Optional[str]
return_value: t.Optional[str] = None


class TaskLogDTO(BaseModel, extra=Extra.forbid):
class TaskLogDTO(BaseModel, extra="forbid"):
id: str
message: str


class CustomTaskEventMessages(BaseModel, extra=Extra.forbid):
class CustomTaskEventMessages(BaseModel, extra="forbid"):
start: str
running: str
end: str


class TaskEventPayload(BaseModel, extra=Extra.forbid):
class TaskEventPayload(BaseModel, extra="forbid"):
id: str
message: str


class TaskDTO(BaseModel, extra=Extra.forbid):
class TaskDTO(BaseModel, extra="forbid"):
id: str
name: str
owner: t.Optional[int]
owner: t.Optional[int] = None
status: TaskStatus
creation_date_utc: str
completion_date_utc: t.Optional[str]
result: t.Optional[TaskResult]
logs: t.Optional[t.List[TaskLogDTO]]
completion_date_utc: t.Optional[str] = None
result: t.Optional[TaskResult] = None
logs: t.Optional[t.List[TaskLogDTO]] = None
type: t.Optional[str] = None
ref_id: t.Optional[str] = None


class TaskListFilter(BaseModel, extra=Extra.forbid):
class TaskListFilter(BaseModel, extra="forbid"):
status: t.List[TaskStatus] = []
name: t.Optional[str] = None
type: t.List[TaskType] = []
Expand Down Expand Up @@ -158,20 +158,23 @@ class TaskJob(Base): # type: ignore
study: "Study" = relationship("Study", back_populates="jobs", uselist=False)

def to_dto(self, with_logs: bool = False) -> TaskDTO:
result = None
if self.completion_date:
assert self.result_status is not None
assert self.result_msg is not None
result = TaskResult(
success=self.result_status,
message=self.result_msg,
return_value=self.result,
)
return TaskDTO(
id=self.id,
owner=self.owner_id,
creation_date_utc=str(self.creation_date),
completion_date_utc=str(self.completion_date) if self.completion_date else None,
name=self.name,
status=TaskStatus(self.status),
result=TaskResult(
success=self.result_status,
message=self.result_msg,
return_value=self.result,
)
if self.completion_date
else None,
result=result,
logs=sorted([log.to_dto() for log in self.logs], key=lambda log: log.id) if with_logs else None,
type=self.type,
ref_id=self.ref_id,
Expand Down
Loading

0 comments on commit 92f9e96

Please sign in to comment.