From 2d2b68f3ceb7ba1c8cc3ff48022db2939d31a598 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 9 Jul 2024 10:23:40 +0100
Subject: [PATCH] amend

---
 tensordict/_lazy.py                 | 2 +-
 tensordict/_td.py                   | 2 +-
 tensordict/nn/functional_modules.py | 8 ++++----
 tensordict/utils.py                 | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index bbab03702..c9fe7cbd8 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -92,7 +92,7 @@
 _has_functorch = False
 try:
     try:
-        from torch._C._functorch import (  # @manual=fbcode//caffe2:_C
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
             _add_batch_dim,
             _remove_batch_dim,
             is_batchedtensor,
diff --git a/tensordict/_td.py b/tensordict/_td.py
index ff42b5930..be8f2dc67 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -101,7 +101,7 @@
 _has_functorch = False
 try:
     try:
-        from torch._C._functorch import (  # @manual=fbcode//caffe2:_C
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
             _add_batch_dim,
             _remove_batch_dim,
             is_batchedtensor,
diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py
index 49e11778b..bf891dbea 100644
--- a/tensordict/nn/functional_modules.py
+++ b/tensordict/nn/functional_modules.py
@@ -109,8 +109,8 @@ def set_tensor_dict(  # noqa: F811
 
 _RESET_OLD_TENSORDICT = True
 try:
-    import torch._functorch.vmap as vmap_src
-    from torch._functorch.vmap import (
+    import torch._functorch.vmap as vmap_src  # @manual=fbcode//caffe2:torch
+    from torch._functorch.vmap import (  # @manual=fbcode//caffe2:torch
         _add_batch_dim,
         _broadcast_to_and_flatten,
         _get_name,
@@ -124,7 +124,7 @@ def set_tensor_dict(  # noqa: F811
     _has_functorch = True
 except ImportError:
     try:
-        from functorch._src.vmap import (
+        from functorch._src.vmap import (  # @manual=fbcode//caffe2/functorch:functorch_src
             _add_batch_dim,
             _broadcast_to_and_flatten,
             _get_name,
@@ -136,7 +136,7 @@ def set_tensor_dict(  # noqa: F811
         )
 
         _has_functorch = True
-        import functorch._src.vmap as vmap_src
+        import functorch._src.vmap as vmap_src  # @manual=fbcode//caffe2/functorch:functorch_src
     except ImportError:
         _has_functorch = False
 
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 86737798e..080f686bd 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -69,7 +69,7 @@
 
 try:
     try:
-        from torch._C._functorch import (  # @manual=fbcode//caffe2:_C
+        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
             get_unwrapped,
             is_batchedtensor,
         )
@@ -2321,7 +2321,7 @@ class _add_batch_dim_pre_hook:
     def __call__(self, mod: torch.nn.Module, args, kwargs):
         for name, param in list(mod.named_parameters(recurse=False)):
             if hasattr(param, "in_dim") and hasattr(param, "vmap_level"):
-                from torch._C._functorch import _add_batch_dim
+                from torch._C._functorch import _add_batch_dim  # @manual=//caffe2:_C
 
                 param = _add_batch_dim(param, param.in_dim, param.vmap_level)
                 delattr(mod, name)
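
Note on the annotations, for reviewers outside Meta's Buck-based builds: a
`# @manual=<target>` comment on an import tells the autodeps tooling not to
infer that dependency and to use the named build target instead. This patch
only retargets the comments (e.g. `fbcode//caffe2:_C` -> `fbcode//caffe2:torch`);
the Python imports themselves are unchanged.

The annotated imports all follow the same guarded-import pattern: functorch's
internals have moved between torch releases, so the current location is tried
first and availability is recorded in a flag. A minimal sketch of that
pattern, mirroring the symbols touched in tensordict/_lazy.py (illustrative
only; the functorch._C fallback is an assumption based on pre-2.0 releases,
not verbatim file contents):

    # Sketch: locate functorch's batching helpers across torch versions.
    _has_functorch = False
    try:
        try:
            # torch >= 2.0 exposes the C bindings under torch._C._functorch
            from torch._C._functorch import (
                _add_batch_dim,
                _remove_batch_dim,
                is_batchedtensor,
            )
        except ImportError:
            # assumed legacy location in the standalone functorch package
            from functorch._C import (
                _add_batch_dim,
                _remove_batch_dim,
                is_batchedtensor,
            )
        _has_functorch = True
    except ImportError:
        # no functorch available: callers check the flag before using vmap
        _has_functorch = False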