Skip to content

Commit

Permalink
feat(commands): change VariantCommandGenerator to use the new detai…
Browse files Browse the repository at this point in the history
…ls model
  • Loading branch information
laurent-laporte-pro committed Mar 27, 2024
1 parent 89a164f commit c9009f0
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 87 deletions.
3 changes: 1 addition & 2 deletions antarest/study/storage/variantstudy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import uuid

import typing_extensions as te

from pydantic import BaseModel

from antarest.core.model import JSON
Expand All @@ -25,7 +24,7 @@ class NewDetailsDTO(te.TypedDict):
msg: message de la génération de la commande ou message d'erreur (si le statut est false).
"""

id: uuid.UUID
id: t.Optional[uuid.UUID]
name: str
status: bool
msg: str
Expand Down
11 changes: 9 additions & 2 deletions antarest/study/storage/variantstudy/snapshot_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,15 @@ def _apply_commands(
if not results.success:
message = f"Failed to generate variant study {variant_study.id}"
if results.details:
detail: t.Tuple[str, bool, str] = results.details[-1]
message += f": {detail[2]}"
detail = results.details[-1]
if isinstance(detail, (tuple, list)):
# old format: LegacyDetailsDTO
message += f": {detail[2]}"
elif isinstance(detail, dict):
# new format since v2.17: NewDetailsDTO
message += f": {detail['msg']}"
else: # pragma: no cover
raise NotImplementedError(f"Unexpected detail type: {type(detail)}")
raise VariantGenerationError(message)
return results

Expand Down
87 changes: 44 additions & 43 deletions antarest/study/storage/variantstudy/variant_command_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@
from antarest.study.storage.variantstudy.model.command.common import CommandOutput
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy
from antarest.study.storage.variantstudy.model.model import GenerationResultInfoDTO
from antarest.study.storage.variantstudy.model.model import GenerationResultInfoDTO, NewDetailsDTO

logger = logging.getLogger(__name__)

APPLY_CALLBACK = Callable[[ICommand, Union[FileStudyTreeConfig, FileStudy]], CommandOutput]


class CmdNotifier:
def __init__(self, study_id: str, total_count: int) -> None:
self.index = 0
self.study_id = study_id
self.total_count = total_count

def __call__(self, x: float) -> None:
logger.info(f"Command {self.index}/{self.total_count} [{self.study_id}] applied in {x}s")


class VariantCommandGenerator:
def __init__(self, study_factory: StudyFactory) -> None:
self.study_factory = study_factory
Expand All @@ -33,53 +43,44 @@ def _generate(
# Apply commands
results: GenerationResultInfoDTO = GenerationResultInfoDTO(success=True, details=[])

stopwatch.reset_current()
logger.info("Applying commands")
command_index = 0
total_commands = len(commands)
study_id = metadata.id if metadata is not None else "-"
for command_batch in commands:
command_output_status = True
command_output_message = ""
command_name = command_batch[0].command_name.value if len(command_batch) > 0 else ""
study_id = "-" if metadata is None else metadata.id

# flatten the list of commands
all_commands = [command for command_batch in commands for command in command_batch]

# Prepare the stopwatch
cmd_notifier = CmdNotifier(study_id, len(all_commands))
stopwatch.reset_current()

# Store all the outputs
for index, cmd in enumerate(all_commands, 1):
try:
command_index += 1
command_output_messages: List[str] = []
for command in command_batch:
output = applier(command, data)
command_output_messages.append(output.message)
command_output_status = command_output_status and output.status
if not command_output_status:
break
command_output_message = "\n".join(command_output_messages)
output = applier(cmd, data)
except Exception as e:
command_output_status = False
command_output_message = f"Error while applying command {command_name}"
logger.error(command_output_message, exc_info=e)
break
finally:
results.details.append(
(
command_name,
command_output_status,
command_output_message,
)
)
results.success = command_output_status
if notifier:
notifier(
command_index - 1,
command_output_status,
command_output_message,
)
stopwatch.log_elapsed(
lambda x: logger.info(
f"Command {command_index}/{total_commands} [{study_id}] {command.match_signature()} applied in {x}s"
)
# Unhandled exception
output = CommandOutput(
status=False,
message=f"Error while applying command {cmd.command_name}",
)
logger.error(output.message, exc_info=e)

detail: NewDetailsDTO = {
"id": cmd.command_id,
"name": cmd.command_name.value,
"status": output.status,
"msg": output.message,
}
results.details.append(detail)

if notifier:
notifier(index - 1, output.status, output.message)

cmd_notifier.index = index
stopwatch.log_elapsed(cmd_notifier)

results.success = all(detail["status"] for detail in results.details) # type: ignore

if not results.success:
break
data_type = isinstance(data, FileStudy)
stopwatch.log_elapsed(
lambda x: logger.info(
Expand Down
168 changes: 137 additions & 31 deletions tests/study/storage/variantstudy/test_snapshot_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,29 @@
from antarest.study.model import RawStudy, Study, StudyAdditionalData
from antarest.study.storage.rawstudy.raw_study_service import RawStudyService
from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy, VariantStudySnapshot
from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO
from antarest.study.storage.variantstudy.model.model import CommandDTO
from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator, search_ref_study
from antarest.study.storage.variantstudy.variant_study_service import VariantStudyService
from tests.db_statement_recorder import DBStatementRecorder
from tests.helpers import with_db_context


class AnyUUID:
"""Mock object to match any UUID."""

def __init__(self, as_string: bool = False):
self.as_string = as_string

def __eq__(self, other):
if self.as_string:
try:
uuid.UUID(other)
return True
except ValueError:
return False
return isinstance(other, uuid.UUID)


def _create_variant(
tmp_path: Path,
variant_name: str,
Expand Down Expand Up @@ -852,15 +868,35 @@ def test_generate__nominal_case(
assert len(db_recorder.sql_statements) == 5, str(db_recorder)

# Check: the variant generation must succeed.
assert results == GenerationResultInfoDTO(
success=True,
details=[
("create_area", True, "Area 'North' created"),
("create_area", True, "Area 'South' created"),
("create_link", True, "Link between 'north' and 'south' created"),
("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."),
assert results.dict() == {
"success": True,
"details": [
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'North' created",
},
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'South' created",
},
{
"id": AnyUUID(),
"name": "create_link",
"status": True,
"msg": "Link between 'north' and 'south' created",
},
{
"id": AnyUUID(),
"name": "create_cluster",
"status": True,
"msg": "Thermal cluster 'gas_cluster' added to area 'south'.",
},
],
)
}

# Check: the variant is correctly generated and all commands are applied.
snapshot_dir = variant_study.snapshot_dir
Expand Down Expand Up @@ -908,13 +944,33 @@ def test_generate__nominal_case(
assert list(snapshot_dir.parent.iterdir()) == [snapshot_dir]

# Check: the notifications are correctly registered.
assert notifier.notifications == [ # type: ignore
assert notifier.notifications == [
{
"details": [
["create_area", True, "Area 'North' created"],
["create_area", True, "Area 'South' created"],
["create_link", True, "Link between 'north' and 'south' created"],
["create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."],
{
"id": AnyUUID(as_string=True),
"msg": "Area 'North' created",
"name": "create_area",
"status": True,
},
{
"id": AnyUUID(as_string=True),
"msg": "Area 'South' created",
"name": "create_area",
"status": True,
},
{
"id": AnyUUID(as_string=True),
"msg": "Link between 'north' and 'south' created",
"name": "create_link",
"status": True,
},
{
"id": AnyUUID(as_string=True),
"msg": "Thermal cluster 'gas_cluster' added to area 'south'.",
"name": "create_cluster",
"status": True,
},
],
"success": True,
}
Expand Down Expand Up @@ -997,15 +1053,35 @@ def test_generate__with_denormalize_true(
)

# Check the results
assert results == GenerationResultInfoDTO(
success=True,
details=[
("create_area", True, "Area 'North' created"),
("create_area", True, "Area 'South' created"),
("create_link", True, "Link between 'north' and 'south' created"),
("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."),
assert results.dict() == {
"success": True,
"details": [
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'North' created",
},
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'South' created",
},
{
"id": AnyUUID(),
"name": "create_link",
"status": True,
"msg": "Link between 'north' and 'south' created",
},
{
"id": AnyUUID(),
"name": "create_cluster",
"status": True,
"msg": "Thermal cluster 'gas_cluster' added to area 'south'.",
},
],
)
}

# Check: the matrices are denormalized (we should have TSV files).
snapshot_dir = variant_study.snapshot_dir
Expand Down Expand Up @@ -1100,15 +1176,35 @@ def test_generate__notification_failure(
)

# Check the results
assert results == GenerationResultInfoDTO(
success=True,
details=[
("create_area", True, "Area 'North' created"),
("create_area", True, "Area 'South' created"),
("create_link", True, "Link between 'north' and 'south' created"),
("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."),
assert results.dict() == {
"success": True,
"details": [
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'North' created",
},
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'South' created",
},
{
"id": AnyUUID(),
"name": "create_link",
"status": True,
"msg": "Link between 'north' and 'south' created",
},
{
"id": AnyUUID(),
"name": "create_cluster",
"status": True,
"msg": "Thermal cluster 'gas_cluster' added to area 'south'.",
},
],
)
}

# Check th logs
assert "Something went wrong" in caplog.text
Expand Down Expand Up @@ -1162,4 +1258,14 @@ def test_generate__variant_of_variant(
)

# Check the results
assert results == GenerationResultInfoDTO(success=True, details=[("create_area", True, "Area 'East' created")])
assert results.dict() == {
"success": True,
"details": [
{
"id": AnyUUID(),
"name": "create_area",
"status": True,
"msg": "Area 'East' created",
},
],
}
Loading

0 comments on commit c9009f0

Please sign in to comment.