From 5f1489dd01b22835e3112a721fe9c0bb145c6352 Mon Sep 17 00:00:00 2001 From: Paul Maria Scheikl Date: Wed, 9 Oct 2024 11:08:05 -0400 Subject: [PATCH] Remove duplicate reshape of x1. --- torchcfm/optimal_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index 53c53f3..b11b3c3 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -81,7 +81,6 @@ def get_map(self, x0, x1): x0 = x0.reshape(x0.shape[0], -1) if x1.dim() > 2: x1 = x1.reshape(x1.shape[0], -1) - x1 = x1.reshape(x1.shape[0], -1) M = torch.cdist(x0, x1) ** 2 if self.normalize_cost: M = M / M.max() # should not be normalized when using minibatches