Skip to content

Commit

Permalink
Add post/pre experiment simulation hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Oct 18, 2024
1 parent da9495d commit b51ecfb
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/ert/config/parsing/hook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ class HookRuntime(StrEnum):
PRE_UPDATE = "PRE_UPDATE"
POST_UPDATE = "POST_UPDATE"
PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE"
PRE_EXPERIMENT = "PRE_EXPERIMENT"
POST_EXPERIMENT = "POST_EXPERIMENT"
8 changes: 6 additions & 2 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
from ert.config import ErtConfig, HookRuntime, QueueConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,23 +81,27 @@ def run_experiment(

self.set_env_key("_ERT_EXPERIMENT_ID", str(self.experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(self.ensemble.id))
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "False")

run_args = create_run_arguments(
self.run_paths,
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
run_args,
self.ensemble,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, self.ensemble)

@classmethod
def name(cls) -> str:
Expand Down
6 changes: 5 additions & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Storage
Expand Down Expand Up @@ -80,7 +80,10 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "True")

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand All @@ -105,6 +108,7 @@ def run_experiment(
posterior,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@classmethod
def name(cls) -> str:
Expand Down
14 changes: 14 additions & 0 deletions src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,18 @@ def run_experiment(
ensemble=prior,
)

self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key(
"_IS_FINAL_ITERATION",
"False",
)
self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, prior)
sample_prior(
prior,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
prior_args,
prior,
Expand All @@ -157,6 +164,11 @@ def run_experiment(

self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior)
for prior_iter in range(self._total_iterations):
self.set_env_key("_ERT_ITERATION", str(prior_iter + 1))
self.set_env_key(
"_IS_FINAL_ITERATION",
"True" if (prior_iter == self._total_iterations - 1) else "False",
)
self.send_event(
RunModelUpdateBeginEvent(iteration=prior_iter, run_id=prior.id)
)
Expand Down Expand Up @@ -219,6 +231,8 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@classmethod
def name(cls) -> str:
return "Iterated ensemble smoother"
Expand Down
14 changes: 13 additions & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Storage
Expand Down Expand Up @@ -97,6 +97,14 @@ def run_experiment(
f"Experiment misconfigured, got starting iteration: {self.start_iteration},"
f"restart iteration = {prior.iteration + 1}"
)

self.set_env_key("_ERT_ITERATION", str(self.start_iteration))
self.set_env_key(
"_IS_FINAL_ITERATION",
"True"
if (self.start_iteration == self._total_iterations - 1)
else "False",
)
except (KeyError, ValueError) as err:
raise ErtRunError(
f"Prior ensemble with ID: {id} does not exists"
Expand Down Expand Up @@ -124,6 +132,8 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down Expand Up @@ -155,6 +165,8 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@staticmethod
def parse_weights(weights: str) -> List[float]:
"""Parse weights string and scale weights such that their reciprocals sum
Expand Down
68 changes: 68 additions & 0 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,74 @@ def test_that_stop_on_fail_workflow_jobs_stop_ert(
run_cli(TEST_RUN_MODE, "--disable-monitor", "poly.ert")


@pytest.mark.usefixtures("copy_poly_case")
def test_that_post_experiment_hook_works(
monkeypatch,
):
monkeypatch.setattr(_ert.threading, "_can_raise", False)

# The executable
with open("dump_final_ensemble_id.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_IS_FINAL_ITERATION> final_ensemble_info.txt
""")
)
os.chmod("dump_final_ensemble_id.sh", 0o755)

# The workflow job
with open("DUMP_FINAL_ENSEMBLE_ID", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_final_ensemble_info.sh
""")

# The workflow
with open("POST_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
s.write("""dump_final_ensemble_id""")

# The executable
with open("dump_first_ensemble_id.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_ERT_ITERATION > first_ensemble_id.txt
""")
)
os.chmod("dump_first_ensemble_id.sh", 0o755)

# The workflow job
with open("DUMP_FIRST_ENSEMBLE_ID", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_first_ensemble_id.sh
""")

# The workflow
with open("PRE_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
s.write("""dump_first_ensemble_id""")

with open("poly.ert", mode="a", encoding="utf-8") as fh:
fh.write(
dedent(
"""
NUM_REALIZATIONS 2
LOAD_WORKFLOW_JOB DUMP_FINAL_ENSEMBLE_ID dump_final_ensemble_id
LOAD_WORKFLOW POST_EXPERIMENT_DUMP.WF POST_EXPERIMENT_DUMP
HOOK_WORKFLOW POST_EXPERIMENT_DUMP POST_EXPERIMENT
LOAD_WORKFLOW_JOB DUMP_FIRST_ENSEMBLE_ID dump_first_ensemble_id
LOAD_WORKFLOW PRE_EXPERIMENT_DUMP.WF PRE_EXPERIMENT_DUMP
HOOK_WORKFLOW PRE_EXPERIMENT_DUMP PRE_EXPERIMENT
"""
)
)

run_cli(ITERATIVE_ENSEMBLE_SMOOTHER_MODE, "--disable-monitor", "poly.ert")

# ...2do assert correct contents in files


@pytest.fixture(name="mock_cli_run")
def fixture_mock_cli_run(monkeypatch):
end_event = Mock()
Expand Down

0 comments on commit b51ecfb

Please sign in to comment.