Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 9, 2024
1 parent b2fa826 commit 2d2b68f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
_has_functorch = False
try:
try:
from torch._C._functorch import ( # @manual=fbcode//caffe2:_C
from torch._C._functorch import ( # @manual=fbcode//caffe2:torch
_add_batch_dim,
_remove_batch_dim,
is_batchedtensor,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
_has_functorch = False
try:
try:
from torch._C._functorch import ( # @manual=fbcode//caffe2:_C
from torch._C._functorch import ( # @manual=fbcode//caffe2:torch
_add_batch_dim,
_remove_batch_dim,
is_batchedtensor,
Expand Down
8 changes: 4 additions & 4 deletions tensordict/nn/functional_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def set_tensor_dict( # noqa: F811

_RESET_OLD_TENSORDICT = True
try:
import torch._functorch.vmap as vmap_src
from torch._functorch.vmap import (
import torch._functorch.vmap as vmap_src # @manual=fbcode//caffe2:torch
from torch._functorch.vmap import ( # @manual=fbcode//caffe2:torch
_add_batch_dim,
_broadcast_to_and_flatten,
_get_name,
Expand All @@ -124,7 +124,7 @@ def set_tensor_dict( # noqa: F811
_has_functorch = True
except ImportError:
try:
from functorch._src.vmap import (
from functorch._src.vmap import ( # @manual=fbcode//caffe2/functorch:functorch_src
_add_batch_dim,
_broadcast_to_and_flatten,
_get_name,
Expand All @@ -136,7 +136,7 @@ def set_tensor_dict( # noqa: F811
)

_has_functorch = True
import functorch._src.vmap as vmap_src
import functorch._src.vmap as vmap_src # @manual=fbcode//caffe2/functorch:functorch_src
except ImportError:
_has_functorch = False

Expand Down
4 changes: 2 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

try:
try:
from torch._C._functorch import ( # @manual=fbcode//caffe2:_C
from torch._C._functorch import ( # @manual=fbcode//caffe2:torch
get_unwrapped,
is_batchedtensor,
)
Expand Down Expand Up @@ -2321,7 +2321,7 @@ class _add_batch_dim_pre_hook:
def __call__(self, mod: torch.nn.Module, args, kwargs):
for name, param in list(mod.named_parameters(recurse=False)):
if hasattr(param, "in_dim") and hasattr(param, "vmap_level"):
from torch._C._functorch import _add_batch_dim
from torch._C._functorch import _add_batch_dim # @manual=//caffe2:_C

param = _add_batch_dim(param, param.in_dim, param.vmap_level)
delattr(mod, name)
Expand Down

0 comments on commit 2d2b68f

Please sign in to comment.