diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 6a1666d6daf..2089e517e00 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -445,6 +445,7 @@ def __eq__(self, other):
             type(self) == type(other)
             and self.low.dtype == other.low.dtype
             and self.high.dtype == other.high.dtype
+            and self.device == other.device
             and torch.isclose(self.low, other.low).all()
             and torch.isclose(self.high, other.high).all()
         )
@@ -1004,6 +1005,8 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
     def __eq__(self, other):
         if not isinstance(other, LazyStackedTensorSpec):
             return False
+        if self.device != other.device:
+            return False
         if len(self._specs) != len(other._specs):
             return False
         for _spec1, _spec2 in zip(self._specs, other._specs):
@@ -1593,6 +1596,7 @@ def __init__(
     def __eq__(self, other):
         return (
             type(other) == type(self)
+            and self.device == other.device
             and self.shape == other.shape
             and self.space == other.space
             and self.dtype == other.dtype
@@ -2825,9 +2829,9 @@ def __eq__(self, other):
             if isinstance(other, DiscreteTensorSpec):
                 return (
                     other.n == 2
+                    and other.device == self.device
                     and other.shape == self.shape
                     and other.dtype == self.dtype
-                    and other.device == self.device
                 )
             return False
         return super().__eq__(other)
@@ -3715,7 +3719,7 @@ def __eq__(self, other):
         return (
             type(self) is type(other)
             and self.shape == other.shape
-            and self._device == other._device
+            and self.device == other.device
             and set(self._specs.keys()) == set(other._specs.keys())
             and all((self._specs[key] == spec) for (key, spec) in other._specs.items())
         )
@@ -3937,6 +3941,8 @@ def __eq__(self, other):
             return False
         if self.stack_dim != other.stack_dim:
             return False
+        if self.device != other.device:
+            return False
        for _spec1, _spec2 in zip(self._specs, other._specs):
            if _spec1 != _spec2:
                return False
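
A minimal sketch of the behaviour this patch adds (assuming a torchrl version where BoundedTensorSpec is importable from torchrl.data and accepts low/high/shape/device keyword arguments): with the device term added to __eq__, two otherwise-identical specs that live on different devices compare unequal, and the check short-circuits before the cross-device torch.isclose comparison of the bounds.

import torch
from torchrl.data import BoundedTensorSpec

cpu_spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), device="cpu")
cpu_spec_2 = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), device="cpu")
assert cpu_spec == cpu_spec_2  # same dtype, bounds, shape and device

if torch.cuda.is_available():
    cuda_spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,), device="cuda")
    # The new `and self.device == other.device` term evaluates before the
    # torch.isclose() calls, so the comparison returns False instead of
    # attempting to compare CPU and CUDA tensors.
    assert cpu_spec != cuda_spec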