diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/models.py b/swirl_dynamics/projects/debiasing/rectified_flow/models.py index bb18fdc..62b0486 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/models.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/models.py @@ -106,6 +106,8 @@ class ReFlowModel(models.BaseModel): This should be close to 1. num_eval_time_levels: Number of times at which the flow will be sampled for each trajectory between x_0 and x_1. + weighted_norm: The norm to use for the loss, if None we use euclidean norm, + otherwise we use weighted norm. """ input_shape: tuple[int, ...] @@ -122,10 +124,23 @@ class ReFlowModel(models.BaseModel): max_eval_time_lvl: float = 1.0 - 1e-4 # It should be close to 1. num_eval_time_levels: ClassVar[int] = 10 + weighted_norm: Array | None = None + def initialize(self, rng: Array): # TODO: Add a dtype object to ensure consistency of types. x = jnp.ones((1,) + self.input_shape) + # If weighted norm is provided, it must have the right shape. + if self.weighted_norm is not None: + if ( + self.weighted_norm.shape[0] != 1 + or self.weighted_norm.shape[1:] != x.shape[1:] + ): + raise ValueError( + "Weighted norm shape must be (1, *x.shape[1:]) instead we have" + f" {self.weighted_norm.shape}, with x.shape = {x.shape}" + ) + return self.flow_model.init( rng, x=x, sigma=jnp.ones((1,)), is_training=False ) @@ -172,8 +187,16 @@ def loss_fn( is_training=True, rngs={"dropout": dropout_rng}, ) - # Eq. (1) in [1]. - loss = jnp.mean(jnp.square((batch["x_1"] - batch["x_0"]) - v_t)) + + # TODO: Define a metric class that incorporates the weights. + if self.weighted_norm is None: + weighted_norm = 1.0 + else: + weighted_norm = self.weighted_norm + + # Eq. (1) in [1], but with a possibly weighted norm. + error = ((batch["x_1"] - batch["x_0"]) - v_t) + loss = jnp.mean(weighted_norm * jnp.square(error)) metric = dict(loss=loss) return loss, (metric, mutables)