From 6839833858c4a0c2d2874249c2cc58f0701af90c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 29 Aug 2023 13:44:58 +0200 Subject: [PATCH] PT forward, better error on missing batch dim #1385 --- returnn/torch/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index cc7bc586b2..35decc450b 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -677,6 +677,8 @@ def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIf batch_dim = _get_batch_dim_from_extern_data(self.extern_data) def _get_tensor_wo_batch_numpy(x: Tensor) -> Tensor: + if batch_dim not in x.dims: + raise Exception(f"Expected {batch_dim} in {x}.") if x.dims.index(batch_dim) != 0: x = x.copy_move_axis(x.dims.index(batch_dim), 0)