diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 8124c1d..3e518ef 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -326,6 +326,7 @@ def compute_mu_t(self, x0, x1, t): [2] Flow Matching for Generative Modelling, ICLR, Lipman et al. """ del x0 + t = pad_t_like_x(t, x1) return t * x1 def compute_sigma_t(self, t): @@ -369,6 +370,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): [1] Flow Matching for Generative Modelling, ICLR, Lipman et al. """ del x0 + t = pad_t_like_x(t, x1) return (x1 - (1 - self.sigma) * xt) / (1 - (1 - self.sigma) * t)