Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(variants): avoid Recursive error when creating big variant tree #1967

Merged
merged 4 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions antarest/study/storage/variantstudy/business/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def get_or_create_section(json_ini: JSON, section: str) -> JSON:


def remove_none_args(command_dto: CommandDTO) -> CommandDTO:
if isinstance(command_dto.args, list):
command_dto.args = [{k: v for k, v in args.items() if v is not None} for args in command_dto.args]
args = command_dto.args
if isinstance(args, list):
command_dto.args = [{k: v for k, v in args.items() if v is not None} for args in args]
elif isinstance(args, dict):
command_dto.args = {k: v for k, v in args.items() if v is not None}
else:
command_dto.args = {k: v for k, v in command_dto.args.items() if v is not None}
raise TypeError(f"Invalid type for args: {type(args)}")
return command_dto


Expand Down
53 changes: 44 additions & 9 deletions antarest/study/storage/variantstudy/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
import typing as t

from pydantic import BaseModel

Expand All @@ -7,28 +7,63 @@


class GenerationResultInfoDTO(BaseModel):
"""
Result information of a snapshot generation process.

Attributes:
success: A boolean indicating whether the generation process was successful.
details: A list of tuples containing detailed information about the generation process.
"""

success: bool
details: List[Tuple[str, bool, str]]
details: t.MutableSequence[t.Tuple[str, bool, str]]


class CommandDTO(BaseModel):
id: Optional[str]
"""
This class represents a command.

Attributes:
id: The unique identifier of the command.
action: The action to be performed by the command.
args: The arguments for the command action.
version: The version of the command.
"""

id: t.Optional[str]
action: str
# if args is a list, this mean the command will be mapped to the list of args
args: Union[List[JSON], JSON]
args: t.Union[t.MutableSequence[JSON], JSON]
version: int = 1


class CommandResultDTO(BaseModel):
"""
This class represents the result of a command.

Attributes:
study_id: The unique identifier of the study.
id: The unique identifier of the command.
success: A boolean indicating whether the command was successful.
message: A message detailing the result of the command.
"""

study_id: str
id: str
success: bool
message: str


class VariantTreeDTO(BaseModel):
node: StudyMetadataDTO
children: List["VariantTreeDTO"]
class VariantTreeDTO:
"""
This class represents a variant tree structure.

Attributes:
node: The metadata of the study (ID, name, version, etc.).
children: A list of variant children.
"""

VariantTreeDTO.update_forward_refs()
def __init__(self, node: StudyMetadataDTO, children: t.MutableSequence["VariantTreeDTO"]) -> None:
# We are intentionally not using Pydantic’s `BaseModel` here to prevent potential
# `RecursionError` exceptions that can occur with Pydantic versions before v2.
self.node = node
self.children = children or []
2 changes: 1 addition & 1 deletion tests/integration/studies_blueprint/test_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_raw_study(
)
assert res.status_code == 200, res.json()
duration = time.time() - start
assert 0 <= duration <= 0.1, f"Duration is {duration} seconds"
assert 0 <= duration <= 0.3, f"Duration is {duration} seconds"

def test_variant_study(
self,
Expand Down
13 changes: 13 additions & 0 deletions tests/integration/variant_blueprint/test_variant_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,16 @@ def test_comments(client: TestClient, admin_access_token: str, variant_id: str)
# Asserts comments did not disappear
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment


def test_recursive_variant_tree(client: TestClient, admin_access_token: str):
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers)
base_study_id = base_study_res.json()
parent_id = base_study_res.json()
for k in range(150):
res = client.post(f"/v1/studies/{base_study_id}/variants?name=variant_{k}", headers=admin_headers)
base_study_id = res.json()
# Asserts that we do not trigger a Recursive Exception
res = client.get(f"/v1/studies/{parent_id}/variants", headers=admin_headers)
assert res.status_code == 200
Loading