Skip to content

Commit

Permalink
pimixin: seeding: option to ignore seeds at t0
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailbhemit committed Oct 22, 2024
1 parent 0db8fb8 commit aa98526
Showing 1 changed file with 138 additions and 131 deletions.
269 changes: 138 additions & 131 deletions src/rtctools/optimization/pi_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def imported_seed_options(self) -> Dict[str, Union[str, float]]:
+------------------------------------------+------------+---------------+
| ``extend_seed_backwards`` | ``Bool`` | ``False`` |
+------------------------------------------+------------+---------------+
| ``seed_variables_in_timeseries_import`` | ``Bool`` | ``False`` |
| ``seed_variables_in_timeseries_import`` | ``Bool`` | ``True`` |
+------------------------------------------+------------+---------------+
| ``ignore_seed_values_at_t0`` | ``Bool`` | ``True`` |
+------------------------------------------+------------+---------------+
The seeding process is controlled by the seeding_options. If ``import_seed``
is true then, The imported seed will be merged with the timeseries_import and used for
Expand All @@ -88,16 +90,19 @@ def imported_seed_options(self) -> Dict[str, Union[str, float]]:
at the beginning of the timehorizon.
Note that extending a seed backwards or seeding variables included in the timeseries_import
may lead to undesirable effects if a controlled input is seeded.
may lead to undesirable effects if a controlled input is seeded. To avoid accidental forcing
of variables at t0, the option ``ignore_seed_values_at_t0```can be used. When true, the seed
value at t0 is ignored for all variables in the imported seed.
:returns: A dictionary of options for importing an isolated seed.
"""

return {
"import_seed": False,
"extend_seed_forwards": False,
"extend_seed_forwards": True,
"extend_seed_backwards": False,
"seed_variables_in_timeseries_import": False,
"ignore_seed_values_at_t0": True,
}

def read(self):
Expand Down Expand Up @@ -175,138 +180,12 @@ def read(self):
# Offer input timeseries to IOMixin
self.io.reference_datetime = self.__timeseries_import.forecast_datetime

# If an imported seed has been provided merge it with the timeseries_import
# TODO if seed is missing or wrong then use default seed instead of raising exceptions
# TODO be careful for adding seeds for variables with fixed=false!
imported_seed_options = self.imported_seed_options()
if imported_seed_options["import_seed"]:
try:
self.__imported_seed_timeseries = pi.Timeseries(
self.__data_config,
self._input_folder,
self.imported_seed_basename,
binary=self.pi_binary_timeseries,
pi_validate_times=self.pi_validate_timeseries,
)
except IOError:
raise Exception(
"PIMixin: {}.xml not found in {}.".format(
self.imported_seed_basename, self._input_folder
)
)

imported_seed_times = self.__imported_seed_timeseries.times

# Timestamp check
if self.pi_validate_timeseries:
for i in range(len(imported_seed_times) - 1):
if imported_seed_times[i] >= imported_seed_times[i + 1]:
raise Exception("PIMixin: Time stamps must be strictly increasing.")

# Check if the timeseries are truly equidistant
if self.pi_validate_timeseries:
dt = imported_seed_times[1] - imported_seed_times[0]
for i in range(len(imported_seed_times) - 1):
if imported_seed_times[i + 1] - imported_seed_times[i] != dt:
raise Exception(
"PIMixin: Expecting equidistant timeseries, the time step "
"towards {} is not the same as the time step(s) before. Seeding using "
"an imported result is only supported for equidistant timesteps".format(
imported_seed_times[i + 1]
)
)
# Check if timestep is same as timeseries_import
if dt != self.__timeseries_import.dt:
raise Exception(
"PIMixin: The timesteps in timeseries_import {} differ from the "
"timesteps in the imported previous result {}. This is not "
"supported".format(self.__timeseries_import.dt, dt)
)

imported_seed_times_t0 = imported_seed_times[0]
t0_difference = self.io.reference_datetime - imported_seed_times_t0
index_difference = int(t0_difference / dt)

# Check that timeseries_import values are in the seed
if len(imported_seed_times) < abs(index_difference):
# TODO should not be an exception
raise Exception(
"Imported result does not overlap with {} range. "
"Default seed is used".format(self.timeseries_import_basename)
)
times = timeseries_import_times

# timeseries_import_variables_dict = {}
for ensemble_member in range(self.__timeseries_import.ensemble_size):
for variable, values in self.__timeseries_import.items(ensemble_member):
self.io.set_timeseries(
variable, timeseries_import_times, values, ensemble_member
)

for ensemble_member in range(self.__imported_seed_timeseries.ensemble_size):
for variable, values in self.__imported_seed_timeseries.items(ensemble_member):
write_ts = True
if index_difference >= 0:
values = np.asarray(values[index_difference:], dtype=np.float64)
if len(times) < len(values):
values = values[: len(times)]
elif len(times) > len(values):
if imported_seed_options["extend_seed_forwards"]:
# extend the last entry
values = np.append(
values, [values[-1]] * (len(times) - len(values))
)
else:
values = np.append(values, np.nan * (len(times) - len(values)))
else:
values = np.asarray(values, dtype=np.float64)
if imported_seed_options["extend_seed_backwards"]:
# extend first entry back to t0
values = np.append([values[0]] * abs(index_difference), values)
else:
values = np.append(np.nan * abs(index_difference), values)
if len(times) < len(values):
values = values[: len(times)]
elif len(times) > len(values):
if imported_seed_options["extend_seed_forwards"]:
# extend the last entry
values = np.append(
values, [values[-1]] * (len(times) - len(values))
)
else:
values = np.append(values, np.nan * (len(times) - len(values)))
for (
timeseries_import_variable,
timeseries_import_values,
) in self.__timeseries_import.items(ensemble_member):
if timeseries_import_variable == variable:
if not imported_seed_options["seed_variables_in_timeseries_import"]:
write_ts = False
elif np.any(timeseries_import_values):
values = [
a if not math.isnan(a) else b
for a, b in zip(timeseries_import_values, values)
]
else:
write_ts = False
break
if write_ts:
self.io.set_timeseries(
variable, timeseries_import_times, values, ensemble_member
)

logger.info("PIMixin: updated imported timeseries with data from imported seed.")

for ensemble_member in range(self.__timeseries_import.ensemble_size):
if not imported_seed_options["import_seed"]:
for variable, values in self.__timeseries_import.items(ensemble_member):
self.io.set_timeseries(
variable, timeseries_import_times, values, ensemble_member
)

# store the parameters in the internal data store. Note that we
# are effectively broadcasting parameters, as ParameterConfig does
# not support parameters varying per ensemble member
# Note we do this before storing timeseries such that parameters can be used to indicate
# if an imported sed should be used.
for parameter_config in self.__parameter_config:
for location_id, model_id, parameter_id, value in parameter_config:
try:
Expand All @@ -322,6 +201,134 @@ def read(self):
ensemble_member,
check_duplicates=self.pi_check_for_duplicate_parameters,
)
for variable, values in self.__timeseries_import.items(ensemble_member):
self.io.set_timeseries(variable, timeseries_import_times, values, ensemble_member)
imported_seed_options = self.imported_seed_options()
if imported_seed_options["import_seed"]:
# If an imported seed has been provided merge it with the timeseries_import
try:
self.__imported_seed_timeseries = pi.Timeseries(
self.__data_config,
self._input_folder,
self.imported_seed_basename,
binary=self.pi_binary_timeseries,
pi_validate_times=self.pi_validate_timeseries,
)
except IOError:
logger.warning(
"PIMixin: {}.xml not found in {}.".format(
self.imported_seed_basename, self._input_folder
)
)
return

imported_seed_times = self.__imported_seed_timeseries.times

# Timestamp check
if self.pi_validate_timeseries:
for i in range(len(imported_seed_times) - 1):
if imported_seed_times[i] >= imported_seed_times[i + 1]:
logger.warning("PIMixin: Time stamps must be strictly increasing.")
return

# Check if the timeseries are truly equidistant
if self.pi_validate_timeseries:
dt = imported_seed_times[1] - imported_seed_times[0]
for i in range(len(imported_seed_times) - 1):
if imported_seed_times[i + 1] - imported_seed_times[i] != dt:
logger.warning(
"PIMixin: Expecting equidistant timeseries, the time step "
"towards {} is not the same as the time step(s) before. Seeding "
"using an imported result is only supported for equidistant "
"timesteps".format(imported_seed_times[i + 1])
)
return
# Check if timestep is same as timeseries_import
if dt != self.__timeseries_import.dt:
logger.warning(
"PIMixin: The timesteps in timeseries_import {} differ from the "
"timesteps in the imported previous result {}. This is not "
"supported".format(self.__timeseries_import.dt, dt)
)
return

imported_seed_times_t0 = imported_seed_times[0]
t0_difference = self.io.reference_datetime - imported_seed_times_t0
index_difference = int(t0_difference / dt)

# Check that timeseries_import values are in the seed
if len(imported_seed_times) < abs(index_difference):
logger.warning(
"Imported result does not overlap with {} range. "
"Default seed is used".format(self.timeseries_import_basename)
)
return
times = timeseries_import_times

for ensemble_member in range(self.__imported_seed_timeseries.ensemble_size):
for variable, values in self.__imported_seed_timeseries.items(ensemble_member):
write_ts = True
if index_difference >= 0:
values = np.asarray(values[index_difference:], dtype=np.float64)
if len(times) < len(values):
values = values[: len(times)]
elif len(times) > len(values):
if imported_seed_options["extend_seed_forwards"]:
# extend the last entry
values = np.append(
values, [values[-1]] * (len(times) - len(values))
)
else:
values = np.append(values, np.nan * (len(times) - len(values)))
else:
values = np.asarray(values, dtype=np.float64)
if imported_seed_options["extend_seed_backwards"]:
# extend first entry back to t0
values = np.append([values[0]] * abs(index_difference), values)
else:
values = np.append(np.nan * abs(index_difference), values)
if len(times) < len(values):
values = values[: len(times)]
elif len(times) > len(values):
if imported_seed_options["extend_seed_forwards"]:
# extend the last entry
values = np.append(
values, [values[-1]] * (len(times) - len(values))
)
else:
values = np.append(values, np.nan * (len(times) - len(values)))
if imported_seed_options["ignore_seed_values_at_t0"]:
values[0] = np.nan
for (
timeseries_import_variable,
timeseries_import_values,
) in self.__timeseries_import.items(ensemble_member):
if timeseries_import_variable == variable:
if "GEM_Dolk_Quitlaat" in variable:
print("stop")
if not imported_seed_options["seed_variables_in_timeseries_import"]:
write_ts = False
elif np.isnan(timeseries_import_values).any():
values = [
a if not math.isnan(a) else b
for a, b in zip(timeseries_import_values, values)
]
else:
write_ts = False
break
if write_ts:
self.io.set_timeseries(
variable, timeseries_import_times, values, ensemble_member
)
else:
# without this the timeseries is overwritten
self.io.set_timeseries(
variable,
timeseries_import_times,
timeseries_import_values,
ensemble_member,
)
logger.info("PIMixin: updated imported timeseries with data from imported seed.")

def solver_options(self):
# Call parent
Expand Down

0 comments on commit aa98526

Please sign in to comment.