From a83a371cb84952a78a751328215793cf253f5943 Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Thu, 12 Dec 2024 23:20:02 -0800
Subject: [PATCH] Ensure fused_params does not get modified and affect another
 sharder (#2633)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2633

Sometimes fused_params can be modified in between sharding calls for inference, so a change made through one sharder can affect another. Here, we deepcopy fused_params to ensure this doesn't happen.

Reviewed By: ZhengkaiZ

Differential Revision: D67104769

fbshipit-source-id: 9bdcd9112f7630ccd58226318ce06dc4e40102cc
---
 torchrec/distributed/embedding_types.py    |  5 ++++-
 torchrec/inference/tests/test_inference.py | 21 ++++++++++++++++++++-
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 099389aa3..800b37185 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import abc
+import copy
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -499,7 +500,9 @@ def __init__(
         shardable_params: Optional[List[str]] = None,
     ) -> None:
         super().__init__()
-        self._fused_params = fused_params
+        self._fused_params: Optional[Dict[str, Any]] = (
+            copy.deepcopy(fused_params) if fused_params is not None else fused_params
+        )
         if not shardable_params:
             shardable_params = []
         self._shardable_params: List[str] = shardable_params
diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py
index b13c32e9f..7ee04d9f0 100644
--- a/torchrec/inference/tests/test_inference.py
+++ b/torchrec/inference/tests/test_inference.py
@@ -17,7 +17,10 @@
 from torch.fx import symbolic_trace
 from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
-from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
+    FUSED_PARAM_REGISTER_TBE_BOOL,
+)
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.test_utils.test_model import (
@@ -34,6 +37,8 @@
 )
 from torchrec.inference.modules import (
     assign_weights_to_tbe,
+    DEFAULT_FUSED_PARAMS,
+    DEFAULT_SHARDERS,
     get_table_to_weights_from_tbe,
     quantize_inference_model,
     set_pruning_data,
@@ -392,3 +397,17 @@ def test_quantized_tbe_count_different_pooling(self) -> None:
         self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
         # Changing this back
         self.tables[0].pooling = PoolingType.SUM
+
+    def test_fused_params_overwrite(self) -> None:
+        orig_value = DEFAULT_FUSED_PARAMS[FUSED_PARAM_REGISTER_TBE_BOOL]
+
+        sharders = DEFAULT_SHARDERS
+        ebc_sharder = sharders[0]
+        ebc_fused_params = ebc_sharder.fused_params
+        ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = -1
+
+        ec_sharder = sharders[1]
+        ec_fused_params = ec_sharder.fused_params
+
+        # Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
+        self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)
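
Note (not part of the patch): a minimal standalone sketch of the aliasing problem the deepcopy avoids. The Sharder class and the "register_tbe" key below are illustrative stand-ins, not the real TorchRec classes; the real fix is the constructor change in embedding_types.py above. When two sharders hold references to the same fused_params dict, mutating it through one sharder silently changes the other's view; copying on construction keeps each sharder's fused_params independent.

    import copy
    from typing import Any, Dict, Optional


    class Sharder:
        # Toy stand-in for a sharder that stores fused_params.
        def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
            # Deepcopy so later mutation of the caller's dict (or of another
            # sharder's view of it) cannot leak into this sharder.
            self._fused_params: Optional[Dict[str, Any]] = (
                copy.deepcopy(fused_params) if fused_params is not None else fused_params
            )

        @property
        def fused_params(self) -> Optional[Dict[str, Any]]:
            return self._fused_params


    shared = {"register_tbe": True}  # hypothetical fused_params dict
    ebc_sharder = Sharder(shared)
    ec_sharder = Sharder(shared)

    # Mutate one sharder's fused_params; because each sharder deep-copied the
    # dict, the other sharder still sees the original value.
    ebc_sharder.fused_params["register_tbe"] = -1
    assert ec_sharder.fused_params["register_tbe"] is True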