diff --git a/torchrec/distributed/tests/test_shards_wrapper.py b/torchrec/distributed/tests/test_shards_wrapper.py index af1fa039f..7199552dd 100644 --- a/torchrec/distributed/tests/test_shards_wrapper.py +++ b/torchrec/distributed/tests/test_shards_wrapper.py @@ -11,9 +11,8 @@ from typing import List, Optional, Union import torch -from hypothesis import settings, Verbosity from torch import distributed as dist -from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.test_utils.multi_process import ( MultiProcessContext, MultiProcessTestBase,