Skip to content

Commit

Permalink
(draft) baserunmodel everest batchsimulator
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Oct 14, 2024
1 parent b2b5d39 commit a3ca527
Showing 1 changed file with 7 additions and 48 deletions.
55 changes: 7 additions & 48 deletions src/ert/run_models/batch_simulator_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,25 @@
from queue import SimpleQueue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypedDict

import numpy as np
from ropt.enums import EventType, OptimizerExitCode
from ropt.plan import OptimizationPlanRunner
from seba_sqlite import SqliteStorage

import everest
from ert.analysis import ErtAnalysisError, iterative_smoother_update
from ert.config import ErtConfig, HookRuntime
from ert.config import ErtConfig
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Storage
from ert.storage import Storage
from everest.config import EverestConfig
from everest.jobs import shell_commands
from everest.optimizer.everest2ropt import everest2ropt
from everest.simulator import Simulator
from everest.strings import EVEREST, SIMULATOR_END, SIMULATOR_START, SIMULATOR_UPDATE

from .. import BatchContext
from ..simulator.batch_simulator_context import Status
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from uuid import UUID

import numpy.typing as npt
pass


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -190,10 +186,7 @@ def extract(path_str, key):
else:
jobs: List[JobProgress] = []
for fms in progress_queue.steps:
if (
not self._display_all_jobs
and fms.name in everest.jobs.shell_commands
):
if not self._display_all_jobs and fms.name in shell_commands:
continue
realization = extract(fms.std_out_file, "geo_realization")
simulation = extract(fms.std_out_file, "simulation")
Expand Down Expand Up @@ -280,10 +273,8 @@ def __init__(
self._display_all_jobs = display_all_jobs
self._result: Optional[Any] = None
self._exit_code: Optional[OptimizerExitCode] = None
self._max_batch_num_reached = False

# Need to customize:
# Runpath creation / deletion
# Number of iterations concept .. separate from baserunmodel
super().__init__(
config,
storage,
Expand All @@ -296,38 +287,6 @@ def __init__(

self.num_retries_per_iter = 0 # OK?

def analyzeStep(
self,
prior_storage: Ensemble,
posterior_storage: Ensemble,
ensemble_id: UUID,
iteration: int,
initial_mask: npt.NDArray[np.bool_],
) -> None:
self.validate()
self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior_storage)
try:
_, self.sies_smoother = iterative_smoother_update(
prior_storage,
posterior_storage,
self.sies_smoother,
parameters=prior_storage.experiment.update_parameters,
observations=prior_storage.experiment.observation_keys,
update_settings=self.update_settings,
analysis_config=self.analysis_config,
sies_step_length=self.sies_step_length,
initial_mask=initial_mask,
rng=self.rng,
progress_callback=functools.partial(
self.send_smoother_event, iteration, ensemble_id
),
)
except ErtAnalysisError as e:
raise ErtRunError(
f"Update algorithm failed with the following error: {e}"
) from e
self.run_workflows(HookRuntime.POST_UPDATE, self._storage, posterior_storage)

def _simulation_callback(self, *args, **_):
logging.getLogger(EVEREST).debug("Simulation callback called")
ctx = args[0]
Expand Down

0 comments on commit a3ca527

Please sign in to comment.