diff --git a/src/continuity/operators/deeponet.py b/src/continuity/operators/deeponet.py index 9b9225ed..91dafd38 100644 --- a/src/continuity/operators/deeponet.py +++ b/src/continuity/operators/deeponet.py @@ -77,7 +77,7 @@ def forward( assert u.shape[1:] == torch.Size([self.shapes.u.num * self.shapes.u.dim]) y = y.flatten(0, 1) - assert u.shape[1:] == torch.Size([self.shapes.u.num * self.shapes.u.dim]) + assert y.shape[-1:] == torch.Size([self.shapes.y.dim]) # Pass through branch and trunk networks b = self.branch(u)