diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 0d558edb1..e0a19547c 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -604,6 +604,7 @@ def __init__( pooling_mode=PoolingMode.NONE, weights_precision=weights_precision, device=device, + table_names=[t.name for t in config.embedding_tables], **fused_params, ) )