From 48d52d20d149a290e50e7af81b5c09a7a36db13e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 15:45:27 -0700 Subject: [PATCH] [Feature] Propagate `existsok` in memmap* methods ghstack-source-id: 6dcab0ff5e2ae2bb9b8d3bbf18cfb524c51d144d Pull Request resolved: https://github.com/pytorch/tensordict/pull/990 --- tensordict/_lazy.py | 4 ++++ tensordict/_td.py | 20 ++++++++++++++++---- tensordict/base.py | 16 ++++++++++++++++ tensordict/memmap.py | 26 ++++++++++++++------------ tensordict/persistent.py | 4 +++- tensordict/tensorclass.py | 6 ++++++ 6 files changed, 59 insertions(+), 17 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index c5f79f337..bc12dc015 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2458,6 +2458,7 @@ def _memmap_( inplace=True, like=False, share_non_tensor, + existsok, ) -> T: if prefix is not None: prefix = Path(prefix) @@ -2489,6 +2490,7 @@ def save_metadata(prefix=prefix, self=self): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) ) if not inplace: @@ -3526,6 +3528,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: @@ -3558,6 +3561,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not inplace: dest = type(self)( diff --git a/tensordict/_td.py b/tensordict/_td.py index 0d9cff420..72e3efbe3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2535,6 +2535,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if prefix is not None: @@ -2569,6 +2570,7 @@ def _memmap_( inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if prefix is not None: _update_metadata( @@ -2585,6 +2587,7 @@ def _memmap_( copy_existing=copy_existing, prefix=prefix, like=like, + existsok=existsok, ) else: futures.append( @@ -2596,6 +2599,7 @@ def _memmap_( copy_existing=copy_existing, prefix=prefix, like=like, + existsok=existsok, ) ) if prefix is not None: @@ -2847,7 +2851,12 @@ def make_memmap_from_storage( return memmap_tensor def make_memmap_from_tensor( - self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True + self, + key: NestedKey, + tensor: torch.Tensor, + *, + copy_data: bool = True, + existsok: bool = True, ) -> MemoryMappedTensor: if not self.is_memmap(): raise RuntimeError( @@ -2876,6 +2885,7 @@ def make_memmap_from_tensor( copy_existing=True, prefix=last_node._memmap_prefix, like=not copy_data, + existsok=existsok, ) _update_metadata( metadata=metadata, @@ -3906,6 +3916,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if prefix is not None: @@ -3936,6 +3947,7 @@ def save_metadata(prefix=prefix, self=self): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not inplace: result = _SubTensorDict(_source, idx=self.idx) @@ -4404,7 +4416,7 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None): # user did specify location and memmap is in wrong place, so we copy -def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): +def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok): filename = None if prefix is None else str(prefix / f"{key}.memmap") if value.is_nested: shape = value._nested_tensor_size() @@ -4416,7 +4428,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): shape, 
filename=shape_filename, copy_existing=copy_existing, - existsok=True, + existsok=existsok, copy_data=True, ) else: @@ -4425,9 +4437,9 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): value.data if value.requires_grad else value, filename=filename, copy_existing=copy_existing, - existsok=True, copy_data=not like, shape=shape, + existsok=existsok, ) dest._tensordict[key] = memmap_tensor return memmap_tensor diff --git a/tensordict/base.py b/tensordict/base.py index 51d5180a7..3b326caaa 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3211,6 +3211,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: ... def densify(self, layout: torch.layout = torch.strided): @@ -3743,6 +3744,7 @@ def memmap_( num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: """Writes all tensors onto a corresponding memory-mapped Tensor, in-place. @@ -3767,6 +3769,8 @@ def memmap_( on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -3799,6 +3803,7 @@ def memmap_( inplace=True, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -3813,6 +3818,7 @@ def memmap_( executor=None, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() @abc.abstractmethod @@ -3935,6 +3941,7 @@ def memmap( num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: """Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict. @@ -3958,6 +3965,8 @@ def memmap( on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -3992,6 +4001,7 @@ def memmap( inplace=False, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -4007,6 +4017,7 @@ def memmap( like=False, futures=None, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() def memmap_like( @@ -4014,6 +4025,7 @@ def memmap_like( prefix: str | None = None, copy_existing: bool = False, *, + existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, @@ -4040,6 +4052,8 @@ def memmap_like( on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. 
The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -4089,6 +4103,7 @@ def memmap_like( inplace=False, like=True, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -4106,6 +4121,7 @@ def memmap_like( executor=None, futures=None, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() @classmethod diff --git a/tensordict/memmap.py b/tensordict/memmap.py index f256aba81..9ffe14d46 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -134,12 +134,12 @@ def from_tensor( cls, input, *, - filename=None, - existsok=False, - copy_existing=False, - copy_data=True, - shape=None, - ): + filename: Path | str = None, + existsok: bool = False, + copy_existing: bool = False, + copy_data: bool = True, + shape: torch.Size | None = None, + ): # noqa: D417 """Creates a MemoryMappedTensor with the same content as another tensor. If the tensor is already a MemoryMappedTensor the original tensor is @@ -149,6 +149,8 @@ def from_tensor( Args: input (torch.Tensor): the tensor which content must be copied onto the MemoryMappedTensor. + + Keyword Args: filename (path to a file): the path to the file where the tensor should be stored. If none is provided, a file handler is used instead. @@ -280,12 +282,12 @@ def from_storage( cls, storage, *, - shape=None, - dtype=None, - device=None, - index=None, - filename=None, - handler=None, + shape: torch.Size | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + index: IndexType | None = None, + filename: Path | str = None, + handler: _handler = None, ): if getattr(storage, "filename", None) is not None: if filename is None: diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 2e2c37710..4da8c118d 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -701,6 +701,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if inplace: raise RuntimeError("Cannot call memmap inplace in a persistent tensordict.") @@ -749,6 +750,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): futures=futures, inplace=inplace, share_non_tensor=share_non_tensor, + existsok=existsok, ), inplace=False, validated=True, @@ -776,7 +778,7 @@ def _populate( ), copy_data=not like, copy_existing=copy_existing, - existsok=True, + existsok=existsok, ) tensordict._set_str( key, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index aaccb8869..0f8e72cb5 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -953,6 +953,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ): _non_tensordict = dict(self._non_tensordict) cls = type(self) @@ -997,6 +998,7 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): like=like, copy_existing=copy_existing, share_non_tensor=share_non_tensor, + existsok=existsok, ) if new_futures: futures += new_futures @@ -2816,6 +2818,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ): # For efficiency, we can avoid doing this saving # if the data is already there. 
@@ -2842,6 +2845,7 @@ def _memmap_( like=like, memmaped=memmaped, share_non_tensor=share_non_tensor, + existsok=existsok, ) _metadata["_share_non_tensor"] = share_non_tensor out._non_tensordict["_metadata"] = _metadata @@ -2967,6 +2971,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: memmaped_leaves = memmaped @@ -3013,6 +3018,7 @@ def save_metadata(prefix=prefix, self=self): # no memmapping should be executed memmaped=memmaped_leaves, share_non_tensor=share_non_tensor, + existsok=existsok, ) ) if not inplace:
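
A minimal usage sketch of the behaviour this patch enables. This is not part of the diff: the temporary directory, tensor contents, and the exact exception type raised when a file already exists are illustrative assumptions.

import tempfile

import torch
from tensordict import TensorDict

with tempfile.TemporaryDirectory() as prefix:
    td_first = TensorDict(
        {"a": torch.zeros(3), "nested": {"b": torch.ones(3)}}, batch_size=[3]
    )
    td_second = TensorDict(
        {"a": torch.full((3,), 2.0), "nested": {"b": torch.zeros(3)}}, batch_size=[3]
    )

    # First write: creates the memory-mapped files under ``prefix``.
    td_first.memmap_(prefix=prefix)

    # Default ``existsok=True`` keeps the previous behaviour: files already
    # present under ``prefix`` are overwritten silently.
    td_second.memmap(prefix=prefix)

    # With ``existsok=False``, writing to a path where memmap files already
    # exist is expected to raise instead of overwriting them.
    try:
        td_second.memmap(prefix=prefix, existsok=False)
    except Exception as err:  # assumed to raise; exception type not pinned down here
        print(f"refused to overwrite existing memmap files: {err}")

Since ``existsok`` defaults to ``True`` in ``memmap``, ``memmap_``, and ``memmap_like``, existing call sites keep the prior overwrite-on-save behaviour; only callers that opt into ``existsok=False`` get the new guard.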