diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py b/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py
index c964d85..a6342cc 100644
--- a/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py
+++ b/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py
@@ -26,7 +26,6 @@
 import ml_collections
 from ml_collections import config_flags
 import numpy as np
-
 from swirl_dynamics.data import hdf5_utils
 from swirl_dynamics.lib.solvers import ode as ode_solvers
 from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
@@ -405,6 +404,7 @@ def evaluation_pipeline(
     lens2_member_indexer: tuple[dict[str, str], ...] | None = None,
     lens2_variable_names: dict[str, dict[str, str]] | None = None,
     era5_variables: dict[str, dict[str, int] | None] | None = None,
+    num_sampling_steps: int = 100,
 ):
   """The evaluation pipeline.
 
@@ -418,6 +418,7 @@
     lens2_member_indexer: The member indexer for the LENS2 dataset.
     lens2_variable_names: The variable names for the LENS2 dataset.
     era5_variables: The variable names for the ERA5 dataset.
+    num_sampling_steps: The number of sampling steps for solving the ODE.
 
   Returns:
     A dictionary with the evaluation metrics.
@@ -504,7 +505,7 @@
   integrate_fn = functools.partial(
       integrator,
       latent_dynamics_fn,
-      tspan=jnp.arange(0.0, 1.0, 0.01),
+      tspan=jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps),
       params=trained_state.model_variables,
   )
   pmap_integrate_fn = jax.pmap(integrate_fn, in_axes=0, out_axes=0)
@@ -701,6 +702,15 @@
   else:
     lens2_member_indexer_tuple = _LENS2_MEMBER_INDEXER
 
+  if "num_sampling_steps" in config:
+    logging.info("Using num_sampling_steps from config file.")
+    num_sampling_steps = config.num_sampling_steps
+  else:
+    logging.info("Using default num_sampling_steps.")
+    num_sampling_steps = 100
+
+  logging.info("Number of sampling steps %d", num_sampling_steps)
+
   print("Indexers")
   print(lens2_member_indexer_tuple, flush=True)
 
@@ -738,6 +748,7 @@
         ),
        lens2_variable_names=lens2_variable_names,
        era5_variables=era5_variables,
+        num_sampling_steps=num_sampling_steps,
     )
 
     # Save all the error into a file.
diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py b/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py
index a3e6173..c4a321d 100644
--- a/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py
+++ b/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py
@@ -26,7 +26,6 @@
 import ml_collections
 from ml_collections import config_flags
 import numpy as np
-
 from swirl_dynamics.lib.solvers import ode as ode_solvers
 from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
 from swirl_dynamics.projects.debiasing.rectified_flow import models
@@ -184,12 +183,12 @@ def read_stats(
   if mean_lens2.shape != mean_era5.shape:
     raise ValueError(
         "The shape of the mean_lens2 and mean_era5 must be the same; ",
-        f"instead got {mean_lens2.shape} and {mean_era5.shape}"
+        f"instead got {mean_lens2.shape} and {mean_era5.shape}",
     )
   if std_lens2.shape != std_era5.shape:
     raise ValueError(
         "The shape of the std_lens2 and std_era5 must be the same; ",
-        f"instead got {std_lens2.shape} and {std_era5.shape}"
+        f"instead got {std_lens2.shape} and {std_era5.shape}",
     )
 
   return {
@@ -306,6 +305,7 @@ def inference_pipeline(
     lens2_member_indexer: tuple[dict[str, str], ...] | None = None,
     lens2_variable_names: dict[str, dict[str, str]] | None = None,
     era5_variables: dict[str, dict[str, int] | None] | None = None,
+    num_sampling_steps: int = 100,
 ) -> dict[str, np.ndarray]:
   """The evaluation pipeline.
 
@@ -318,6 +318,7 @@
     lens2_member_indexer: The member indexer for the LENS2 dataset.
     lens2_variable_names: The names of the variables in the LENS2 dataset.
     era5_variables: The names of the variables in the ERA5 dataset.
+    num_sampling_steps: The number of sampling steps for solving the ODE.
 
   Returns:
     A dictionary with the data both input and output with their corresponding
@@ -400,7 +401,7 @@
   integrate_fn = functools.partial(
       integrator,
       latent_dynamics_fn,
-      tspan=jnp.arange(0.0, 1.0, 0.01),
+      tspan=jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps),
       params=trained_state.model_variables,
   )
   pmap_integrate_fn = jax.pmap(integrate_fn, in_axes=0, out_axes=0)
@@ -520,6 +521,15 @@
   else:
     lens2_member_indexer = _LENS2_MEMBER_INDEXER
 
+  if "num_sampling_steps" in config:
+    logging.info("Using num_sampling_steps from config file.")
+    num_sampling_steps = config.num_sampling_steps
+  else:
+    logging.info("Using default num_sampling_steps.")
+    num_sampling_steps = 100
+
+  logging.info("Number of sampling steps %d", num_sampling_steps)
+
   for lens2_indexer in lens2_member_indexer:
     print("Evaluating on CMIP dataset indexer: ", lens2_indexer)
     index_member = lens2_indexer
@@ -544,6 +554,7 @@
         ),
         lens2_variable_names=lens2_variable_names,
         era5_variables=era5_variables,
+        num_sampling_steps=num_sampling_steps,
     )
 
     if jax.process_index() == 0:
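For context, both scripts read the new field from the ml_collections config object loaded via config_flags and fall back to 100 when it is absent. Below is a minimal sketch of how the field could be set in a config file, assuming the usual get_config() convention; only the name num_sampling_steps and the default of 100 come from the diff above, everything else is illustrative.

import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Hypothetical config sketch; only num_sampling_steps is taken from the diff."""
  config = ml_collections.ConfigDict()
  # ... other evaluation/inference settings would go here ...
  # Number of ODE sampling steps; used to build
  # tspan = jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps).
  config.num_sampling_steps = 100
  return config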