Skip to content

Commit

Permalink
fix: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoPascoli committed Sep 30, 2024
1 parent e057c4b commit 1355def
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 89 deletions.
14 changes: 2 additions & 12 deletions antarest/study/business/link_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class LinkManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service

def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkInfoDTOType]:
def get_all_links(self, study: RawStudy) -> t.List[LinkInfoDTOType]:
file_study = self.storage_service.get_storage(study).get_raw(study)
result: t.List[LinkInfoDTOType] = []

Expand All @@ -68,16 +68,6 @@ def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkIn

link_creation_data: t.Dict[str, t.Any] = {"area1": area_id, "area2": link}
link_creation_data.update(link_properties)
if not with_ui:
link_creation_data.update(
{
"colorr": None,
"colorb": None,
"colorg": None,
"link-width": None,
"link-style": None,
}
)

link_data: LinkInfoDTOType
if int(study.version) < 820:
Expand All @@ -91,7 +81,7 @@ def get_all_links(self, study: RawStudy, with_ui: bool = False) -> t.List[LinkIn

def create_link(self, study: RawStudy, link_creation_info: LinkInfoDTOType) -> LinkInfoDTOType:
if link_creation_info.area1 == link_creation_info.area2:
raise LinkValidationError("Cannot create link on same node")
raise LinkValidationError(f"Cannot create link on same area: {link_creation_info.area1}")

study_version = int(study.version)
if study_version < 820 and isinstance(link_creation_info, LinkInfoDTO820):
Expand Down
3 changes: 1 addition & 2 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,12 +1811,11 @@ def get_all_areas(
def get_all_links(
self,
uuid: str,
with_ui: bool,
params: RequestParameters,
) -> t.List[LinkInfoDTOType]:
study = self.get_study(uuid)
assert_permission(params.user, study, StudyPermissionType.READ)
return self.links.get_all_links(study, with_ui)
return self.links.get_all_links(study)

def create_area(
self,
Expand Down
70 changes: 34 additions & 36 deletions antarest/study/storage/variantstudy/model/command/create_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
import typing as t
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from pydantic import AliasGenerator, BaseModel, Field, ValidationInfo, field_validator, model_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

from antarest.core.exceptions import LinkValidationError
from antarest.core.utils.string import to_kebab_case
Expand All @@ -35,54 +34,47 @@ class AreaInfo(BaseModel):


class LinkInfoProperties(BaseModel):
hurdles_cost: Optional[bool] = Field(False, alias="hurdles-cost")
loop_flow: Optional[bool] = Field(False, alias="loop-flow")
use_phase_shifter: Optional[bool] = Field(False, alias="use-phase-shifter")
transmission_capacities: Optional[TransmissionCapacity] = Field(
TransmissionCapacity.ENABLED, alias="transmission-capacities"
)
asset_type: Optional[AssetType] = Field(AssetType.AC, alias="asset-type")
display_comments: Optional[bool] = Field(True, alias="display-comments")
colorr: Optional[int] = DEFAULT_COLOR
colorb: Optional[int] = DEFAULT_COLOR
colorg: Optional[int] = DEFAULT_COLOR
link_width: Optional[float] = Field(1, alias="link-width")
link_style: Optional[LinkStyle] = Field(LinkStyle.PLAIN, alias="link-style")

@model_validator(mode="before")
def validate_colors(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
if type(values) is dict:
colors = {
"colorr": values.get("colorr"),
"colorb": values.get("colorb"),
"colorg": values.get("colorg"),
}
for color_name, color_value in colors.items():
if color_value is not None and (color_value < 0 or color_value > 255):
raise LinkValidationError(f"Invalid value for {color_name}. Must be between 0 and 255.")

return values
hurdles_cost: Optional[bool] = False
loop_flow: Optional[bool] = False
use_phase_shifter: Optional[bool] = False
transmission_capacities: Optional[TransmissionCapacity] = TransmissionCapacity.ENABLED
asset_type: Optional[AssetType] = AssetType.AC
display_comments: Optional[bool] = True
colorr: Optional[int] = Field(default=DEFAULT_COLOR, gt=0, lt=255)
colorb: Optional[int] = Field(default=DEFAULT_COLOR, gt=0, lt=255)
colorg: Optional[int] = Field(default=DEFAULT_COLOR, gt=0, lt=255)
link_width: Optional[float] = 1
link_style: Optional[LinkStyle] = LinkStyle.PLAIN

class Config:
alias_generator = to_kebab_case
allow_population_by_field_name = True


class LinkInfoProperties820(LinkInfoProperties):
filter_synthesis: Optional[str] = Field(None, alias="filter-synthesis")
filter_year_by_year: Optional[str] = Field(None, alias="filter-year-by-year")
filter_synthesis: Optional[str] = None
filter_year_by_year: Optional[str] = None

class Config:
alias_generator = to_kebab_case
allow_population_by_field_name = True

@field_validator("filter_synthesis", "filter_year_by_year", mode="before")
def validate_individual_filters(cls, value: Optional[str], field: Any) -> Optional[str]:
def validate_individual_filters(cls, value: Optional[str]) -> Optional[str]:
if value is not None:
filter_options = ["hourly", "daily", "weekly", "monthly", "annual"]
from antarest.study.business.areas.properties_management import DEFAULT_FILTER_VALUE

options = value.replace(" ", "").split(",")
invalid_options = [opt for opt in options if opt not in filter_options]
invalid_options = [opt for opt in options if opt not in DEFAULT_FILTER_VALUE]
if invalid_options:
raise LinkValidationError(
f"Invalid value(s) in filters: {', '.join(invalid_options)}. "
f"Allowed values are: {', '.join(filter_options)}."
f"Allowed values are: {', '.join(DEFAULT_FILTER_VALUE)}."
)
return value


class LinkProperties(LinkInfoProperties820, alias_generator=AliasGenerator(serialization_alias=to_kebab_case)):
class LinkProperties(LinkInfoProperties820):
pass


Expand Down Expand Up @@ -114,6 +106,12 @@ def validate_series(
new_values = values if isinstance(values, dict) else values.data
return validate_matrix(v, new_values) if v is not None else v

@model_validator(mode="after")
def validate_areas(self) -> "CreateLink":
if self.area1 == self.area2:
raise ValueError("Cannot create link on same node")
return self

def _create_link_in_config(self, area_from: str, area_to: str, study_data: FileStudyTreeConfig) -> None:
self.parameters = self.parameters or {}
study_data.areas[area_from].links[area_to] = Link(
Expand Down
3 changes: 1 addition & 2 deletions antarest/study/web/study_data_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,14 @@ def get_areas(
)
def get_links(
uuid: str,
with_ui: bool = False,
current_user: JWTUser = Depends(auth.get_current_user),
) -> t.Any:
logger.info(
f"Fetching link list for study {uuid}",
extra={"user": current_user.id},
)
params = RequestParameters(user=current_user)
areas_list = study_service.get_all_links(uuid, with_ui, params)
areas_list = study_service.get_all_links(uuid, params)
return areas_list

@bp.post(
Expand Down
76 changes: 41 additions & 35 deletions tests/integration/study_data_blueprint/test_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
from starlette.testclient import TestClient

from antarest.study.storage.rawstudy.model.filesystem.config.links import TransmissionCapacity
from tests.integration.prepare_proxy import PreparerProxy


Expand All @@ -23,7 +24,10 @@ def test_link_820(self, client: TestClient, user_access_token: str, study_type:
client.headers = {"Authorization": f"Bearer {user_access_token}"} # type: ignore

preparer = PreparerProxy(client, user_access_token)
study_id = preparer.create_study("foo", version=880)
study_id = preparer.create_study("foo", version=820)
if study_type == "variant":
study_id = preparer.create_variant(study_id, name="Variant 1")

area1_id = preparer.create_area(study_id, name="Area 1")["id"]
area2_id = preparer.create_area(study_id, name="Area 2")["id"]
area3_id = preparer.create_area(study_id, name="Area 3")["id"]
Expand Down Expand Up @@ -51,51 +55,38 @@ def test_link_820(self, client: TestClient, user_access_token: str, study_type:
"use-phase-shifter": False,
}
assert expected == res.json()
client.delete(f"/v1/studies/{study_id}/links/{area1_id}/{area2_id}")
res = client.delete(f"/v1/studies/{study_id}/links/{area1_id}/{area2_id}")
res.raise_for_status()

# Test create link with parameters

res = client.post(
f"/v1/studies/{study_id}/links",
json={
"area1": area1_id,
"area2": area2_id,
"asset-type": "dc",
"colorb": 160,
"colorg": 170,
"colorr": 180,
"display-comments": True,
"filter-synthesis": "hourly",
"hurdles-cost": True,
"link-style": "plain",
"link-width": 2.0,
"loop-flow": False,
"transmission-capacities": "enabled",
"use-phase-shifter": True,
},
)

assert res.status_code == 200, res.json()

expected = {
"area1": "area 1",
"area2": "area 2",
parameters = {
"area1": area1_id,
"area2": area2_id,
"asset-type": "dc",
"colorb": 160,
"colorg": 170,
"colorr": 180,
"display-comments": True,
"filter-synthesis": "hourly",
"filter-year-by-year": "hourly, daily, weekly, monthly, annual",
"hurdles-cost": True,
"link-style": "plain",
"link-width": 2.0,
"loop-flow": False,
"transmission-capacities": "enabled",
"use-phase-shifter": True,
}
assert expected == res.json()
res = client.post(
f"/v1/studies/{study_id}/links",
json=parameters,
)

assert res.status_code == 200, res.json()
parameters["filter-year-by-year"] = "hourly, daily, weekly, monthly, annual"

assert parameters == res.json()
res = client.delete(f"/v1/studies/{study_id}/links/{area1_id}/{area2_id}")
res.raise_for_status()

# Create two links, count them, then delete one

Expand All @@ -111,20 +102,35 @@ def test_link_820(self, client: TestClient, user_access_token: str, study_type:
assert 2 == len(res.json())

res = client.delete(f"/v1/studies/{study_id}/links/{area1_id}/{area3_id}")
assert res.status_code == 200, res.json()
res.raise_for_status()

res = client.get(f"/v1/studies/{study_id}/links")

assert res.status_code == 200, res.json()
assert 1 == len(res.json())
client.delete(f"/v1/studies/{study_id}/links/{area1_id}/{area2_id}")
res.raise_for_status()

# Test create link with same area

res = client.post(f"/v1/studies/{study_id}/links", json={"area1": area1_id, "area2": area1_id})

assert res.status_code == 422, res.json()
expected = {"description": "Cannot create link on same node", "exception": "LinkValidationError"}
expected = {"description": "Cannot create link on same area: area 1", "exception": "LinkValidationError"}
assert expected == res.json()

# Test create link with wrong value for enum

res = client.post(
f"/v1/studies/{study_id}/links",
json={"area1": area1_id, "area2": area1_id, "asset-type": TransmissionCapacity.ENABLED},
)
assert res.status_code == 422, res.json()
expected = {
"body": {"area1": "area 1", "area2": "area 1", "asset-type": "enabled"},
"description": "Input should be 'ac', 'dc', 'gaz', 'virt' or 'other'",
"exception": "RequestValidationError",
}
assert expected == res.json()

# Test create link with wrong color parameter
Expand All @@ -133,8 +139,9 @@ def test_link_820(self, client: TestClient, user_access_token: str, study_type:

assert res.status_code == 422, res.json()
expected = {
"description": "Invalid value for colorr. Must be between 0 and 255.",
"exception": "LinkValidationError",
"body": {"area1": "area 1", "area2": "area 2", "colorr": 260},
"description": "Input should be less than 255",
"exception": "RequestValidationError",
}
assert expected == res.json()

Expand All @@ -152,8 +159,7 @@ def test_link_820(self, client: TestClient, user_access_token: str, study_type:
}
assert expected == res.json()

@pytest.mark.parametrize("study_type", ["raw", "variant"])
def test_create_link_810(self, client: TestClient, user_access_token: str, study_type: str) -> None:
def test_create_link_810(self, client: TestClient, user_access_token: str) -> None:
client.headers = {"Authorization": f"Bearer {user_access_token}"} # type: ignore

preparer = PreparerProxy(client, user_access_token)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None:
},
)
res.raise_for_status()
res_links = client.get(f"/v1/studies/{study_id}/links?with_ui=true")
res_links = client.get(f"/v1/studies/{study_id}/links")
assert res_links.json() == [
{
"area1": "area 1",
Expand Down
4 changes: 3 additions & 1 deletion tests/storage/business/test_arealink_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService):
}

area_manager.create_area(study, AreaCreationDTO(name="test2", type=AreaType.AREA))
study.version = 820
link_manager.create_link(
study,
LinkInfoDTO820(
Expand All @@ -144,6 +145,7 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService):
),
)
assert empty_study.config.areas["test"].links.get("test2") is not None
study.version = -1

link_manager.delete_link(study, "test", "test2")
assert empty_study.config.areas["test"].links.get("test2") is None
Expand Down Expand Up @@ -528,7 +530,7 @@ def test_get_all_area():
}
},
]
links = link_manager.get_all_links(study, with_ui=True)
links = link_manager.get_all_links(study)
assert [
{
"area1": "a1",
Expand Down
9 changes: 9 additions & 0 deletions tests/variantstudy/model/command/test_create_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def test_validation(self, empty_study: FileStudy, command_context: CommandContex
}
).apply(empty_study)

with pytest.raises(ValidationError):
CreateLink(
area1=area1,
area2=area1,
parameters={},
command_context=command_context,
series=[[0]],
)

def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
study_path = empty_study.config.study_path
area1 = "Area1"
Expand Down

0 comments on commit 1355def

Please sign in to comment.