From 2fb64bef7c5d7bbdc69df93803b655d4cf272056 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 10 Oct 2023 13:10:18 +0100 Subject: [PATCH] init --- tensordict/tensordict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 0cde92342..e143c07d4 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -7924,6 +7924,8 @@ def sort_keys(element): rename_key = _renamed_inplace_method(rename_key_) def where(self, condition, other, *, out=None, pad=None): + if condition.ndim < self.ndim: + condition = expand_right(condition, self.batch_size) condition = condition.unbind(self.stack_dim) if _is_tensor_collection(other.__class__) or ( isinstance(other, Tensor)