Ensure fused_params does not get modified and affect another sharder (#2633)

Summary:
Pull Request resolved: #2633

Sometimes fused_params can be modified in between sharding for inference; when sharders are built from the same fused_params dict, a change made for one sharder leaks into the others. Here, we deepcopy fused_params in the sharder constructor to ensure this doesn't happen.
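A minimal, self-contained sketch of the aliasing problem and the deepcopy fix. This uses a hypothetical SketchSharder class and an illustrative "register_tbe" key, not the real TorchRec sharder API; it mirrors the spirit of test_fused_params_overwrite in the diff below.

import copy
from typing import Any, Dict, Optional


class SketchSharder:
    """Toy stand-in for a sharder that stores fused_params (not the TorchRec class)."""

    def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
        # Deepcopy so that sharders built from the same dict do not share state.
        self._fused_params: Optional[Dict[str, Any]] = (
            copy.deepcopy(fused_params) if fused_params is not None else fused_params
        )

    @property
    def fused_params(self) -> Optional[Dict[str, Any]]:
        return self._fused_params


shared_params = {"register_tbe": True}  # illustrative key name only
ebc_sharder = SketchSharder(shared_params)
ec_sharder = SketchSharder(shared_params)

# With the deepcopy, overwriting one sharder's fused_params does not leak
# into the other sharder.
ebc_sharder.fused_params["register_tbe"] = False
assert ec_sharder.fused_params["register_tbe"] is True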

Reviewed By: ZhengkaiZ

Differential Revision: D67104769

fbshipit-source-id: 9bdcd9112f7630ccd58226318ce06dc4e40102cc
PaulZhang12 authored and facebook-github-bot committed Dec 13, 2024
1 parent 8afe20e commit a83a371
Showing 2 changed files with 24 additions and 2 deletions.
5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_types.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import abc
+import copy
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -499,7 +500,9 @@ def __init__(
         shardable_params: Optional[List[str]] = None,
     ) -> None:
         super().__init__()
-        self._fused_params = fused_params
+        self._fused_params: Optional[Dict[str, Any]] = (
+            copy.deepcopy(fused_params) if fused_params is not None else fused_params
+        )
         if not shardable_params:
             shardable_params = []
         self._shardable_params: List[str] = shardable_params
21 changes: 20 additions & 1 deletion torchrec/inference/tests/test_inference.py
@@ -17,7 +17,10 @@
 from torch.fx import symbolic_trace
 from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
-from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
+    FUSED_PARAM_REGISTER_TBE_BOOL,
+)
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
 from torchrec.distributed.test_utils.test_model import (
@@ -34,6 +37,8 @@
 )
 from torchrec.inference.modules import (
     assign_weights_to_tbe,
+    DEFAULT_FUSED_PARAMS,
+    DEFAULT_SHARDERS,
     get_table_to_weights_from_tbe,
     quantize_inference_model,
     set_pruning_data,
@@ -392,3 +397,17 @@ def test_quantized_tbe_count_different_pooling(self) -> None:
         self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
         # Changing this back
         self.tables[0].pooling = PoolingType.SUM
+
+    def test_fused_params_overwrite(self) -> None:
+        orig_value = DEFAULT_FUSED_PARAMS[FUSED_PARAM_REGISTER_TBE_BOOL]
+
+        sharders = DEFAULT_SHARDERS
+        ebc_sharder = sharders[0]
+        ebc_fused_params = ebc_sharder.fused_params
+        ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = -1
+
+        ec_sharder = sharders[1]
+        ec_fused_params = ec_sharder.fused_params
+
+        # Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
+        self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)
