Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
1 parent b033acc commit 5a87498
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def __eq__(self, other):
type(self) == type(other)
and self.low.dtype == other.low.dtype
and self.high.dtype == other.high.dtype
and self.device == other.device
and torch.isclose(self.low, other.low).all()
and torch.isclose(self.high, other.high).all()
)
Expand Down Expand Up @@ -1004,6 +1005,8 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
def __eq__(self, other):
if not isinstance(other, LazyStackedTensorSpec):
return False
if self.device != other.device:
return False
if len(self._specs) != len(other._specs):
return False
for _spec1, _spec2 in zip(self._specs, other._specs):
Expand Down Expand Up @@ -1593,6 +1596,7 @@ def __init__(
def __eq__(self, other):
return (
type(other) == type(self)
and self.device == other.device
and self.shape == other.shape
and self.space == other.space
and self.dtype == other.dtype
Expand Down Expand Up @@ -2825,9 +2829,9 @@ def __eq__(self, other):
if isinstance(other, DiscreteTensorSpec):
return (
other.n == 2
and other.device == self.device
and other.shape == self.shape
and other.dtype == self.dtype
and other.device == self.device
)
return False
return super().__eq__(other)
Expand Down Expand Up @@ -3715,7 +3719,7 @@ def __eq__(self, other):
return (
type(self) is type(other)
and self.shape == other.shape
and self._device == other._device
and self.device == other.device
and set(self._specs.keys()) == set(other._specs.keys())
and all((self._specs[key] == spec) for (key, spec) in other._specs.items())
)
Expand Down Expand Up @@ -3937,6 +3941,8 @@ def __eq__(self, other):
return False
if self.stack_dim != other.stack_dim:
return False
if self.device != other.device:
return False
for _spec1, _spec2 in zip(self._specs, other._specs):
if _spec1 != _spec2:
return False
Expand Down

0 comments on commit 5a87498

Please sign in to comment.