Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a SeedMixin class for seed options #1642

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/rtctools/data/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Module for processing timeseries data."""
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np

import rtctools.data.csv as csv
import rtctools.data.pi as pi
import rtctools.data.rtc as rtc


def get_timeseries_from_csv(
    file: Path, delimiter=","
) -> Tuple[List[datetime], Dict[str, np.ndarray]]:
    """Load a timeseries table from a CSV file.

    The first column of the file is used as the time axis. Every remaining
    column becomes one entry of the returned dict, converted to float64.

    :param file: Path to the CSV file.
    :param delimiter: Column delimiter used in the CSV file.

    :returns: The time stamps and a dict mapping each column name to its values.
    """
    table = csv.load(file, delimiter=delimiter, with_time=True)
    column_names = table.dtype.names
    time_stamps = table[column_names[0]]
    series = {}
    for name in column_names[1:]:
        series[name] = np.asarray(table[name], dtype=np.float64)
    return time_stamps, series


def get_timeseries_from_pi(
    file: Path, data_config=rtc.DataConfig, is_binary=False, validate=False
) -> Tuple[List[datetime], Dict[str, np.ndarray]]:
    """Load timeseries from a Delft-FEWS Published Interface (PI) file.

    The folder and basename handed to :class:`pi.Timeseries` are the parent
    directory and the stem of ``file``.

    :param file: Path to the xml file.
    :param data_config: A :class:`rtc.DataConfig` object.
    :param is_binary: Indicate if the file is a binary file.
    :param validate: Validate the timeseries (forwarded as ``pi_validate_times``).

    :returns: The time stamps and a dict of values.
    """
    # NOTE(review): the default of data_config is the DataConfig class itself, not an
    # instance; presumably a type annotation was intended. Confirm before relying on it.
    source = Path(file)
    parsed = pi.Timeseries(
        data_config=data_config,
        folder=source.parent,
        basename=source.stem,
        binary=is_binary,
        pi_validate_times=validate,
    )
    return parsed.times, dict(parsed.items())


def check_times_are_increasing(times: List[datetime]):
    """Check that time stamps are strictly increasing.

    Sequences with fewer than two time stamps pass trivially.

    :param times: Time stamps in the order they appear in the timeseries.

    :raises ValueError: If a time stamp is not later than its predecessor.
    """
    # Compare each time stamp with its successor.
    for previous, current in zip(times, times[1:]):
        if previous >= current:
            raise ValueError("Time stamps must be strictly increasing.")


def check_times_are_equidistant(times: List[datetime]):
    """Check that times are equidistant.

    The step between the first two time stamps is the reference; every
    later step has to be equal to it. Sequences with fewer than two time
    stamps have no step to compare and pass trivially.

    :param times: Time stamps in the order they appear in the timeseries.

    :raises ValueError: If a time step differs from the first time step.
    """
    if len(times) < 2:
        # Nothing to compare; avoids an IndexError on times[1].
        return
    dt = times[1] - times[0]
    for previous, current in zip(times, times[1:]):
        if current - previous != dt:
            raise ValueError(
                "Expecting equidistant timeseries, the time step towards "
                "{} is not the same as the time step(s) before. ".format(current)
            )


def fill_nan_in_timeseries(times: List[datetime], values: np.ndarray, interp_args: dict = None):
    """Fill in missing values in a timeseries using linear interpolation.

    The time stamps are converted to seconds relative to the first time stamp
    and the NaN entries are replaced by ``numpy.interp`` evaluated on the
    remaining entries.

    :param times: List of datetimes.
    :param values: 1D array (or array-like) of values with the same length as times.
    :param interp_args: Dict of keyword arguments passed to numpy.interp.

    :returns: Array of values where nans have been replaced with interpolated values.
        If every value is nan (or there are no values), ``values`` is returned unchanged.
    """
    if interp_args is None:
        interp_args = {}
    # Work on a float array so that plain lists can be masked as well.
    data = np.asarray(values, dtype=np.float64)
    nans = np.isnan(data)
    if nans.all():
        # No valid points to interpolate between.
        return values
    times_sec = np.array([(t - times[0]).total_seconds() for t in times])
    return np.interp(times_sec, times_sec[~nans], data[~nans], **interp_args)
51 changes: 17 additions & 34 deletions src/rtctools/optimization/csv_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import rtctools.data.csv as csv
from rtctools._internal.alias_tools import AliasDict
from rtctools._internal.caching import cached
from rtctools.data.timeseries import (
check_times_are_equidistant,
check_times_are_increasing,
get_timeseries_from_csv,
)
from rtctools.optimization.io_mixin import IOMixin
from rtctools.optimization.timeseries import Timeseries

Expand Down Expand Up @@ -102,26 +107,18 @@ def check_initial_state_array(initial_state):
logger.debug("CSVMixin: Read ensemble description")

for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]):
_timeseries = csv.load(
times, values = get_timeseries_from_csv(
os.path.join(
self._input_folder,
ensemble_member_name,
self.timeseries_import_basename + ".csv",
),
delimiter=self.csv_delimiter,
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]

self.io.reference_datetime = self.__timeseries_times[0]

for key in _timeseries.dtype.names[1:]:
self.io.set_timeseries(
key,
self.__timeseries_times,
np.asarray(_timeseries[key], dtype=np.float64),
ensemble_member_index,
)
self.__timeseries_times = times
self.io.reference_datetime = times[0]
for key, value in values.items():
self.io.set_timeseries(key, times, value, ensemble_member_index)
logger.debug("CSVMixin: Read timeseries")

for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble["name"]):
Expand Down Expand Up @@ -159,19 +156,14 @@ def check_initial_state_array(initial_state):
self.__initial_state.append(AliasDict(self.alias_relation, _initial_state))
logger.debug("CSVMixin: Read initial state.")
else:
_timeseries = csv.load(
times, values = get_timeseries_from_csv(
os.path.join(self._input_folder, self.timeseries_import_basename + ".csv"),
delimiter=self.csv_delimiter,
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]

self.io.reference_datetime = self.__timeseries_times[0]

for key in _timeseries.dtype.names[1:]:
self.io.set_timeseries(
key, self.__timeseries_times, np.asarray(_timeseries[key], dtype=np.float64)
)
self.__timeseries_times = times
self.io.reference_datetime = times[0]
for key, value in values.items():
self.io.set_timeseries(key, times, value)
logger.debug("CSVMixin: Read timeseries.")

try:
Expand Down Expand Up @@ -202,22 +194,13 @@ def check_initial_state_array(initial_state):
# Timestamp check
if self.csv_validate_timeseries:
times = self.__timeseries_times
for i in range(len(times) - 1):
if times[i] >= times[i + 1]:
raise Exception("CSVMixin: Time stamps must be strictly increasing.")
check_times_are_increasing(times)

if self.csv_equidistant:
# Check if the timeseries are truly equidistant
if self.csv_validate_timeseries:
times = self.__timeseries_times
dt = times[1] - times[0]
for i in range(len(times) - 1):
if times[i + 1] - times[i] != dt:
raise Exception(
"CSVMixin: Expecting equidistant timeseries, the time step towards "
"{} is not the same as the time step(s) before. "
"Set csv_equidistant = False if this is intended.".format(times[i + 1])
)
check_times_are_equidistant(times)

def ensemble_member_probability(self, ensemble_member):
if self.csv_ensemble_mode:
Expand Down
26 changes: 13 additions & 13 deletions src/rtctools/optimization/pi_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

import rtctools.data.pi as pi
import rtctools.data.rtc as rtc
from rtctools.data.timeseries import (
check_times_are_equidistant,
check_times_are_increasing,
)
from rtctools.optimization.io_mixin import IOMixin

logger = logging.getLogger("rtctools")
Expand Down Expand Up @@ -109,23 +113,12 @@ def read(self):

# Timestamp check
if self.pi_validate_timeseries:
for i in range(len(timeseries_import_times) - 1):
if timeseries_import_times[i] >= timeseries_import_times[i + 1]:
raise Exception("PIMixin: Time stamps must be strictly increasing.")
check_times_are_increasing(timeseries_import_times)

if self.__timeseries_import.dt:
# Check if the timeseries are truly equidistant
if self.pi_validate_timeseries:
dt = timeseries_import_times[1] - timeseries_import_times[0]
for i in range(len(timeseries_import_times) - 1):
if timeseries_import_times[i + 1] - timeseries_import_times[i] != dt:
raise Exception(
"PIMixin: Expecting equidistant timeseries, the time step "
"towards {} is not the same as the time step(s) before. Set "
"unit to nonequidistant if this is intended.".format(
timeseries_import_times[i + 1]
)
)
check_times_are_equidistant(timeseries_import_times)

# Offer input timeseries to IOMixin
self.io.reference_datetime = self.__timeseries_import.forecast_datetime
Expand Down Expand Up @@ -278,6 +271,13 @@ def timeseries_export(self):
"""
return self.__timeseries_export

@property
def data_config(self):
    """
    :class:`rtc.DataConfig` object for holding the configuration data.
    """
    return self.__data_config

def set_unit(self, variable: str, unit: str):
"""
Set the unit of a time series.
Expand Down
126 changes: 126 additions & 0 deletions src/rtctools/optimization/seed_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Module for seeding functionalities.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from rtctools.data.timeseries import (
check_times_are_equidistant,
check_times_are_increasing,
fill_nan_in_timeseries,
get_timeseries_from_csv,
get_timeseries_from_pi,
)
from rtctools.optimization.csv_mixin import CSVMixin
from rtctools.optimization.goal_programming_mixin import GoalProgrammingMixin
from rtctools.optimization.homotopy_mixin import HomotopyMixin
from rtctools.optimization.optimization_problem import OptimizationProblem
from rtctools.optimization.pi_mixin import PIMixin
from rtctools.optimization.timeseries import Timeseries

# Directory named "reservoir" next to this module.
# NOTE(review): not referenced anywhere else in this module; possibly a leftover.
# Confirm that nothing imports it before removing.
DATA_DIR = Path(__file__).parent / "reservoir"


@dataclass
class SeedOptions:
    """
    Class describing options for using a seed.

    :cvar seed_timeseries:
        Path to the file containing the seed.
        If None, the default seed is used.
    :cvar is_binary:
        Indicates if the file is in binary format.
        Only passed on when the seed is read from a PI file.
    :cvar validate_timeseries:
        Indicates if the timeseries should be validated,
        i.e. checked for strictly increasing time stamps.
    :cvar times_are_equidistant:
        Indicates if the time stamps are expected to be equidistant.
    :cvar fallback:
        Indicates if the solver should fall back to the default seed
        if it fails for the given seed.
    """

    seed_timeseries: Union[Path, None] = None
    is_binary: bool = False
    validate_timeseries: bool = False
    times_are_equidistant: bool = False
    fallback: bool = False


class SeedMixin(OptimizationProblem):
    """
    Adds options for reading a seed from a given file to your optimization problem.

    The seed file is read with the CSV reader when the problem is a
    :class:`CSVMixin` and with the PI reader when it is a :class:`PIMixin`.
    Override :meth:`seed_options` to select the file and the related options.
    """

    def __init__(self, **kwargs):
        # Path of the seed file for the current run; None means the default seed is used.
        self._seed_timeseries: Union[Path, None] = None
        super().__init__(**kwargs)

    def _use_seed_timeseries(self):
        """Return True if a seed from a timeseries is used for the current run."""
        if self._seed_timeseries is None:
            return False
        # With goal programming, the file seed is only used for the first run.
        if isinstance(self, GoalProgrammingMixin) and not self._gp_first_run:
            return False
        # With homotopy, the file seed is only used while theta is at its start value.
        if isinstance(self, HomotopyMixin):
            homotopy_options = self.homotopy_options()
            theta_name = homotopy_options["homotopy_parameter"]
            theta = self.parameters(ensemble_member=0)[theta_name]
            if theta > homotopy_options["theta_start"]:
                return False
        return True

    def seed_options(self) -> SeedOptions:
        """Get the seed options."""
        return SeedOptions()

    def seed_timeseries(self) -> Path:
        """Get the path of the seed timeseries."""
        return self._seed_timeseries

    def seed(self, ensemble_member):
        """Return the seed, updated with the values from the seed file if one is in use.

        :param ensemble_member: The ensemble member index, passed on to the parent seed.

        :returns: The parent seed when no seed file is used for the current run;
            otherwise a copy of it in which every variable of the seed file is
            replaced by a :class:`Timeseries` with its nans interpolated.

        :raises TypeError: If a seed file is in use but the problem is neither a
            :class:`CSVMixin` nor a :class:`PIMixin`.
        """
        if not self._use_seed_timeseries():
            return super().seed(ensemble_member)
        seed: dict = super().seed(ensemble_member).copy()  # Copy to prevent updating cached seeds.
        options = self.seed_options()
        if isinstance(self, CSVMixin):
            # NOTE(review): the default delimiter is used here, not self.csv_delimiter;
            # confirm that this is intended.
            times, values_dict = get_timeseries_from_csv(self._seed_timeseries)
        elif isinstance(self, PIMixin):
            times, values_dict = get_timeseries_from_pi(
                self._seed_timeseries,
                data_config=self.data_config,
                is_binary=options.is_binary,
                validate=options.validate_timeseries,
            )
        else:
            # Without a known reader there is no way to load the file.
            raise TypeError(
                "SeedMixin can only read a seed timeseries when combined with "
                "CSVMixin or PIMixin."
            )
        if options.validate_timeseries:
            check_times_are_increasing(times)
        if options.times_are_equidistant:
            check_times_are_equidistant(times)
        times_sec = self.io.datetime_to_sec(times, self.io.reference_datetime)
        seed.update(
            {
                var: Timeseries(times_sec, fill_nan_in_timeseries(times, values))
                for var, values in values_dict.items()
            }
        )
        return seed

    def optimize(
        self,
        preprocessing: bool = True,
        postprocessing: bool = True,
        log_solver_failure_as_error: bool = True,
    ) -> bool:
        """Optimize using the seed file, optionally retrying with the default seed.

        Pre- and postprocessing are done here, once, around the solver run(s).
        If the run fails while the seed file is in use and the ``fallback``
        option is set, the problem is solved again without the seed file.

        :param preprocessing: Whether to call ``pre()`` before solving.
        :param postprocessing: Whether to call ``post()`` after solving.
        :param log_solver_failure_as_error: Passed on to the parent ``optimize``.

        :returns: True if the (last) solver run succeeded.
        """
        if preprocessing:
            self.pre()
        options = self.seed_options()
        self._seed_timeseries = options.seed_timeseries
        success = super().optimize(
            preprocessing=False,
            postprocessing=False,
            log_solver_failure_as_error=log_solver_failure_as_error,
        )
        if not success and self._use_seed_timeseries() and options.fallback:
            self._seed_timeseries = None
            # pre() has already run and post() runs below, so neither is repeated here.
            success = super().optimize(
                preprocessing=False,
                postprocessing=False,
                log_solver_failure_as_error=log_solver_failure_as_error,
            )
        if postprocessing:
            self.post()
        return success
15 changes: 15 additions & 0 deletions tests/data/data/timeseries/rtcDataConfig.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<rtcDataConfig xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:rtc="http://www.wldelft.nl/fews" xmlns="http://www.wldelft.nl/fews" xsi:schemaLocation="http://www.wldelft.nl/fews ../../../xsd/rtcDataConfig.xsd">
<timeSeries id="Q_out">
<PITimeSeries>
<locationId>Seeds</locationId>
<parameterId>Q</parameterId>
</PITimeSeries>
</timeSeries>
<timeSeries id="Q_in">
<PITimeSeries>
<locationId>Inputs</locationId>
<parameterId>Q</parameterId>
</PITimeSeries>
</timeSeries>
</rtcDataConfig>
6 changes: 6 additions & 0 deletions tests/data/data/timeseries/timeseries.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Time,Q_out
2020-01-01 00:00:00,
2020-01-01 00:00:01,1.0
2020-01-01 00:00:02,
2020-01-01 00:00:03,3.0
2020-01-01 00:00:04,
Loading