Commit

Code update
PiperOrigin-RevId: 690705718
Forgotten authored and The swirl_dynamics Authors committed Oct 28, 2024
1 parent 7f84285 commit b07ffdc
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions swirl_dynamics/projects/debiasing/rectified_flow/models.py
@@ -106,6 +106,8 @@ class ReFlowModel(models.BaseModel):
This should be close to 1.
num_eval_time_levels: Number of times at which the flow will be sampled for
each trajectory between x_0 and x_1.
weighted_norm: Weights to use in the loss. If None, the plain squared
Euclidean norm is used; otherwise the squared error is weighted
elementwise by this array before being averaged.
"""

input_shape: tuple[int, ...]
@@ -122,10 +124,23 @@ class ReFlowModel(models.BaseModel):
max_eval_time_lvl: float = 1.0 - 1e-4 # It should be close to 1.
num_eval_time_levels: ClassVar[int] = 10

weighted_norm: Array | None = None

def initialize(self, rng: Array):
# TODO: Add a dtype object to ensure consistency of types.
x = jnp.ones((1,) + self.input_shape)

# If weighted norm is provided, it must have the right shape.
if self.weighted_norm is not None:
if (
self.weighted_norm.shape[0] != 1
or self.weighted_norm.shape[1:] != x.shape[1:]
):
raise ValueError(
"Weighted norm shape must be (1, *x.shape[1:]) instead we have"
f" {self.weighted_norm.shape}, with x.shape = {x.shape}"
)

return self.flow_model.init(
rng, x=x, sigma=jnp.ones((1,)), is_training=False
)
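
For context, a minimal, self-contained sketch of constructing a weighted_norm array that satisfies the (1, *x.shape[1:]) check performed in initialize above. The (lat, lon, channels) layout and the cosine-latitude weighting are illustrative assumptions, not part of this commit.

import jax.numpy as jnp
import numpy as np

# Hypothetical input layout: (lat, lon, channels).
input_shape = (90, 180, 2)

# Illustrative weighting: cosine-latitude weights, normalized to unit mean
# so the overall loss scale is roughly preserved.
lats = np.linspace(-89.0, 89.0, input_shape[0])
lat_weights = np.cos(np.deg2rad(lats))
lat_weights /= lat_weights.mean()

# Broadcast to a leading batch axis of size 1, matching (1, *input_shape).
weighted_norm = jnp.asarray(lat_weights)[None, :, None, None] * jnp.ones(
    (1,) + input_shape
)
assert weighted_norm.shape == (1,) + input_shape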
@@ -172,8 +187,16 @@ def loss_fn(
is_training=True,
rngs={"dropout": dropout_rng},
)
# Eq. (1) in [1].
loss = jnp.mean(jnp.square((batch["x_1"] - batch["x_0"]) - v_t))

# TODO: Define a metric class that incorporates the weights.
if self.weighted_norm is None:
weighted_norm = 1.0
else:
weighted_norm = self.weighted_norm

# Eq. (1) in [1], but with a possibly weighted norm.
error = ((batch["x_1"] - batch["x_0"]) - v_t)
loss = jnp.mean(weighted_norm * jnp.square(error))
metric = dict(loss=loss)
return loss, (metric, mutables)
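
As a sanity check on the weighted loss above, the following standalone sketch recomputes it with toy arrays standing in for batch["x_0"], batch["x_1"], and the model output v_t; the shapes and the all-ones weights are illustrative assumptions only.

import jax
import jax.numpy as jnp

k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)

shape = (4, 8, 8, 2)  # (batch, *spatial, channels), illustrative only
x_0 = jax.random.normal(k0, shape)  # stand-in for batch["x_0"]
x_1 = jax.random.normal(k1, shape)  # stand-in for batch["x_1"]
v_t = jax.random.normal(k2, shape)  # stand-in for the flow model output

weighted_norm = jnp.ones((1,) + shape[1:])  # broadcasts over the batch axis

error = (x_1 - x_0) - v_t
weighted_loss = jnp.mean(weighted_norm * jnp.square(error))
plain_loss = jnp.mean(jnp.square(error))

# With all-ones weights, the weighted loss reduces to Eq. (1) in [1].
assert jnp.allclose(weighted_loss, plain_loss)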

