From c4005c929997b3eb0d12e6574220baf02283ce67 Mon Sep 17 00:00:00 2001
From: yhshin
Date: Thu, 12 Dec 2024 10:25:29 -0800
Subject: [PATCH] Remove unused imports

Differential Revision: D67111062
---
 torchrec/distributed/benchmark/benchmark_train.py  |  1 -
 torchrec/distributed/embedding_dim_bucketer.py     |  5 +----
 torchrec/distributed/fused_params.py               | 12 ++++++------
 torchrec/distributed/shard.py                      |  1 -
 torchrec/distributed/sharding/cw_sharding.py       |  1 -
 .../distributed/sharding/rw_sequence_sharding.py   |  1 -
 torchrec/distributed/tensor_sharding.py            |  2 +-
 torchrec/distributed/test_utils/test_sharding.py   |  2 --
 torchrec/distributed/tests/test_infer_shardings.py |  2 --
 .../distributed/tests/test_quant_model_parallel.py |  2 +-
 .../test_sequence_model_parallel_single_rank.py    | 11 +++--------
 .../distributed/train_pipeline/train_pipelines.py  |  3 ---
 torchrec/inference/client.py                       |  1 -
 torchrec/inference/tests/test_inference.py         |  1 -
 torchrec/ir/types.py                               |  2 +-
 15 files changed, 13 insertions(+), 34 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py
index d8cf35d00..15ea780f2 100644
--- a/torchrec/distributed/benchmark/benchmark_train.py
+++ b/torchrec/distributed/benchmark/benchmark_train.py
@@ -10,7 +10,6 @@
 #!/usr/bin/env python3

 import argparse
-import copy
 import logging
 import os
 import time
diff --git a/torchrec/distributed/embedding_dim_bucketer.py b/torchrec/distributed/embedding_dim_bucketer.py
index 21283a445..ef2f58b15 100644
--- a/torchrec/distributed/embedding_dim_bucketer.py
+++ b/torchrec/distributed/embedding_dim_bucketer.py
@@ -10,10 +10,7 @@
 from enum import Enum, unique
 from typing import Dict, List

-from torchrec.distributed.embedding_types import (
-    EmbeddingComputeKernel,
-    ShardedEmbeddingTable,
-)
+from torchrec.distributed.embedding_types import ShardedEmbeddingTable
 from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType


diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py
index 26af33938..71b6b4786 100644
--- a/torchrec/distributed/fused_params.py
+++ b/torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, Optional

 import torch

@@ -55,7 +55,7 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool:


 def get_fused_param_tbe_row_alignment(
-    fused_params: Optional[Dict[str, Any]]
+    fused_params: Optional[Dict[str, Any]],
 ) -> Optional[int]:
     if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params:
         return None
@@ -64,7 +64,7 @@


 def fused_param_bounds_check_mode(
-    fused_params: Optional[Dict[str, Any]]
+    fused_params: Optional[Dict[str, Any]],
 ) -> Optional[BoundsCheckMode]:
     if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params:
         return None
@@ -73,7 +73,7 @@


 def fused_param_lengths_to_offsets_lookup(
-    fused_params: Optional[Dict[str, Any]]
+    fused_params: Optional[Dict[str, Any]],
 ) -> bool:
     if (
         fused_params is None
@@ -85,7 +85,7 @@


 def is_fused_param_quant_state_dict_split_scale_bias(
-    fused_params: Optional[Dict[str, Any]]
+    fused_params: Optional[Dict[str, Any]],
 ) -> bool:
     return (
         fused_params
@@ -95,7 +95,7 @@


 def tbe_fused_params(
-    fused_params: Optional[Dict[str, Any]]
+    fused_params: Optional[Dict[str, Any]],
 ) -> Optional[Dict[str, Any]]:
     if not fused_params:
         return None
diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index 4c44ae221..a755d2c8b 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -29,7 +29,6 @@
 )
 from torchrec.distributed.utils import init_parameters
 from torchrec.modules.utils import reset_module_states_post_sharding
-from torchrec.types import CacheMixin


 def _join_module_path(path: str, name: str) -> str:
diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py
index 940f1a0ca..44c2096e9 100644
--- a/torchrec/distributed/sharding/cw_sharding.py
+++ b/torchrec/distributed/sharding/cw_sharding.py
@@ -32,7 +32,6 @@
     DTensorMetadata,
     EmbeddingComputeKernel,
     InputDistOutputs,
-    KJTList,
     ShardedEmbeddingTable,
 )
 from torchrec.distributed.sharding.tw_sharding import (
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 38b68c3ed..1d9fb71d5 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -17,7 +17,6 @@
 )
 from torchrec.distributed.embedding_lookup import (
     GroupedEmbeddingsLookup,
-    InferCPUGroupedEmbeddingsLookup,
     InferGroupedEmbeddingsLookup,
 )
 from torchrec.distributed.embedding_sharding import (
diff --git a/torchrec/distributed/tensor_sharding.py b/torchrec/distributed/tensor_sharding.py
index 9c211c5aa..6a7ca0715 100644
--- a/torchrec/distributed/tensor_sharding.py
+++ b/torchrec/distributed/tensor_sharding.py
@@ -11,7 +11,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import cast, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import distributed as dist
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index dbd8f1007..4b0aedfd6 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -41,7 +41,6 @@
 )
 from torchrec.distributed.types import (
     EmbeddingModuleShardingPlan,
-    EnumerableShardingSpec,
     ModuleSharder,
     ShardedTensor,
     ShardingEnv,
@@ -288,7 +287,6 @@ def sharding_single_rank_test(
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
 ) -> None:
-
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
         (global_model, inputs) = gen_model_and_input(
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 9bf975d3e..83b4649ee 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -37,10 +37,8 @@
 )
 from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
 from torchrec.distributed.quant_embeddingbag import (
-    QuantEmbeddingBagCollection,
     QuantEmbeddingBagCollectionSharder,
     QuantFeatureProcessedEmbeddingBagCollectionSharder,
-    ShardedQuantEmbeddingBagCollection,
 )
 from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec
 from torchrec.distributed.shard import _shard_modules
diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py
index 7dc4746de..131b9d3d6 100644
--- a/torchrec/distributed/tests/test_quant_model_parallel.py
+++ b/torchrec/distributed/tests/test_quant_model_parallel.py
@@ -8,7 +8,7 @@
 # pyre-strict

 import unittest
-from typing import Any, cast, Dict, List, Optional, Tuple
+from typing import cast, Dict, Optional, Tuple

 import hypothesis.strategies as st
 import torch
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py
index 8e3699825..26ca8c55b 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py
@@ -9,18 +9,13 @@


 import unittest
-from typing import cast, Dict, List, Optional, OrderedDict, Tuple
+from typing import cast, OrderedDict

 import hypothesis.strategies as st
 import torch
 from hypothesis import given, settings, Verbosity
-from torch import distributed as dist, nn
-from torchrec import distributed as trec_dist
-from torchrec.distributed import DistributedModelParallel
+from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.model_parallel import get_default_sharders
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.test_utils.test_model import ModelInput
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSingleRankBase,
 )
@@ -28,7 +23,7 @@
     TestEmbeddingCollectionSharder,
     TestSequenceSparseNN,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.distributed.types import ModuleSharder, ShardingType
 from torchrec.modules.embedding_configs import DataType, EmbeddingConfig


diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index e9189bd3f..d42a2e9ac 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -11,7 +11,6 @@
 import contextlib
 import logging
 from collections import deque
-from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -31,7 +30,6 @@
 )

 import torch
-import torchrec.distributed.comm_ops
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
@@ -760,7 +758,6 @@ def _grad_swap(self) -> None:
             param.grad = grad

     def _init_embedding_streams(self) -> None:
-
         for _ in self._pipelined_modules:
             self._embedding_streams.append(
                 (torch.get_device_module(self._device).Stream(priority=0))
diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py
index 725338f46..50bdc09ea 100644
--- a/torchrec/inference/client.py
+++ b/torchrec/inference/client.py
@@ -11,7 +11,6 @@
 import grpc
 import predictor_pb2, predictor_pb2_grpc
 import torch
-from torch.utils.data import DataLoader
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.datasets.random import RandomRecDataset
 from torchrec.datasets.utils import Batch
diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py
index 7ee04d9f0..5c4563185 100644
--- a/torchrec/inference/tests/test_inference.py
+++ b/torchrec/inference/tests/test_inference.py
@@ -14,7 +14,6 @@

 import torch
 from fbgemm_gpu.split_embedding_configs import SparseType
-from torch.fx import symbolic_trace
 from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.distributed.fused_params import (
diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py
index caa7a35a5..7dc1695b9 100644
--- a/torchrec/ir/types.py
+++ b/torchrec/ir/types.py
@@ -10,7 +10,7 @@
 #!/usr/bin/env python3

 import abc
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional

 import torch