Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving timesteps parameter #305

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/dynode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,16 @@ class is accepted to modify/create the downstream parameters.
# "validate": do_nothing,
"type": lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date(),
},
{
# list[date] on which the user wishes to save the state of each
# compartment; the final timestep is always saved automatically
"name": "COMPARTMENT_SAVE_DATES",
# "validate": do_nothing,
# type list[date]
"type": lambda lst: [
datetime.datetime.strptime(s, "%Y-%m-%d").date() for s in lst
],
},
{
"name": "VACCINATION_SEASON_CHANGE",
# "validate": do_nothing,
Expand Down
50 changes: 36 additions & 14 deletions src/dynode/dynode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def save_inference_posteriors(
self,
inferer: MechanisticInferer,
save_filename="checkpoint.json",
exclude_prefixes=["final_timestep"],
exclude_prefixes=["timestep"],
save_chains_plot=True,
save_pairs_correlation_plot=True,
) -> None:
Expand All @@ -443,7 +443,8 @@ def save_inference_posteriors(
exclude_prefixes: list[str], optional
a list of strs that, if found in a sample name,
are excluded from the saved json. This is common for large logging
info that will bloat filesize like, by default ["final_timestep"]
info that will bloat filesize like, by default ["timestep"]
to exclude all timestep deterministic variables.
save_chains_plot: bool, optional
whether to save accompanying mcmc chains plot, by default True
save_pairs_correlation_plot: bool, optional
def save_inference_final_timesteps(
    self,
    inferer: MechanisticInferer,
    save_filename="final_timesteps.json",
    final_timestep_identifier="final_timestep",
):
    """saves the `final_timestep` posterior, if it is found in
    mcmc.get_samples(), otherwise raises a warning and saves nothing

    Thin wrapper around `save_inference_timesteps` that only keeps the
    samples whose name contains `final_timestep_identifier`.

    Parameters
    ----------
    inferer : MechanisticInferer
        inferer that was run with `inferer.infer()`
    save_filename : str, optional
        output filename, by default "final_timesteps.json"
    final_timestep_identifier : str, optional
        identifying token attached to the final_timestep parameters,
        by default "final_timestep"
    """
    # forward the caller's identifier instead of a hard-coded literal so
    # that a non-default `final_timestep_identifier` is actually honored
    self.save_inference_timesteps(
        inferer,
        save_filename,
        timestep_identifier=final_timestep_identifier,
    )

def save_inference_timesteps(
    self,
    inferer: MechanisticInferer,
    save_filename="timesteps.json",
    timestep_identifier="timestep",
):
    """saves all `timestep` posteriors, if they are found in
    mcmc.get_samples(), otherwise raises a warning and saves nothing

    Parameters
    ----------
    inferer : MechanisticInferer
        inferer that was run with `inferer.infer()`
    save_filename : str, optional
        output filename, by default "timesteps.json"
    timestep_identifier : str, optional
        identifying token attached to any timestep parameter; a sample is
        saved when this token appears anywhere in its name,
        by default "timestep"
    """
    # nothing can be saved until inference has produced samples
    if not inferer.infer_complete:
        warnings.warn(
            "attempting to call `save_inference_timesteps` before inference is complete. Something is likely wrong..."
        )
        return
    samples = inferer.inference_algo.get_samples(group_by_chain=True)
    # substring match, so both `final_timestep_*` and dated
    # `<date>_timestep_*` sample names are picked up by the default token
    timesteps = {
        name: values
        for name, values in samples.items()
        if timestep_identifier in name
    }
    # if it is empty, warn the user, save nothing
    if not timesteps:
        warnings.warn(
            "attempting to call `save_inference_timesteps` but failed to find any timesteps with prefix %s"
            % timestep_identifier
        )
        return
    save_path = os.path.join(self.azure_output_dir, save_filename)
    self._save_samples(timesteps, save_path)

def save_inference_timelines(
Expand Down
50 changes: 38 additions & 12 deletions src/dynode/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
observed metrics.
"""

import datetime
import json
from typing import Union

Expand All @@ -21,6 +22,7 @@
from .abstract_parameters import AbstractParameters
from .config import Config
from .mechanistic_runner import MechanisticRunner
from .utils import date_to_sim_day


class MechanisticInferer(AbstractParameters):
Expand Down Expand Up @@ -182,18 +184,7 @@ def likelihood(
dct = self.run_simulation(tf)
solution = dct["solution"]
predicted_metrics = dct["hospitalizations"]
numpyro.deterministic(
"final_timestep_s", solution.ys[self.config.COMPARTMENT_IDX.S][-1]
)
numpyro.deterministic(
"final_timestep_e", solution.ys[self.config.COMPARTMENT_IDX.E][-1]
)
numpyro.deterministic(
"final_timestep_i", solution.ys[self.config.COMPARTMENT_IDX.I][-1]
)
numpyro.deterministic(
"final_timestep_c", solution.ys[self.config.COMPARTMENT_IDX.C][-1]
)
self._checkpoint_compartment_sizes(solution)
predicted_metrics = jnp.maximum(predicted_metrics, 1e-6)
numpyro.sample(
"incidence",
Expand Down Expand Up @@ -247,6 +238,41 @@ def _debug_likelihood(self, **kwargs) -> bx.Model:
)
return bx_model

def _checkpoint_compartment_sizes(self, solution: Solution):
    """registers compartment sizes as `numpyro.deterministic` sites: always
    for the last entry of `solution.ys`, and additionally for every date in
    `self.config.COMPARTMENT_SAVE_DATES` when that parameter exists.
    Requested dates that fall outside of `solution` are silently skipped.

    Nothing is written to disk here; the values are only recorded as
    deterministic sites so that they may be saved later by
    `self.checkpoint()` or by the user.

    Parameters
    ----------
    solution : diffrax.Solution
        a diffrax Solution object returned by solving ODEs, most often
        retrieved by `self.run_simulation()`
    """
    compartments = self.config.COMPARTMENT_IDX
    # NOTE(review): site names use `compartment.name` verbatim; the sites
    # this replaces were lower-case (e.g. `final_timestep_s`) -- confirm
    # that downstream readers expect the new spelling.
    for idx in compartments:
        numpyro.deterministic(
            "final_timestep_%s" % idx.name,
            solution.ys[idx][-1],
        )
    # the config parameter is optional, so fall back to "no extra dates"
    requested_dates = getattr(self.config, "COMPARTMENT_SAVE_DATES", [])
    for save_date in requested_dates:
        save_date: datetime.date
        label = save_date.strftime("%Y_%m_%d")
        day = date_to_sim_day(save_date, self.config.INIT_DATE)
        # ensure user requests a day we actually have in `solution`
        if not (0 <= day < len(solution.ys[compartments.S])):
            continue
        for idx in compartments:
            numpyro.deterministic(
                "%s_timestep_%s" % (label, idx.name),
                solution.ys[idx][day],
            )

def checkpoint(
self, checkpoint_path: str, group_by_chain: bool = True
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def plot_checkpoint_inference_correlation_pairs(
for key, val in posteriors.items()
}
posteriors: dict[str, np.ndarray] = flatten_list_parameters(posteriors)
# drop any final_timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "final_timestep")
# drop any timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "timestep")
number_of_samples = posteriors[list(posteriors.keys())[0]].shape[1]
# if we are dealing with many samples per chain,
# narrow down to max_samples_calculated samples per chain
Expand Down Expand Up @@ -435,8 +435,8 @@ def plot_mcmc_chains(
for key, val in samples.items()
}
samples: dict[str, np.ndarray] = flatten_list_parameters(samples)
# drop any final_timestep parameters in case they snuck in
samples = drop_keys_with_substring(samples, "final_timestep")
# drop any timestep parameters in case they snuck in
samples = drop_keys_with_substring(samples, "timestep")
param_names = list(samples.keys())
num_params = len(param_names)
num_chains = samples[param_names[0]].shape[0]
Expand Down
Loading