From 3edb80890ac33bf2a19f4b40b0f9a84c5b2f5eb1 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Wed, 4 Dec 2024 00:05:17 +0000 Subject: [PATCH 1/6] checkpoint, now save compartment sizes at each date in COMPARTMENT_SAVE_DATES --- src/dynode/config.py | 10 +++++++ src/dynode/dynode_runner.py | 39 ++++++++++++++++++------ src/dynode/mechanistic_inferer.py | 50 +++++++++++++++++++++++-------- 3 files changed, 78 insertions(+), 21 deletions(-) diff --git a/src/dynode/config.py b/src/dynode/config.py index 402dbe6..9b778ec 100644 --- a/src/dynode/config.py +++ b/src/dynode/config.py @@ -839,6 +839,16 @@ class is accepted to modify/create the downstream parameters. # "validate": do_nothing, "type": lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date(), }, + { + # list[date] on which the user wishes to save the state of each + # compartment, final_timesteps automatically + "name": "COMPARTMENT_SAVE_DATES", + # "validate": do_nothing, + # type list[date] + "type": lambda lst: [ + datetime.datetime.strptime(s, "%Y-%m-%d").date() for s in lst + ], + }, { "name": "VACCINATION_SEASON_CHANGE", # "validate": do_nothing, diff --git a/src/dynode/dynode_runner.py b/src/dynode/dynode_runner.py index c47b7be..2c4b138 100644 --- a/src/dynode/dynode_runner.py +++ b/src/dynode/dynode_runner.py @@ -480,19 +480,40 @@ def save_inference_final_timesteps( self, inferer: MechanisticInferer, save_filename="final_timesteps.json", - final_timestep_identifier="final_timestep", ): - """saves the `final_timestep` posterior, if it is found in mcmc.get_samples(), otherwise raises a warning - and saves nothing + """saves the `final_timestep` posterior, if it is found in + mcmc.get_samples(), otherwise raises a warning and saves nothing Parameters ---------- inferer : MechanisticInferer inferer that was run with `inferer.infer()` save_filename : str, optional - output filename, by default "final_timesteps.json" + output filename, by default "timesteps.json" final_timestep_identifier : str, optional - prefix attached to the final_timestep parameter, by default "final_timestep" + prefix attached to the final_timestep parameter, by default "timestep" + """ + self.save_inference_timesteps( + inferer, save_filename, timestep_identifier="final_timestep" + ) + + def save_inference_timesteps( + self, + inferer: MechanisticInferer, + save_filename="timesteps.json", + timestep_identifier="timestep", + ): + """saves all `timestep` posteriors, if they are found in + mcmc.get_samples(), otherwise raises a warning and saves nothing + + Parameters + ---------- + inferer : MechanisticInferer + inferer that was run with `inferer.infer()` + save_filename : str, optional + output filename, by default "timesteps.json" + final_timestep_identifier : str, optional + prefix attached to the final_timestep parameter, by default "timestep" """ # if inference complete, convert jnp/np arrays to list, then json dump if inferer.infer_complete: @@ -500,7 +521,7 @@ def save_inference_final_timesteps( final_timesteps = { name: timesteps for name, timesteps in samples.items() - if final_timestep_identifier in name + if timestep_identifier in name } # if it is empty, warn the user, save nothing if final_timesteps: @@ -508,12 +529,12 @@ def save_inference_final_timesteps( self._save_samples(final_timesteps, save_path) else: warnings.warn( - "attempting to call `save_inference_final_timesteps` but failed to find any final_timesteps with prefix %s" - % final_timestep_identifier + "attempting to call `save_inference_timesteps` but failed to find any timesteps with prefix %s" + % timestep_identifier ) else: warnings.warn( - "attempting to call `save_inference_final_timesteps` before inference is complete. Something is likely wrong..." + "attempting to call `save_inference_timesteps` before inference is complete. Something is likely wrong..." ) def save_inference_timelines( diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index fcd959a..fe1219d 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -4,6 +4,7 @@ observed metrics. """ +import datetime import json from typing import Union @@ -21,6 +22,7 @@ from .abstract_parameters import AbstractParameters from .config import Config from .mechanistic_runner import MechanisticRunner +from .utils import date_to_sim_day class MechanisticInferer(AbstractParameters): @@ -182,18 +184,7 @@ def likelihood( dct = self.run_simulation(tf) solution = dct["solution"] predicted_metrics = dct["hospitalizations"] - numpyro.deterministic( - "final_timestep_s", solution.ys[self.config.COMPARTMENT_IDX.S][-1] - ) - numpyro.deterministic( - "final_timestep_e", solution.ys[self.config.COMPARTMENT_IDX.E][-1] - ) - numpyro.deterministic( - "final_timestep_i", solution.ys[self.config.COMPARTMENT_IDX.I][-1] - ) - numpyro.deterministic( - "final_timestep_c", solution.ys[self.config.COMPARTMENT_IDX.C][-1] - ) + self._checkpoint_compartment_sizes(solution) predicted_metrics = jnp.maximum(predicted_metrics, 1e-6) numpyro.sample( "incidence", @@ -247,6 +238,41 @@ def _debug_likelihood(self, **kwargs) -> bx.Model: ) return bx_model + def _checkpoint_compartment_sizes(self, solution: Solution): + """marks the final_timesteps parameters as well as any + requested dates from self.config.COMPARTMENT_SAVE_DATES if the + parameter exists. Skipping over any invalid dates. + + This method does not actually save the compartment sizes to a file, + instead it stores the values within `self.inference_algo.get_samples()` + so that they may be later saved by self.checkpoint() or by the user. + + + Parameters + ---------- + solution : diffrax.Solution + a diffrax Solution object returned by solving ODEs, most often + retrieved by `self.run_simulation()` + """ + for compartment in self.config.COMPARTMENT_IDX: + numpyro.deterministic( + "final_timestep_%s" % compartment.name, + solution.ys[compartment][-1], + ) + for date in getattr(self.config, "COMPARTMENT_SAVE_DATES", []): + date: datetime.date + date_str = date.strftime("%Y-%m-%d") + sim_day = date_to_sim_day(date, self.config.INIT_DATE) + # ensure user requests a day we actually have in `solution` + if sim_day >= 0 or sim_day < len( + solution.ys[self.config.COMPARTMENT_IDX.S] + ): + for compartment in self.config.COMPARTMENT_IDX: + numpyro.deterministic( + "%s_timestep_%s" % (date_str, compartment.name), + solution.ys[compartment][-1], + ) + def checkpoint( self, checkpoint_path: str, group_by_chain: bool = True ) -> None: From 7ad407fb9ee6fb1b8e8c6fd5878e18428bf629df Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Wed, 4 Dec 2024 00:44:43 +0000 Subject: [PATCH 2/6] bugfix sim_day to _checkpoint_compartment_sizes() --- src/dynode/mechanistic_inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index fe1219d..d9549db 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -270,7 +270,7 @@ def _checkpoint_compartment_sizes(self, solution: Solution): for compartment in self.config.COMPARTMENT_IDX: numpyro.deterministic( "%s_timestep_%s" % (date_str, compartment.name), - solution.ys[compartment][-1], + solution.ys[compartment][sim_day], ) def checkpoint( From 039db30dd0ac1d7855a0ddc3c2fc585321af8d41 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Tue, 10 Dec 2024 23:48:13 +0000 Subject: [PATCH 3/6] checkpoint trying to debug OOM issues when saving timesteps --- src/dynode/mechanistic_inferer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index b306606..b2e2d25 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -22,7 +22,7 @@ from .abstract_parameters import AbstractParameters from .config import Config from .mechanistic_runner import MechanisticRunner -from .utils import date_to_sim_day +from .utils import date_to_sim_day, sim_day_to_date class MechanisticInferer(AbstractParameters): @@ -259,6 +259,13 @@ def _checkpoint_compartment_sizes(self, solution: Solution): "final_timestep_%s" % compartment.name, solution.ys[compartment][-1], ) + date_str = sim_day_to_date(50, self.config.INIT_DATE).strftime( + "%Y-%m-%d" + ) + numpyro.deterministic( + "%s_timestep_%s" % (date_str, compartment.name), + solution.ys[compartment][50], + ) for date in getattr(self.config, "COMPARTMENT_SAVE_DATES", []): date: datetime.date date_str = date.strftime("%Y-%m-%d") From 0ce61b36d489e8c54113ce9603c579a19b4ecb70 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Wed, 11 Dec 2024 00:16:45 +0000 Subject: [PATCH 4/6] excluding all timesteps variable --- src/dynode/dynode_runner.py | 2 +- src/dynode/mechanistic_inferer.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/dynode/dynode_runner.py b/src/dynode/dynode_runner.py index 1b96ddf..33510c6 100644 --- a/src/dynode/dynode_runner.py +++ b/src/dynode/dynode_runner.py @@ -426,7 +426,7 @@ def save_inference_posteriors( self, inferer: MechanisticInferer, save_filename="checkpoint.json", - exclude_prefixes=["final_timestep"], + exclude_prefixes=["timestep"], save_chains_plot=True, save_pairs_correlation_plot=True, ) -> None: diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index b2e2d25..b306606 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -22,7 +22,7 @@ from .abstract_parameters import AbstractParameters from .config import Config from .mechanistic_runner import MechanisticRunner -from .utils import date_to_sim_day, sim_day_to_date +from .utils import date_to_sim_day class MechanisticInferer(AbstractParameters): @@ -259,13 +259,6 @@ def _checkpoint_compartment_sizes(self, solution: Solution): "final_timestep_%s" % compartment.name, solution.ys[compartment][-1], ) - date_str = sim_day_to_date(50, self.config.INIT_DATE).strftime( - "%Y-%m-%d" - ) - numpyro.deterministic( - "%s_timestep_%s" % (date_str, compartment.name), - solution.ys[compartment][50], - ) for date in getattr(self.config, "COMPARTMENT_SAVE_DATES", []): date: datetime.date date_str = date.strftime("%Y-%m-%d") From 5e41cc2196d4678601f1e734e2de4d78394ccc58 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Wed, 11 Dec 2024 01:06:27 +0000 Subject: [PATCH 5/6] timesteps feature working --- src/dynode/dynode_runner.py | 13 +++++++------ src/dynode/mechanistic_inferer.py | 2 +- src/dynode/vis_utils.py | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/dynode/dynode_runner.py b/src/dynode/dynode_runner.py index 33510c6..9e67f54 100644 --- a/src/dynode/dynode_runner.py +++ b/src/dynode/dynode_runner.py @@ -443,7 +443,8 @@ def save_inference_posteriors( exclude_prefixes: list[str], optional a list of strs that, if found in a sample name, are exlcuded from the saved json. This is common for large logging - info that will bloat filesize like, by default ["final_timestep"] + info that will bloat filesize like, by default ["timestep"] + to exclude all timestep deterministic variables. save_chains_plot: bool, optional whether to save accompanying mcmc chains plot, by default True save_pairs_correlation_plot: bool, optional @@ -512,21 +513,21 @@ def save_inference_timesteps( inferer that was run with `inferer.infer()` save_filename : str, optional output filename, by default "timesteps.json" - final_timestep_identifier : str, optional - prefix attached to the final_timestep parameter, by default "timestep" + step_identifier : str, optional + identifying token attached to any timestep parameter, by default "timestep" """ # if inference complete, convert jnp/np arrays to list, then json dump if inferer.infer_complete: samples = inferer.inference_algo.get_samples(group_by_chain=True) - final_timesteps = { + timesteps = { name: timesteps for name, timesteps in samples.items() if timestep_identifier in name } # if it is empty, warn the user, save nothing - if final_timesteps: + if timesteps: save_path = os.path.join(self.azure_output_dir, save_filename) - self._save_samples(final_timesteps, save_path) + self._save_samples(timesteps, save_path) else: warnings.warn( "attempting to call `save_inference_timesteps` but failed to find any timesteps with prefix %s" diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index b306606..6207b91 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -261,7 +261,7 @@ def _checkpoint_compartment_sizes(self, solution: Solution): ) for date in getattr(self.config, "COMPARTMENT_SAVE_DATES", []): date: datetime.date - date_str = date.strftime("%Y-%m-%d") + date_str = date.strftime("%Y_%m_%d") sim_day = date_to_sim_day(date, self.config.INIT_DATE) # ensure user requests a day we actually have in `solution` if sim_day >= 0 or sim_day < len( diff --git a/src/dynode/vis_utils.py b/src/dynode/vis_utils.py index 6428986..9d75455 100644 --- a/src/dynode/vis_utils.py +++ b/src/dynode/vis_utils.py @@ -304,8 +304,8 @@ def plot_checkpoint_inference_correlation_pairs( for key, val in posteriors.items() } posteriors: dict[str, np.ndarray] = flatten_list_parameters(posteriors) - # drop any final_timestep parameters in case they snuck in - posteriors = drop_keys_with_substring(posteriors, "final_timestep") + # drop any timestep parameters in case they snuck in + posteriors = drop_keys_with_substring(posteriors, "timestep") number_of_samples = posteriors[list(posteriors.keys())[0]].shape[1] # if we are dealing with many samples per chain, # narrow down to max_samples_calculated samples per chain @@ -435,8 +435,8 @@ def plot_mcmc_chains( for key, val in samples.items() } samples: dict[str, np.ndarray] = flatten_list_parameters(samples) - # drop any final_timestep parameters in case they snuck in - samples = drop_keys_with_substring(samples, "final_timestep") + # drop any timestep parameters in case they snuck in + samples = drop_keys_with_substring(samples, "timestep") param_names = list(samples.keys()) num_params = len(param_names) num_chains = samples[param_names[0]].shape[0] From 785819ae62f9e327f62673c363ed9a76fdd58561 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Wed, 11 Dec 2024 01:40:01 +0000 Subject: [PATCH 6/6] bugfix boolean logic or -> and --- src/dynode/mechanistic_inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dynode/mechanistic_inferer.py b/src/dynode/mechanistic_inferer.py index 6207b91..0106f46 100644 --- a/src/dynode/mechanistic_inferer.py +++ b/src/dynode/mechanistic_inferer.py @@ -264,7 +264,7 @@ def _checkpoint_compartment_sizes(self, solution: Solution): date_str = date.strftime("%Y_%m_%d") sim_day = date_to_sim_day(date, self.config.INIT_DATE) # ensure user requests a day we actually have in `solution` - if sim_day >= 0 or sim_day < len( + if sim_day >= 0 and sim_day < len( solution.ys[self.config.COMPARTMENT_IDX.S] ): for compartment in self.config.COMPARTMENT_IDX: