diff --git a/src/rtctools/data/timeseries.py b/src/rtctools/data/timeseries.py new file mode 100644 index 00000000..4d1bd0a7 --- /dev/null +++ b/src/rtctools/data/timeseries.py @@ -0,0 +1,88 @@ +"""Module for processing timeseries data.""" +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np + +import rtctools.data.csv as csv +import rtctools.data.pi as pi +import rtctools.data.rtc as rtc + + +def get_timeseries_from_csv( + file: Path, delimiter="," +) -> Tuple[List[datetime], Dict[str, np.ndarray]]: + """Read timeseries from csv + + :param file: Path to CSV. + :param delimiter: CSV column delimiter. + + :returns: A list of time stamps (datetimes) and dict of values. + """ + data = csv.load(file, delimiter=delimiter, with_time=True) + times = data[data.dtype.names[0]] + values = {key: np.asarray(data[key], dtype=np.float64) for key in data.dtype.names[1:]} + return times, values + + +def get_timeseries_from_pi( + file: Path, data_config=rtc.DataConfig, is_binary=False, validate=False +) -> Tuple[List[datetime], Dict[str, np.ndarray]]: + """Read timeseries from Delft-FEWS Published Interface file. + + :param file: Path to the xml file. + :param data_config: A :class:rtc.DataConfig object. + :param is_binary: Indicate if the file is a binary file. + :param validate: Validate the timeseries. + + :returns: A list of time stamps (datetimes) and dict of values. + """ + file = Path(file) + timeseries = pi.Timeseries( + data_config=data_config, + folder=file.parent, + basename=file.stem, + binary=is_binary, + pi_validate_times=validate, + ) + times = timeseries.times + values = dict(timeseries.items()) + return times, values + + +def check_times_are_increasing(times: List[datetime]): + """Check that time stamps are increasing.""" + for i in range(len(times) - 1): + if times[i] >= times[i + 1]: + raise ValueError("Time stamps must be strictly increasing.") + + +def check_times_are_equidistant(times: List[datetime]): + """Check that times are eqeuidistant.""" + dt = times[1] - times[0] + for i in range(len(times) - 1): + if times[i + 1] - times[i] != dt: + raise ValueError( + "Expecting equidistant timeseries, the time step towards " + "{} is not the same as the time step(s) before. ".format(times[i + 1]) + ) + + +def fill_nan_in_timeseries(times: List[datetime], values: np.ndarray, interp_args: dict = None): + """Fill in missing values in a timeseries using lienar interpolation. + + :param times: List of datetimes. + :param values: 1D array of values with the same length as times. + :interp_args: Dict of arguments passed to numpy.interp. + + :returns: List of values where nans have been replaced with interpolated values. + """ + if interp_args is None: + interp_args = {} + times_sec = np.array([(t - times[0]).total_seconds() for t in times]) + nans = np.isnan(values) + if all(nans): + return values + result = np.interp(times_sec, times_sec[~nans], values[~nans], **interp_args) + return result diff --git a/src/rtctools/optimization/csv_mixin.py b/src/rtctools/optimization/csv_mixin.py index 3f35be5f..a0fd0422 100644 --- a/src/rtctools/optimization/csv_mixin.py +++ b/src/rtctools/optimization/csv_mixin.py @@ -7,6 +7,11 @@ import rtctools.data.csv as csv from rtctools._internal.alias_tools import AliasDict from rtctools._internal.caching import cached +from rtctools.data.timeseries import ( + check_times_are_equidistant, + check_times_are_increasing, + get_timeseries_from_csv, +) from rtctools.optimization.io_mixin import IOMixin from rtctools.optimization.timeseries import Timeseries @@ -102,26 +107,18 @@ def check_initial_state_array(initial_state): logger.debug("CSVMixin: Read ensemble description") for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]): - _timeseries = csv.load( + times, values = get_timeseries_from_csv( os.path.join( self._input_folder, ensemble_member_name, self.timeseries_import_basename + ".csv", ), delimiter=self.csv_delimiter, - with_time=True, ) - self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]] - - self.io.reference_datetime = self.__timeseries_times[0] - - for key in _timeseries.dtype.names[1:]: - self.io.set_timeseries( - key, - self.__timeseries_times, - np.asarray(_timeseries[key], dtype=np.float64), - ensemble_member_index, - ) + self.__timeseries_times = times + self.io.reference_datetime = times[0] + for key, value in values.items(): + self.io.set_timeseries(key, times, value, ensemble_member_index) logger.debug("CSVMixin: Read timeseries") for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]): @@ -159,19 +156,14 @@ def check_initial_state_array(initial_state): self.__initial_state.append(AliasDict(self.alias_relation, _initial_state)) logger.debug("CSVMixin: Read initial state.") else: - _timeseries = csv.load( + times, values = get_timeseries_from_csv( os.path.join(self._input_folder, self.timeseries_import_basename + ".csv"), delimiter=self.csv_delimiter, - with_time=True, ) - self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]] - - self.io.reference_datetime = self.__timeseries_times[0] - - for key in _timeseries.dtype.names[1:]: - self.io.set_timeseries( - key, self.__timeseries_times, np.asarray(_timeseries[key], dtype=np.float64) - ) + self.__timeseries_times = times + self.io.reference_datetime = times[0] + for key, value in values.items(): + self.io.set_timeseries(key, times, value) logger.debug("CSVMixin: Read timeseries.") try: @@ -202,22 +194,13 @@ def check_initial_state_array(initial_state): # Timestamp check if self.csv_validate_timeseries: times = self.__timeseries_times - for i in range(len(times) - 1): - if times[i] >= times[i + 1]: - raise Exception("CSVMixin: Time stamps must be strictly increasing.") + check_times_are_increasing(times) if self.csv_equidistant: # Check if the timeseries are truly equidistant if self.csv_validate_timeseries: times = self.__timeseries_times - dt = times[1] - times[0] - for i in range(len(times) - 1): - if times[i + 1] - times[i] != dt: - raise Exception( - "CSVMixin: Expecting equidistant timeseries, the time step towards " - "{} is not the same as the time step(s) before. " - "Set csv_equidistant = False if this is intended.".format(times[i + 1]) - ) + check_times_are_equidistant(times) def ensemble_member_probability(self, ensemble_member): if self.csv_ensemble_mode: diff --git a/src/rtctools/optimization/pi_mixin.py b/src/rtctools/optimization/pi_mixin.py index f0af57f2..a2f43220 100644 --- a/src/rtctools/optimization/pi_mixin.py +++ b/src/rtctools/optimization/pi_mixin.py @@ -3,6 +3,10 @@ import rtctools.data.pi as pi import rtctools.data.rtc as rtc +from rtctools.data.timeseries import ( + check_times_are_equidistant, + check_times_are_increasing, +) from rtctools.optimization.io_mixin import IOMixin logger = logging.getLogger("rtctools") @@ -109,23 +113,12 @@ def read(self): # Timestamp check if self.pi_validate_timeseries: - for i in range(len(timeseries_import_times) - 1): - if timeseries_import_times[i] >= timeseries_import_times[i + 1]: - raise Exception("PIMixin: Time stamps must be strictly increasing.") + check_times_are_increasing(timeseries_import_times) if self.__timeseries_import.dt: # Check if the timeseries are truly equidistant if self.pi_validate_timeseries: - dt = timeseries_import_times[1] - timeseries_import_times[0] - for i in range(len(timeseries_import_times) - 1): - if timeseries_import_times[i + 1] - timeseries_import_times[i] != dt: - raise Exception( - "PIMixin: Expecting equidistant timeseries, the time step " - "towards {} is not the same as the time step(s) before. Set " - "unit to nonequidistant if this is intended.".format( - timeseries_import_times[i + 1] - ) - ) + check_times_are_equidistant(timeseries_import_times) # Offer input timeseries to IOMixin self.io.reference_datetime = self.__timeseries_import.forecast_datetime @@ -278,6 +271,13 @@ def timeseries_export(self): """ return self.__timeseries_export + @property + def data_config(self): + """ + :class:`rtc.ConfigData` object for holding the configuration data. + """ + return self.__data_config + def set_unit(self, variable: str, unit: str): """ Set the unit of a time series. diff --git a/src/rtctools/optimization/seed_mixin.py b/src/rtctools/optimization/seed_mixin.py new file mode 100644 index 00000000..a9023fa0 --- /dev/null +++ b/src/rtctools/optimization/seed_mixin.py @@ -0,0 +1,126 @@ +"""Module for seeding functionalities. +""" +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +from rtctools.data.timeseries import ( + check_times_are_equidistant, + check_times_are_increasing, + fill_nan_in_timeseries, + get_timeseries_from_csv, + get_timeseries_from_pi, +) +from rtctools.optimization.csv_mixin import CSVMixin +from rtctools.optimization.goal_programming_mixin import GoalProgrammingMixin +from rtctools.optimization.homotopy_mixin import HomotopyMixin +from rtctools.optimization.optimization_problem import OptimizationProblem +from rtctools.optimization.pi_mixin import PIMixin +from rtctools.optimization.timeseries import Timeseries + +DATA_DIR = Path(__file__).parent / "reservoir" + + +@dataclass +class SeedOptions: + """ + Class describing options for using a seed. + + :cvar seed_timeseries: + Path to the file containing the seed. + If None, the default seed is used. + :cvar is_binary: + Indicates if the file is in binary format. + :cvar validate_timeseries: + Indicates if the timeseries should be validated. + :cvar times_are_equidistant: + Indicates if the time stamps are equidistant. + :cvar fallback: + Indicates if the solver should fallback to the default seed + if it fails for the given seed. + """ + + seed_timeseries: Union[Path, None] = None + is_binary: bool = False + validate_timeseries: bool = False + times_are_equidistant: bool = False + fallback: bool = False + + +class SeedMixin(OptimizationProblem): + """ + Adds options for reading a seed from a given file to your optimization problem. + """ + + def __init__(self, **kwargs): + self._seed_timeseries: Path = None + super().__init__(**kwargs) + + def _use_seed_timeseries(self): + """Return True if a seed from a timeseries is used for the current run.""" + if self._seed_timeseries is None: + return False + if isinstance(self, GoalProgrammingMixin): + if not self._gp_first_run: + return False + if isinstance(self, HomotopyMixin): + theta_name = self.homotopy_options()["homotopy_parameter"] + theta = self.parameters(ensemble_member=0)[theta_name] + theta_start = self.homotopy_options()["theta_start"] + if theta > theta_start: + return False + return True + + def seed_options(self) -> SeedOptions: + """Get the seed options.""" + return SeedOptions() + + def seed_timeseries(self) -> Path: + """Get the path of the seed timeseries.""" + return self._seed_timeseries + + def seed(self, ensemble_member): + if not self._use_seed_timeseries(): + return super().seed(ensemble_member) + seed: dict = super().seed(ensemble_member).copy() # Copy to prevent updating cached seeds. + if isinstance(self, CSVMixin): + times, values_dict = get_timeseries_from_csv(self._seed_timeseries) + elif isinstance(self, PIMixin): + times, values_dict = get_timeseries_from_pi( + self._seed_timeseries, + data_config=self.data_config, + is_binary=self.seed_options().is_binary, + validate=self.seed_options().validate_timeseries, + ) + if self.seed_options().validate_timeseries: + check_times_are_increasing(times) + if self.seed_options().times_are_equidistant: + check_times_are_equidistant(times) + times_sec = self.io.datetime_to_sec(times, self.io.reference_datetime) + for var in values_dict: + values = fill_nan_in_timeseries(times, values_dict[var]) + values_dict[var] = Timeseries(times_sec, values) + seed.update(values_dict) + return seed + + def optimize( + self, + preprocessing: bool = True, + postprocessing: bool = True, + log_solver_failure_as_error: bool = True, + ) -> bool: + if preprocessing: + self.pre() + self._seed_timeseries = self.seed_options().seed_timeseries + fallback = self.seed_options().fallback + success = super().optimize( + preprocessing=False, + postprocessing=False, + log_solver_failure_as_error=log_solver_failure_as_error, + ) + if not success and self._use_seed_timeseries() and fallback: + self._seed_timeseries = None + success = super().optimize(preprocessing, postprocessing, log_solver_failure_as_error) + if postprocessing: + self.post() + return success diff --git a/tests/data/data/timeseries/rtcDataConfig.xml b/tests/data/data/timeseries/rtcDataConfig.xml new file mode 100644 index 00000000..0b20f91e --- /dev/null +++ b/tests/data/data/timeseries/rtcDataConfig.xml @@ -0,0 +1,15 @@ + + + + + Seeds + Q + + + + + Inputs + Q + + + diff --git a/tests/data/data/timeseries/timeseries.csv b/tests/data/data/timeseries/timeseries.csv new file mode 100644 index 00000000..ad9162b6 --- /dev/null +++ b/tests/data/data/timeseries/timeseries.csv @@ -0,0 +1,6 @@ +Time,Q_out +2020-01-01 00:00:00, +2020-01-01 00:00:01,1.0 +2020-01-01 00:00:02, +2020-01-01 00:00:03,3.0 +2020-01-01 00:00:04, \ No newline at end of file diff --git a/tests/data/data/timeseries/timeseries.xml b/tests/data/data/timeseries/timeseries.xml new file mode 100644 index 00000000..b88c829f --- /dev/null +++ b/tests/data/data/timeseries/timeseries.xml @@ -0,0 +1,21 @@ + + 0.0 + + + instantaneous + Seeds + Q + + + + + -999.0 + m3/s + + + + + + + + \ No newline at end of file diff --git a/tests/data/test_timeseries.py b/tests/data/test_timeseries.py new file mode 100644 index 00000000..82308e3f --- /dev/null +++ b/tests/data/test_timeseries.py @@ -0,0 +1,74 @@ +"""Module for testing timeseries data functionanilities. +""" +from datetime import datetime +from pathlib import Path + +import numpy as np +from rtctools.data.rtc import DataConfig +from rtctools.data.timeseries import ( + check_times_are_equidistant, + check_times_are_increasing, + fill_nan_in_timeseries, + get_timeseries_from_csv, + get_timeseries_from_pi, +) +from test_case import TestCase + +DATA_DIR = Path(__file__).parent / "data" / "timeseries" + + +class TestTimeseries(TestCase): + def test_check_times_are_equidistant(self): + """Test checking if a timeseries is equidistant.""" + times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 3]] + check_times_are_equidistant(times) + times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 4]] + with self.assertRaises(ValueError): + check_times_are_equidistant(times) + + def test_check_timeseries_is_increasing(self): + """Test checking times are strictly increasing.""" + times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 3]] + check_times_are_increasing(times) + times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 2]] + with self.assertRaises(ValueError): + check_times_are_increasing(times) + + def test_get_timeseries_from_csv(self): + """Test getting a timeseries from a cvs file.""" + times, values = get_timeseries_from_csv(DATA_DIR / "timeseries.csv") + ref_times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)] + ref_values = {"Q_out": [np.nan, 1.0, np.nan, 3.0, np.nan]} + for time, ref_time in zip(times, ref_times): + self.assertEqual(time, ref_time) + for var in values: + np.testing.assert_almost_equal(values[var], ref_values[var]) + + def test_get_timeseries_from_pi(self): + """Test getting a timeseries from a Delft-FEWS Published Interface file.""" + data_config = DataConfig(DATA_DIR) + times, values = get_timeseries_from_pi( + file=DATA_DIR / "timeseries.xml", data_config=data_config, validate=True + ) + ref_times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)] + ref_values = {"Q_out": [np.nan, 1.0, np.nan, 3.0, np.nan]} + for time, ref_time in zip(times, ref_times): + self.assertEqual(time, ref_time) + for var in values: + np.testing.assert_almost_equal(values[var], ref_values[var]) + + def test_fill_nan_in_timeseries(self): + """Test filling nan values of a timeseries.""" + del self + times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)] + values = np.array([np.nan, 1.0, np.nan, 3.0, np.nan]) + result = fill_nan_in_timeseries(times, values) + ref_result = np.array([1.0, 1.0, 2.0, 3.0, 3.0]) + np.testing.assert_almost_equal(result, ref_result) + values = np.array([np.nan, 1.0, np.nan, np.nan, np.nan]) + result = fill_nan_in_timeseries(times, values) + ref_result = np.array([1.0, 1.0, 1.0, 1.0, 1.0]) + np.testing.assert_almost_equal(result, ref_result) + values = np.array([np.nan, np.nan, np.nan, np.nan, np.nan]) + result = fill_nan_in_timeseries(times, values) + assert all(np.isnan(result)) diff --git a/tests/optimization/data/reservoir/.gitignore b/tests/optimization/data/reservoir/.gitignore new file mode 100644 index 00000000..6cf9204d --- /dev/null +++ b/tests/optimization/data/reservoir/.gitignore @@ -0,0 +1 @@ +timeseries_export.* \ No newline at end of file diff --git a/tests/optimization/data/reservoir/initial_state.csv b/tests/optimization/data/reservoir/initial_state.csv new file mode 100644 index 00000000..92deaaf6 --- /dev/null +++ b/tests/optimization/data/reservoir/initial_state.csv @@ -0,0 +1,2 @@ +Volume +15.0 \ No newline at end of file diff --git a/tests/optimization/data/reservoir/reservoir.mo b/tests/optimization/data/reservoir/reservoir.mo new file mode 100644 index 00000000..c71748ac --- /dev/null +++ b/tests/optimization/data/reservoir/reservoir.mo @@ -0,0 +1,9 @@ +model Reservoir + // Basic model for in/outlow of a reservoir + parameter Real theta; + input Real Q_in(fixed=true); + input Real Q_out(fixed=false, min=0.0, max=5.0); + output Real Volume; +equation + der(Volume) = Q_in - (2 - theta) * Q_out; +end Reservoir; diff --git a/tests/optimization/data/reservoir/rtcDataConfig.xml b/tests/optimization/data/reservoir/rtcDataConfig.xml new file mode 100644 index 00000000..0b20f91e --- /dev/null +++ b/tests/optimization/data/reservoir/rtcDataConfig.xml @@ -0,0 +1,15 @@ + + + + + Seeds + Q + + + + + Inputs + Q + + + diff --git a/tests/optimization/data/reservoir/seed.csv b/tests/optimization/data/reservoir/seed.csv new file mode 100644 index 00000000..ad9162b6 --- /dev/null +++ b/tests/optimization/data/reservoir/seed.csv @@ -0,0 +1,6 @@ +Time,Q_out +2020-01-01 00:00:00, +2020-01-01 00:00:01,1.0 +2020-01-01 00:00:02, +2020-01-01 00:00:03,3.0 +2020-01-01 00:00:04, \ No newline at end of file diff --git a/tests/optimization/data/reservoir/seed.xml b/tests/optimization/data/reservoir/seed.xml new file mode 100644 index 00000000..b88c829f --- /dev/null +++ b/tests/optimization/data/reservoir/seed.xml @@ -0,0 +1,21 @@ + + 0.0 + + + instantaneous + Seeds + Q + + + + + -999.0 + m3/s + + + + + + + + \ No newline at end of file diff --git a/tests/optimization/data/reservoir/timeseries_import.csv b/tests/optimization/data/reservoir/timeseries_import.csv new file mode 100644 index 00000000..1a71a10c --- /dev/null +++ b/tests/optimization/data/reservoir/timeseries_import.csv @@ -0,0 +1,6 @@ +Time,Q_in +2020-01-01 00:00:00,0 +2020-01-01 00:00:01,1.0 +2020-01-01 00:00:02,2.0 +2020-01-01 00:00:03,3.0 +2020-01-01 00:00:04,4.0 \ No newline at end of file diff --git a/tests/optimization/data/reservoir/timeseries_import.xml b/tests/optimization/data/reservoir/timeseries_import.xml new file mode 100644 index 00000000..5bdf9bc4 --- /dev/null +++ b/tests/optimization/data/reservoir/timeseries_import.xml @@ -0,0 +1,21 @@ + + 0.0 + + + instantaneous + Inputs + Q + + + + + -999.0 + m3/s + + + + + + + + \ No newline at end of file diff --git a/tests/optimization/test_seed_mixin.py b/tests/optimization/test_seed_mixin.py new file mode 100644 index 00000000..5305e1ce --- /dev/null +++ b/tests/optimization/test_seed_mixin.py @@ -0,0 +1,193 @@ +"""Module for testing the SeedMixin class.""" +from pathlib import Path + +from rtctools.optimization.collocated_integrated_optimization_problem import ( + CollocatedIntegratedOptimizationProblem, +) +from rtctools.optimization.csv_mixin import CSVMixin +from rtctools.optimization.goal_programming_mixin import Goal, GoalProgrammingMixin, StateGoal +from rtctools.optimization.homotopy_mixin import HomotopyMixin +from rtctools.optimization.io_mixin import IOMixin +from rtctools.optimization.modelica_mixin import ModelicaMixin +from rtctools.optimization.optimization_problem import OptimizationProblem +from rtctools.optimization.pi_mixin import PIMixin +from rtctools.optimization.seed_mixin import SeedMixin, SeedOptions +from rtctools.optimization.timeseries import Timeseries +from test_case import TestCase + +DATA_DIR = Path(__file__).parent / "data" / "reservoir" + + +class WaterVolumeGoal(StateGoal): + """Keep the volume within a given range.""" + + priority = 1 + state = "Volume" + target_min = 10 + target_max = 15 + + +class MinimizeQOutGoal(Goal): + """Minimize the outflow.""" + + priority = 2 + + def function(self, optimization_problem, ensemble_member): + del self + del ensemble_member + return optimization_problem.integral("Q_out") + + +class Reservoir( + SeedMixin, + HomotopyMixin, + GoalProgrammingMixin, + IOMixin, + ModelicaMixin, + CollocatedIntegratedOptimizationProblem, + OptimizationProblem, +): + """Optimization problem for controlling a reservoir.""" + + def __init__(self, **kwargs): + kwargs["model_name"] = "Reservoir" + kwargs["input_folder"] = DATA_DIR + kwargs["output_folder"] = DATA_DIR + kwargs["model_folder"] = DATA_DIR + super().__init__(**kwargs) + + def bounds(self): + bounds = super().bounds() + bounds["Volume"] = (0, 20.0) + return bounds + + def goals(self): + del self + return [MinimizeQOutGoal()] + + def path_goals(self): + return [WaterVolumeGoal(self)] + + def seed_options(self) -> SeedOptions: + return SeedOptions(seed_timeseries=DATA_DIR / "seed.csv", fallback=True) + + def homotopy_options(self): + options = super().homotopy_options() + if self.seed_timeseries() is not None: + options["theta_start"] = 1.0 + return options + + +class ReservoirCSV( + Reservoir, + CSVMixin, + ModelicaMixin, + CollocatedIntegratedOptimizationProblem, +): + """Reservoir class using CSV files.""" + + pass + + +class ReservoirPI( + Reservoir, + PIMixin, + ModelicaMixin, + CollocatedIntegratedOptimizationProblem, +): + """Reservoir class using Delft-FEWS Published Interface files.""" + + pi_parameter_config_basenames = [] + + +class DummySolver(OptimizationProblem): + """Class for enforcing a solver result for testing purposes.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.success = [] # Keep track of solver results. + + def enforced_solver_result(self): + """Return the enforced solver result.""" + del self + return None + + def optimize( + self, + preprocessing: bool = True, + postprocessing: bool = True, + log_solver_failure_as_error: bool = True, + ) -> bool: + success = super().optimize(preprocessing, postprocessing, log_solver_failure_as_error) + if self.enforced_solver_result() is not None: + success = self.enforced_solver_result() + self.success.append(success) + return success + + +class ReservoirTest(Reservoir, DummySolver): + """Class for testing a reservoir model.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Keep track of seeds/priorities/thetas of each run. + self.seeds_q_out = [] + self.priorities = [] + self.thetas = [] + self.used_seeds = [] + + def enforced_solver_result(self): + # Enforce failure when using a seed. + return False if self.seed_timeseries() is not None else None + + def priority_started(self, priority: int): + super().priority_started(priority) + # Keep track of seeds/priorities/thetas for testing purposes. + seed = self.seed(ensemble_member=0) + self.seeds_q_out.append(seed.get("Q_out")) + self.used_seeds.append(self.seed_timeseries() is not None) + self.priorities.append(priority) + self.thetas.append(self.parameters(ensemble_member=0)["theta"]) + + +class ReservoirCSVTest(ReservoirTest, ReservoirCSV, DummySolver): + """ReservoirTest class using CSV files.""" + + pass + + +class ReservoirPITest(ReservoirTest, ReservoirPI, DummySolver): + """ReservoirTest class using Delft-FEWS Published Interface files.""" + + pass + + +class TestSeedMixin(TestCase): + """Test class for seeding with fallback.""" + + def _test_seeding_with_fallback(self, model: ReservoirTest): + """Test using a seed from a file with a fallback option.""" + model.optimize() + ref_used_seeds = [True, False, False, False, False] + ref_thetas = [1, 0, 0, 1, 1] + ref_priorities = [1, 1, 2, 1, 2] + ref_success = [False, True, True, True, True] + ref_seeds = [ + Timeseries([0, 1, 2, 3, 4], [1, 1, 2, 3, 3]), + None, + ] + self.assertEqual(model.used_seeds, ref_used_seeds) + self.assertEqual(model.thetas, ref_thetas) + self.assertEqual(model.priorities, ref_priorities) + self.assertEqual(model.success, ref_success) + self.assertEqual(model.seeds_q_out[:2], ref_seeds) + + def test_seeding_with_fallback_csv(self): + """Test using a seed from a CSV file with a fallback option.""" + model = ReservoirCSVTest() + self._test_seeding_with_fallback(model) + + def test_seeding_with_fallback_pi(self): + """Test using a seed from a PI file with a fallback option.""" + model = ReservoirPITest() + self._test_seeding_with_fallback(model)