Quant FPEBC with automatic TBE registration and per_table_weight_dtype (pytorch#2400)

Summary:
Pull Request resolved: pytorch#2400

Expose TBEs in Quant FPEBC and support per_table_weight_dtype for per-table quantization of FPEBC tables.

Reviewed By: ZhengkaiZ

Differential Revision: D62773718

fbshipit-source-id: 6136082e0e6a9dad086236b2fe110ab63e942be8
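
For orientation, a minimal usage sketch of what this change enables follows. Only quantize_inference_model and its per_table_weight_dtype argument come from this diff; build_inference_model, the table names, and the exact import path are illustrative assumptions, not part of the commit.

import torch

from torchrec.inference.modules import quantize_inference_model

# Hypothetical helper: any eager model containing EmbeddingBagCollections and a
# FeatureProcessedEmbeddingBagCollection (FPEBC) would do here.
model = build_inference_model()

# Map table names to the weight dtype they should be quantized to; tables that
# are not listed fall back to the default (int8). With this change the mapping
# is also honored for tables owned by the FPEBC.
per_table_weight_dtype = {
    "table_0": torch.quint4x2,  # hypothetical table names
    "table_1": torch.qint8,
}

quantized_model = quantize_inference_model(
    model, per_table_weight_dtype=per_table_weight_dtype
)

# Because TBE registration now covers the FPEBC as well, the quantized model
# exposes IntNBitTableBatchedEmbeddingBagsCodegen modules whose embedding_specs
# can be inspected, as the new test below does.
for module in quantized_model.modules():
    if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen":
        for spec in module.embedding_specs:
            print(spec[0], spec[3])  # table name and quantized weight SparseType
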
PaulZhang12 authored and facebook-github-bot committed Sep 19, 2024
1 parent e401bea commit 3262651
Showing 2 changed files with 55 additions and 2 deletions.
17 changes: 15 additions & 2 deletions torchrec/inference/modules.py
@@ -420,14 +420,19 @@ def _quantize_fp_module(
         fp_module_fqn: str,
         activation_dtype: torch.dtype = torch.float,
         weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE,
+        per_fp_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None,
     ) -> None:
         """
         If FeatureProcessedEmbeddingBagCollection is found, quantize via direct module swap.
         """
-        fp_module.qconfig = quant.QConfig(
+
+        quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
+        fp_module.qconfig = QuantConfig(
             activation=quant.PlaceholderObserver.with_args(dtype=activation_dtype),
             weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
+            per_table_weight_dtype=per_fp_table_weight_dtype,
         )
+
         # ie. "root.submodule.feature_processed_mod" -> "root.submodule", "feature_processed_mod"
         fp_ebc_parent_fqn, fp_ebc_name = fp_module_fqn.rsplit(".", 1)
         fp_ebc_parent = getattr_recursive(model, fp_ebc_parent_fqn)
@@ -447,7 +452,15 @@ def _quantize_fp_module(
             additional_mapping[type(m)] = quantization_mapping[typename]
         elif typename == FEATURE_PROCESSED_EBC_TYPE:
             # handle the fp ebc separately
-            _quantize_fp_module(model, m, n, weight_dtype=fp_weight_dtype)
+            _quantize_fp_module(
+                model,
+                m,
+                n,
+                weight_dtype=fp_weight_dtype,
+                # Pass per_table_weight_dtype through if it is provided; it may
+                # also contain entries for the FPEBC tables
+                per_fp_table_weight_dtype=per_table_weight_dtype,
+            )

     quant_prep_enable_register_tbes(model, list(additional_mapping.keys()))
     quantize_embeddings(
40 changes: 40 additions & 0 deletions torchrec/inference/tests/test_inference.py
@@ -12,6 +12,7 @@
 from argparse import Namespace

 import torch
+from fbgemm_gpu.split_embedding_configs import SparseType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.test_utils.test_model import (
@@ -175,3 +176,42 @@ def test_set_pruning_data(self) -> None:
                 spec[1],
                 pruning_dict[spec[0]],
             )
+
+    def test_quantize_per_table_dtype(self) -> None:
+        max_feature_lengths = {}
+
+        # First two tables as FPEBC
+        max_feature_lengths[self.tables[0].name] = 100
+        max_feature_lengths[self.tables[1].name] = 100
+
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=torch.device("cpu"),
+            sparse_device=torch.device("cpu"),
+            over_arch_clazz=TestOverArchRegroupModule,
+            max_feature_lengths=max_feature_lengths,
+        )
+
+        per_table_dtype = {}
+
+        for table in self.tables + self.weighted_tables:
+            # quint4x2 differs from int8, which is the default
+            per_table_dtype[table.name] = torch.quint4x2
+
+        quantized_model = quantize_inference_model(
+            model, per_table_weight_dtype=per_table_dtype
+        )
+
+        num_tbes = 0
+        # Check the registered TBEs for the correct weight dtype
+        for module in quantized_model.modules():
+            if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen":
+                num_tbes += 1
+                for i, spec in enumerate(module.embedding_specs):
+                    self.assertEqual(spec[3], SparseType.INT4)
+
+        # 3 TBEs: 1 FPEBC, 2 EBCs (1 weighted, 1 unweighted)
+
+        self.assertEqual(num_tbes, 3)
