Skip to content

Commit

Permalink
feat(tests): add tests for v87 bc
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jan 23, 2024
1 parent 9159e4c commit bcf072d
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 19 deletions.
44 changes: 37 additions & 7 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

from pydantic import BaseModel

Expand Down Expand Up @@ -264,12 +264,8 @@ def update_binding_constraint(
}

study_version = int(study.version)
if data.key == "group" and data.value is not None:
if study_version < 870:
raise InvalidFieldForVersionError(
f"You cannot specify a group as your study version is older than v8.7: {data.value}"
)
args["group"] = data.value
args = BindingConstraintManager.fill_group_value(data, constraint, study_version, args)
args = BindingConstraintManager.fill_matrices_according_to_version(data, study_version, args)

if data.key == "time_step" and data.value != constraint.time_step:
# The user changed the time step, we need to update the matrix accordingly
Expand Down Expand Up @@ -297,6 +293,40 @@ def remove_binding_constraint(self, study: Study, binding_constraint_id: str) ->

execute_or_add_commands(study, file_study, [command], self.storage_service)

@staticmethod
def fill_group_value(
data: UpdateBindingConstProps, constraint: BindingConstraintConfigType, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if version < 870:
if data.key == "group":
raise InvalidFieldForVersionError(
f"You cannot specify a group as your study version is older than v8.7: {data.value}"
)
else:
# cast to 870 to use the attribute group
constraint = cast(BindingConstraintConfig870, constraint)
args["group"] = data.value if data.key == "group" else constraint.group
return args

@staticmethod
def fill_matrices_according_to_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if data.key == "values":
if version >= 870:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")
args["values"] = data.value
return args
for matrix in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]:
if data.key == matrix:
if version < 870:
raise InvalidFieldForVersionError(
"You cannot fill a 'matrix_term' as these values refer to v8.7+ studies"
)
args[matrix] = data.value
return args
return args

@staticmethod
def replace_matrices_according_to_frequency_and_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_corresponding_matrices(self, v: Optional[Union[MatrixType, str]], old: b
# Check the matrix link
return validate_matrix(v, {"command_context": self.command_context})
if isinstance(v, list):
check_matrix_values(time_step, v, old=True)
check_matrix_values(time_step, v, old=old)
return validate_matrix(v, {"command_context": self.command_context})
# Invalid datatype
# pragma: no cover
Expand Down
218 changes: 207 additions & 11 deletions tests/integration/study_data_blueprint/test_binding_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st
assert res.status_code == 200, res.json()
assert constraints is None

# The user change the time_step to daily instead of hourly.
# The user changed the time_step to daily instead of hourly.
# We must check that the matrix is a daily/weekly matrix.
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
Expand All @@ -265,10 +265,8 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st
params={"path": f"input/bindingconstraints/{bc_id}", "depth": 1, "formatted": True},
headers=user_headers,
)
assert res.status_code == 200, res.json()
dataframe = res.json()
assert len(dataframe["index"]) == 366
assert len(dataframe["columns"]) == 3 # less, equal, greater
assert res.status_code == 200
assert res.json()["data"] == np.zeros((366, 3)).tolist()

# Delete a binding constraint
res = client.delete(f"/v1/studies/{study_id}/bindingconstraints/{bc_id}", headers=user_headers)
Expand Down Expand Up @@ -305,10 +303,8 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st
params={"path": f"input/bindingconstraints/{bc_id_with_matrix}", "depth": 1, "formatted": True},
headers=user_headers,
)
assert res.status_code == 200, res.json()
dataframe = res.json()
assert len(dataframe["index"]) == 366
assert len(dataframe["columns"]) == 3 # less, equal, greater
assert res.status_code == 200
assert res.json()["data"] == daily_matrix

# =============================
# ERRORS
Expand Down Expand Up @@ -437,9 +433,46 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st
assert res.json()["exception"] == "CommandApplicationError"
assert res.json()["description"] == "Binding constraint not found"

# todo : add lots of test for v8.7
def test_for_version_870(self, client: TestClient, admin_access_token: str, study_id: str) -> None:
# Add a group before v8.7
grp_name = "random_grp"
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/binding_constraint_2",
json={"key": "group", "value": grp_name},
headers=user_headers,
)
assert res.status_code == 400
assert res.json()["exception"] == "InvalidFieldForVersionError"
assert (
res.json()["description"]
== f"You cannot specify a group as your study version is older than v8.7: {grp_name}"
)

# Update with a matrix from v8.7
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/binding_constraint_2",
json={"key": "less_term_matrix", "value": [[]]},
headers=user_headers,
)
assert res.status_code == 400
assert res.json()["exception"] == "InvalidFieldForVersionError"
assert res.json()["description"] == "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies"

@pytest.mark.parametrize("study_type", ["raw", "variant"])
def test_for_version_870(self, client: TestClient, admin_access_token: str, study_type: str) -> None:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

# =============================
# STUDY PREPARATION
# =============================

res = client.post(
"/v1/studies",
headers=admin_headers,
params={"name": "foo"},
)
assert res.status_code == 201, res.json()
study_id = res.json()

# Upgrade study to version 870
res = client.put(
f"/v1/studies/{study_id}/upgrade",
Expand All @@ -453,6 +486,159 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud

assert task.status == TaskStatus.COMPLETED, task

if study_type == "variant":
# Create Variant
res = client.post(
f"/v1/studies/{study_id}/variants",
headers=admin_headers,
params={"name": "Variant 1"},
)
assert res.status_code == 200
study_id = res.json()

# =============================
# CREATION
# =============================

# Creation of a bc without group
bc_id_wo_group = "binding_constraint_1"
args = {"enabled": True, "time_step": "hourly", "operator": "less", "coeffs": {}, "comments": "New API"}
res = client.post(
f"/v1/studies/{study_id}/bindingconstraints",
json={"name": bc_id_wo_group, **args},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_wo_group}", headers=admin_headers)
assert res.json()["group"] == "default"

# Creation of bc with a group
bc_id_w_group = "binding_constraint_2"
res = client.post(
f"/v1/studies/{study_id}/bindingconstraints",
json={"name": bc_id_w_group, "group": "specific_grp", **args},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_group}", headers=admin_headers)
assert res.json()["group"] == "specific_grp"

# Creation of bc with a matrix
bc_id_w_matrix = "binding_constraint_3"
matrix = np.ones((8784, 1))
matrix_to_list = matrix.tolist()
res = client.post(
f"/v1/studies/{study_id}/bindingconstraints",
json={"name": bc_id_w_matrix, "less_term_matrix": matrix_to_list, **args},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

if study_type == "variant":
res = client.get(f"/v1/studies/{study_id}/commands", headers=admin_headers)
last_cmd_args = res.json()[-1]["args"]
less_term_matrix = last_cmd_args["less_term_matrix"]
equal_term_matrix = last_cmd_args["equal_term_matrix"]
greater_term_matrix = last_cmd_args["greater_term_matrix"]
assert greater_term_matrix == equal_term_matrix != less_term_matrix

# Check that raw matrices are created
for term in ["lt", "gt", "eq"]:
res = client.get(
f"/v1/studies/{study_id}/raw",
params={"path": f"input/bindingconstraints/{bc_id_w_matrix}_{term}", "depth": 1, "formatted": True},
headers=admin_headers,
)
assert res.status_code == 200
data = res.json()["data"]
if term == "lt":
assert data == matrix_to_list
else:
assert data == np.zeros(matrix.shape).tolist()

# =============================
# UPDATE
# =============================

# Add a group
grp_name = "random_grp"
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_matrix}",
json={"key": "group", "value": grp_name},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

# Asserts the groupe is created
res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_matrix}", headers=admin_headers)
assert res.json()["group"] == grp_name

# Update matrix_term
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_matrix}",
json={"key": "greater_term_matrix", "value": matrix_to_list},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

res = client.get(
f"/v1/studies/{study_id}/raw",
params={"path": f"input/bindingconstraints/{bc_id_w_matrix}_gt", "depth": 1, "formatted": True},
headers=admin_headers,
)
assert res.status_code == 200
assert res.json()["data"] == matrix_to_list

# The user changed the time_step to daily instead of hourly.
# We must check that the matrices have been updated.
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_matrix}",
json={"key": "time_step", "value": "daily"},
headers=admin_headers,
)
assert res.status_code == 200, res.json()

if study_type == "variant":
# Check the last command is a change time_step
res = client.get(f"/v1/studies/{study_id}/commands", headers=admin_headers)
commands = res.json()
args = commands[-1]["args"]
assert args["time_step"] == "daily"
assert args["less_term_matrix"] == args["greater_term_matrix"] == args["equal_term_matrix"] is not None

# Check that the matrices are daily/weekly matrices
expected_matrix = np.zeros((366, 1)).tolist()
for term_alias in ["lt", "gt", "eq"]:
res = client.get(
f"/v1/studies/{study_id}/raw",
params={
"path": f"input/bindingconstraints/{bc_id_w_matrix}_{term_alias}",
"depth": 1,
"formatted": True,
},
headers=admin_headers,
)
assert res.status_code == 200
assert res.json()["data"] == expected_matrix

# =============================
# DELETE
# =============================

# Delete a binding constraint
res = client.delete(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_matrix}", headers=admin_headers)
assert res.status_code == 200, res.json()

# Asserts that the deletion worked
res = client.get(f"/v1/studies/{study_id}/bindingconstraints", headers=admin_headers)
assert len(res.json()) == 2

# =============================
# ERRORS
# =============================

# Creation with wrong matrix according to version
res = client.post(
f"/v1/studies/{study_id}/bindingconstraints",
Expand All @@ -469,3 +655,13 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud
)
assert res.status_code == 400
assert res.json()["description"] == "You cannot fill 'values' as it refers to the matrix before v8.7"

# Update with old matrices
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_group}",
json={"key": "values", "value": [[]]},
headers=admin_headers,
)
assert res.status_code == 400
assert res.json()["exception"] == "InvalidFieldForVersionError"
assert res.json()["description"] == "You cannot fill 'values' as it refers to the matrix before v8.7"

0 comments on commit bcf072d

Please sign in to comment.