Skip to content

Commit

Permalink
Remove unused imports
Browse files Browse the repository at this point in the history
Differential Revision: D67111062
  • Loading branch information
yhshin authored and facebook-github-bot committed Dec 12, 2024
1 parent 4a227d0 commit c4005c9
Show file tree
Hide file tree
Showing 15 changed files with 13 additions and 34 deletions.
1 change: 0 additions & 1 deletion torchrec/distributed/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#!/usr/bin/env python3

import argparse
import copy
import logging
import os
import time
Expand Down
5 changes: 1 addition & 4 deletions torchrec/distributed/embedding_dim_bucketer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from enum import Enum, unique
from typing import Dict, List

from torchrec.distributed.embedding_types import (
EmbeddingComputeKernel,
ShardedEmbeddingTable,
)
from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType


Expand Down
12 changes: 6 additions & 6 deletions torchrec/distributed/fused_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, Optional

import torch

Expand Down Expand Up @@ -55,7 +55,7 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool:


def get_fused_param_tbe_row_alignment(
fused_params: Optional[Dict[str, Any]]
fused_params: Optional[Dict[str, Any]],
) -> Optional[int]:
if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params:
return None
Expand All @@ -64,7 +64,7 @@ def get_fused_param_tbe_row_alignment(


def fused_param_bounds_check_mode(
fused_params: Optional[Dict[str, Any]]
fused_params: Optional[Dict[str, Any]],
) -> Optional[BoundsCheckMode]:
if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params:
return None
Expand All @@ -73,7 +73,7 @@ def fused_param_bounds_check_mode(


def fused_param_lengths_to_offsets_lookup(
fused_params: Optional[Dict[str, Any]]
fused_params: Optional[Dict[str, Any]],
) -> bool:
if (
fused_params is None
Expand All @@ -85,7 +85,7 @@ def fused_param_lengths_to_offsets_lookup(


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
fused_params: Optional[Dict[str, Any]],
) -> bool:
return (
fused_params
Expand All @@ -95,7 +95,7 @@ def is_fused_param_quant_state_dict_split_scale_bias(


def tbe_fused_params(
fused_params: Optional[Dict[str, Any]]
fused_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
if not fused_params:
return None
Expand Down
1 change: 0 additions & 1 deletion torchrec/distributed/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from torchrec.distributed.utils import init_parameters
from torchrec.modules.utils import reset_module_states_post_sharding
from torchrec.types import CacheMixin


def _join_module_path(path: str, name: str) -> str:
Expand Down
1 change: 0 additions & 1 deletion torchrec/distributed/sharding/cw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
DTensorMetadata,
EmbeddingComputeKernel,
InputDistOutputs,
KJTList,
ShardedEmbeddingTable,
)
from torchrec.distributed.sharding.tw_sharding import (
Expand Down
1 change: 0 additions & 1 deletion torchrec/distributed/sharding/rw_sequence_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from torchrec.distributed.embedding_lookup import (
GroupedEmbeddingsLookup,
InferCPUGroupedEmbeddingsLookup,
InferGroupedEmbeddingsLookup,
)
from torchrec.distributed.embedding_sharding import (
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/tensor_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import cast, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

import torch
from torch import distributed as dist
Expand Down
2 changes: 0 additions & 2 deletions torchrec/distributed/test_utils/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
)
from torchrec.distributed.types import (
EmbeddingModuleShardingPlan,
EnumerableShardingSpec,
ModuleSharder,
ShardedTensor,
ShardingEnv,
Expand Down Expand Up @@ -288,7 +287,6 @@ def sharding_single_rank_test(
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
) -> None:

with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
# Generate model & inputs.
(global_model, inputs) = gen_model_and_input(
Expand Down
2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_infer_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@
)
from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.quant_embeddingbag import (
QuantEmbeddingBagCollection,
QuantEmbeddingBagCollectionSharder,
QuantFeatureProcessedEmbeddingBagCollectionSharder,
ShardedQuantEmbeddingBagCollection,
)
from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec
from torchrec.distributed.shard import _shard_modules
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/tests/test_quant_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Any, cast, Dict, List, Optional, Tuple
from typing import cast, Dict, Optional, Tuple

import hypothesis.strategies as st
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,21 @@


import unittest
from typing import cast, Dict, List, Optional, OrderedDict, Tuple
from typing import cast, OrderedDict

import hypothesis.strategies as st
import torch
from hypothesis import given, settings, Verbosity
from torch import distributed as dist, nn
from torchrec import distributed as trec_dist
from torchrec.distributed import DistributedModelParallel
from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.model_parallel import get_default_sharders
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.distributed.test_utils.test_model_parallel_base import (
ModelParallelSingleRankBase,
)
from torchrec.distributed.tests.test_sequence_model import (
TestEmbeddingCollectionSharder,
TestSequenceSparseNN,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.modules.embedding_configs import DataType, EmbeddingConfig


Expand Down
3 changes: 0 additions & 3 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import contextlib
import logging
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Expand All @@ -31,7 +30,6 @@
)

import torch
import torchrec.distributed.comm_ops
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.model_parallel import ShardedModule
Expand Down Expand Up @@ -760,7 +758,6 @@ def _grad_swap(self) -> None:
param.grad = grad

def _init_embedding_streams(self) -> None:

for _ in self._pipelined_modules:
self._embedding_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
Expand Down
1 change: 0 additions & 1 deletion torchrec/inference/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import grpc
import predictor_pb2, predictor_pb2_grpc
import torch
from torch.utils.data import DataLoader
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
from torchrec.datasets.utils import Batch
Expand Down
1 change: 0 additions & 1 deletion torchrec/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.fx import symbolic_trace
from torchrec import PoolingType
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.distributed.fused_params import (
Expand Down
2 changes: 1 addition & 1 deletion torchrec/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#!/usr/bin/env python3

import abc
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional

import torch

Expand Down

0 comments on commit c4005c9

Please sign in to comment.