
Commit fdf118d

Code update
PiperOrigin-RevId: 643242607
Forgotten authored and The swirl_dynamics Authors committed Jun 14, 2024
1 parent f581fc9 commit fdf118d
Showing 2 changed files with 28 additions and 6 deletions.
@@ -26,7 +26,6 @@
import ml_collections
from ml_collections import config_flags
import numpy as np

from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.lib.solvers import ode as ode_solvers
from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
@@ -405,6 +404,7 @@ def evaluation_pipeline(
lens2_member_indexer: tuple[dict[str, str], ...] | None = None,
lens2_variable_names: dict[str, dict[str, str]] | None = None,
era5_variables: dict[str, dict[str, int] | None] | None = None,
num_sampling_steps: int = 100,
):
"""The evaluation pipeline.
@@ -418,6 +418,7 @@ def evaluation_pipeline(
lens2_member_indexer: The member indexer for the LENS2 dataset.
lens2_variable_names: The variable names for the LENS2 dataset.
era5_variables: The variable names for the ERA5 dataset.
num_sampling_steps: The number of sampling steps for solving the ODE.
Returns:
A dictionary with the evaluation metrics.
@@ -504,7 +505,7 @@ def evaluation_pipeline(
integrate_fn = functools.partial(
integrator,
latent_dynamics_fn,
tspan=jnp.arange(0.0, 1.0, 0.01),
tspan=jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps),
params=trained_state.model_variables,
)
pmap_integrate_fn = jax.pmap(integrate_fn, in_axes=0, out_axes=0)
@@ -701,6 +702,15 @@ def main(argv):
else:
lens2_member_indexer_tuple = _LENS2_MEMBER_INDEXER

if "num_sampling_steps" in config:
logging.info("Using num_sampling_steps from config file.")
num_sampling_steps = config.num_sampling_steps
else:
logging.info("Using default num_sampling_steps.")
num_sampling_steps = 100

logging.info("Number of sampling steps %d", num_sampling_steps)

print("Indexers")
print(lens2_member_indexer_tuple, flush=True)

@@ -738,6 +748,7 @@ def main(argv):
),
lens2_variable_names=lens2_variable_names,
era5_variables=era5_variables,
num_sampling_steps=num_sampling_steps,
)

# Save all the error into a file.
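Note: the core change above swaps the hard-coded integration step of 0.01 for a step derived from the new num_sampling_steps argument. A minimal sketch of that logic, assuming only jax.numpy (the integrator, model, and data plumbing from the diff are omitted):

import jax.numpy as jnp

def make_tspan(num_sampling_steps: int = 100) -> jnp.ndarray:
  # Mirrors tspan=jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps):
  # roughly num_sampling_steps points in [0, 1), so a larger value
  # gives a finer grid for the rectified-flow ODE solve.
  return jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps)

coarse = make_tspan(50)   # ~50 integration steps
default = make_tspan()    # ~100 integration steps, matching the old 0.01 step
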
19 changes: 15 additions & 4 deletions swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py
@@ -26,7 +26,6 @@
import ml_collections
from ml_collections import config_flags
import numpy as np

from swirl_dynamics.lib.solvers import ode as ode_solvers
from swirl_dynamics.projects.debiasing.rectified_flow import data_utils
from swirl_dynamics.projects.debiasing.rectified_flow import models
@@ -184,12 +183,12 @@ def read_stats(
if mean_lens2.shape != mean_era5.shape:
raise ValueError(
"The shape of the mean_lens2 and mean_era5 must be the same; ",
f"instead got {mean_lens2.shape} and {mean_era5.shape}"
f"instead got {mean_lens2.shape} and {mean_era5.shape}",
)
if std_lens2.shape != std_era5.shape:
raise ValueError(
"The shape of the std_lens2 and std_era5 must be the same; ",
f"instead got {std_lens2.shape} and {std_era5.shape}"
f"instead got {std_lens2.shape} and {std_era5.shape}",
)

return {
@@ -306,6 +305,7 @@ def inference_pipeline(
lens2_member_indexer: tuple[dict[str, str], ...] | None = None,
lens2_variable_names: dict[str, dict[str, str]] | None = None,
era5_variables: dict[str, dict[str, int] | None] | None = None,
num_sampling_steps: int = 100,
) -> dict[str, np.ndarray]:
"""The evaluation pipeline.
@@ -318,6 +318,7 @@ def inference_pipeline(
lens2_member_indexer: The member indexer for the LENS2 dataset.
lens2_variable_names: The names of the variables in the LENS2 dataset.
era5_variables: The names of the variables in the ERA5 dataset.
num_sampling_steps: The number of sampling steps for solving the ODE.
Returns:
A dictionary with the data both input and output with their corresponding
@@ -400,7 +401,7 @@ def inference_pipeline(
integrate_fn = functools.partial(
integrator,
latent_dynamics_fn,
tspan=jnp.arange(0.0, 1.0, 0.01),
tspan=jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps),
params=trained_state.model_variables,
)
pmap_integrate_fn = jax.pmap(integrate_fn, in_axes=0, out_axes=0)
@@ -520,6 +521,15 @@ def main(argv):
else:
lens2_member_indexer = _LENS2_MEMBER_INDEXER

if "num_sampling_steps" in config:
logging.info("Using num_sampling_steps from config file.")
num_sampling_steps = config.num_sampling_steps
else:
logging.info("Using default num_sampling_steps.")
num_sampling_steps = 100

logging.info("Number of sampling steps %d", num_sampling_steps)

for lens2_indexer in lens2_member_indexer:
print("Evaluating on CMIP dataset indexer: ", lens2_indexer)
index_member = lens2_indexer
@@ -544,6 +554,7 @@
),
lens2_variable_names=lens2_variable_names,
era5_variables=era5_variables,
num_sampling_steps=num_sampling_steps,
)

if jax.process_index() == 0:
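Note: no config file is touched in this commit, so the snippet below is only a hypothetical illustration of how a config could opt into the new knob. The get_config structure and the num_sampling_steps field follow the `"num_sampling_steps" in config` check added in both main() functions; the value 200 is an arbitrary example.

import ml_collections

def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()
  # Hypothetical entry: when present, main() logs
  # "Using num_sampling_steps from config file." and forwards the value
  # to the pipeline; when absent, the default of 100 is used.
  config.num_sampling_steps = 200
  return config
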
