diff --git a/src/rtctools/data/timeseries.py b/src/rtctools/data/timeseries.py
new file mode 100644
index 00000000..4d1bd0a7
--- /dev/null
+++ b/src/rtctools/data/timeseries.py
@@ -0,0 +1,88 @@
+"""Module for processing timeseries data."""
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+
+import rtctools.data.csv as csv
+import rtctools.data.pi as pi
+import rtctools.data.rtc as rtc
+
+
+def get_timeseries_from_csv(
+ file: Path, delimiter=","
+) -> Tuple[List[datetime], Dict[str, np.ndarray]]:
+ """Read timeseries from csv
+
+ :param file: Path to CSV.
+ :param delimiter: CSV column delimiter.
+
+ :returns: A list of time stamps (datetimes) and dict of values.
+ """
+ data = csv.load(file, delimiter=delimiter, with_time=True)
+ times = data[data.dtype.names[0]]
+ values = {key: np.asarray(data[key], dtype=np.float64) for key in data.dtype.names[1:]}
+ return times, values
+
+
+def get_timeseries_from_pi(
+ file: Path, data_config=rtc.DataConfig, is_binary=False, validate=False
+) -> Tuple[List[datetime], Dict[str, np.ndarray]]:
+ """Read timeseries from Delft-FEWS Published Interface file.
+
+ :param file: Path to the xml file.
+ :param data_config: A :class:rtc.DataConfig object.
+ :param is_binary: Indicate if the file is a binary file.
+ :param validate: Validate the timeseries.
+
+ :returns: A list of time stamps (datetimes) and dict of values.
+ """
+ file = Path(file)
+ timeseries = pi.Timeseries(
+ data_config=data_config,
+ folder=file.parent,
+ basename=file.stem,
+ binary=is_binary,
+ pi_validate_times=validate,
+ )
+ times = timeseries.times
+ values = dict(timeseries.items())
+ return times, values
+
+
+def check_times_are_increasing(times: List[datetime]):
+ """Check that time stamps are increasing."""
+ for i in range(len(times) - 1):
+ if times[i] >= times[i + 1]:
+ raise ValueError("Time stamps must be strictly increasing.")
+
+
+def check_times_are_equidistant(times: List[datetime]):
+ """Check that times are eqeuidistant."""
+ dt = times[1] - times[0]
+ for i in range(len(times) - 1):
+ if times[i + 1] - times[i] != dt:
+ raise ValueError(
+ "Expecting equidistant timeseries, the time step towards "
+ "{} is not the same as the time step(s) before. ".format(times[i + 1])
+ )
+
+
+def fill_nan_in_timeseries(times: List[datetime], values: np.ndarray, interp_args: dict = None):
+ """Fill in missing values in a timeseries using lienar interpolation.
+
+ :param times: List of datetimes.
+ :param values: 1D array of values with the same length as times.
+ :interp_args: Dict of arguments passed to numpy.interp.
+
+ :returns: List of values where nans have been replaced with interpolated values.
+ """
+ if interp_args is None:
+ interp_args = {}
+ times_sec = np.array([(t - times[0]).total_seconds() for t in times])
+ nans = np.isnan(values)
+ if all(nans):
+ return values
+ result = np.interp(times_sec, times_sec[~nans], values[~nans], **interp_args)
+ return result
diff --git a/src/rtctools/optimization/csv_mixin.py b/src/rtctools/optimization/csv_mixin.py
index 3f35be5f..a0fd0422 100644
--- a/src/rtctools/optimization/csv_mixin.py
+++ b/src/rtctools/optimization/csv_mixin.py
@@ -7,6 +7,11 @@
import rtctools.data.csv as csv
from rtctools._internal.alias_tools import AliasDict
from rtctools._internal.caching import cached
+from rtctools.data.timeseries import (
+ check_times_are_equidistant,
+ check_times_are_increasing,
+ get_timeseries_from_csv,
+)
from rtctools.optimization.io_mixin import IOMixin
from rtctools.optimization.timeseries import Timeseries
@@ -102,26 +107,18 @@ def check_initial_state_array(initial_state):
logger.debug("CSVMixin: Read ensemble description")
for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]):
- _timeseries = csv.load(
+ times, values = get_timeseries_from_csv(
os.path.join(
self._input_folder,
ensemble_member_name,
self.timeseries_import_basename + ".csv",
),
delimiter=self.csv_delimiter,
- with_time=True,
)
- self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
-
- self.io.reference_datetime = self.__timeseries_times[0]
-
- for key in _timeseries.dtype.names[1:]:
- self.io.set_timeseries(
- key,
- self.__timeseries_times,
- np.asarray(_timeseries[key], dtype=np.float64),
- ensemble_member_index,
- )
+ self.__timeseries_times = times
+ self.io.reference_datetime = times[0]
+ for key, value in values.items():
+ self.io.set_timeseries(key, times, value, ensemble_member_index)
logger.debug("CSVMixin: Read timeseries")
for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]):
@@ -159,19 +156,14 @@ def check_initial_state_array(initial_state):
self.__initial_state.append(AliasDict(self.alias_relation, _initial_state))
logger.debug("CSVMixin: Read initial state.")
else:
- _timeseries = csv.load(
+ times, values = get_timeseries_from_csv(
os.path.join(self._input_folder, self.timeseries_import_basename + ".csv"),
delimiter=self.csv_delimiter,
- with_time=True,
)
- self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
-
- self.io.reference_datetime = self.__timeseries_times[0]
-
- for key in _timeseries.dtype.names[1:]:
- self.io.set_timeseries(
- key, self.__timeseries_times, np.asarray(_timeseries[key], dtype=np.float64)
- )
+ self.__timeseries_times = times
+ self.io.reference_datetime = times[0]
+ for key, value in values.items():
+ self.io.set_timeseries(key, times, value)
logger.debug("CSVMixin: Read timeseries.")
try:
@@ -202,22 +194,13 @@ def check_initial_state_array(initial_state):
# Timestamp check
if self.csv_validate_timeseries:
times = self.__timeseries_times
- for i in range(len(times) - 1):
- if times[i] >= times[i + 1]:
- raise Exception("CSVMixin: Time stamps must be strictly increasing.")
+ check_times_are_increasing(times)
if self.csv_equidistant:
# Check if the timeseries are truly equidistant
if self.csv_validate_timeseries:
times = self.__timeseries_times
- dt = times[1] - times[0]
- for i in range(len(times) - 1):
- if times[i + 1] - times[i] != dt:
- raise Exception(
- "CSVMixin: Expecting equidistant timeseries, the time step towards "
- "{} is not the same as the time step(s) before. "
- "Set csv_equidistant = False if this is intended.".format(times[i + 1])
- )
+ check_times_are_equidistant(times)
def ensemble_member_probability(self, ensemble_member):
if self.csv_ensemble_mode:
diff --git a/src/rtctools/optimization/pi_mixin.py b/src/rtctools/optimization/pi_mixin.py
index f0af57f2..a2f43220 100644
--- a/src/rtctools/optimization/pi_mixin.py
+++ b/src/rtctools/optimization/pi_mixin.py
@@ -3,6 +3,10 @@
import rtctools.data.pi as pi
import rtctools.data.rtc as rtc
+from rtctools.data.timeseries import (
+ check_times_are_equidistant,
+ check_times_are_increasing,
+)
from rtctools.optimization.io_mixin import IOMixin
logger = logging.getLogger("rtctools")
@@ -109,23 +113,12 @@ def read(self):
# Timestamp check
if self.pi_validate_timeseries:
- for i in range(len(timeseries_import_times) - 1):
- if timeseries_import_times[i] >= timeseries_import_times[i + 1]:
- raise Exception("PIMixin: Time stamps must be strictly increasing.")
+ check_times_are_increasing(timeseries_import_times)
if self.__timeseries_import.dt:
# Check if the timeseries are truly equidistant
if self.pi_validate_timeseries:
- dt = timeseries_import_times[1] - timeseries_import_times[0]
- for i in range(len(timeseries_import_times) - 1):
- if timeseries_import_times[i + 1] - timeseries_import_times[i] != dt:
- raise Exception(
- "PIMixin: Expecting equidistant timeseries, the time step "
- "towards {} is not the same as the time step(s) before. Set "
- "unit to nonequidistant if this is intended.".format(
- timeseries_import_times[i + 1]
- )
- )
+ check_times_are_equidistant(timeseries_import_times)
# Offer input timeseries to IOMixin
self.io.reference_datetime = self.__timeseries_import.forecast_datetime
@@ -278,6 +271,13 @@ def timeseries_export(self):
"""
return self.__timeseries_export
+ @property
+ def data_config(self):
+ """
+ :class:`rtc.ConfigData` object for holding the configuration data.
+ """
+ return self.__data_config
+
def set_unit(self, variable: str, unit: str):
"""
Set the unit of a time series.
diff --git a/src/rtctools/optimization/seed_mixin.py b/src/rtctools/optimization/seed_mixin.py
new file mode 100644
index 00000000..a9023fa0
--- /dev/null
+++ b/src/rtctools/optimization/seed_mixin.py
@@ -0,0 +1,126 @@
+"""Module for seeding functionalities.
+"""
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+from rtctools.data.timeseries import (
+ check_times_are_equidistant,
+ check_times_are_increasing,
+ fill_nan_in_timeseries,
+ get_timeseries_from_csv,
+ get_timeseries_from_pi,
+)
+from rtctools.optimization.csv_mixin import CSVMixin
+from rtctools.optimization.goal_programming_mixin import GoalProgrammingMixin
+from rtctools.optimization.homotopy_mixin import HomotopyMixin
+from rtctools.optimization.optimization_problem import OptimizationProblem
+from rtctools.optimization.pi_mixin import PIMixin
+from rtctools.optimization.timeseries import Timeseries
+
+DATA_DIR = Path(__file__).parent / "reservoir"
+
+
+@dataclass
+class SeedOptions:
+ """
+ Class describing options for using a seed.
+
+ :cvar seed_timeseries:
+ Path to the file containing the seed.
+ If None, the default seed is used.
+ :cvar is_binary:
+ Indicates if the file is in binary format.
+ :cvar validate_timeseries:
+ Indicates if the timeseries should be validated.
+ :cvar times_are_equidistant:
+ Indicates if the time stamps are equidistant.
+ :cvar fallback:
+ Indicates if the solver should fallback to the default seed
+ if it fails for the given seed.
+ """
+
+ seed_timeseries: Union[Path, None] = None
+ is_binary: bool = False
+ validate_timeseries: bool = False
+ times_are_equidistant: bool = False
+ fallback: bool = False
+
+
+class SeedMixin(OptimizationProblem):
+ """
+ Adds options for reading a seed from a given file to your optimization problem.
+ """
+
+ def __init__(self, **kwargs):
+ self._seed_timeseries: Path = None
+ super().__init__(**kwargs)
+
+ def _use_seed_timeseries(self):
+ """Return True if a seed from a timeseries is used for the current run."""
+ if self._seed_timeseries is None:
+ return False
+ if isinstance(self, GoalProgrammingMixin):
+ if not self._gp_first_run:
+ return False
+ if isinstance(self, HomotopyMixin):
+ theta_name = self.homotopy_options()["homotopy_parameter"]
+ theta = self.parameters(ensemble_member=0)[theta_name]
+ theta_start = self.homotopy_options()["theta_start"]
+ if theta > theta_start:
+ return False
+ return True
+
+ def seed_options(self) -> SeedOptions:
+ """Get the seed options."""
+ return SeedOptions()
+
+ def seed_timeseries(self) -> Path:
+ """Get the path of the seed timeseries."""
+ return self._seed_timeseries
+
+ def seed(self, ensemble_member):
+ if not self._use_seed_timeseries():
+ return super().seed(ensemble_member)
+ seed: dict = super().seed(ensemble_member).copy() # Copy to prevent updating cached seeds.
+ if isinstance(self, CSVMixin):
+ times, values_dict = get_timeseries_from_csv(self._seed_timeseries)
+ elif isinstance(self, PIMixin):
+ times, values_dict = get_timeseries_from_pi(
+ self._seed_timeseries,
+ data_config=self.data_config,
+ is_binary=self.seed_options().is_binary,
+ validate=self.seed_options().validate_timeseries,
+ )
+ if self.seed_options().validate_timeseries:
+ check_times_are_increasing(times)
+ if self.seed_options().times_are_equidistant:
+ check_times_are_equidistant(times)
+ times_sec = self.io.datetime_to_sec(times, self.io.reference_datetime)
+ for var in values_dict:
+ values = fill_nan_in_timeseries(times, values_dict[var])
+ values_dict[var] = Timeseries(times_sec, values)
+ seed.update(values_dict)
+ return seed
+
+ def optimize(
+ self,
+ preprocessing: bool = True,
+ postprocessing: bool = True,
+ log_solver_failure_as_error: bool = True,
+ ) -> bool:
+ if preprocessing:
+ self.pre()
+ self._seed_timeseries = self.seed_options().seed_timeseries
+ fallback = self.seed_options().fallback
+ success = super().optimize(
+ preprocessing=False,
+ postprocessing=False,
+ log_solver_failure_as_error=log_solver_failure_as_error,
+ )
+ if not success and self._use_seed_timeseries() and fallback:
+ self._seed_timeseries = None
+ success = super().optimize(preprocessing, postprocessing, log_solver_failure_as_error)
+ if postprocessing:
+ self.post()
+ return success
diff --git a/tests/data/data/timeseries/rtcDataConfig.xml b/tests/data/data/timeseries/rtcDataConfig.xml
new file mode 100644
index 00000000..0b20f91e
--- /dev/null
+++ b/tests/data/data/timeseries/rtcDataConfig.xml
@@ -0,0 +1,15 @@
+
+
+
+
+ Seeds
+ Q
+
+
+
+
+ Inputs
+ Q
+
+
+
diff --git a/tests/data/data/timeseries/timeseries.csv b/tests/data/data/timeseries/timeseries.csv
new file mode 100644
index 00000000..ad9162b6
--- /dev/null
+++ b/tests/data/data/timeseries/timeseries.csv
@@ -0,0 +1,6 @@
+Time,Q_out
+2020-01-01 00:00:00,
+2020-01-01 00:00:01,1.0
+2020-01-01 00:00:02,
+2020-01-01 00:00:03,3.0
+2020-01-01 00:00:04,
\ No newline at end of file
diff --git a/tests/data/data/timeseries/timeseries.xml b/tests/data/data/timeseries/timeseries.xml
new file mode 100644
index 00000000..b88c829f
--- /dev/null
+++ b/tests/data/data/timeseries/timeseries.xml
@@ -0,0 +1,21 @@
+
+ 0.0
+
+
+ instantaneous
+ Seeds
+ Q
+
+
+
+
+ -999.0
+ m3/s
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/data/test_timeseries.py b/tests/data/test_timeseries.py
new file mode 100644
index 00000000..82308e3f
--- /dev/null
+++ b/tests/data/test_timeseries.py
@@ -0,0 +1,74 @@
+"""Module for testing timeseries data functionanilities.
+"""
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+from rtctools.data.rtc import DataConfig
+from rtctools.data.timeseries import (
+ check_times_are_equidistant,
+ check_times_are_increasing,
+ fill_nan_in_timeseries,
+ get_timeseries_from_csv,
+ get_timeseries_from_pi,
+)
+from test_case import TestCase
+
+DATA_DIR = Path(__file__).parent / "data" / "timeseries"
+
+
+class TestTimeseries(TestCase):
+ def test_check_times_are_equidistant(self):
+ """Test checking if a timeseries is equidistant."""
+ times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 3]]
+ check_times_are_equidistant(times)
+ times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 4]]
+ with self.assertRaises(ValueError):
+ check_times_are_equidistant(times)
+
+ def test_check_timeseries_is_increasing(self):
+ """Test checking times are strictly increasing."""
+ times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 3]]
+ check_times_are_increasing(times)
+ times = [datetime(2020, 1, 1, 0, 0, sec) for sec in [1, 2, 2]]
+ with self.assertRaises(ValueError):
+ check_times_are_increasing(times)
+
+ def test_get_timeseries_from_csv(self):
+ """Test getting a timeseries from a cvs file."""
+ times, values = get_timeseries_from_csv(DATA_DIR / "timeseries.csv")
+ ref_times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)]
+ ref_values = {"Q_out": [np.nan, 1.0, np.nan, 3.0, np.nan]}
+ for time, ref_time in zip(times, ref_times):
+ self.assertEqual(time, ref_time)
+ for var in values:
+ np.testing.assert_almost_equal(values[var], ref_values[var])
+
+ def test_get_timeseries_from_pi(self):
+ """Test getting a timeseries from a Delft-FEWS Published Interface file."""
+ data_config = DataConfig(DATA_DIR)
+ times, values = get_timeseries_from_pi(
+ file=DATA_DIR / "timeseries.xml", data_config=data_config, validate=True
+ )
+ ref_times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)]
+ ref_values = {"Q_out": [np.nan, 1.0, np.nan, 3.0, np.nan]}
+ for time, ref_time in zip(times, ref_times):
+ self.assertEqual(time, ref_time)
+ for var in values:
+ np.testing.assert_almost_equal(values[var], ref_values[var])
+
+ def test_fill_nan_in_timeseries(self):
+ """Test filling nan values of a timeseries."""
+ del self
+ times = [datetime(2020, 1, 1, 0, 0, sec) for sec in range(5)]
+ values = np.array([np.nan, 1.0, np.nan, 3.0, np.nan])
+ result = fill_nan_in_timeseries(times, values)
+ ref_result = np.array([1.0, 1.0, 2.0, 3.0, 3.0])
+ np.testing.assert_almost_equal(result, ref_result)
+ values = np.array([np.nan, 1.0, np.nan, np.nan, np.nan])
+ result = fill_nan_in_timeseries(times, values)
+ ref_result = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
+ np.testing.assert_almost_equal(result, ref_result)
+ values = np.array([np.nan, np.nan, np.nan, np.nan, np.nan])
+ result = fill_nan_in_timeseries(times, values)
+ assert all(np.isnan(result))
diff --git a/tests/optimization/data/reservoir/.gitignore b/tests/optimization/data/reservoir/.gitignore
new file mode 100644
index 00000000..6cf9204d
--- /dev/null
+++ b/tests/optimization/data/reservoir/.gitignore
@@ -0,0 +1 @@
+timeseries_export.*
\ No newline at end of file
diff --git a/tests/optimization/data/reservoir/initial_state.csv b/tests/optimization/data/reservoir/initial_state.csv
new file mode 100644
index 00000000..92deaaf6
--- /dev/null
+++ b/tests/optimization/data/reservoir/initial_state.csv
@@ -0,0 +1,2 @@
+Volume
+15.0
\ No newline at end of file
diff --git a/tests/optimization/data/reservoir/reservoir.mo b/tests/optimization/data/reservoir/reservoir.mo
new file mode 100644
index 00000000..c71748ac
--- /dev/null
+++ b/tests/optimization/data/reservoir/reservoir.mo
@@ -0,0 +1,9 @@
+model Reservoir
+ // Basic model for in/outlow of a reservoir
+ parameter Real theta;
+ input Real Q_in(fixed=true);
+ input Real Q_out(fixed=false, min=0.0, max=5.0);
+ output Real Volume;
+equation
+ der(Volume) = Q_in - (2 - theta) * Q_out;
+end Reservoir;
diff --git a/tests/optimization/data/reservoir/rtcDataConfig.xml b/tests/optimization/data/reservoir/rtcDataConfig.xml
new file mode 100644
index 00000000..0b20f91e
--- /dev/null
+++ b/tests/optimization/data/reservoir/rtcDataConfig.xml
@@ -0,0 +1,15 @@
+
+
+
+
+ Seeds
+ Q
+
+
+
+
+ Inputs
+ Q
+
+
+
diff --git a/tests/optimization/data/reservoir/seed.csv b/tests/optimization/data/reservoir/seed.csv
new file mode 100644
index 00000000..ad9162b6
--- /dev/null
+++ b/tests/optimization/data/reservoir/seed.csv
@@ -0,0 +1,6 @@
+Time,Q_out
+2020-01-01 00:00:00,
+2020-01-01 00:00:01,1.0
+2020-01-01 00:00:02,
+2020-01-01 00:00:03,3.0
+2020-01-01 00:00:04,
\ No newline at end of file
diff --git a/tests/optimization/data/reservoir/seed.xml b/tests/optimization/data/reservoir/seed.xml
new file mode 100644
index 00000000..b88c829f
--- /dev/null
+++ b/tests/optimization/data/reservoir/seed.xml
@@ -0,0 +1,21 @@
+
+ 0.0
+
+
+ instantaneous
+ Seeds
+ Q
+
+
+
+
+ -999.0
+ m3/s
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/optimization/data/reservoir/timeseries_import.csv b/tests/optimization/data/reservoir/timeseries_import.csv
new file mode 100644
index 00000000..1a71a10c
--- /dev/null
+++ b/tests/optimization/data/reservoir/timeseries_import.csv
@@ -0,0 +1,6 @@
+Time,Q_in
+2020-01-01 00:00:00,0
+2020-01-01 00:00:01,1.0
+2020-01-01 00:00:02,2.0
+2020-01-01 00:00:03,3.0
+2020-01-01 00:00:04,4.0
\ No newline at end of file
diff --git a/tests/optimization/data/reservoir/timeseries_import.xml b/tests/optimization/data/reservoir/timeseries_import.xml
new file mode 100644
index 00000000..5bdf9bc4
--- /dev/null
+++ b/tests/optimization/data/reservoir/timeseries_import.xml
@@ -0,0 +1,21 @@
+
+ 0.0
+
+
+ instantaneous
+ Inputs
+ Q
+
+
+
+
+ -999.0
+ m3/s
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/optimization/test_seed_mixin.py b/tests/optimization/test_seed_mixin.py
new file mode 100644
index 00000000..5305e1ce
--- /dev/null
+++ b/tests/optimization/test_seed_mixin.py
@@ -0,0 +1,193 @@
+"""Module for testing the SeedMixin class."""
+from pathlib import Path
+
+from rtctools.optimization.collocated_integrated_optimization_problem import (
+ CollocatedIntegratedOptimizationProblem,
+)
+from rtctools.optimization.csv_mixin import CSVMixin
+from rtctools.optimization.goal_programming_mixin import Goal, GoalProgrammingMixin, StateGoal
+from rtctools.optimization.homotopy_mixin import HomotopyMixin
+from rtctools.optimization.io_mixin import IOMixin
+from rtctools.optimization.modelica_mixin import ModelicaMixin
+from rtctools.optimization.optimization_problem import OptimizationProblem
+from rtctools.optimization.pi_mixin import PIMixin
+from rtctools.optimization.seed_mixin import SeedMixin, SeedOptions
+from rtctools.optimization.timeseries import Timeseries
+from test_case import TestCase
+
+DATA_DIR = Path(__file__).parent / "data" / "reservoir"
+
+
+class WaterVolumeGoal(StateGoal):
+ """Keep the volume within a given range."""
+
+ priority = 1
+ state = "Volume"
+ target_min = 10
+ target_max = 15
+
+
+class MinimizeQOutGoal(Goal):
+ """Minimize the outflow."""
+
+ priority = 2
+
+ def function(self, optimization_problem, ensemble_member):
+ del self
+ del ensemble_member
+ return optimization_problem.integral("Q_out")
+
+
+class Reservoir(
+ SeedMixin,
+ HomotopyMixin,
+ GoalProgrammingMixin,
+ IOMixin,
+ ModelicaMixin,
+ CollocatedIntegratedOptimizationProblem,
+ OptimizationProblem,
+):
+ """Optimization problem for controlling a reservoir."""
+
+ def __init__(self, **kwargs):
+ kwargs["model_name"] = "Reservoir"
+ kwargs["input_folder"] = DATA_DIR
+ kwargs["output_folder"] = DATA_DIR
+ kwargs["model_folder"] = DATA_DIR
+ super().__init__(**kwargs)
+
+ def bounds(self):
+ bounds = super().bounds()
+ bounds["Volume"] = (0, 20.0)
+ return bounds
+
+ def goals(self):
+ del self
+ return [MinimizeQOutGoal()]
+
+ def path_goals(self):
+ return [WaterVolumeGoal(self)]
+
+ def seed_options(self) -> SeedOptions:
+ return SeedOptions(seed_timeseries=DATA_DIR / "seed.csv", fallback=True)
+
+ def homotopy_options(self):
+ options = super().homotopy_options()
+ if self.seed_timeseries() is not None:
+ options["theta_start"] = 1.0
+ return options
+
+
+class ReservoirCSV(
+ Reservoir,
+ CSVMixin,
+ ModelicaMixin,
+ CollocatedIntegratedOptimizationProblem,
+):
+ """Reservoir class using CSV files."""
+
+ pass
+
+
+class ReservoirPI(
+ Reservoir,
+ PIMixin,
+ ModelicaMixin,
+ CollocatedIntegratedOptimizationProblem,
+):
+ """Reservoir class using Delft-FEWS Published Interface files."""
+
+ pi_parameter_config_basenames = []
+
+
+class DummySolver(OptimizationProblem):
+ """Class for enforcing a solver result for testing purposes."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.success = [] # Keep track of solver results.
+
+ def enforced_solver_result(self):
+ """Return the enforced solver result."""
+ del self
+ return None
+
+ def optimize(
+ self,
+ preprocessing: bool = True,
+ postprocessing: bool = True,
+ log_solver_failure_as_error: bool = True,
+ ) -> bool:
+ success = super().optimize(preprocessing, postprocessing, log_solver_failure_as_error)
+ if self.enforced_solver_result() is not None:
+ success = self.enforced_solver_result()
+ self.success.append(success)
+ return success
+
+
+class ReservoirTest(Reservoir, DummySolver):
+ """Class for testing a reservoir model."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # Keep track of seeds/priorities/thetas of each run.
+ self.seeds_q_out = []
+ self.priorities = []
+ self.thetas = []
+ self.used_seeds = []
+
+ def enforced_solver_result(self):
+ # Enforce failure when using a seed.
+ return False if self.seed_timeseries() is not None else None
+
+ def priority_started(self, priority: int):
+ super().priority_started(priority)
+ # Keep track of seeds/priorities/thetas for testing purposes.
+ seed = self.seed(ensemble_member=0)
+ self.seeds_q_out.append(seed.get("Q_out"))
+ self.used_seeds.append(self.seed_timeseries() is not None)
+ self.priorities.append(priority)
+ self.thetas.append(self.parameters(ensemble_member=0)["theta"])
+
+
+class ReservoirCSVTest(ReservoirTest, ReservoirCSV, DummySolver):
+ """ReservoirTest class using CSV files."""
+
+ pass
+
+
+class ReservoirPITest(ReservoirTest, ReservoirPI, DummySolver):
+ """ReservoirTest class using Delft-FEWS Published Interface files."""
+
+ pass
+
+
+class TestSeedMixin(TestCase):
+ """Test class for seeding with fallback."""
+
+ def _test_seeding_with_fallback(self, model: ReservoirTest):
+ """Test using a seed from a file with a fallback option."""
+ model.optimize()
+ ref_used_seeds = [True, False, False, False, False]
+ ref_thetas = [1, 0, 0, 1, 1]
+ ref_priorities = [1, 1, 2, 1, 2]
+ ref_success = [False, True, True, True, True]
+ ref_seeds = [
+ Timeseries([0, 1, 2, 3, 4], [1, 1, 2, 3, 3]),
+ None,
+ ]
+ self.assertEqual(model.used_seeds, ref_used_seeds)
+ self.assertEqual(model.thetas, ref_thetas)
+ self.assertEqual(model.priorities, ref_priorities)
+ self.assertEqual(model.success, ref_success)
+ self.assertEqual(model.seeds_q_out[:2], ref_seeds)
+
+ def test_seeding_with_fallback_csv(self):
+ """Test using a seed from a CSV file with a fallback option."""
+ model = ReservoirCSVTest()
+ self._test_seeding_with_fallback(model)
+
+ def test_seeding_with_fallback_pi(self):
+ """Test using a seed from a PI file with a fallback option."""
+ model = ReservoirPITest()
+ self._test_seeding_with_fallback(model)