Skip to content

Commit

Permalink
feat(bc): add the ability to filter BC on various properties
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Mar 21, 2024
1 parent 0752fca commit 314ab6f
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 28 deletions.
140 changes: 120 additions & 20 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,103 @@ def generate_id(self) -> str:
return self.data.generate_id()


class BindingConstraintFilter(BaseModel, frozen=True, extra="forbid"):
"""
Binding Constraint Filter gathering the main filtering parameters.
Attributes:
bc_id: binding constraint ID (exact match)
enabled: enabled status
operator: operator
comments: comments (word match)
group: on group name (exact match)
time_step: time step
area_name: area name (word match)
cluster_name: cluster name (word match)
link_id: link ID ('area1%area2') in at least one term.
cluster_id: cluster ID ('area.cluster') in at least one term.
"""

bc_id: str = ""
enabled: Optional[bool] = None
operator: Optional[BindingConstraintOperator] = None
comments: str = ""
group: str = ""
time_step: Optional[BindingConstraintFrequency] = None
area_name: str = ""
cluster_name: str = ""
link_id: str = ""
cluster_id: str = ""

def accept(self, constraint: "BindingConstraintConfigType") -> bool:
"""
Check if the constraint matches the filter.
Args:
constraint: the constraint to check
Returns:
True if the constraint matches the filter, False otherwise
"""
if self.bc_id and self.bc_id != constraint.id:
return False
if self.enabled is not None and self.enabled != constraint.enabled:
return False
if self.operator is not None and self.operator != constraint.operator:
return False
if self.comments:
comments = constraint.comments or ""
if self.comments.upper() not in comments.upper():
return False
if self.group:
group = getattr(constraint, "group") or ""
if self.group.upper() != group.upper():
return False
if self.time_step is not None and self.time_step != constraint.time_step:
return False

# Filter on terms
terms = constraint.constraints or []

if self.area_name:
all_areas = []
for term in terms:
if term.data is None:
continue
if isinstance(term.data, AreaLinkDTO):
all_areas.extend([term.data.area1, term.data.area2])
elif isinstance(term.data, AreaClusterDTO):
all_areas.append(term.data.area)
else: # pragma: no cover
raise NotImplementedError(f"Unknown term data type: {type(term.data)}")
upper_area_name = self.area_name.upper()
if all_areas and not any(upper_area_name in area.upper() for area in all_areas):
return False

if self.cluster_name:
all_clusters = []
for term in terms:
if term.data is None:
continue
if isinstance(term.data, AreaClusterDTO):
all_clusters.append(term.data.cluster)
upper_cluster_name = self.cluster_name.upper()
if all_clusters and not any(upper_cluster_name in cluster.upper() for cluster in all_clusters):
return False

if self.link_id:
all_link_ids = [term.data.generate_id() for term in terms if isinstance(term.data, AreaLinkDTO)]
if not any(self.link_id.lower() == link_id.lower() for link_id in all_link_ids):
return False

if self.cluster_id:
all_cluster_ids = [term.data.generate_id() for term in terms if isinstance(term.data, AreaClusterDTO)]
if not any(self.cluster_id.lower() == cluster_id.lower() for cluster_id in all_cluster_ids):
return False

return True


class BindingConstraintEditionModel(BaseModel, metaclass=AllOptionalMetaclass):
group: str
enabled: bool
Expand Down Expand Up @@ -229,26 +326,27 @@ def constraints_to_coeffs(
return coeffs

def get_binding_constraint(
self, study: Study, constraint_id: Optional[str]
self,
study: Study,
bc_filter: BindingConstraintFilter = BindingConstraintFilter(),
) -> Union[BindingConstraintConfigType, List[BindingConstraintConfigType], None]:
storage_service = self.storage_service.get_storage(study)
file_study = storage_service.get_raw(study)
config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"])
config_values = list(config.values())
study_version = int(study.version)
if constraint_id:
try:
index = [value["id"] for value in config_values].index(constraint_id)
config_value = config_values[index]
return BindingConstraintManager.process_constraint(config_value, study_version)
except ValueError:
return None

binding_constraint = []
for config_value in config_values:
new_config = BindingConstraintManager.process_constraint(config_value, study_version)
binding_constraint.append(new_config)
return binding_constraint
bc_by_ids: Dict[str, BindingConstraintConfigType] = {}
for value in config.values():
new_config = BindingConstraintManager.process_constraint(value, int(study.version))
bc_by_ids[new_config.id] = new_config

result = {bc_id: bc for bc_id, bc in bc_by_ids.items() if bc_filter.accept(bc)}

# If a specific bc_id is provided, we return a single element
if bc_filter.bc_id:
return result.get(bc_filter.bc_id)

# Else we return all the matching elements
return list(result.values())

def validate_binding_constraint(self, study: Study, constraint_id: str) -> None:
if int(study.version) < 870:
Expand Down Expand Up @@ -276,7 +374,7 @@ def create_binding_constraint(
if not bc_id:
raise InvalidConstraintName(f"Invalid binding constraint name: {data.name}.")

if bc_id in {bc.id for bc in self.get_binding_constraint(study, None)}: # type: ignore
if bc_id in {bc.id for bc in self.get_binding_constraint(study)}: # type: ignore
raise DuplicateConstraintName(f"A binding constraint with the same name already exists: {bc_id}.")

check_attributes_coherence(data, version)
Expand Down Expand Up @@ -329,7 +427,7 @@ def update_binding_constraint(
data: BindingConstraintEdition,
) -> BindingConstraintConfigType:
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
constraint = self.get_binding_constraint(study, BindingConstraintFilter(bc_id=binding_constraint_id))
study_version = int(study.version)
if not isinstance(constraint, BindingConstraintConfig) and not isinstance(
constraint, BindingConstraintConfig870
Expand Down Expand Up @@ -390,7 +488,9 @@ def remove_binding_constraint(self, study: Study, binding_constraint_id: str) ->
file_study = self.storage_service.get_storage(study).get_raw(study)

# Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy) and not self.get_binding_constraint(study, binding_constraint_id):
if isinstance(study, VariantStudy) and not self.get_binding_constraint(
study, BindingConstraintFilter(bc_id=binding_constraint_id)
):
raise CommandApplicationError("Binding constraint not found")

execute_or_add_commands(study, file_study, [command], self.storage_service)
Expand All @@ -402,7 +502,7 @@ def update_constraint_term(
term: Union[ConstraintTermDTO, str],
) -> None:
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
constraint = self.get_binding_constraint(study, BindingConstraintFilter(bc_id=binding_constraint_id))

if not isinstance(constraint, BindingConstraintConfig) and not isinstance(constraint, BindingConstraintConfig):
raise BindingConstraintNotFoundError(study.id)
Expand Down Expand Up @@ -454,7 +554,7 @@ def add_new_constraint_term(
constraint_term: ConstraintTermDTO,
) -> None:
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
constraint = self.get_binding_constraint(study, BindingConstraintFilter(bc_id=binding_constraint_id))
if not isinstance(constraint, BindingConstraintConfig) and not isinstance(constraint, BindingConstraintConfig):
raise BindingConstraintNotFoundError(study.id)

Expand Down
64 changes: 56 additions & 8 deletions antarest/study/web/study_data_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Sequence, Union, cast

from fastapi import APIRouter, Body, Depends
from fastapi.params import Query
from fastapi import APIRouter, Body, Depends, Query
from starlette.responses import RedirectResponse

from antarest.core.config import Config
Expand Down Expand Up @@ -45,6 +44,7 @@
BindingConstraintConfigType,
BindingConstraintCreation,
BindingConstraintEdition,
BindingConstraintFilter,
ConstraintTermDTO,
)
from antarest.study.business.correlation_management import CorrelationFormFields, CorrelationManager, CorrelationMatrix
Expand All @@ -53,11 +53,16 @@
from antarest.study.business.link_management import LinkInfoDTO
from antarest.study.business.optimization_management import OptimizationFormFields
from antarest.study.business.playlist_management import PlaylistColumns
from antarest.study.business.table_mode_management import ColumnsModelTypes, TableTemplateType
from antarest.study.business.table_mode_management import (
BindingConstraintOperator,
ColumnsModelTypes,
TableTemplateType,
)
from antarest.study.business.thematic_trimming_management import ThematicTrimmingFormFields
from antarest.study.business.timeseries_config_management import TSFormFields
from antarest.study.model import PatchArea, PatchCluster
from antarest.study.service import StudyService
from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -889,6 +894,35 @@ def update_version(
)
def get_binding_constraint_list(
uuid: str,
enabled: Optional[bool] = Query(None, description="Filter results based on enabled status"),
operator: Optional[BindingConstraintOperator] = Query(None, description="Filter results based on operator"),
comments: str = Query("", description="Filter results based on comments (word match)"),
group: str = Query("", description="filter binding constraints based on group name (exact match)"),
time_step: Optional[BindingConstraintFrequency] = Query(
None,
description="Filter results based on time step",
alias="timeStep",
),
area_name: str = Query(
"",
description="Filter results based on area name (word match)",
alias="areaName",
),
cluster_name: str = Query(
"",
description="Filter results based on cluster name (word match)",
alias="clusterName",
),
link_id: str = Query(
"",
description="Filter results based on link ID ('area1%area2')",
alias="linkId",
),
cluster_id: str = Query(
"",
description="Filter results based on cluster ID ('area.cluster')",
alias="clusterId",
),
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
logger.info(
Expand All @@ -897,7 +931,18 @@ def get_binding_constraint_list(
)
params = RequestParameters(user=current_user)
study = study_service.check_study_access(uuid, StudyPermissionType.READ, params)
return study_service.binding_constraint_manager.get_binding_constraint(study, None)
bc_filter = BindingConstraintFilter(
enabled=enabled,
operator=operator,
comments=comments,
group=group,
time_step=time_step,
area_name=area_name,
cluster_name=cluster_name,
link_id=link_id,
cluster_id=cluster_id,
)
return study_service.binding_constraint_manager.get_binding_constraint(study, bc_filter)

@bp.get(
"/studies/{uuid}/bindingconstraints/{binding_constraint_id}",
Expand All @@ -916,7 +961,8 @@ def get_binding_constraint(
)
params = RequestParameters(user=current_user)
study = study_service.check_study_access(uuid, StudyPermissionType.READ, params)
return study_service.binding_constraint_manager.get_binding_constraint(study, binding_constraint_id)
bc_filter = BindingConstraintFilter(bc_id=binding_constraint_id)
return study_service.binding_constraint_manager.get_binding_constraint(study, bc_filter)

@bp.put(
"/studies/{uuid}/bindingconstraints/{binding_constraint_id}",
Expand Down Expand Up @@ -968,7 +1014,9 @@ def validate_binding_constraint(

@bp.post("/studies/{uuid}/bindingconstraints", tags=[APITag.study_data], summary="Create a binding constraint")
def create_binding_constraint(
uuid: str, data: BindingConstraintCreation, current_user: JWTUser = Depends(auth.get_current_user)
uuid: str,
data: BindingConstraintCreation,
current_user: JWTUser = Depends(auth.get_current_user),
) -> BindingConstraintConfigType:
logger.info(
f"Creating a new binding constraint for study {uuid}",
Expand Down Expand Up @@ -1171,7 +1219,7 @@ def get_correlation_matrix(
"value": "north,east",
},
},
), # type: ignore
),
current_user: JWTUser = Depends(auth.get_current_user),
) -> CorrelationMatrix:
"""
Expand Down Expand Up @@ -2088,7 +2136,7 @@ def duplicate_cluster(
area_id: str,
cluster_type: ClusterType,
source_cluster_id: str,
new_cluster_name: str = Query(..., alias="newName", title="New Cluster Name"), # type: ignore
new_cluster_name: str = Query(..., alias="newName", title="New Cluster Name"),
current_user: JWTUser = Depends(auth.get_current_user),
) -> Union[STStorageOutput, ThermalClusterOutput, RenewableClusterOutput]:
logger.info(
Expand Down

0 comments on commit 314ab6f

Please sign in to comment.