Skip to content

Commit

Permalink
Refactor FeatureToggling to FeatureScheduler
Browse files Browse the repository at this point in the history
We only need to toggle scheduler. If we ever need to toggle some other
feature, we can implement that as a separate class when the time comes.
This commit removes the complexity associated with supporting multiple
such features.
  • Loading branch information
pinkwah committed Feb 26, 2024
1 parent a1eee46 commit fe12aea
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 178 deletions.
8 changes: 4 additions & 4 deletions src/ert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ert.namespace import Namespace
from ert.run_models.multiple_data_assimilation import MultipleDataAssimilation
from ert.services import StorageService, WebvizErt
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler
from ert.shared.plugins.plugin_manager import ErtPluginContext, ErtPluginManager
from ert.shared.storage.command import add_parser_options as ert_api_add_parser_options
from ert.validation import (
Expand Down Expand Up @@ -239,7 +239,7 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
gui_parser.add_argument(
"--verbose", action="store_true", help="Show verbose output.", default=False
)
FeatureToggling.add_feature_toggling_args(gui_parser)
FeatureScheduler.add_to_argparse(gui_parser)

# lint_parser
lint_parser = subparsers.add_parser(
Expand Down Expand Up @@ -478,7 +478,7 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
)
cli_parser.add_argument("config", type=valid_file, help=config_help)

FeatureToggling.add_feature_toggling_args(cli_parser)
FeatureScheduler.add_to_argparse(cli_parser)

return parser

Expand Down Expand Up @@ -556,7 +556,7 @@ def main() -> None:
handler.setLevel(logging.INFO)
root_logger.addHandler(handler)

FeatureToggling.update_from_args(args)
FeatureScheduler.set_value(args)
try:
with ErtPluginContext() as context:
context.plugin_manager.add_logging_handle_to_root(logging.getLogger())
Expand Down
14 changes: 5 additions & 9 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
from cloudevents.http.event import CloudEvent

from ert.async_utils import get_event_loop, new_event_loop
from ert.config.parsing.queue_system import QueueSystem
from ert.ensemble_evaluator import identifiers
from ert.job_queue import JobQueue
from ert.scheduler import Scheduler, create_driver
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler

from .._wait_for_evaluator import wait_for_evaluator
from ._ensemble import Ensemble
Expand Down Expand Up @@ -175,12 +174,9 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
"""
event_creator = self.generate_event_creator(experiment_id=experiment_id)
timeout_queue: Optional[asyncio.Queue[Any]] = None
if (
self._queue_config.queue_system in [QueueSystem.LOCAL]
and FeatureToggling.value("scheduler") is not False
):
FeatureToggling._conf["scheduler"].value = True
if not FeatureToggling.is_enabled("scheduler"):
using_scheduler = FeatureScheduler.is_enabled(self._queue_config.queue_system)

if not using_scheduler:
# Set up the timeout-mechanism
timeout_queue = asyncio.Queue()
# Based on the experiment id the generator will
Expand All @@ -195,7 +191,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
raise ValueError("no config") # mypy

try:
if FeatureToggling.is_enabled("scheduler"):
if using_scheduler:
driver = create_driver(self._queue_config)
queue = Scheduler(
driver,
Expand Down
165 changes: 61 additions & 104 deletions src/ert/shared/feature_toggling.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,76 @@
from __future__ import annotations

import logging
import os
from argparse import ArgumentParser
from copy import deepcopy
from typing import TYPE_CHECKING, Dict, Optional, Union

if TYPE_CHECKING:
from ert.namespace import Namespace
from typing import TYPE_CHECKING, Optional

logger = logging.getLogger()

class _Feature:
def __init__(
self, default: Optional[bool], msg: Optional[str] = None, optional: bool = False
) -> None:
self._value = default
self.msg = msg
self.optional = optional

def validate_value(self, value: Union[bool, str, None]) -> Optional[bool]:
if type(value) is bool or value is None:
return value
elif value.lower() in ["true", "1"]:
return True
elif value.lower() in ["false", "0"]:
return False
elif self.optional and value.lower() in ["default", ""]:
return None
else:
raise ValueError(
f"This option can only be set to {'True/1, False/0 or Default/<empty>' if self.optional else 'True/1 or False/0'}"
)

@property
def value(self) -> Optional[bool]:
return self._value
if TYPE_CHECKING:
from argparse import ArgumentParser

@value.setter
def value(self, value: Optional[bool]) -> None:
self._value = self.validate_value(value)
from ert.config.parsing.queue_system import QueueSystem
from ert.namespace import Namespace


class FeatureToggling:
_conf_original: Dict[str, _Feature] = {
"scheduler": _Feature(
default=None,
msg="Default value for use of Scheduler has been overridden\n"
"This is experimental and may cause problems",
optional=True,
),
class FeatureScheduler:
_DEFAULTS = {
"LOCAL": True,
"LSF": False,
"SLURM": False,
"TORQUE": False,
}

_conf = deepcopy(_conf_original)

@staticmethod
def is_enabled(feature_name: str) -> bool:
return FeatureToggling._conf[feature_name].value is True

@staticmethod
def value(feature_name: str) -> Optional[bool]:
return FeatureToggling._conf[feature_name].value
_value: Optional[bool] = None

@classmethod
def is_enabled(cls, queue_system: QueueSystem) -> bool:
if cls._value is not None:
return cls._value
return cls._DEFAULTS[queue_system.name]

@classmethod
def set_value(cls, args: Namespace) -> None:
if ((value := cls._get_from_args(args)) is not None) or (
(value := cls._get_from_env()) is not None
):
cls._value = value
else:
cls._value = None

@staticmethod
def add_feature_toggling_args(parser: ArgumentParser) -> None:
for name, feature in FeatureToggling._conf.items():
env_var_name = f"ERT_FEATURE_{name.replace('-', '_').upper()}"
env_value: Union[bool, str, None] = None
if env_var_name in os.environ:
try:
feature.value = feature.validate_value(os.environ[env_var_name])
except ValueError as e:
# TODO: this is a bit spammy. It will get called 6 times for each incorrect env var.
logging.getLogger().warning(
f"Failed to set {env_var_name} to '{os.environ[env_var_name]}'. {e}"
)

if not feature.optional:
parser.add_argument(
f"--{'disable' if feature.value else 'enable'}-{name}",
action="store_false" if feature.value else "store_true",
help=f"Toggle {name} (Warning: This is experimental)",
dest=f"feature-{name}",
default=env_value if env_value is not None else feature.value,
)
else:
group = parser.add_mutually_exclusive_group()
group.add_argument(
f"--enable-{name}",
action="store_true",
help=f"Enable {name}",
dest=f"feature-{name}",
default=feature.value,
)
group.add_argument(
f"--disable-{name}",
action="store_false",
help=f"Disable {name}",
dest=f"feature-{name}",
default=feature.value,
)
def add_to_argparse(parser: ArgumentParser) -> None:
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--enable-scheduler",
action="store_true",
help="Enable new scheduler",
dest="feature_scheduler",
default=None,
)
group.add_argument(
"--disable-scheduler",
action="store_false",
help="Disable new scheduler",
dest="feature_scheduler",
default=None,
)

@staticmethod
def update_from_args(args: "Namespace") -> None:
pattern = "feature-"
feature_args = [arg for arg in vars(args).items() if arg[0].startswith(pattern)]
for name, value in feature_args:
name = name[len(pattern) :]
if name in FeatureToggling._conf:
FeatureToggling._conf[name].value = value

# Print warnings for enabled features.
for name, feature in FeatureToggling._conf.items():
if FeatureToggling.is_enabled(name) and feature.msg is not None:
logging.getLogger().warning(
f"{feature.msg}\nValue is set to {feature.value}"
)
def _get_from_env() -> Optional[bool]:
if (value := os.environ.get("ERT_FEATURE_SCHEDULER")) is None:
return None
value = value.lower()
if value in ("true", "1"):
return True
elif value in ("false", "0"):
return False
elif value in ("auto", "default", ""):
return None
raise ValueError(
"This option can only be set to 'true'/'1', 'false'/'0' or 'auto'/'default'/''"
)

@staticmethod
def reset() -> None:
FeatureToggling._conf = deepcopy(FeatureToggling._conf_original)
def _get_from_args(args: Namespace) -> Optional[bool]:
return args.feature_scheduler
4 changes: 2 additions & 2 deletions src/ert/simulator/batch_simulator_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

from ert.scheduler.job import State as JobState
from ert.shared.feature_toggling import FeatureToggling
from ert.scheduler.scheduler import Scheduler

from .simulation_context import SimulationContext

Expand Down Expand Up @@ -57,7 +57,7 @@ def status(self) -> Status:
NB: Killed realizations are not reported.
"""
if FeatureToggling.is_enabled("scheduler"):
if isinstance(self._job_queue, Scheduler):
states = self._job_queue.count_states()
return Status(
running=states[JobState.RUNNING],
Expand Down
11 changes: 2 additions & 9 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import numpy as np

from ert.config import HookRuntime
from ert.config.parsing.queue_system import QueueSystem
from ert.enkf_main import create_run_path
from ert.ensemble_evaluator import Realization
from ert.job_queue import JobQueue, JobStatus
from ert.run_context import RunContext
from ert.runpaths import Runpaths
from ert.scheduler import Scheduler, create_driver
from ert.scheduler.job import State as JobState
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler

from .forward_model_status import ForwardModelStatus

Expand Down Expand Up @@ -94,13 +93,7 @@ def __init__(
self._ert = ert
self._mask = mask

if (
ert.ert_config.queue_config.queue_system in [QueueSystem.LOCAL]
and FeatureToggling.value("scheduler") is not False
):
FeatureToggling._conf["scheduler"].value = True
if ert.ert_config.queue_config.queue_system != QueueSystem.LOCAL:
raise NotImplementedError()
if FeatureScheduler.is_enabled(ert.ert_config.queue_config.queue_system):
driver = create_driver(ert.ert_config.queue_config)
self._job_queue = Scheduler(
driver, max_running=ert.ert_config.queue_config.max_running
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ert.config import ErtConfig
from ert.ensemble_evaluator.config import EvaluatorServerConfig
from ert.services import StorageService
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler
from ert.storage import open_storage

from .utils import SOURCE_DIR
Expand Down Expand Up @@ -296,6 +296,7 @@ def using_scheduler(request, monkeypatch):
_ = get_event_loop()

monkeypatch.setenv("ERT_FEATURE_SCHEDULER", "1" if should_enable_scheduler else "0")
monkeypatch.setattr(FeatureScheduler, "_value", should_enable_scheduler)
yield should_enable_scheduler


Expand Down
17 changes: 10 additions & 7 deletions tests/integration_tests/run_cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from argparse import ArgumentParser

import pytest

from ert.__main__ import ert_parser
from ert.cli.main import run_cli as cli_runner
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler


def run_cli(*args):
parser = ArgumentParser(prog="test_main")
parsed = ert_parser(parser, args)
FeatureToggling.update_from_args(parsed)
res = cli_runner(parsed)
FeatureToggling.reset()
return res
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(FeatureScheduler, "_value", None)
parser = ArgumentParser(prog="test_main")
parsed = ert_parser(parser, args)
FeatureScheduler.set_value(parsed)
res = cli_runner(parsed)
return res
6 changes: 3 additions & 3 deletions tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ert.ensemble_evaluator.monitor import Monitor
from ert.job_queue.queue import JobQueue
from ert.scheduler import Scheduler
from ert.shared.feature_toggling import FeatureToggling
from ert.shared.feature_toggling import FeatureScheduler


@pytest.mark.timeout(60)
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_run_legacy_ensemble_with_bare_exception(
):
"""This test function is not ported to Scheduler, as it will not
catch general exceptions."""
monkeypatch.setattr(FeatureToggling._conf["scheduler"], "_value", False)
monkeypatch.setattr(FeatureScheduler, "_value", False)
num_reals = 2
custom_port_range = range(1024, 65535)
with tmpdir.as_cwd():
Expand Down Expand Up @@ -125,7 +125,7 @@ async def test_queue_config_properties_propagated_to_scheduler(
tmpdir, make_ensemble_builder, monkeypatch
):
num_reals = 1
monkeypatch.setattr(FeatureToggling._conf["scheduler"], "_value", True)
monkeypatch.setattr(FeatureScheduler, "_value", True)
mocked_scheduler = MagicMock()
mocked_scheduler.__class__ = Scheduler
monkeypatch.setattr(Scheduler, "__init__", mocked_scheduler)
Expand Down
Loading

0 comments on commit fe12aea

Please sign in to comment.