
Commit

Code update
PiperOrigin-RevId: 691282826
Forgotten authored and The swirl_dynamics Authors committed Oct 30, 2024
1 parent 49ded0f commit 413e460
Showing 4 changed files with 76 additions and 7 deletions.
@@ -538,6 +538,18 @@ def evaluation_pipeline(
   n_lat = 121
   n_field = len(variables)

+  if "weighted_norm" in config and config.weighted_norm:
+    logging.info("Using weighted norm")
+    lat = jnp.linspace(-90., 90., n_lat)
+    # Reshapes to the broadcast shape (time, lon, lat).
+    weighted_norm = (
+        jnp.cos(jnp.deg2rad(lat)).reshape((1, 1, -1))
+    )
+    # In the metric, we handle one field at a time.
+    weighted_norm = jnp.tile(weighted_norm, (1, n_lon, 1))
+  else:
+    weighted_norm = 1.0
+
   logging.info("Batch size per device: %d", batch_size_eval)

   for ii, batch in enumerate(eval_dataloader):
@@ -596,7 +608,11 @@ def evaluation_pipeline(
     )

     mean_err_dict = metrics.smoothed_average_l1_error(
-        input_array, output_array, target_array, variables=variables
+        input_array,
+        output_array,
+        target_array,
+        variables=variables,
+        weighted_norm=weighted_norm,
     )

     err_dict = dict(
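Note on the evaluation-side change: on an equiangular lat-lon grid, grid points crowd together near the poles, so an unweighted mean over (lon, lat) overrepresents polar regions; the cos(lat) factor weights each point by (approximately) its cell area. A minimal standalone sketch of the construction, where n_lon = 240 is an assumed value for illustration (only n_lat = 121 appears in the hunk above):

import jax.numpy as jnp

n_lon, n_lat = 240, 121  # n_lon is assumed; n_lat = 121 comes from the hunk.
lat = jnp.linspace(-90.0, 90.0, n_lat)
weighted_norm = jnp.cos(jnp.deg2rad(lat)).reshape((1, 1, -1))  # (1, 1, n_lat)
weighted_norm = jnp.tile(weighted_norm, (1, n_lon, 1))         # (1, n_lon, n_lat)

# The weights broadcast against per-field snapshot stacks of shape
# (time, lon, lat), matching how the metric indexes one field at a time.
snapshots = jnp.ones((8, n_lon, n_lat))
assert (weighted_norm * snapshots).shape == (8, n_lon, n_lat)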
@@ -283,6 +283,7 @@ def smoothed_average_l1_error(
     target: jax.Array,
     variables: Sequence[str],
     window_size: int = 365,
+    weighted_norm: jax.Array | float = 1.0,
 ) -> dict[str, dict[str, jax.Array]]:
   """Computes the l1 error between global averages smoothed in time.
@@ -293,6 +294,7 @@
     target: Reference ERA5 data with the same dimensions as the input.
     variables: List of physical variables in the snapshots.
     window_size: The size of the window for the smoothing in time.
+    weighted_norm: Weights to use in the global averages; defaults to 1.0.

   Returns:
     A dictionary with the errors per field.
@@ -316,9 +318,15 @@

   for field_idx in range(num_fields):
     # Compute global averages per snapshot.
-    era5_mean = np.mean(target[:, :, :, field_idx], axis=(-1, -2))
-    lens2_mean = np.mean(input_array[:, :, :, field_idx], axis=(-1, -2))
-    reflow_mean = np.mean(output[:, :, :, field_idx], axis=(-1, -2))
+    era5_mean = np.mean(
+        weighted_norm * target[:, :, :, field_idx], axis=(-1, -2)
+    )
+    lens2_mean = np.mean(
+        weighted_norm * input_array[:, :, :, field_idx], axis=(-1, -2)
+    )
+    reflow_mean = np.mean(
+        weighted_norm * output[:, :, :, field_idx], axis=(-1, -2)
+    )

     # Compute the moving average.
     conv_window = np.ones(window_size) / window_size
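For context, window_size defaults to 365, i.e., a one-year box filter on (presumably daily) snapshots. A minimal sketch of the full per-field computation, under the assumption that the smoothing step outside this hunk uses np.convolve in "valid" mode:

import numpy as np

window_size = 365
n_time, n_lon, n_lat = 1000, 240, 121  # illustrative sizes only
target_field = np.random.rand(n_time, n_lon, n_lat)  # e.g. target[:, :, :, field_idx]
output_field = np.random.rand(n_time, n_lon, n_lat)

lat = np.linspace(-90.0, 90.0, n_lat)
weighted_norm = np.tile(np.cos(np.deg2rad(lat))[None, None, :], (1, n_lon, 1))

# Weighted global average per snapshot, as in the hunk above.
era5_mean = np.mean(weighted_norm * target_field, axis=(-1, -2))
reflow_mean = np.mean(weighted_norm * output_field, axis=(-1, -2))

# Box-filter smoothing in time, then the l1 error between the smoothed series.
conv_window = np.ones(window_size) / window_size
era5_smooth = np.convolve(era5_mean, conv_window, mode="valid")
reflow_smooth = np.convolve(reflow_mean, conv_window, mode="valid")
l1_error = np.mean(np.abs(era5_smooth - reflow_smooth))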
19 changes: 19 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py
@@ -22,6 +22,7 @@
 from absl import flags
 from absl import logging
 import jax
+import jax.numpy as jnp
 from ml_collections import config_flags
 import optax
 from orbax import checkpoint
@@ -273,6 +274,23 @@ def main(argv):
       jax.random.uniform, dtype=jax.numpy.float32
   )

+  # Adds the weighted norm for the loss function.
+  if "weighted_norm" in config and config.weighted_norm:
+    lat = jnp.linspace(-90., 90., config.input_shapes[0][2])
+    if "reg_factor" in config:
+      reg_factor = config.reg_factor
+    else:
+      reg_factor = 0.05
+    # Reshapes to the correct broadcast shape.
+    weighted_norm = (
+        jnp.cos(jnp.deg2rad(lat)).reshape((1, 1, -1, 1)) + reg_factor
+    )
+    weighted_norm = jnp.broadcast_to(
+        weighted_norm, (1,) + config.input_shapes[0][1:]
+    )
+  else:
+    weighted_norm = None
+
   model = models.ConditionalReFlowModel(
       # TODO: clean this part.
       input_shape=(
@@ -296,6 +314,7 @@
       time_sampling=time_sampler,
       min_train_time=config.min_time,  # It should be close to 0.
       max_train_time=config.max_time,  # It should be close to 1.
+      weighted_norm=weighted_norm,
   )

   # Defining the trainer.
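The reg_factor offset (default 0.05) keeps the training weights strictly positive where cos(lat) vanishes at the poles, so no latitude row drops out of the loss entirely. A small shape sketch, using a hypothetical config.input_shapes[0] of (16, 240, 121, 2) (leading dim, lon, lat, fields; the script reads the real value from the config):

import jax.numpy as jnp

input_shape = (16, 240, 121, 2)  # hypothetical value for illustration
reg_factor = 0.05
lat = jnp.linspace(-90., 90., input_shape[2])
weighted_norm = jnp.cos(jnp.deg2rad(lat)).reshape((1, 1, -1, 1)) + reg_factor
weighted_norm = jnp.broadcast_to(weighted_norm, (1,) + input_shape[1:])

assert weighted_norm.shape == (1, 240, 121, 2)
# Weights range from ~reg_factor at the poles to 1 + reg_factor at the equator.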
32 changes: 29 additions & 3 deletions swirl_dynamics/projects/debiasing/rectified_flow/models.py
@@ -250,7 +250,15 @@ def eval_fn(
     v_t = jax.vmap(flow_fn, in_axes=(1, None), out_axes=1)(x_t, time_eval)

     # Eq. (1) in [1]. (by default in_axes=0 and out_axes=0 in vmap)
-    int_losses = jax.vmap(jnp.mean)(jnp.square((x_1 - x_0 - v_t)))
+    if self.weighted_norm is None:
+      weighted_norm = 1.0
+    else:
+      # Adds an extra leading dimension so the weights broadcast.
+      weighted_norm = self.weighted_norm[None, ...]
+
+    int_losses = jax.vmap(jnp.mean)(
+        weighted_norm * jnp.square(x_1 - x_0 - v_t)
+    )
     eval_losses = {f"time_lvl{i}": loss for i, loss in enumerate(int_losses)}

     return eval_losses
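In equation form, the change to the objective the comments call Eq. (1) is only the weight w inside the mean; note w multiplies the squared error elementwise (it is not placed under a square root), matching the code above. As a LaTeX sketch:

\mathcal{L}_w(\theta)
  = \mathbb{E}_{t,\,x_0,\,x_1}\Big[\operatorname{mean}\big(w \odot (x_1 - x_0 - v_\theta(x_t, t))^2\big)\Big],

with w = 1 recovering the unweighted loss.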
@@ -281,6 +289,17 @@ def initialize(self, rng: Array):
     x = jnp.ones((1,) + self.input_shape)
     cond = cond_sample_from_shape(self.cond_shape, batch_dims=(1,))

+    # If a weighted norm is provided, it must have the right shape.
+    if self.weighted_norm is not None:
+      if (
+          self.weighted_norm.shape[0] != 1
+          or self.weighted_norm.shape[1:] != x.shape[1:]
+      ):
+        raise ValueError(
+            "Weighted norm shape must be (1, *x.shape[1:]); instead we have"
+            f" {self.weighted_norm.shape}, with x.shape = {x.shape}."
+        )
+
     return self.flow_model.init(  # Add conditional input here.
         rng, x=x, sigma=jnp.ones((1,)), cond=cond, is_training=False
     )
@@ -334,8 +353,15 @@ def loss_fn(
         is_training=True,
         rngs={"dropout": dropout_rng},
     )
-    # Eq. (1) in [1].
-    loss = jnp.mean(jnp.square((batch["x_1"] - batch["x_0"]) - v_t))
+
+    if self.weighted_norm is None:
+      weighted_norm = 1.0
+    else:
+      weighted_norm = self.weighted_norm
+
+    # Eq. (1) in [1], but with a possibly weighted norm.
+    error = (batch["x_1"] - batch["x_0"]) - v_t
+    loss = jnp.mean(weighted_norm * jnp.square(error))
     metric = dict(loss=loss)
     return loss, (metric, mutables)
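Taken together, the model-side change reduces to the following standalone form; a minimal sketch, assuming batches of shape (batch, lon, lat, fields) and weights of shape (1, lon, lat, fields) as enforced by the check in initialize:

import jax.numpy as jnp

def weighted_reflow_loss(x_0, x_1, v_t, weighted_norm=None):
  """Flow-matching loss with an optional per-pixel weight."""
  if weighted_norm is None:
    weighted_norm = 1.0  # Recovers the original unweighted loss.
  error = (x_1 - x_0) - v_t
  return jnp.mean(weighted_norm * jnp.square(error))

# Usage sketch: when v_t exactly matches x_1 - x_0, the loss is zero
# regardless of the weights.
x_0 = jnp.zeros((4, 240, 121, 2))
x_1 = jnp.ones((4, 240, 121, 2))
v_t = x_1 - x_0
w = jnp.full((1, 240, 121, 2), 0.5)
assert weighted_reflow_loss(x_0, x_1, v_t, w) == 0.0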
