Skip to content

Commit

Permalink
Initialize experiment along with simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Oct 10, 2024
1 parent b9cfd46 commit 0684762
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 110 deletions.
5 changes: 3 additions & 2 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
experiment: Experiment,
controls: Iterable[str],
results: Iterable[str],
callback: Optional[Callable[[BatchContext], None]] = None,
Expand Down Expand Up @@ -98,6 +99,7 @@ def callback(*args, **kwargs):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.experiment = experiment
self.control_keys = set(controls)
self.result_keys = set(results)
self.callback = callback
Expand Down Expand Up @@ -162,7 +164,6 @@ def start(
self,
case_name: str,
case_data: List[Tuple[int, Dict[str, Dict[str, Any]]]],
experiment: Experiment,
) -> BatchContext:
"""Start batch simulation, return a simulation context
Expand Down Expand Up @@ -221,7 +222,7 @@ def start(
time, so when you have called the 'start' method you need to let that
batch complete before you start a new batch.
"""
ensemble = experiment.create_ensemble(
ensemble = self.experiment.create_ensemble(
name=case_name,
ensemble_size=self.ert_config.model_config.num_realizations,
)
Expand Down
8 changes: 8 additions & 0 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ def refresh(self) -> None:
self._ensembles = self._load_ensembles()
self._experiments = self._load_experiments()

def load_experiment(self, uuid: UUID) -> LocalExperiment:
if uuid not in self._experiments:
self._experiments[uuid] = LocalExperiment(
self, self._experiment_path(uuid), self.mode
)

return self._experiments[uuid]

def get_experiment(self, uuid: UUID) -> LocalExperiment:
"""
Retrieves an experiment by UUID.
Expand Down
4 changes: 2 additions & 2 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage):
responses=[],
)

_server = BatchSimulator(ert_config, {}, [])
_context = _server.start("dispatch_server", [(0, {})], experiment)
_server = BatchSimulator(ert_config, experiment, {}, [])
_context = _server.start("dispatch_server", [(0, {})])

return _context

Expand Down
73 changes: 25 additions & 48 deletions src/everest/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import time
from collections import defaultdict
from datetime import datetime
from itertools import count
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple

Expand All @@ -10,35 +9,25 @@
from ropt.evaluator import EvaluatorContext, EvaluatorResult

from ert import BatchSimulator, WorkflowRunner
from ert.config import ErtConfig, ExtParamConfig, HookRuntime
from ert.storage import open_storage
from ert.config import ErtConfig, HookRuntime
from ert.storage import Experiment, Storage
from everest.config import EverestConfig
from everest.simulator.everest_to_ert import everest_to_ert_config


class Simulator(BatchSimulator):
"""Everest simulator: BatchSimulator"""

def __init__(self, ever_config: EverestConfig, callback=None) -> None:
config_dict = everest_to_ert_config(
ever_config, site_config=ErtConfig.read_site_config()
)
ert_config = ErtConfig.with_plugins().from_dict(config_dict=config_dict)

# Inject ExtParam nodes. This is needed because EXT_PARAM is not an ERT
# configuration key, but only a placeholder for the control definitions.
ens_config = ert_config.ensemble_config
for control_name, variables in config_dict["EXT_PARAM"].items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

def __init__(
self,
ever_config: EverestConfig,
ert_config: ErtConfig,
storage: Storage,
experiment: Experiment,
callback=None,
) -> None:
super(Simulator, self).__init__(
ert_config,
experiment,
self._get_controls(ever_config),
self._get_results(ever_config),
callback=callback,
Expand All @@ -51,6 +40,8 @@ def __init__(self, ever_config: EverestConfig, callback=None) -> None:
if ever_config.simulator is not None and ever_config.simulator.enable_cache:
self._cache = _SimulatorCache()

self.storage = storage

def _get_controls(self, ever_config: EverestConfig) -> List[str]:
controls = ever_config.controls or []
return [control.name for control in controls]
Expand Down Expand Up @@ -118,33 +109,19 @@ def __call__(
self._add_control(controls, control_name, control_value)
case_data.append((real_id, controls))

with open_storage(self.ert_config.ens_path, "w") as storage:
if self._experiment_id is None:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self.ert_config.ensemble_config.parameter_configuration,
responses=self.ert_config.ensemble_config.response_configuration,
)
sim_context = self.start(f"batch_{self._batch}", case_data)

self._experiment_id = experiment.id
else:
experiment = storage.get_experiment(self._experiment_id)

sim_context = self.start(f"batch_{self._batch}", case_data, experiment)

while sim_context.running():
time.sleep(0.2)
results = sim_context.results()

# Pre-simulation workflows are run by sim_context, but
# post-stimulation workflows are not, do it here:
ensemble = sim_context.get_ensemble()
for workflow in self.ert_config.hooked_workflows[
HookRuntime.POST_SIMULATION
]:
WorkflowRunner(
workflow, storage, ensemble, ert_config=self.ert_config
).run_blocking()
while sim_context.running():
time.sleep(0.2)
results = sim_context.results()

# Pre-simulation workflows are run by sim_context, but
# post-stimulation workflows are not, do it here:
ensemble = sim_context.get_ensemble()
for workflow in self.ert_config.hooked_workflows[HookRuntime.POST_SIMULATION]:
WorkflowRunner(
workflow, self.storage, ensemble, ert_config=self.ert_config
).run_blocking()

for fnc_name, alias in self._function_aliases.items():
for result in results:
Expand Down
95 changes: 64 additions & 31 deletions src/everest/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
from seba_sqlite import SqliteStorage

import everest
from ert.config import ErtConfig, ExtParamConfig
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.optimizer.everest2ropt import everest2ropt
from everest.plugins.site_config_env import PluginSiteConfigEnv
from everest.simulator import Simulator
from everest.simulator.everest_to_ert import everest_to_ert_config
from everest.strings import EVEREST, SIMULATOR_END, SIMULATOR_START, SIMULATOR_UPDATE
from everest.util import makedirs_if_needed

Expand Down Expand Up @@ -389,38 +392,68 @@ def start_optimization(self):
"""
assert self._monitor_thread is None

# Initialize the Everest simulator:
simulator = Simulator(self.config, callback=self._simulation_callback)

# Initialize the ropt optimizer:
optimizer = self._configure_optimizer(simulator)

# Before each batch evaluation we check if we should abort:
optimizer.add_observer(
EventType.START_EVALUATION,
partial(self._ropt_callback, optimizer=optimizer, simulator=simulator),
config_dict = everest_to_ert_config(
self.config, site_config=ErtConfig.read_site_config()
)

# The SqliteStorage object is used to store optimization results from
# Seba in an sqlite database. It reacts directly to events emitted by
# Seba and is not called by Everest directly. The stored results are
# accessed by Everest via separate SebaSnapshot objects.
# This mechanism is outdated and not supported by the ropt package. It
# is retained for now via the seba_sqlite package.
seba_storage = SqliteStorage(optimizer, self.config.optimization_output_dir)

# Run the optimization:
exit_code = optimizer.run().exit_code

# Extract the best result from the storage.
self._result = seba_storage.get_optimal_result()

if self._monitor_thread is not None:
self._monitor_thread.stop()
self._monitor_thread.join()
self._monitor_thread = None

return "max_batch_num_reached" if self._max_batch_num_reached else exit_code
ert_config = ErtConfig.with_plugins().from_dict(config_dict=config_dict)

# Inject ExtParam nodes. This is needed because EXT_PARAM is not an ERT
# configuration key, but only a placeholder for the control definitions.
ens_config = ert_config.ensemble_config
for control_name, variables in config_dict["EXT_PARAM"].items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

with open_storage(ert_config.ens_path, mode="w") as storage:
# Initialize the Everest simulator:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=ert_config.ensemble_config.parameter_configuration,
responses=ert_config.ensemble_config.response_configuration,
)

simulator = Simulator(
self.config,
ert_config,
storage,
experiment,
callback=self._simulation_callback,
)

# Initialize the ropt optimizer:
optimizer = self._configure_optimizer(simulator)

# Before each batch evaluation we check if we should abort:
optimizer.add_observer(
EventType.START_EVALUATION,
partial(self._ropt_callback, optimizer=optimizer, simulator=simulator),
)

# The SqliteStorage object is used to store optimization results from
# Seba in an sqlite database. It reacts directly to events emitted by
# Seba and is not called by Everest directly. The stored results are
# accessed by Everest via separate SebaSnapshot objects.
# This mechanism is outdated and not supported by the ropt package. It
# is retained for now via the seba_sqlite package.
seba_storage = SqliteStorage(optimizer, self.config.optimization_output_dir)

# Run the optimization:
exit_code = optimizer.run().exit_code

# Extract the best result from the storage.
self._result = seba_storage.get_optimal_result()

if self._monitor_thread is not None:
self._monitor_thread.stop()
self._monitor_thread.join()
self._monitor_thread = None

return "max_batch_num_reached" if self._max_batch_num_reached else exit_code

@property
def result(self):
Expand Down
84 changes: 57 additions & 27 deletions tests/everest/test_simulator_cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import datetime

import numpy as np
from ropt.plan import OptimizationPlanRunner

from ert.config import ErtConfig, ExtParamConfig
from ert.storage import open_storage
from everest.config import EverestConfig, SimulatorConfig
from everest.optimizer.everest2ropt import everest2ropt
from everest.simulator import Simulator
from everest.simulator.everest_to_ert import everest_to_ert_config

CONFIG_FILE = "config_advanced_scipy.yml"

Expand All @@ -24,34 +29,59 @@ def new_call(*args):
config.simulator = SimulatorConfig(enable_cache=True)

ropt_config = everest2ropt(config)
simulator = Simulator(config)

# Run once, populating the cache of the simulator:
variables1 = (
OptimizationPlanRunner(
enopt_config=ropt_config,
evaluator=simulator,
seed=config.environment.random_seed,
)
.run()
.variables

config_dict = everest_to_ert_config(
config, site_config=ErtConfig.read_site_config()
)
assert variables1 is not None
assert np.allclose(variables1, [0.1, 0, 0.4], atol=0.02)
assert n_evals > 0
ert_config = ErtConfig.with_plugins().from_dict(config_dict=config_dict)

# Run again with the same simulator:
n_evals = 0
variables2 = (
OptimizationPlanRunner(
enopt_config=ropt_config,
evaluator=simulator,
seed=config.environment.random_seed,
# Inject ExtParam nodes. This is needed because EXT_PARAM is not an ERT
# configuration key, but only a placeholder for the control definitions.
ens_config = ert_config.ensemble_config
for control_name, variables in config_dict["EXT_PARAM"].items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)
.run()
.variables
)
assert variables2 is not None
assert n_evals == 0

assert np.array_equal(variables1, variables2)
with open_storage(ert_config.ens_path, mode="w") as storage:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=ert_config.ensemble_config.parameter_configuration,
responses=ert_config.ensemble_config.response_configuration,
)

simulator = Simulator(config, ert_config, storage, experiment)

# Run once, populating the cache of the simulator:
variables1 = (
OptimizationPlanRunner(
enopt_config=ropt_config,
evaluator=simulator,
seed=config.environment.random_seed,
)
.run()
.variables
)
assert variables1 is not None
assert np.allclose(variables1, [0.1, 0, 0.4], atol=0.02)
assert n_evals > 0

# Run again with the same simulator:
n_evals = 0
variables2 = (
OptimizationPlanRunner(
enopt_config=ropt_config,
evaluator=simulator,
seed=config.environment.random_seed,
)
.run()
.variables
)
assert variables2 is not None
assert n_evals == 0

assert np.array_equal(variables1, variables2)

0 comments on commit 0684762

Please sign in to comment.