Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fix/fix-endpoint-links-creation'…
Browse files Browse the repository at this point in the history
… into feature/add-endpoint-link-update
  • Loading branch information
TheoPascoli committed Sep 30, 2024
2 parents 39f81f6 + 0f4d6e0 commit d81eda0
Show file tree
Hide file tree
Showing 57 changed files with 3,575 additions and 4,012 deletions.
1 change: 1 addition & 0 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class TaskType(str, Enum):
SCAN = "SCAN"
UPGRADE_STUDY = "UPGRADE_STUDY"
THERMAL_CLUSTER_SERIES_GENERATION = "THERMAL_CLUSTER_SERIES_GENERATION"
SNAPSHOT_CLEARING = "SNAPSHOT_CLEARING"


class TaskStatus(Enum):
Expand Down
15 changes: 2 additions & 13 deletions antarest/study/business/link_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import typing as t


from antarest.core.exceptions import ConfigFileNotFound, LinkValidationError
from antarest.core.model import JSON
from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model
Expand Down Expand Up @@ -56,7 +55,7 @@ class LinkManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service

def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkInfoDTOType]:
def get_all_links(self, study: RawStudy) -> t.List[LinkInfoDTOType]:
file_study = self.storage_service.get_storage(study).get_raw(study)
result: t.List[LinkInfoDTOType] = []

Expand All @@ -68,16 +67,6 @@ def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkIn

link_creation_data: t.Dict[str, t.Any] = {"area1": area_id, "area2": link}
link_creation_data.update(link_properties)
if not with_ui:
link_creation_data.update(
{
"colorr": None,
"colorb": None,
"colorg": None,
"link-width": None,
"link-style": None,
}
)

link_data: LinkInfoDTOType
if int(study.version) < 820:
Expand All @@ -91,7 +80,7 @@ def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkIn

def create_link(self, study: RawStudy, link_creation_info: LinkInfoDTOType) -> LinkInfoDTOType:
if link_creation_info.area1 == link_creation_info.area2:
raise LinkValidationError("Cannot create link on same node")
raise LinkValidationError(f"Cannot create link on same area: {link_creation_info.area1}")

study_version = int(study.version)
if study_version < 820 and isinstance(link_creation_info, LinkInfoDTO820):
Expand Down
3 changes: 1 addition & 2 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,12 +1811,11 @@ def get_all_areas(
def get_all_links(
self,
uuid: str,
with_ui: bool,
params: RequestParameters,
) -> t.List[LinkInfoDTOType]:
study = self.get_study(uuid)
assert_permission(params.user, study, StudyPermissionType.READ)
return self.links.get_all_links(study, with_ui)
return self.links.get_all_links(study)

def create_area(
self,
Expand Down
69 changes: 34 additions & 35 deletions antarest/study/storage/variantstudy/model/command/create_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import typing as t
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from pydantic import AliasGenerator, BaseModel, Field, ValidationInfo, field_validator, model_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

from antarest.core.exceptions import LinkValidationError
from antarest.core.utils.string import to_kebab_case
Expand All @@ -35,54 +35,47 @@ class AreaInfo(BaseModel):


class LinkInfoProperties(BaseModel):
hurdles_cost: Optional[bool] = Field(False, alias="hurdles-cost")
loop_flow: Optional[bool] = Field(False, alias="loop-flow")
use_phase_shifter: Optional[bool] = Field(False, alias="use-phase-shifter")
transmission_capacities: Optional[TransmissionCapacity] = Field(
TransmissionCapacity.ENABLED, alias="transmission-capacities"
)
asset_type: Optional[AssetType] = Field(AssetType.AC, alias="asset-type")
display_comments: Optional[bool] = Field(True, alias="display-comments")
colorr: Optional[int] = DEFAULT_COLOR
colorb: Optional[int] = DEFAULT_COLOR
colorg: Optional[int] = DEFAULT_COLOR
link_width: Optional[float] = Field(1, alias="link-width")
link_style: Optional[LinkStyle] = Field(LinkStyle.PLAIN, alias="link-style")

@model_validator(mode="before")
def validate_colors(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
if type(values) is dict:
colors = {
"colorr": values.get("colorr"),
"colorb": values.get("colorb"),
"colorg": values.get("colorg"),
}
for color_name, color_value in colors.items():
if color_value is not None and (color_value < 0 or color_value > 255):
raise LinkValidationError(f"Invalid value for {color_name}. Must be between 0 and 255.")

return values
hurdles_cost: bool = False
loop_flow: bool = False
use_phase_shifter: bool = False
transmission_capacities: TransmissionCapacity = TransmissionCapacity.ENABLED
asset_type: AssetType = AssetType.AC
display_comments: bool = True
colorr: int = Field(default=DEFAULT_COLOR, gt=0, lt=255)
colorb: int = Field(default=DEFAULT_COLOR, gt=0, lt=255)
colorg: int = Field(default=DEFAULT_COLOR, gt=0, lt=255)
link_width: float = 1
link_style: LinkStyle = LinkStyle.PLAIN

class Config:
alias_generator = to_kebab_case
populate_by_name = True


class LinkInfoProperties820(LinkInfoProperties):
filter_synthesis: Optional[str] = Field(None, alias="filter-synthesis")
filter_year_by_year: Optional[str] = Field(None, alias="filter-year-by-year")
filter_synthesis: t.Optional[str] = None
filter_year_by_year: t.Optional[str] = None

class Config:
alias_generator = to_kebab_case
populate_by_name = True

@field_validator("filter_synthesis", "filter_year_by_year", mode="before")
def validate_individual_filters(cls, value: Optional[str], field: Any) -> Optional[str]:
def validate_individual_filters(cls, value: Optional[str]) -> Optional[str]:
if value is not None:
filter_options = ["hourly", "daily", "weekly", "monthly", "annual"]
filter_values = ["hourly", "daily", "weekly", "monthly", "annual"]

options = value.replace(" ", "").split(",")
invalid_options = [opt for opt in options if opt not in filter_options]
invalid_options = [opt for opt in options if opt not in filter_values]
if invalid_options:
raise LinkValidationError(
f"Invalid value(s) in filters: {', '.join(invalid_options)}. "
f"Allowed values are: {', '.join(filter_options)}."
f"Allowed values are: {', '.join(filter_values)}."
)
return value


class LinkProperties(LinkInfoProperties820, alias_generator=AliasGenerator(serialization_alias=to_kebab_case)):
class LinkProperties(LinkInfoProperties820):
pass


Expand Down Expand Up @@ -114,6 +107,12 @@ def validate_series(
new_values = values if isinstance(values, dict) else values.data
return validate_matrix(v, new_values) if v is not None else v

@model_validator(mode="after")
def validate_areas(self) -> "CreateLink":
if self.area1 == self.area2:
raise ValueError("Cannot create link on same node")
return self

def _create_link_in_config(self, area_from: str, area_to: str, study_data: FileStudyTreeConfig) -> None:
self.parameters = self.parameters or {}
study_data.areas[area_from].links[area_to] = Link(
Expand Down
74 changes: 70 additions & 4 deletions antarest/study/storage/variantstudy/variant_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import re
import shutil
import typing as t
from datetime import datetime
from datetime import datetime, timedelta
from functools import reduce
from pathlib import Path
from uuid import uuid4
Expand Down Expand Up @@ -45,9 +45,11 @@
from antarest.core.serialization import to_json_string
from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskResult, TaskType
from antarest.core.tasks.service import DEFAULT_AWAIT_MAX_TIMEOUT, ITaskService, TaskUpdateNotifier, noop_notifier
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import assert_this, suppress_exception
from antarest.matrixstore.service import MatrixService
from antarest.study.model import RawStudy, Study, StudyAdditionalData, StudyMetadataDTO, StudySimResultDTO
from antarest.study.repository import AccessPermissions, StudyFilter
from antarest.study.storage.abstract_storage_service import AbstractStorageService
from antarest.study.storage.patch_service import PatchService
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig, FileStudyTreeConfigDTO
Expand Down Expand Up @@ -625,9 +627,9 @@ def callback(notifier: TaskUpdateNotifier) -> TaskResult:
)
return TaskResult(
success=generate_result.success,
message=f"{study_id} generated successfully"
if generate_result.success
else f"{study_id} not generated",
message=(
f"{study_id} generated successfully" if generate_result.success else f"{study_id} not generated"
),
return_value=generate_result.model_dump_json(),
)

Expand Down Expand Up @@ -1053,3 +1055,67 @@ def initialize_additional_data(self, variant_study: VariantStudy) -> bool:
exc_info=e,
)
return False

def clear_all_snapshots(self, retention_hours: timedelta, params: t.Optional[RequestParameters] = None) -> str:
"""
Admin command that clear all variant snapshots older than `retention_hours` (in hours).
Only available for admin users.
Args:
retention_hours: number of retention hours
params: request parameters used to identify the user status
Returns: None
Raises:
UserHasNotPermissionError
"""
if params is None or (params.user and not params.user.is_site_admin() and not params.user.is_admin_token()):
raise UserHasNotPermissionError()

task_name = f"Cleaning all snapshot updated or accessed at least {retention_hours} hours ago."

snapshot_clearing_task_instance = SnapshotCleanerTask(
variant_study_service=self, retention_hours=retention_hours
)

return self.task_service.add_task(
snapshot_clearing_task_instance,
task_name,
task_type=TaskType.SNAPSHOT_CLEARING,
ref_id="SNAPSHOT_CLEANING",
custom_event_messages=None,
request_params=params,
)


class SnapshotCleanerTask:
def __init__(
self,
variant_study_service: VariantStudyService,
retention_hours: timedelta,
) -> None:
self._variant_study_service = variant_study_service
self._retention_hours = retention_hours

def _clear_all_snapshots(self) -> None:
with db():
variant_list = self._variant_study_service.repository.get_all(
study_filter=StudyFilter(
variant=True,
access_permissions=AccessPermissions(is_admin=True),
)
)
for variant in variant_list:
if variant.updated_at and variant.updated_at < datetime.utcnow() - self._retention_hours:
if variant.last_access and variant.last_access < datetime.utcnow() - self._retention_hours:
self._variant_study_service.clear_snapshot(variant)

def run_task(self, notifier: TaskUpdateNotifier) -> TaskResult:
msg = f"Start cleaning all snapshots updated or accessed {self._retention_hours} hours ago."
notifier(msg)
self._clear_all_snapshots()
msg = f"All selected snapshots were successfully cleared."
notifier(msg)
return TaskResult(success=True, message=msg)

__call__ = run_task
3 changes: 1 addition & 2 deletions antarest/study/web/study_data_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,14 @@ def get_areas(
)
def get_links(
uuid: str,
with_ui: bool = False,
current_user: JWTUser = Depends(auth.get_current_user),
) -> t.Any:
logger.info(
f"Fetching link list for study {uuid}",
extra={"user": current_user.id},
)
params = RequestParameters(user=current_user)
areas_list = study_service.get_all_links(uuid, with_ui, params)
areas_list = study_service.get_all_links(uuid, params)
return areas_list

@bp.post(
Expand Down
31 changes: 30 additions & 1 deletion antarest/study/web/variant_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.

import datetime
import logging
from typing import List, Optional, Union

Expand Down Expand Up @@ -416,4 +416,33 @@ def create_from_variant(
params = RequestParameters(user=current_user)
raise NotImplementedError()

@bp.put(
"/studies/variants/clear-snapshots",
tags=[APITag.study_variant_management],
summary="Clear variant snapshots",
responses={
200: {
"description": "Delete snapshots older than a specific number of hours. By default, this number is 24."
}
},
)
def clear_variant_snapshots(
hours: int = 24,
current_user: JWTUser = Depends(auth.get_current_user),
) -> str:
"""
Endpoint that clear `limit` hours old and older variant snapshots.
Args: limit (int, optional): Number of hours to clear. Defaults to 24.
Returns: ID of the task running the snapshot clearing.
"""
retention_hours = datetime.timedelta(hours=hours)
logger.info(
f"Delete all variant snapshots older than {retention_hours} hours.",
extra={"user": current_user.id},
)
params = RequestParameters(user=current_user)
return variant_study_service.clear_all_snapshots(retention_hours, params)

return bp
Loading

0 comments on commit d81eda0

Please sign in to comment.