Commit
PiperOrigin-RevId: 584684425
Showing 3 changed files with 322 additions and 1 deletion.
swirl_dynamics/projects/probabilistic_diffusion/evaluate.py (174 additions, 0 deletions)
@@ -0,0 +1,174 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modules for evaluating conditional samples."""

from collections.abc import Mapping, Sequence
import dataclasses
from typing import Any

from clu import metrics as clu_metrics
import flax
import jax
import jax.numpy as jnp
from swirl_dynamics.lib.metrics import probabilistic_forecast as prob_metrics
from swirl_dynamics.projects.probabilistic_diffusion import inference
from swirl_dynamics.templates import evaluate

Array = jax.Array
PyTree = Any

@dataclasses.dataclass(frozen=True)
class CondSamplingBenchmark(evaluate.Benchmark):
  """Draws conditional samples and evaluates probabilistic scores.

  Required `batch` schema::

    batch["cond"]: dict[str, jax.Array] | None  # a-priori condition
    batch["guidance_inputs"]: dict[str, jax.Array] | None  # guidance inputs
    batch["obs"]: jax.Array  # observation wrt which the samples are evaluated

  NOTE: Batch size *should always be 1*.

  Attributes:
    num_samples_per_cond: The number of conditional samples to generate per
      condition. The samples are generated in batches.
    sample_batch_size: The batch size at which conditional samples are
      generated.
    brier_score_thresholds: The threshold values at which the Brier scores are
      evaluated.
  """

  num_samples_per_cond: int
  sample_batch_size: int
  brier_score_thresholds: Sequence[int | float]

  def __post_init__(self):
    if self.num_samples_per_cond % self.sample_batch_size != 0:
      raise ValueError(
          f"`num_samples_per_cond` ({self.num_samples_per_cond}) must be"
          f" divisible by `sample_batch_size` ({self.sample_batch_size})."
      )

  def run_batch_inference(
      self,
      inference_fn: inference.CondSampler,
      batch: Mapping[str, Any],
      rng: Array,
  ) -> Array:
    """Runs batch inference on a conditional sampler."""
    num_batches = self.num_samples_per_cond // self.sample_batch_size
    rngs = jax.random.split(rng, num=num_batches)

    squeeze_fn = lambda x: jnp.squeeze(x, axis=0)
    cond = jax.tree_map(squeeze_fn, batch["cond"])
    guidance_inputs = (
        batch["guidance_inputs"] if "guidance_inputs" in batch else None
    )

    def _batch_inference_fn(rng: jax.Array) -> jax.Array:
      return inference_fn(
          num_samples=self.sample_batch_size,
          rng=rng,
          cond=cond,
          guidance_inputs=guidance_inputs,
      )

    # Uses `jax.lax.map` instead of `jax.vmap` because the former is less
    # memory intensive, and batch inference is expected to be very demanding.
    samples = jax.lax.map(
        _batch_inference_fn, rngs
    )  # ~ (num_batches, batch_size, *spatial_dims, channels)
    samples = jnp.reshape(samples, (1, -1) + samples.shape[2:])
    return samples

  def compute_batch_metrics(
      self, pred: jax.Array, batch: Mapping[str, Any]
  ) -> tuple[dict[str, jax.Array], dict[str, jax.Array]]:
    """Computes metrics on the batch.

    Results consist of collected types and aggregated types (see the
    `swirl_dynamics.templates.Benchmark` protocol for their definitions and
    distinctions).

    The collected metrics consist of:
      * The observation.
      * A subset of the conditional samples generated.
      * Channel-wise CRPS.
      * Threshold Brier scores.
      * Conditional standard deviations.

    The aggregated metrics consist of:
      * Global mean CRPS (scalar).
      * Local mean CRPS (averaged for each location).
      * Global mean threshold Brier scores.
      * Local mean threshold Brier scores (averaged for each location).

    Args:
      pred: The conditional samples generated by a benchmarked model.
      batch: The evaluation batch data containing a reference observation.

    Returns:
      Metrics to be collected and aggregated respectively.
    """
    # pred ~ (1, num_samples, *spatial, channels)
    obs = batch["obs"]  # ~ (1, *spatial, channels)
    cond_stddev = jnp.std(pred, axis=1)  # ~ (1, *spatial, channels)
    crps = prob_metrics.crps(
        pred, obs, direct_broadcast=False
    )  # ~ (1, *spatial, channels)
    thres_brier_scores = prob_metrics.threshold_brier_score(
        pred, obs, jnp.asarray(self.brier_score_thresholds)
    )  # ~ (1, *spatial, channels, n_thresholds)
    batch_collect = dict(
        observation=obs,
        example1=pred[:, 0],
        cond_stddev=cond_stddev,
        crps=crps,
        thres_brier_scores=thres_brier_scores,
    )
    batch_result = dict(
        crps=crps,
        thres_brier_scores=thres_brier_scores,
    )
    return batch_collect, batch_result


class CondSamplingEvaluator(evaluate.Evaluator):
  """Evaluator for the conditional sampling benchmark."""

  @flax.struct.dataclass
  class AggregatingMetrics(clu_metrics.Collection):
    global_mean_crps: evaluate.TensorAverage(axis=None).from_output("crps")
    local_mean_crps: evaluate.TensorAverage(axis=0).from_output("crps")
    global_mean_threshold_brier_score: evaluate.TensorAverage(
        axis=(0, 1, 2, 3)  # NOTE: specific to the 2d case
    ).from_output("thres_brier_scores")
    local_mean_threshold_brier_score: evaluate.TensorAverage(
        axis=0
    ).from_output("thres_brier_scores")

  @property
  def scalar_metrics_to_log(self) -> dict[str, Array]:
    """Logs the global CRPS and threshold Brier scores."""
    scalar_metrics = {}
    agg_metrics = self.state.compute_aggregated_metrics()
    for model_key, metric_dict in agg_metrics.items():
      scalar_metrics[f"{model_key}/global_mean_crps"] = metric_dict[
          "global_mean_crps"
      ]
      for i, sc in enumerate(metric_dict["global_mean_threshold_brier_score"]):
        scalar_metrics[f"{model_key}/global_mean_thres_brier_score_{i}"] = sc
    return scalar_metrics
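
The batching pattern in `run_batch_inference` can be exercised on its own. Below is a minimal, self-contained sketch (not part of this commit) in which a hypothetical `dummy_sampler` stands in for a real conditional diffusion sampler: `num_samples_per_cond` draws are generated in `num_batches` sequential chunks via `jax.lax.map`, then reshaped back to a single leading batch dimension of 1.

import jax
import jax.numpy as jnp

num_samples_per_cond = 8
sample_batch_size = 2
num_batches = num_samples_per_cond // sample_batch_size

def dummy_sampler(num_samples: int, rng: jax.Array) -> jax.Array:
  # Hypothetical stand-in: fake "samples" on a 4x4 grid with one channel.
  return jax.random.normal(rng, (num_samples, 4, 4, 1))

rngs = jax.random.split(jax.random.PRNGKey(0), num=num_batches)
# `jax.lax.map` runs the batches sequentially, which is less memory intensive
# than vectorizing over all samples at once with `jax.vmap`.
samples = jax.lax.map(lambda r: dummy_sampler(sample_batch_size, r), rngs)
# samples ~ (num_batches, sample_batch_size, 4, 4, 1); flatten the first two
# axes into a leading batch dimension of 1, as the benchmark does.
samples = jnp.reshape(samples, (1, -1) + samples.shape[2:])
assert samples.shape == (1, num_samples_per_cond, 4, 4, 1)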
swirl_dynamics/projects/probabilistic_diffusion/inference.py (147 additions, 0 deletions)
@@ -0,0 +1,147 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference modules."""

from collections.abc import Sequence
import dataclasses
from typing import Any, Protocol

import flax.linen as nn
import jax
import numpy as np
from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.lib.diffusion import samplers
from swirl_dynamics.projects.probabilistic_diffusion import trainers

Array = jax.Array
PyTree = Any

class CondSampler(Protocol):
  """The conditional sampler interface."""

  def __call__(
      self, num_samples: int, rng: Array, cond: PyTree, guidance_inputs: PyTree
  ) -> Array:
    ...


class PreprocTransform(Protocol):
  """The pre-processing transform interface."""

  def __call__(
      self, cond: PyTree, guidance_inputs: PyTree
  ) -> tuple[PyTree, PyTree]:
    ...


@dataclasses.dataclass(frozen=True)
class StandardizeCondField:
  """Standardizes a field in the conditioning dict."""

  cond_field: str
  mean: np.ndarray | Array
  std: np.ndarray | Array

  def __call__(
      self, cond: PyTree, guidance_inputs: PyTree
  ) -> tuple[PyTree, PyTree]:
    cond[self.cond_field] = (cond[self.cond_field] - self.mean) / self.std
    return cond, guidance_inputs


class PostprocTransform(Protocol):
  """The post-processing transform interface."""

  def __call__(self, cond_samples: Array) -> Array:
    ...


@dataclasses.dataclass(frozen=True)
class RescaleSamples:
  """Rescales the samples linearly (inverse standardization)."""

  mean: np.ndarray | Array
  std: np.ndarray | Array

  def __call__(self, cond_samples: Array) -> Array:
    return cond_samples * self.std + self.mean


def chain(
    cond_sampler: CondSampler,
    preprocessors: Sequence[PreprocTransform] = (),
    postprocessors: Sequence[PostprocTransform] = (),
) -> CondSampler:
  """Chains a conditional sampler together with pre- and post-processors."""

  def _cond_sample(
      num_samples: int, rng: Array, cond: PyTree, guidance_inputs: PyTree
  ) -> Array:
    for preproc_fn in preprocessors:
      cond, guidance_inputs = preproc_fn(cond, guidance_inputs)
    cond_samples = cond_sampler(num_samples, rng, cond, guidance_inputs)
    for postproc_fn in postprocessors:
      cond_samples = postproc_fn(cond_samples)
    return cond_samples

  return _cond_sample


def get_trained_denoise_fn(
    denoiser: nn.Module,
    ckpt_dir: str,
    step: int | None = None,
    use_ema: bool = True,
) -> samplers.DenoiseFn:
  """Loads a checkpoint and creates an inference denoising function."""
  return trainers.DenoisingTrainer.inference_fn_from_state_dict(
      state=trainers.DenoisingModelTrainState.restore_from_orbax_ckpt(
          ckpt_dir, step=step
      ),
      denoiser=denoiser,
      use_ema=use_ema,
  )


def get_inference_fn_from_sampler(
    sampler: samplers.Sampler, **kwargs
) -> CondSampler:
  """Creates a conditional sampling inference function for evaluation."""

  def _gen_cond_sample(
      num_samples: int, rng: Array, cond: PyTree, guidance_inputs: PyTree
  ) -> Array:
    samples, _ = sampler.generate(
        num_samples, rng, cond=cond, guidance_inputs=guidance_inputs, **kwargs
    )
    return samples

  return _gen_cond_sample


# ****************
# Helpers
# ****************


def read_stats_from_hdf5(
    file_path: str, variables: Sequence[str], group: str = "mean"
):
  """Reads statistics from an HDF5 file and applies a 90-degree rotation."""
  variables = [f"{group}/{v}" for v in variables]
  arrays = hdf5_utils.read_arrays_as_tuple(file_path, keys=variables)
  out = np.stack(arrays, axis=-1)
  return np.rot90(out, k=1, axes=(0, 1))
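
Taken together, `chain` composes a `CondSampler` with the standardization transforms above: each preprocessor rewrites the conditioning inputs on the way in, and each postprocessor rewrites the samples on the way out. Here is a minimal, self-contained sketch (not part of this commit), assuming the definitions in this file are importable and using a hypothetical `dummy_sampler` in place of one built by `get_inference_fn_from_sampler`.

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in satisfying the `CondSampler` protocol; a real sampler
# would be built via `get_trained_denoise_fn` + `get_inference_fn_from_sampler`.
def dummy_sampler(num_samples, rng, cond, guidance_inputs):
  return jax.random.normal(rng, (num_samples, 4, 4, 1)) + cond["field"]

mean = np.zeros((4, 4, 1), dtype=np.float32)
std = np.ones((4, 4, 1), dtype=np.float32)

# Standardize the conditioning field going in; invert the standardization on
# the generated samples coming out.
sampler = chain(
    dummy_sampler,
    preprocessors=[StandardizeCondField("field", mean=mean, std=std)],
    postprocessors=[RescaleSamples(mean=mean, std=std)],
)
samples = sampler(
    num_samples=2,
    rng=jax.random.PRNGKey(0),
    cond={"field": jnp.ones((4, 4, 1))},
    guidance_inputs={},
)
assert samples.shape == (2, 4, 4, 1)

In practice, the `mean` and `std` arrays would typically come from `read_stats_from_hdf5` above, so that sampling happens in standardized space while evaluation sees physical units.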