Skip to content

Commit

Permalink
PT forward, better error on missing batch dim
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Aug 29, 2023
1 parent 5e8e145 commit 6839833
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,8 @@ def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIf
batch_dim = _get_batch_dim_from_extern_data(self.extern_data)

def _get_tensor_wo_batch_numpy(x: Tensor) -> Tensor:
if batch_dim not in x.dims:
raise Exception(f"Expected {batch_dim} in {x}.")
if x.dims.index(batch_dim) != 0:
x = x.copy_move_axis(x.dims.index(batch_dim), 0)

Expand Down

0 comments on commit 6839833

Please sign in to comment.