diff --git a/tensordict/base.py b/tensordict/base.py
index 48b306520..358cae1b1 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3142,6 +3142,92 @@ def _set_device(self, device: torch.device) -> T:
                 value._set_device(device=device)
         return self
 
+    @cache  # noqa: B019
+    def param_count(self, *, count_duplicates: bool = True) -> int:
+        """Counts the number of parameters (total number of indexable items), accounting for tensors only.
+
+        Keyword Args:
+            count_duplicates (bool): Whether to count duplicated tensors as independent or not.
+                If ``False``, only strictly identical tensors will be discarded (same views but different
+                ids from a common base tensor will be counted twice). Defaults to ``True`` (each tensor is
+                assumed to be a single copy).
+
+        """
+        vals = self._values_list(True, True)
+        total = 0
+        if not count_duplicates:
+            vals = set(vals)
+        for v in vals:
+            total += v.numel()
+        return total
+
+    @cache  # noqa: B019
+    def bytes(self, *, count_duplicates: bool = True) -> int:
+        """Counts the number of bytes of the contained tensors.
+
+        Keyword Args:
+            count_duplicates (bool): Whether to count duplicated tensors as independent or not.
+                If ``False``, only strictly identical tensors will be discarded (same views but different
+                ids from a common base tensor will be counted twice). Defaults to ``True`` (each tensor is
+                assumed to be a single copy).
+
+        """
+        set_of_tensors = set() if not count_duplicates else []
+
+        def add(tensor):
+            if count_duplicates:
+                set_of_tensors.append(tensor)
+            else:
+                set_of_tensors.add(tensor)
+
+        def count_bytes(tensor):
+            if tensor.is_nested:
+                if not tensor.layout == torch.jagged:
+                    raise RuntimeError(
+                        "NTs that are not jagged are not supported by the bytes method. Please use the "
+                        "jagged layout instead or raise an issue on https://github.com/pytorch/tensordict/issues."
+                    )
+                attrs, ctx = tensor.__tensor_flatten__()
+                for attr in attrs:
+                    t = getattr(tensor, attr)
+                    count_bytes(t)
+                return
+            if isinstance(tensor, torch.Tensor):
+                if isinstance(tensor, MemoryMappedTensor):
+                    add(tensor)
+                    return
+                if type(tensor) is not torch.Tensor:
+                    try:
+                        attrs, ctx = tensor.__tensor_flatten__()
+                        for attr in attrs:
+                            t = getattr(tensor, attr)
+                            count_bytes(t)
+                        return
+                    except AttributeError:
+                        warnings.warn(
+                            "The sub-tensor doesn't have a __tensor_flatten__ attribute, making it "
+                            "impossible to count the bytes it contains. Falling back on regular count.",
+                            category=UserWarning,
+                        )
+                        count_bytes(torch.as_tensor(tensor))
+                        return
+
+                grad = getattr(tensor, "grad", None)
+                if grad is not None:
+                    count_bytes(grad)
+                    count_bytes(tensor.data)
+                else:
+                    add(tensor)
+                return
+
+        vals = self._values_list(True, True)
+        for v in vals:
+            count_bytes(v)
+        total = 0
+        for tensor in set_of_tensors:
+            total += tensor.numel() * tensor.dtype.itemsize
+        return total
+
     def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T:
         """Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors.
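For reference, a minimal usage sketch of the new `param_count` and `bytes` methods (illustrative only; it assumes the default float32 dtype, i.e. 4 bytes per element, and that an aliased entry shares the same tensor object):

    import torch
    from tensordict import TensorDict

    td = TensorDict(a=torch.zeros(4), b=torch.zeros(2, 3))
    td["c"] = td["a"]  # "c" aliases the same tensor object as "a"

    # 4 + 6 + 4 = 14 elements when duplicates are counted, 10 when deduplicated
    assert td.param_count() == 14
    assert td.param_count(count_duplicates=False) == 10

    # float32 entries: 4 bytes per element
    assert td.bytes() == 14 * 4
    assert td.bytes(count_duplicates=False) == 10 * 4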
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index e67b48416..0f1f65b5d 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -132,6 +132,17 @@
 mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
 
 
+@pytest.fixture
+def device_fixture():
+    device = torch.get_default_device()
+    if torch.cuda.is_available():
+        torch.set_default_device(torch.device("cuda:0"))
+    elif torch.backends.mps.is_available():
+        torch.set_default_device(torch.device("mps:0"))
+    yield
+    torch.set_default_device(device)
+
+
 def _compare_tensors_identity(td0, td1):
     if isinstance(td0, LazyStackedTensorDict):
         if not isinstance(td1, LazyStackedTensorDict):
@@ -242,7 +253,32 @@ def test_batchsize_reset(self):
         td_u.batch_size = [1]
         td_u.to_tensordict().batch_size = [1]
 
-    def test_depth(ggself):
+    @pytest.mark.parametrize("count_duplicates", [False, True])
+    def test_bytes(self, count_duplicates, device_fixture):
+        tensor = torch.zeros(3)
+        tensor_with_grad = torch.ones(3, requires_grad=True)
+        (tensor_with_grad + 1).sum().backward()
+        v = torch.ones(3) * 2  # 12 bytes
+        offsets = torch.tensor([0, 1, 3])  # 24 bytes
+        lengths = torch.tensor([1, 2])  # 16 bytes
+        njt = torch.nested.nested_tensor_from_jagged(
+            v, offsets, lengths=lengths
+        )  # 52 bytes
+        tricky = torch.nested.nested_tensor_from_jagged(
+            tensor, offsets, lengths=lengths
+        )  # 52 bytes or 0
+        td = TensorDict(
+            tensor=tensor,  # 3 * 4 = 12 bytes
+            tensor_with_grad=tensor_with_grad,  # 3 * 4 * 2 = 24 bytes
+            njt=njt,  # 52
+            tricky=tricky,  # 52 or 0
+        )
+        if count_duplicates:
+            assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 52
+        else:
+            assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 0
+
+    def test_depth(self):
         td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_()
         assert td.depth == 3
         with td.unlock_():
@@ -1903,6 +1939,16 @@ def test_pad_sequence_pad_dim1(self, make_mask):
         else:
             assert "masks" not in padded_td.keys()
 
+    @pytest.mark.parametrize("count_duplicates", [False, True])
+    def test_param_count(self, count_duplicates):
+        td = TensorDict(a=torch.randn(3), b=torch.randn(6))
+        td["c"] = td["a"]
+        assert len(td._values_list(True, True)) == 3
+        if count_duplicates:
+            assert td.param_count(count_duplicates=count_duplicates) == 12
+        else:
+            assert td.param_count(count_duplicates=count_duplicates) == 9
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_permute(self, device):
         torch.manual_seed(1)
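The 52-byte figure used in `test_bytes` corresponds to the inner tensors of the jagged NJT, which `bytes` discovers through `__tensor_flatten__`. A rough sketch of that accounting (the inner attribute names are an assumption about the NJT internals, not something the test relies on):

    import torch

    values = torch.ones(3) * 2           # 3 * 4 bytes (float32) = 12
    offsets = torch.tensor([0, 1, 3])    # 3 * 8 bytes (int64)   = 24
    lengths = torch.tensor([1, 2])       # 2 * 8 bytes (int64)   = 16
    njt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths=lengths)

    # The jagged NJT exposes its inner tensors (typically _values, _offsets, _lengths)
    attrs, _ = njt.__tensor_flatten__()
    total = sum(getattr(njt, a).numel() * getattr(njt, a).dtype.itemsize for a in attrs)
    print(total)  # expected: 12 + 24 + 16 = 52, matching the comments in test_bytes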