diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/models.py b/swirl_dynamics/projects/debiasing/rectified_flow/models.py index 9dccf18..a2a43be 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/models.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/models.py @@ -192,18 +192,18 @@ def eval_fn( vmap_mult = jax.vmap(jnp.multiply, in_axes=(0, 0)) x_t = vmap_mult(x_1, time_eval) + vmap_mult(x_0, 1 - time_eval) - flow_fn = self.flow_fn(variables, self.flow_model) + flow_fn = self.inference_fn(variables, self.flow_model) v_t = jax.vmap(flow_fn, in_axes=(1, None), out_axes=1)( x_t, time_eval ) # Eq. (1) in [1] int_losses = jax.vmap(jnp.mean)(jnp.square((x_1 - x_0 - v_t))) - eval_losses = {f"sigma_lvl{i}": loss for i, loss in enumerate(int_losses)} + eval_losses = {f"time_lvl{i}": loss for i, loss in enumerate(int_losses)} return eval_losses @staticmethod - def flow_fn(variables: models.PyTree, flow_model: nn.Module): + def inference_fn(variables: models.PyTree, flow_model: nn.Module): """Returns the inference flow function.""" def _flow( diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py b/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py new file mode 100644 index 0000000..0cc9eff --- /dev/null +++ b/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py @@ -0,0 +1,104 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainers for ReFlow models.""" + +from collections.abc import Callable + +from clu import metrics as clu_metrics +import flax +import jax +import optax +from swirl_dynamics.projects.debiasing.rectified_flow import models +from swirl_dynamics.templates import train_states +from swirl_dynamics.templates import trainers + +Array = jax.Array +VariableDict = trainers.VariableDict +TrainState = train_states.BasicTrainState + + +class ReFlowTrainer( + trainers.BasicTrainer[models.ReFlowModel, TrainState] +): + """Single-device trainer for rectified flow models.""" + + @flax.struct.dataclass + class TrainMetrics(clu_metrics.Collection): + train_loss: clu_metrics.Average.from_output("loss") + train_loss_std: clu_metrics.Std.from_output("loss") + + EvalMetrics = clu_metrics.Collection.create( # pylint: disable=invalid-name + **{ + f"eval_time_lvl{i}": clu_metrics.Average.from_output( + f"time_lvl{i}" + ) + for i in range(models.ReFlowModel.num_eval_time_levels) + } + ) + + def initialize_train_state(self, rng: Array) -> TrainState: + init_vars = self.model.initialize(rng) + mutables, params = flax.core.pop(init_vars, "params") + return TrainState.create( + replicate=self.is_distributed, + params=params, + opt_state=self.optimizer.init(params), + flax_mutables=mutables, + ) + + @property + def update_train_state( + self, + ) -> Callable[[TrainState, VariableDict, VariableDict], TrainState]: + """Returns function that updates the train state.""" + + def _update_train_state( + train_state: TrainState, + grads: VariableDict, + mutables: VariableDict, + ) -> TrainState: + updates, new_opt_state = self.optimizer.update( + grads, train_state.opt_state, train_state.params + ) + new_params = optax.apply_updates(train_state.params, updates) + + return train_state.replace( + step=train_state.step + 1, + opt_state=new_opt_state, + params=new_params, + flax_mutables=mutables, + ) + + return _update_train_state + + @staticmethod + def inference_fn_from_state_dict( + state: TrainState, *args, **kwargs + ): + return models.ReFlowModel.inference_fn( + state.model_variables, *args, **kwargs + ) + + +class DistributedReFlowTrainer( + ReFlowTrainer, + trainers.BasicDistributedTrainer[models.ReFlowModel, TrainState], +): + """Multi-device trainer for rectified flow models.""" + + # TODO(lzepedanunez): Write a test for this trainer. + + # MRO: ReFlowTrainer > BasicDistributedTrainer > BasicTrainer + ...