diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 099389aa3..800b37185 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -8,6 +8,7 @@
 # pyre-strict

 import abc
+import copy
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -499,7 +500,9 @@ def __init__(
         shardable_params: Optional[List[str]] = None,
     ) -> None:
         super().__init__()
-        self._fused_params = fused_params
+        self._fused_params: Optional[Dict[str, Any]] = (
+            copy.deepcopy(fused_params) if fused_params is not None else fused_params
+        )
         if not shardable_params:
             shardable_params = []
         self._shardable_params: List[str] = shardable_params
diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py
index b13c32e9f..7ee04d9f0 100644
--- a/torchrec/inference/tests/test_inference.py
+++ b/torchrec/inference/tests/test_inference.py
@@ -17,7 +17,10 @@
 from torch.fx import symbolic_trace
 from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
-from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
+    FUSED_PARAM_REGISTER_TBE_BOOL,
+)
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.test_utils.test_model import (
@@ -34,6 +37,8 @@
 )
 from torchrec.inference.modules import (
     assign_weights_to_tbe,
+    DEFAULT_FUSED_PARAMS,
+    DEFAULT_SHARDERS,
     get_table_to_weights_from_tbe,
     quantize_inference_model,
     set_pruning_data,
@@ -392,3 +397,17 @@ def test_quantized_tbe_count_different_pooling(self) -> None:
         self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
         # Changing this back
         self.tables[0].pooling = PoolingType.SUM
+
+    def test_fused_params_overwrite(self) -> None:
+        orig_value = DEFAULT_FUSED_PARAMS[FUSED_PARAM_REGISTER_TBE_BOOL]
+
+        sharders = DEFAULT_SHARDERS
+        ebc_sharder = sharders[0]
+        ebc_fused_params = ebc_sharder.fused_params
+        ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = -1
+
+        ec_sharder = sharders[1]
+        ec_fused_params = ec_sharder.fused_params
+
+        # Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
+        self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)
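
For context, a minimal self-contained sketch of the aliasing this change guards against: every sharder in DEFAULT_SHARDERS is constructed with the same module-level DEFAULT_FUSED_PARAMS dict, so without the deepcopy a mutation through one sharder's fused_params was visible through every other sharder. The SharderBefore/SharderAfter classes and the "register_tbe" key below are hypothetical stand-ins for BaseEmbeddingSharder and FUSED_PARAM_REGISTER_TBE_BOOL, not TorchRec APIs.

    # Sketch of the aliasing bug fixed by the embedding_types.py hunk above.
    import copy
    from typing import Any, Dict, Optional


    class SharderBefore:
        """Old behavior: stores a reference to the caller's dict."""

        def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
            self.fused_params = fused_params


    class SharderAfter:
        """New behavior: deep-copies, so each sharder owns its params."""

        def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
            self.fused_params: Optional[Dict[str, Any]] = (
                copy.deepcopy(fused_params) if fused_params is not None else fused_params
            )


    defaults = {"register_tbe": True}  # stand-in for DEFAULT_FUSED_PARAMS
    a, b = SharderBefore(defaults), SharderBefore(defaults)
    a.fused_params["register_tbe"] = False
    assert b.fused_params["register_tbe"] is False  # mutation leaked via the shared dict

    defaults = {"register_tbe": True}  # fresh dict for the fixed version
    c, d = SharderAfter(defaults), SharderAfter(defaults)
    c.fused_params["register_tbe"] = False
    assert d.fused_params["register_tbe"] is True  # each sharder's copy stays isolated

The new test_fused_params_overwrite test checks exactly this isolation across the real sharders: it mutates the EmbeddingBagCollection sharder's fused_params and asserts the EmbeddingCollection sharder's value is unchanged.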