Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597375760
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Jan 10, 2024
1 parent e134884 commit 9e3fc82
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
6 changes: 3 additions & 3 deletions swirl_dynamics/projects/debiasing/rectified_flow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,18 @@ def eval_fn(

vmap_mult = jax.vmap(jnp.multiply, in_axes=(0, 0))
x_t = vmap_mult(x_1, time_eval) + vmap_mult(x_0, 1 - time_eval)
flow_fn = self.flow_fn(variables, self.flow_model)
flow_fn = self.inference_fn(variables, self.flow_model)
v_t = jax.vmap(flow_fn, in_axes=(1, None), out_axes=1)(
x_t, time_eval
)

# Eq. (1) in [1]
int_losses = jax.vmap(jnp.mean)(jnp.square((x_1 - x_0 - v_t)))
eval_losses = {f"sigma_lvl{i}": loss for i, loss in enumerate(int_losses)}
eval_losses = {f"time_lvl{i}": loss for i, loss in enumerate(int_losses)}
return eval_losses

@staticmethod
def flow_fn(variables: models.PyTree, flow_model: nn.Module):
def inference_fn(variables: models.PyTree, flow_model: nn.Module):
"""Returns the inference flow function."""

def _flow(
Expand Down
104 changes: 104 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/trainers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainers for ReFlow models."""

from collections.abc import Callable

from clu import metrics as clu_metrics
import flax
import jax
import optax
from swirl_dynamics.projects.debiasing.rectified_flow import models
from swirl_dynamics.templates import train_states
from swirl_dynamics.templates import trainers

Array = jax.Array
VariableDict = trainers.VariableDict
TrainState = train_states.BasicTrainState


class ReFlowTrainer(
trainers.BasicTrainer[models.ReFlowModel, TrainState]
):
"""Single-device trainer for rectified flow models."""

@flax.struct.dataclass
class TrainMetrics(clu_metrics.Collection):
train_loss: clu_metrics.Average.from_output("loss")
train_loss_std: clu_metrics.Std.from_output("loss")

EvalMetrics = clu_metrics.Collection.create( # pylint: disable=invalid-name
**{
f"eval_time_lvl{i}": clu_metrics.Average.from_output(
f"time_lvl{i}"
)
for i in range(models.ReFlowModel.num_eval_time_levels)
}
)

def initialize_train_state(self, rng: Array) -> TrainState:
init_vars = self.model.initialize(rng)
mutables, params = flax.core.pop(init_vars, "params")
return TrainState.create(
replicate=self.is_distributed,
params=params,
opt_state=self.optimizer.init(params),
flax_mutables=mutables,
)

@property
def update_train_state(
self,
) -> Callable[[TrainState, VariableDict, VariableDict], TrainState]:
"""Returns function that updates the train state."""

def _update_train_state(
train_state: TrainState,
grads: VariableDict,
mutables: VariableDict,
) -> TrainState:
updates, new_opt_state = self.optimizer.update(
grads, train_state.opt_state, train_state.params
)
new_params = optax.apply_updates(train_state.params, updates)

return train_state.replace(
step=train_state.step + 1,
opt_state=new_opt_state,
params=new_params,
flax_mutables=mutables,
)

return _update_train_state

@staticmethod
def inference_fn_from_state_dict(
state: TrainState, *args, **kwargs
):
return models.ReFlowModel.inference_fn(
state.model_variables, *args, **kwargs
)


class DistributedReFlowTrainer(
ReFlowTrainer,
trainers.BasicDistributedTrainer[models.ReFlowModel, TrainState],
):
"""Multi-device trainer for rectified flow models."""

# TODO(lzepedanunez): Write a test for this trainer.

# MRO: ReFlowTrainer > BasicDistributedTrainer > BasicTrainer
...

0 comments on commit 9e3fc82

Please sign in to comment.