diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index 53c53f3..b11b3c3 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -81,7 +81,6 @@ def get_map(self, x0, x1): x0 = x0.reshape(x0.shape[0], -1) if x1.dim() > 2: x1 = x1.reshape(x1.shape[0], -1) - x1 = x1.reshape(x1.shape[0], -1) M = torch.cdist(x0, x1) ** 2 if self.normalize_cost: M = M / M.max() # should not be normalized when using minibatches