Skip to content

Commit

Permalink
Refactor BatchSimulator
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Oct 10, 2024
1 parent 3dfec31 commit b9cfd46
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 101 deletions.
10 changes: 3 additions & 7 deletions src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class ExtParamConfig(ParameterConfig):
If a list of strings is given, the order is preserved.
"""

input_keys: Union[List[str], Dict[str, List[Tuple[str, str]]]] = field(
default_factory=list
)
input_keys: Union[List[str], Dict[str, List[str]]] = field(default_factory=list)
forward_init: bool = False
output_file: str = ""
forward_init_file: str = ""
Expand Down Expand Up @@ -136,16 +134,14 @@ def __contains__(self, key: Union[Tuple[str, str], str]) -> bool:
"""
if isinstance(self.input_keys, dict) and isinstance(key, tuple):
key, suffix = key
return (
key in self.input_keys and suffix in self.input_keys[key] # type: ignore[comparison-overlap]
)
return key in self.input_keys and suffix in self.input_keys[key]
else:
return key in self.input_keys

def __repr__(self) -> str:
return f"ExtParamConfig(keys={self.input_keys})"

def __getitem__(self, index: str) -> List[Tuple[str, str]]:
def __getitem__(self, index: str) -> List[str]:
"""Retrieve an item from the configuration
If @index is a string, assumes its a key and retrieves the suffixes
Expand Down
51 changes: 16 additions & 35 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from ert.config import ErtConfig, ExtParamConfig, GenDataConfig
from ert.config import ErtConfig, ExtParamConfig

from .batch_simulator_context import BatchContext

Expand All @@ -16,8 +26,8 @@ class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
controls: Dict[str, List[str]],
results: List[str],
controls: Iterable[str],
results: Iterable[str],
callback: Optional[Callable[[BatchContext], None]] = None,
):
"""Will create simulator which can be used to run multiple simulations.
Expand Down Expand Up @@ -88,39 +98,10 @@ def callback(*args, **kwargs):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.control_keys = set(controls.keys())
self.control_keys = set(controls)
self.result_keys = set(results)
self.callback = callback

ens_config = self.ert_config.ensemble_config
for control_name, variables in controls.items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

if "gen_data" not in ens_config:
ens_config.addNode(
GenDataConfig(
keys=results,
input_files=[f"{k}" for k in results],
report_steps_list=[None for _ in results],
)
)
else:
existing_gendata = ens_config.response_configs["gen_data"]
existing_keys = existing_gendata.keys
assert isinstance(existing_gendata, GenDataConfig)

for key in results:
if key not in existing_keys:
existing_gendata.keys.append(key)
existing_gendata.input_files.append(f"{key}")
existing_gendata.report_steps_list.append(None)

def _setup_sim(
self,
sim_id: int,
Expand All @@ -143,7 +124,7 @@ def _check_suffix(
f"these suffixes: {missingsuffixes}"
)
for suffix in assignment:
if suffix not in suffixes: # type: ignore[comparison-overlap]
if suffix not in suffixes:
raise KeyError(
f"Key {key} has suffixes {suffixes}. "
f"Can't find the requested suffix {suffix}"
Expand Down
53 changes: 52 additions & 1 deletion src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import json
import logging
import os
from typing import Union
from typing import DefaultDict, Dict, List, Union

import everest
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.config.install_data_config import InstallDataConfig
from everest.config.install_job_config import InstallJobConfig
from everest.config.simulator_config import SimulatorConfig
Expand Down Expand Up @@ -455,6 +459,51 @@ def _extract_seed(ever_config: EverestConfig, ert_config):
ert_config["RANDOM_SEED"] = random_seed


def _extract_controls(ever_config: EverestConfig, ert_config):
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = collections.defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result)

# This adds an EXT_PARAM key to the ert_config, which is not a true ERT
# configuration key. When initializing an ERT config object, it is ignored.
# It is used by the Simulator object to inject ExtParamConfig nodes.
controls = ever_config.controls or []
ert_config["EXT_PARAM"] = {
control.name: _get_variables(control.variables) for control in controls
}


def _extract_results(ever_config: EverestConfig, ert_config):
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
gen_data = ert_config.get("GEN_DATA", [])
for name in objectives_names + constraint_names:
gen_data.append((name, f"RESULT_FILE:{name}"))
ert_config["GEN_DATA"] = gen_data


def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
"""
Takes as input an Everest configuration, the site-config and converts them
Expand All @@ -475,5 +524,7 @@ def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
_extract_model(ever_config, ert_config)
_extract_queue_system(ever_config, ert_config)
_extract_seed(ever_config, ert_config)
_extract_controls(ever_config, ert_config)
_extract_results(ever_config, ert_config)

return ert_config
99 changes: 42 additions & 57 deletions src/everest/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,72 @@
from collections import defaultdict
from datetime import datetime
from itertools import count
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple

import numpy as np
from numpy import float64
from numpy._typing import NDArray
from ropt.evaluator import EvaluatorContext, EvaluatorResult

from ert import BatchSimulator, WorkflowRunner
from ert.config import ErtConfig, HookRuntime
from ert.config import ErtConfig, ExtParamConfig, HookRuntime
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.simulator.everest_to_ert import everest_to_ert_config


class Simulator(BatchSimulator):
"""Everest simulator: BatchSimulator"""

def __init__(self, ever_config: EverestConfig, callback=None):
self._ert_config = ErtConfig.with_plugins().from_dict(
config_dict=everest_to_ert_config(
ever_config, site_config=ErtConfig.read_site_config()
)
def __init__(self, ever_config: EverestConfig, callback=None) -> None:
config_dict = everest_to_ert_config(
ever_config, site_config=ErtConfig.read_site_config()
)
controls_def = self._get_controls_def(ever_config)
results_def = self._get_results_def(ever_config)
ert_config = ErtConfig.with_plugins().from_dict(config_dict=config_dict)

# Inject ExtParam nodes. This is needed because EXT_PARAM is not an ERT
# configuration key, but only a placeholder for the control definitions.
ens_config = ert_config.ensemble_config
for control_name, variables in config_dict["EXT_PARAM"].items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

super(Simulator, self).__init__(
self._ert_config, controls_def, results_def, callback=callback
ert_config,
self._get_controls(ever_config),
self._get_results(ever_config),
callback=callback,
)

self._function_aliases = self._get_aliases(ever_config)
self._experiment_id = None
self._batch = 0
self._cache: Optional[_SimulatorCache] = None
if ever_config.simulator is not None and ever_config.simulator.enable_cache:
self._cache = _SimulatorCache()

@staticmethod
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result) # { name : [ index ]

def _get_controls_def(
self, ever_config: EverestConfig
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
def _get_controls(self, ever_config: EverestConfig) -> List[str]:
controls = ever_config.controls or []
return {
control.name: self._get_variables(control.variables) for control in controls
}
return [control.name for control in controls]

def _get_results_def(self, ever_config: EverestConfig):
self._function_aliases = {
def _get_results(self, ever_config: EverestConfig) -> List[str]:
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names

def _get_aliases(self, ever_config: EverestConfig) -> Dict[str, str]:
aliases = {
objective.name: objective.alias
for objective in ever_config.objective_functions
if objective.alias is not None
Expand All @@ -83,19 +78,9 @@ def _get_results_def(self, ever_config: EverestConfig):
constraint.upper_bound is not None
and constraint.lower_bound is not None
):
self._function_aliases[f"{constraint.name}:lower"] = constraint.name
self._function_aliases[f"{constraint.name}:upper"] = constraint.name

objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.name not in self._function_aliases
]

constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names
aliases[f"{constraint.name}:lower"] = constraint.name
aliases[f"{constraint.name}:upper"] = constraint.name
return aliases

def __call__(
self, control_values: NDArray[np.float64], metadata: EvaluatorContext
Expand Down Expand Up @@ -133,7 +118,7 @@ def __call__(
self._add_control(controls, control_name, control_value)
case_data.append((real_id, controls))

with open_storage(self._ert_config.ens_path, "w") as storage:
with open_storage(self.ert_config.ens_path, "w") as storage:
if self._experiment_id is None:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
Expand Down
47 changes: 46 additions & 1 deletion tests/ert/unit_tests/simulator/test_batch_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from ert.config import ErtConfig
from ert.config import ErtConfig, ExtParamConfig, GenDataConfig
from ert.scheduler import JobState
from ert.simulator import BatchContext, BatchSimulator

Expand Down Expand Up @@ -35,6 +35,51 @@ def batch_sim_example(setup_case):
return setup_case("batch_sim", "batch_sim.ert")


# The batch simulator was recently refactored. It now requires an ERT config
# object that has been generated in the derived Simulator class. The resulting
# ERT config object includes features that cannot be specified in an ERT
# configuration file. This is acceptable since the batch simulator is only used
# by Everest and slated to be replaced in the near future with newer ERT
# functionality. However, the tests in this file assume that the batch simulator
# can be configured independently from an Everest configuration. To make the
# tests work, the batch simulator class is patched here to inject the missing
# functionality.
class PatchedBatchSimulator(BatchSimulator):
def __init__(self, ert_config, controls, results, callback=None):
super().__init__(ert_config, set(controls), results, callback)
ens_config = ert_config.ensemble_config
for control_name, variables in controls.items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

if "gen_data" not in ens_config:
ens_config.addNode(
GenDataConfig(
keys=results,
input_files=[f"{k}" for k in results],
report_steps_list=[None for _ in results],
)
)
else:
existing_gendata = ens_config.response_configs["gen_data"]
existing_keys = existing_gendata.keys
assert isinstance(existing_gendata, GenDataConfig)

for key in results:
if key not in existing_keys:
existing_gendata.keys.append(key)
existing_gendata.input_files.append(f"{key}")
existing_gendata.report_steps_list.append(None)


BatchSimulator = PatchedBatchSimulator


def test_that_simulator_raises_error_when_missing_ertconfig():
with pytest.raises(ValueError, match="The first argument must be valid ErtConfig"):
_ = BatchSimulator(
Expand Down
Loading

0 comments on commit b9cfd46

Please sign in to comment.