diff --git a/models/yolo.py b/models/yolo.py index a28c7d305..e68aec221 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -485,10 +485,13 @@ def forward(self, x): mc = [torch.cat([self.cv6[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2), torch.cat([self.cv7[i](x[self.nl+i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)] # mask coefficients - d = self.detect(self, x[:-2]) + x = self.detect(self, x[:-2]) if self.training: return d, mc, p - return (torch.cat([d[0][1], mc[1]], 1), (d[1][1], mc[1], p[1])) + if self.export: + return (torch.cat([x[1], mc[1]], 1), p[1]) + else: + return (torch.cat([x[0][1], mc[1]], 1), (x[1][1], mc[1], p[1])) class Panoptic(Detect):