From d3fba082f3892eb5e9c44027d3eedb7638551d24 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 18:19:17 +0000 Subject: [PATCH] [Doc] Better doc for non-tensor data handling ghstack-source-id: c25571282d7bb63a14cc7b4ba9fb217785060cc7 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1173 --- docs/source/overview.rst | 265 +++++++++++++++++++++++++++++++-------- 1 file changed, 212 insertions(+), 53 deletions(-) diff --git a/docs/source/overview.rst b/docs/source/overview.rst index aebbf8c66..5b6d100a2 100644 --- a/docs/source/overview.rst +++ b/docs/source/overview.rst @@ -1,16 +1,21 @@ Overview ======== -TensorDict makes it easy to organise data and write reusable, generic PyTorch code. Originally developed for TorchRL, we've spun it out into a separate library. +TensorDict makes it easy to organise data and write reusable, generic PyTorch code. Originally developed for TorchRL, +we've spun it out into a separate library. -TensorDict is primarily a dictionary but also a tensor-like class: it supports multiple tensor operations that are mostly shape and storage-related. It is designed to be efficiently serialised or transmitted from node to node or process to process. Finally, it is shipped with its own ``tensordict.nn`` module which is compatible with ``functorch`` and aims at making model ensembling and parameter manipulation easier. +TensorDict is primarily a dictionary but also a tensor-like class: it supports multiple tensor operations that are +mostly shape and storage-related. It is designed to be efficiently serialised or transmitted from node to node or +process to process. Finally, it is shipped with its own :mod:`~tensordict.nn` module which is compatible with ``torch.func`` +and aims at making model ensembling and parameter manipulation easier. -On this page we will motivate ``TensorDict`` and give some examples of what it can do. +On this page we will motivate :class:`~tensordict.TensorDict` and give some examples of what it can do. Motivation ---------- -TensorDict allows you to write generic code modules that are re-usable across paradigms. For instance, the following loop can be re-used across most SL, SSL, UL and RL tasks. +TensorDict allows you to write generic code modules that are re-usable across paradigms. For instance, the following +loop can be re-used across most SL, SSL, UL and RL tasks. >>> for i, tensordict in enumerate(dataset): ... # the model reads and writes tensordicts @@ -20,9 +25,11 @@ TensorDict allows you to write generic code modules that are re-usable across pa ... optimizer.step() ... optimizer.zero_grad() -With its ``tensordict.nn`` module, the package provides many tools to use ``TensorDict`` in a code base with little or no effort. +With its :mod:`~tensordict.nn` module, the package provides many tools to use :class:`~tensordict.TensorDict` in a code +base with little or no effort. -In multiprocessing or distributed settings, ``tensordict`` allows you to seamlessly dispatch data to each worker: +In multiprocessing or distributed settings, :class:`~tensordict.TensorDict` allows you to seamlessly dispatch data to +each worker: >>> # creates batches of 10 datapoints >>> splits = torch.arange(tensordict.shape[0]).split(10) @@ -56,12 +63,15 @@ The nested case is even more compelling: ... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]} ... for i in range(3) -Decomposing the output dictionary in three similarly structured dictionaries after applying the unbind operation quickly becomes significantly cumbersome when working naively with pytree. With tensordict, we provide a simple API for users that want to unbind or split nested structures, rather than computing a nested split / unbound nested structure. +Decomposing the output dictionary in three similarly structured dictionaries after applying the unbind operation quickly +becomes significantly cumbersome when working naively with pytree. With tensordict, we provide a simple API for users +that want to unbind or split nested structures, rather than computing a nested split / unbound nested structure. Features -------- -A ``TensorDict`` is a dict-like container for tensors. To instantiate a ``TensorDict``, you must specify key-value pairs as well as the batch size. The leading dimensions of any values in the ``TensorDict`` must be compatible with the batch size. +A :class:`~tensordict.TensorDict` is a dict-like container for tensors. To instantiate a :class:`~tensordict.TensorDict`, +you must specify key-value pairs as well as the batch size. The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size. >>> import torch >>> from tensordict import TensorDict @@ -81,7 +91,7 @@ a few characters (notice that indexing the nth leading dimensions with tree_map >>> sub_tensordict = tensordict[..., :2] -One can also use the set method with ``inplace=True`` or the ``set_`` method to do inplace updates of the contents. +One can also use the set method with ``inplace=True`` or the :meth:`~tensordict.TensorDict.set_` method to do inplace updates of the contents. The former is a fault-tolerant version of the latter: if no matching key is found, it will write a new one. The contents of the TensorDict can now be manipulated collectively. @@ -93,7 +103,167 @@ To reshape the batch dimensions one can do >>> tensordict = tensordict.reshape(6) -The class supports many other operations, including squeeze, unsqueeze, view, permute, unbind, stack, cat and many more. If an operation is not present, the TensorDict.apply method will usually provide the solution that was needed. +The class supports many other operations, including squeeze, unsqueeze, view, permute, unbind, stack, cat and many more. +If an operation is not present, the TensorDict.apply method will usually provide the solution that was needed. + +Non-tensor data +--------------- + +Tensordict is a powerful library for working with tensor data, but it also supports non-tensor data. This guide will +show you how to use tensordict with non-tensor data. + +Creating a TensorDict with Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can create a TensorDict with non-tensor data using the :class:`~tensordict.NonTensorData` class. + + >>> from tensordict import TensorDict, NonTensorData + >>> import torch + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b=torch.zeros(()), + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +As you can see, the :class:`~tensordict.NonTensorData` object is stored in the TensorDict just like a regular tensor. + +Accessing Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can access the non-tensor data using the key or the get method. Regular `getattr` calls will return the content of +the :class:`~tensordict.NonTensorData` object whereas :meth:`~tensordict.TensorDict.get` will return the +:class:`~tensordict.NonTensorData` object itself. + + >>> print(td["a"]) # prints: a string! + >>> print(td.get("a")) # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None) + + +Batched Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~ + +If you have a batch of non-tensor data, you can store it in a TensorDict with a specified batch size. + + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None), + b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + +In this case, we assume that all elements of the tensordict have the same non-tensor data. + + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +To assign a different non-tensor data object to each element in a shaped tensordict, you can use stacks of non-tensor +data. + +Stacked Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~ + +If you have a list of non-tensor data that you want to store in a :class:`~tensordict.TensorDict`, you can use the +:class:`~tensordict.NonTensorStack` class. + + >>> td = TensorDict( + ... a=NonTensorStack("a string!", "another string!", "a third string!"), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorStack( + ['a string!', 'another string!', 'a third string!'..., + batch_size=torch.Size([3]), + device=None), + b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + +You can access the first element and you will get the first of the strings: + + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +In contrast, using :class:`~tensordict.NonTensorData` with a list will not lead to the same result, as there is no +way to tell what to do in general with a non-tensor data that happens to be a list: + + >>> td = TensorDict( + ... a=NonTensorData(["a string!", "another string!", "a third string!"]), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +Stacking TensorDicts with Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To stack non-tensor data, :func:`~torch.stack` will check the identity of the non-tensor objects and produce a single +:class:`~tensordict.NonTensorData` if they match, or a :class:`~tensordict.NonTensorStack` otherwise: + + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b = torch.zeros(()), + ... ) + >>> print(torch.stack([td, td])) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([2]), device=None), + b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + +If you want to make sure the result is a stack, use :meth:`~tensordict.TensorDict.lazy_stack` instead. + + >>> print(TensorDict.lazy_stack([td, td])) + LazyStackedTensorDict( + fields={ + a: NonTensorStack( + ['a string!', 'a string!'], + batch_size=torch.Size([2]), + device=None), + b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2]), + device=None, + is_shared=False, + stack_dim=0) Named dimensions ---------------- @@ -111,7 +281,8 @@ similar to the torch.Tensor dimension name feature: Nested TensorDicts ------------------ -The values in a ``TensorDict`` can themselves be TensorDicts (the nested dictionaries in the example below will be converted to nested TensorDicts). +The values in a :class:`~tensordict.TensorDict` can themselves be TensorDicts (the nested dictionaries in the example +below will be converted to nested TensorDicts). >>> tensordict = TensorDict( ... { @@ -133,7 +304,10 @@ Accessing or setting nested keys can be done with tuples of strings Lazy evaluation --------------- -Some operations on ``TensorDict`` defer execution until items are accessed. For example stacking, squeezing, unsqueezing, permuting batch dimensions and creating a view are not executed immediately on all the contents of the ``TensorDict``. Instead they are performed lazily when values in the ``TensorDict`` are accessed. This can save a lot of unnecessary calculation should the ``TensorDict`` contain many values. +Some operations on :class:`~tensordict.TensorDict` defer execution until items are accessed. For example stacking, +squeezing, unsqueezing, permuting batch dimensions and creating a view are not executed immediately on all the contents +of the :class:`~tensordict.TensorDict`. Instead they are performed lazily when values in the :class:`~tensordict.TensorDict` +are accessed. This can save a lot of unnecessary calculation should the :class:`~tensordict.TensorDict` contain many values. >>> tensordicts = [TensorDict({ ... "a": torch.rand(10), @@ -147,7 +321,10 @@ It also has the advantage that we can manipulate the original tensordicts in a s >>> stacked["a"] = torch.zeros_like(stacked["a"]) >>> assert (tensordicts[0]["a"] == 0).all() -The caveat is that the get method has now become an expensive operation and, if repeated many times, may cause some overhead. One can avoid this by simply calling tensordict.contiguous() after the execution of stack. To further mitigate this, TensorDict comes with its own meta-data class (MetaTensor) that keeps track of the type, shape, dtype and device of each entry of the dict, without performing the expensive operation. +The caveat is that the get method has now become an expensive operation and, if repeated many times, may cause some +overhead. One can avoid this by simply calling tensordict.contiguous() after the execution of stack. To further mitigate +this, TensorDict comes with its own meta-data class (MetaTensor) that keeps track of the type, shape, dtype and device +of each entry of the dict, without performing the expensive operation. Lazy pre-allocation ------------------- @@ -158,14 +335,16 @@ Suppose we have some function foo() -> TensorDict and that we do something like >>> for i in range(N): ... tensordict[i] = foo() -When ``i == 0`` the empty ``TensorDict`` will automatically be populated with empty tensors with batch size N. In subsequent iterations of the loop the updates will all be written in-place. +When ``i == 0`` the empty :class:`~tensordict.TensorDict` will automatically be populated with empty tensors with batch +size N. In subsequent iterations of the loop the updates will all be written in-place. TensorDictModule ---------------- -To make it easy to integrate ``TensorDict`` in one's code base, we provide a tensordict.nn package that allows users to pass ``TensorDict`` instances to ``nn.Module`` objects. +To make it easy to integrate :class:`~tensordict.TensorDict` in one's code base, we provide a tensordict.nn package that allows users to +pass :class:`~tensordict.TensorDict` instances to :class:`~torch.nn.Module` objects (or any callable). -``TensorDictModule`` wraps ``nn.Module`` and accepts a single ``TensorDict`` as an input. You can specify where the underlying module should take its input from, and where it should write its output. This is a key reason we can write reusable, generic high-level code such as the training loop in the motivation section. +:class:`~tensordict.nn.TensorDictModule` wraps :class:`~torch.nn.Module` and accepts a single :class:`~tensordict.TensorDict` as an input. You can specify where the underlying module should take its input from, and where it should write its output. This is a key reason we can write reusable, generic high-level code such as the training loop in the motivation section. >>> from tensordict.nn import TensorDictModule >>> class Net(nn.Module): @@ -191,11 +370,17 @@ To facilitate the adoption of this class, one can also pass the tensors as kwarg >>> tensordict = module(input=torch.randn(32, 100)) -which will return a ``TensorDict`` identical to the one in the previous code box. +which will return a :class:`~tensordict.TensorDict` identical to the one in the previous code box. See :ref:`the export tutorial` for +more context on this feature. -A key pain-point of multiple PyTorch users is the inability of nn.Sequential to handle modules with multiple inputs. Working with key-based graphs can easily solve that problem as each node in the sequence knows what data needs to be read and where to write it. +A key pain-point of multiple PyTorch users is the inability of nn.Sequential to handle modules with multiple inputs. +Working with key-based graphs can easily solve that problem as each node in the sequence knows what data needs to be +read and where to write it. -For this purpose, we provide the ``TensorDictSequential`` class which passes data through a sequence of ``TensorDictModules``. Each module in the sequence takes its input from, and writes its output to the original ``TensorDict``, meaning it's possible for modules in the sequence to ignore output from their predecessors, or take additional input from the tensordict as necessary. Here's an example. +For this purpose, we provide the :class:`~tensordict.nn.TensorDictSequential` class which passes data through a +sequence of ``TensorDictModules``. Each module in the sequence takes its input from, and writes its output to the +original :class:`~tensordict.TensorDict`, meaning it's possible for modules in the sequence to ignore output from their +predecessors, or take additional input from the tensordict as necessary. Here's an example: >>> class Net(nn.Module): ... def __init__(self, input_size=100, hidden_size=50, output_size=10): @@ -232,38 +417,12 @@ For this purpose, we provide the ``TensorDictSequential`` class which passes dat >>> intermediate_x = tensordict["intermediate", "x"] >>> probabilities = tensordict["output", "probabilities"] -In this example, the second module combines the output of the first with the mask stored under ("inputs", "mask") in the ``TensorDict``. - -``TensorDictSequential`` offers a bunch of other features: one can access the list of input and output keys by querying the in_keys and out_keys attributes. It is also possible to ask for a sub-graph by querying ``select_subsequence()`` with the desired sets of input and output keys that are desired. This will return another ``TensorDictSequential`` with only the modules that are indispensable to satisfy those requirements. The ``TensorDictModule`` is also compatible with ``vmap`` and other ``functorch`` capabilities. +In this example, the second module combines the output of the first with the mask stored under ("inputs", "mask") in the +:class:`~tensordict.TensorDict`. -Functional Programming ----------------------- - -We provide and API to use ``TensorDict`` in conjunction with ``functorch``. For instance, ``TensorDict`` makes it easy to concatenate model weights to do model ensembling: - ->>> from torch import nn ->>> from tensordict import TensorDict ->>> from tensordict.nn import make_functional ->>> import torch ->>> from torch import vmap ->>> layer1 = nn.Linear(3, 4) ->>> layer2 = nn.Linear(4, 4) ->>> model = nn.Sequential(layer1, layer2) ->>> # we represent the weights hierarchically ->>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(separator=".") ->>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(separator=".") ->>> params = make_functional(model) ->>> # params provided by make_functional match state_dict: ->>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all() ->>> # Let's use our functional module ->>> x = torch.randn(10, 3) ->>> out = model(x, params=params) # params is the last arg (or kwarg) ->>> # an ensemble of models: we stack params along the first dimension... ->>> params_stack = torch.stack([params, params], 0) ->>> # ... and use it as an input we'd like to pass through the model ->>> y = vmap(model, (None, 0))(x, params_stack) ->>> print(y.shape) -torch.Size([2, 10, 4]) - - -The functional API is comparable if not faster than the current ``FunctionalModule`` implemented in ``functorch``. +:class:`~tensordict.nn.TensorDictSequential` offers a bunch of other features: one can access the list of input and +output keys by querying the in_keys and out_keys attributes. It is also possible to ask for a sub-graph by querying +:meth:`~tensordict.nn.TensorDictSequential.select_subsequence` with the desired sets of input and output keys that are desired. This will return another +:class:`~tensordict.nn.TensorDictSequential` with only the modules that are indispensable to satisfy those requirements. +The :class:`~tensordict.nn.TensorDictModule` is also compatible with :func:`~torch.vmap` and other ``torch.func`` +capabilities.