hbm (torchrec.distributed.planner.types.Storage attribute)
-
- HBM (torchrec.distributed.types.ParameterStorage attribute)
|
+ - HBM (torchrec.distributed.types.ParameterStorage attribute)
+
- hbm_mem_bw (torchrec.distributed.planner.types.Topology property)
- HeteroEmbeddingShardingPlanner (class in torchrec.distributed.planner.planners)
@@ -3278,6 +3282,8 @@
S
- stride_per_key_per_rank() (torchrec.sparse.jagged_tensor.KeyedJaggedTensor method)
- SUM (torchrec.modules.embedding_configs.PoolingType attribute)
+
+ - supported_fields (torchrec.distributed.planner.types.CustomTopologyData attribute)
- SwishLayerNorm (class in torchrec.modules.activation)
diff --git a/objects.inv b/objects.inv
index 308d482ae..d9fb69c48 100644
Binary files a/objects.inv and b/objects.inv differ
diff --git a/searchindex.js b/searchindex.js
index ff30051ba..2c01bca1e 100644
--- a/searchindex.js
+++ b/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["index", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"pytorch": [0, 3, 9, 10, 12], "domain": 0, "librari": [0, 7], "built": [0, 9], "provid": [0, 3, 4, 5, 7, 9, 11], "common": [0, 9, 12], "sparsiti": 0, "parallel": [0, 3, 5], "primit": [0, 3, 5], "need": [0, 3, 5, 6, 7, 9, 10, 11, 12], "larg": [0, 4], "scale": 0, "recommend": 0, "system": [0, 3, 4], "recsi": [0, 8, 10], "It": [0, 3, 4, 5, 7, 9, 10, 11, 12], "allow": [0, 3, 4, 6, 9, 10], "author": [0, 3, 7], "train": [0, 3, 4, 5, 7, 8, 9, 10, 11, 12], "model": [0, 3, 4, 5, 6, 7, 9, 10, 11], "embed": [0, 4, 5, 6, 8, 9, 11, 12], "shard": [0, 3, 4, 7, 9, 10, 11], "across": [0, 3, 4, 5], "mani": [0, 3, 5], "gpu": [0, 3, 4, 7], "For": [0, 3, 4, 5, 8, 9, 10, 11, 12], "instal": 0, "instruct": 0, "visit": 0, "http": [0, 3, 4, 8, 9, 12], "github": [0, 9], "com": [0, 9], "readm": 0, "In": [0, 3, 4, 9, 10, 12], "thi": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "we": [0, 3, 4, 5, 6, 7, 9, 10, 11, 12], "introduc": [0, 10], "primari": [0, 7], "call": [0, 3, 4, 5, 7, 9, 10, 11], "distributedmodelparallel": [0, 3], "dmp": [0, 3], "like": [0, 3, 4, 5, 6, 9, 10, 12], "s": [0, 3, 4, 6, 7, 8, 9, 10, 11, 12], "distributeddataparallel": 0, "wrap": [0, 3, 5, 10], "enabl": [0, 3, 4, 10], "distribut": [0, 7, 9, 10, 12], "sourc": [0, 8, 9], "open": 0, "googl": 0, "colab": 0, "dataset": [0, 3, 4], "criteo": 0, "movielen": 0, "random": [0, 4], "util": [0, 5], "script": [0, 12], "contiguous_preproc_criteo": 0, "npy_preproc_criteo": 0, "collective_util": 0, "comm": 0, "comm_op": 0, "dist_data": 0, "embedding_lookup": 0, "embedding_shard": 0, "embedding_typ": 0, "embeddingbag": [0, 4, 6, 8, 9, 11], "grouped_position_weight": 0, "model_parallel": 0, "quant_embeddingbag": 0, "train_pipelin": 0, "type": [0, 5, 6, 7, 8, 9, 11, 12], "mc_modul": 0, "mc_embeddingbag": 0, "mc_embed": 0, "planner": [0, 3], "constant": [0, 10], "enumer": [0, 3, 9, 10], "partition": 0, "perf_model": 0, "propos": [0, 8], "shard_estim": 0, "stat": [0, 3], "storage_reserv": 0, "cw_shard": 0, "dp_shard": 0, "rw_shard": 0, "tw_shard": 0, "twcw_shard": 0, "twrw_shard": 0, "fx": [0, 7], "tracer": 0, "modul": [0, 3, 4, 5], "infer": [0, 3, 4, 5, 11, 12], "model_packag": 0, "deepfm": 0, "dlrm": [0, 7], "activ": [0, 11], "crossnet": 0, "embedding_config": [0, 3, 11], "embedding_modul": 0, "feature_processor": [0, 3, 5, 11], "lazy_extens": 0, "mlp": [0, 8], "mc_embedding_modul": 0, "optim": [0, 3, 4, 9, 11], "clip": 0, "fuse": [0, 3, 5], "kei": [0, 3, 5, 7, 8, 9, 11, 12], "warmup": 0, "quant": [0, 3], "spars": [0, 3, 5, 8, 9, 11], "jagged_tensor": [0, 3], "index": [0, 9, 12], "search": [0, 4], "page": 0, "necessari": [3, 4, 5], "oper": [3, 4, 5, 6, 9, 12], "These": [3, 4, 7, 9], "includ": [3, 4, 6, 7, 9, 12], "through": [3, 6, 10], "collect": [3, 5, 8, 9, 10, 11], "all": [3, 4, 5, 7, 8, 9, 10, 12], "reduc": [3, 5, 9, 11], "scatter": [3, 5], "wrapper": [3, 10], "featur": [3, 4, 5, 8, 9, 11, 12], "kjt": [3, 4, 5, 8, 9, 11, 12], "variou": [3, 7, 9], "implement": [3, 4, 5, 7, 9, 10, 12], "shardedembeddingbag": 3, "nn": [3, 4, 6, 9, 11], "shardedembeddingbagcollect": [3, 9, 11], "embeddingbagcollect": [3, 8, 9, 11], "sharder": [3, 4], "defin": [3, 5, 7, 8, 9], "ani": [3, 4, 5, 6, 7, 9, 10, 12], "support": [3, 4, 5, 6, 9, 10], "comput": [3, 4, 5, 7, 8, 9, 11], "kernel": [3, 4, 9], "which": [3, 4, 5, 7, 9, 10, 12], "ar": [3, 4, 5, 7, 9, 10, 11, 12], "devic": [3, 4, 5, 6, 7, 8, 9, 11, 12], "cpu": [3, 4], "mai": [3, 12], "batch": [3, 4, 5, 6, 7, 8, 9, 11, 12], "togeth": [3, 9], "tabl": [3, 4, 5, 6, 8, 9, 11], "fusion": 3, "pipelin": [3, 4, 9, 12], "trainpipelinesparsedist": 3, "overlap": 3, "dataload": 3, "transfer": 3, "copi": [3, 5, 7, 9, 10, 12], "inter": [3, 9], "commun": [3, 4, 5], "input_dist": [3, 9], "forward": [3, 4, 5, 7, 8, 9, 11, 12], "backward": [3, 4, 6, 10], "increas": 3, "perform": [3, 4, 5, 7, 9, 10, 11], "quantiz": [3, 5, 6, 11], "precis": [3, 9, 11], "file": 3, "contain": [3, 4, 7, 9, 10, 11], "construct": [3, 6, 9, 12], "base": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "control": [3, 6], "flow": [3, 6], "invoke_on_rank_and_broadcast_result": 3, "pg": [3, 4, 5], "processgroup": [3, 4, 5], "rank": [3, 4, 5, 9, 10, 12], "int": [3, 4, 5, 6, 8, 9, 10, 11, 12], "func": 3, "callabl": [3, 5, 6, 9, 10, 11], "t": [3, 4, 5, 6, 7, 9, 10, 12], "arg": [3, 4, 7, 9, 11, 12], "kwarg": [3, 9, 12], "invok": [3, 4], "function": [3, 4, 5, 6, 7, 9, 10, 12], "design": [3, 7, 9], "broadcast": [3, 4], "result": [3, 4, 5, 7, 9, 11], "member": [3, 9], "within": [3, 4, 5, 7, 9, 12], "group": [3, 4, 5, 9, 10, 12], "exampl": [3, 4, 5, 7, 8, 9, 10, 11, 12], "id": [3, 4, 5, 9], "0": [3, 4, 5, 8, 9, 10, 11, 12], "allocate_id": 3, "is_lead": 3, "option": [3, 4, 5, 6, 7, 9, 10, 11, 12], "leader_rank": 3, "bool": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "check": [3, 4, 9, 10, 12], "current": [3, 4, 7, 9], "processs": 3, "leader": 3, "paramet": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "dist": [3, 5], "process": [3, 4, 5, 8, 9, 11], "us": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "determin": [3, 4, 5], "being": [3, 4, 7, 9], "none": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "impli": 3, "onli": [3, 4, 5, 9, 12], "e": [3, 4, 5, 6, 7, 8, 9, 10], "g": [3, 4, 7, 9, 10], "singl": [3, 4, 5, 9, 10], "program": 3, "definit": [3, 6, 7], "default": [3, 4, 6, 7, 8, 9, 10, 11, 12], "The": [3, 4, 5, 6, 7, 8, 9, 10, 12], "caller": 3, "can": [3, 4, 7, 9, 10, 12], "overrid": [3, 4, 6, 7], "context": [3, 5, 12], "specif": [3, 4, 7, 10], "run_on_lead": 3, "get_group_rank": 3, "world_siz": [3, 4, 5], "get": [3, 4, 5], "worker": 3, "also": [3, 4, 7, 9, 10], "avail": [3, 4, 5], "group_rank": 3, "environ": [3, 7], "varibl": 3, "A": [3, 4, 5, 6, 7, 10, 12], "number": [3, 4, 5, 8, 9, 12], "between": [3, 4, 7, 8, 9, 12], "get_num_group": 3, "see": [3, 4, 5, 6, 9, 12], "org": [3, 4, 8, 12], "doc": [3, 12], "stabl": [3, 12], "elast": 3, "run": [3, 4, 5, 7, 9, 10], "html": [3, 12], "get_local_rank": 3, "local": [3, 4, 5, 9], "usual": [3, 4, 9], "its": [3, 4, 5, 9, 10, 12], "node": [3, 6], "get_local_s": 3, "equival": 3, "max_nnod": 3, "intra_and_cross_node_pg": 3, "backend": [3, 5, 7], "str": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "tupl": [3, 4, 5, 6, 7, 9, 10, 11, 12], "creat": [3, 6, 7, 9, 10, 12], "sub": 3, "intra": 3, "cross": [3, 9], "class": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "all2alldenseinfo": 3, "output_split": [3, 5], "list": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "batch_siz": [3, 4, 5, 9, 12], "input_shap": 3, "input_split": [3, 5], "object": [3, 4, 7, 9, 10], "data": [3, 4, 5, 6, 7, 9, 10, 11, 12], "attribut": [3, 4, 10], "when": [3, 4, 6, 9, 10], "alltoall_dens": 3, "all2allpooledinfo": 3, "batch_size_per_rank": [3, 5], "dim_sum_per_rank": [3, 5], "dim_sum_per_rank_tensor": 3, "tensor": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "cumsum_dim_sum_per_rank_tensor": 3, "codec": [3, 5], "quantizedcommcodec": [3, 5], "alltoall_pool": [3, 5], "size": [3, 4, 5, 8, 9, 11, 12], "each": [3, 4, 5, 8, 9, 11, 12], "sum": [3, 4, 5, 9], "dimens": [3, 4, 5, 8, 9, 11, 12], "version": [3, 9, 11], "fast": 3, "_recat_pooled_embedding_grad_out": 3, "cumul": [3, 12], "all2allsequenceinfo": 3, "embedding_dim": [3, 4, 5, 8, 9, 11], "lengths_after_sparse_data_all2al": 3, "forward_recat_tensor": 3, "backward_recat_tensor": 3, "variable_batch_s": 3, "fals": [3, 4, 5, 7, 9, 10, 11, 12], "permuted_lengths_after_sparse_data_all2al": 3, "alltoall_sequ": 3, "length": [3, 4, 5, 8, 9, 11, 12], "after": [3, 4, 5, 9], "alltoal": [3, 5], "recat": [3, 5, 12], "input": [3, 4, 5, 6, 7, 8, 9, 11, 12], "split": [3, 4, 5, 7, 12], "output": [3, 4, 5, 7, 8, 9, 11, 12], "whether": [3, 4, 6, 9, 11], "variabl": [3, 5, 9, 11, 12], "befor": [3, 5, 9, 10], "all2allvinfo": 3, "dims_sum_per_rank": 3, "b_global": 3, "b_local": 3, "b_local_list": 3, "d_local_list": 3, "input_split_s": 3, "factori": [3, 4, 9], "output_split_s": 3, "alltoallv": 3, "global": [3, 4, 5], "my": 3, "rememb": [3, 12], "how": [3, 4, 5, 7, 10], "do": [3, 4, 9, 10, 12], "all_to_all_singl": 3, "fill": 3, "all2all_pooled_req": 3, "static": [3, 4, 10, 12], "ctx": 3, "unus": 3, "formula": 3, "differenti": 3, "mode": [3, 4], "automat": [3, 4, 12], "overridden": [3, 5, 7, 9], "subclass": [3, 5, 7, 9, 10], "vjp": 3, "must": [3, 4, 5, 7, 9], "accept": [3, 4, 7, 9], "first": [3, 4, 5, 9, 10, 12], "argument": [3, 6, 7, 9], "follow": [3, 4, 5, 8, 9, 10, 12], "return": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "pass": [3, 4, 5, 7, 9, 10, 11, 12], "non": [3, 4, 5, 6, 9, 11], "should": [3, 4, 5, 7, 8, 9, 10, 12], "were": 3, "gradient": [3, 4, 10], "w": [3, 5, 9, 12], "r": [3, 9], "given": [3, 4, 5, 6, 9], "valu": [3, 4, 5, 6, 8, 9, 10, 11, 12], "correspond": [3, 4, 5, 7, 9, 12], "If": [3, 4, 7, 9, 10, 12], "an": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "requir": [3, 4, 9, 10], "grad": [3, 10], "you": [3, 5, 6, 12], "just": [3, 4, 8, 9, 12], "retriev": 3, "save": [3, 9, 10], "dure": [3, 4, 10], "ha": [3, 4, 9, 12], "needs_input_grad": 3, "boolean": 3, "repres": [3, 4, 7, 8, 9, 11, 12], "have": [3, 4, 5, 8, 9, 10, 12], "true": [3, 4, 7, 9, 10, 12], "myreq": 3, "request": [3, 7, 10], "a2ai": 3, "input_embed": [3, 9], "custom": [3, 6, 9], "autograd": [3, 7, 9], "There": 3, "two": [3, 4, 9, 12], "wai": [3, 4], "usag": [3, 4], "1": [3, 4, 5, 8, 9, 10, 11, 12], "combin": [3, 9, 10], "staticmethod": 3, "def": [3, 9], "other": [3, 4, 10], "more": [3, 4, 5, 9], "detail": [3, 4, 5, 9], "2": [3, 4, 5, 8, 9, 10, 11, 12], "separ": 3, "setup_context": 3, "longer": [3, 4], "instead": [3, 5, 7, 9, 10], "torch": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "handl": [3, 4, 5, 6, 7, 9, 10], "set": [3, 4, 7, 9, 10], "up": [3, 11], "extend": 3, "store": [3, 4, 5, 12], "arbitrari": 3, "directli": [3, 10], "though": 3, "enforc": [3, 7, 9], "compat": [3, 6, 10], "either": [3, 9], "save_for_backward": 3, "thei": [3, 12], "intend": 3, "save_for_forward": 3, "jvp": 3, "all2all_pooled_wait": 3, "grad_output": 3, "dummy_tensor": 3, "all2all_seq_req": 3, "sharded_input_embed": 3, "all2all_seq_req_wait": 3, "sharded_grad_output": 3, "all2allv_req": 3, "all2allv_wait": 3, "allgatherbaseinfo": 3, "input_s": [3, 9], "all_gatther_base_pool": 3, "allgatherbase_req": 3, "agi": 3, "allgatherbase_wait": 3, "reducescatterbaseinfo": 3, "reduce_scatter_base_pool": 3, "flatten": [3, 5, 9], "reducescatterbase_req": 3, "rsi": 3, "reducescatterbase_wait": 3, "reducescatterinfo": 3, "reduce_scatter_pool": 3, "produc": [3, 4], "reducescattervinfo": 3, "equal_split": 3, "total_input_s": 3, "reduce_scatter_v_pool": 3, "along": [3, 5, 10, 12], "dim": [3, 5], "total": [3, 4, 5], "reducescatterv_req": 3, "reducescatterv_wait": 3, "reducescatter_req": 3, "reducescatter_wait": 3, "await": [3, 5, 6], "variablebatchall2allpooledinfo": 3, "batch_size_per_rank_per_featur": [3, 5], "batch_size_per_feature_pre_a2a": [3, 5], "emb_dim_per_rank_per_featur": [3, 5], "variable_batch_alltoall_pool": [3, 5], "per": [3, 4, 5, 9, 12], "variable_batch_all2all_pooled_req": 3, "variable_batch_all2all_pooled_wait": 3, "all2all_pooled_sync": 3, "all2all_sequence_sync": 3, "all2allv_sync": 3, "all_gather_base_pool": 3, "gather": [3, 5], "from": [3, 4, 5, 6, 7, 9, 10, 12], "form": [3, 9, 11], "pool": [3, 4, 5, 8, 9, 11, 12], "output_tensor_s": 3, "work": [3, 4, 7, 9, 12], "async": [3, 5], "wait": [3, 5], "later": [3, 9], "experiment": [3, 9], "subject": 3, "chang": [3, 9, 10], "all_gather_base_sync": 3, "a2a_pooled_embs_tensor": 3, "world": [3, 5], "Then": 3, "concaten": [3, 5, 9, 12], "receiv": [3, 10], "Its": 3, "shape": [3, 5, 9, 12], "b": [3, 4, 5, 8, 9, 11, 12], "x": [3, 4, 5, 8, 9, 11, 12], "d_local_sum": 3, "where": [3, 4, 5, 9, 11], "a2a_sequence_embs_tensor": 3, "sequenc": [3, 4, 5], "doe": [3, 8, 9, 10, 12], "mix": 3, "out_split": 3, "per_rank_split_length": 3, "one": [3, 4, 5, 7, 8, 9, 10], "differ": [3, 4, 5, 9, 10, 12], "specifi": [3, 4, 5, 6, 9, 10], "assumpt": [3, 12], "emb": 3, "same": [3, 4, 5, 7, 8, 9, 12], "fn": [3, 9], "get_gradient_divis": 3, "reduce_scatter_base_sync": 3, "chunk": [3, 5], "reduce_scatter_sync": 3, "reduce_scatter_v_per_feature_pool": 3, "v": [3, 5, 9, 12], "d": [3, 8, 9, 11, 12], "unevenli": 3, "accord": [3, 4, 5, 7, 8, 10, 12], "reduce_scatter_v_sync": 3, "set_gradient_divis": 3, "val": 3, "variable_batch_all2all_pooled_sync": 3, "embeddingsalltoon": [3, 5], "cat_dim": [3, 5, 12], "merg": [3, 5], "buffer": [3, 5, 7, 9], "alloc": [3, 5, 7], "topolog": [3, 4, 5], "would": [3, 5, 12], "alltoon": [3, 5], "set_devic": [3, 5], "device_str": [3, 5], "embeddingsalltoonereduc": [3, 5], "kjtalltoal": [3, 5], "stagger": [3, 5, 12], "redistribut": [3, 5], "keyedjaggedtensor": [3, 5, 8, 9, 11, 12], "part": [3, 4, 5, 9, 10], "kjtalltoallsplitsawait": [3, 5], "transmit": [3, 5], "correct": [3, 5, 12], "space": [3, 4, 5, 8], "kjtalltoalltensorsawait": [3, 5], "actual": [3, 4, 5, 7, 9], "asynchron": [3, 5], "len": [3, 5, 8], "indic": [3, 5, 7, 9, 10, 11, 12], "send": [3, 5], "assum": [3, 4, 5, 7, 8, 10], "order": [3, 4, 5, 7, 9, 12], "destin": [3, 5, 7, 9], "appli": [3, 5, 8, 9], "_get_recat": [3, 5], "c": [3, 5, 7, 12], "kjta2a": [3, 5], "rank0_input": [3, 5], "hold": [3, 4, 5, 10, 12], "v0": [3, 5, 12], "v1": [3, 5, 9, 12], "v2": [3, 5, 9, 12], "rank1_input": [3, 5], "v3": [3, 5, 12], "v4": [3, 5, 12], "rank0_output": [3, 5], "3": [3, 4, 5, 8, 9, 10, 11, 12], "4": [3, 4, 5, 8, 9, 11, 12], "5": [3, 5, 8, 9, 11, 12], "rank1_output": [3, 5], "relev": [3, 4, 5], "issu": [3, 5, 9], "second": [3, 4, 5, 9, 12], "label": [3, 5], "tensor_split": [3, 5], "input_tensor": [3, 5], "dict": [3, 4, 5, 6, 7, 9, 10, 11, 12], "ie": [3, 4, 5, 9, 12], "stride_per_rank": [3, 5, 12], "stride": [3, 5, 12], "case": [3, 4, 5, 9, 10, 12], "kjtonetoal": [3, 5], "onetoal": [3, 5], "essenti": [3, 5, 12], "p2p": [3, 5], "keyjaggedtensor": [3, 5], "them": [3, 5, 7, 9, 10], "kjtlist": [3, 5], "slice": [3, 5, 6, 12], "pooledembeddingsallgath": [3, 5], "layout": [3, 5, 6], "want": [3, 5], "nccl": [3, 5], "happen": [3, 5], "init_distribut": [3, 5], "new_group": [3, 5], "randn": [3, 5, 8, 9], "m": [3, 5, 6, 9], "local_emb": [3, 5], "pooledembeddingsawait": [3, 5], "num_bucket": [3, 5], "pooledembeddingsalltoal": [3, 5], "callback": [3, 5], "a2a": [3, 5], "t0": [3, 5], "rand": [3, 5, 8], "6": [3, 4, 5, 8, 9, 11, 12], "t1": [3, 5, 8, 9, 11], "print": [3, 5, 9, 11], "properti": [3, 4, 5, 7, 9, 10, 11], "tensor_await": [3, 5], "pooledembeddingsreducescatt": [3, 5], "row": [3, 4, 5], "wise": [3, 4, 5, 9], "twrw": [3, 4, 5], "over": [3, 5, 9, 10], "unequ": [3, 5], "bucket": [3, 5], "seqembeddingsalltoon": [3, 5], "concat": [3, 5, 9, 12], "sequenceembeddingsalltoal": [3, 5], "features_per_rank": [3, 5], "sharding_ctx": [3, 5], "sequenceshardingcontext": [3, 5], "lengths_after_input_dist": [3, 5], "unbucketize_permute_tensor": [3, 5], "sparse_features_recat": [3, 5], "sequenceembeddingsawait": [3, 5], "permut": [3, 5, 12], "splitsalltoallawait": [3, 5], "variablebatchpooledembeddingsalltoal": [3, 5], "kjt_split": [3, 5], "24": [3, 5], "r0_batch_siz": [3, 5], "r1_batch_siz": [3, 5], "f_0": [3, 5], "f_1": [3, 5], "f_2": [3, 5], "r0_batch_size_per_rank_per_featur": [3, 5], "r1_batch_size_per_rank_per_featur": [3, 5], "r0_batch_size_per_feature_pre_a2a": [3, 5], "r1_batch_size_per_feature_pre_a2a": [3, 5], "r0": [3, 5], "r1": [3, 5], "16": [3, 5, 9, 11], "14": [3, 5], "post": [3, 5], "rank_0": [3, 5], "rank_1": [3, 5], "variablebatchpooledembeddingsreducescatt": [3, 5], "rw": [3, 4, 5, 9], "1d": [3, 4, 5], "multipli": [3, 4, 5], "batch_size_r0_f0": [3, 5], "emb_dim_f0": [3, 5], "embeddingcollectionawait": 3, "lazyawait": 3, "jaggedtensor": [3, 9, 11, 12], "embeddingcollectioncontext": 3, "sharding_context": 3, "input_featur": 3, "reverse_indic": [3, 9], "multistream": 3, "record_stream": [3, 12], "stream": [3, 12], "gener": [3, 4, 6, 7, 8, 9, 10, 12], "embeddingcollectionshard": 3, "fused_param": [3, 5], "qcomm_codecs_registri": [3, 5], "use_index_dedup": 3, "baseembeddingshard": 3, "embeddingcollect": [3, 9, 11], "module_typ": [3, 11], "param": [3, 10], "parametershard": 3, "env": [3, 5], "shardingenv": [3, 5], "shardedembeddingcollect": [3, 9, 11], "locat": 3, "replic": [3, 4, 5], "embeddingmoduleshardingplan": 3, "fulli": [3, 4, 10], "qualifi": 3, "name": [3, 4, 7, 8, 9, 10, 11, 12], "path": [3, 4, 7], "spec": 3, "shardedmodul": 3, "shardable_paramet": 3, "sharding_typ": [3, 4, 9], "compute_device_typ": 3, "shardingtyp": [3, 4, 9], "well": [3, 4, 9], "known": [3, 4, 9], "table_name_to_parameter_shard": 3, "shardedembeddingmodul": 3, "fusedoptimizermodul": [3, 10], "public": [3, 9], "api": [3, 6, 9], "manual": [3, 10], "dist_input": 3, "compute_and_output_dist": 3, "multipl": [3, 4, 9, 10], "make": [3, 9, 10], "sens": [3, 10], "method": [3, 6, 7, 9], "initi": [3, 9, 10], "distibut": 3, "soon": 3, "complet": [3, 4], "create_context": 3, "fused_optim": [3, 10], "keyedoptim": [3, 10], "output_dist": 3, "reset_paramet": [3, 9], "create_embedding_shard": 3, "sharding_info": [3, 5], "embeddingshardinginfo": [3, 5], "embeddingshard": [3, 5], "create_sharding_infos_by_shard": 3, "embeddingcollectioninterfac": [3, 9, 11], "get_ec_index_dedup": 3, "set_ec_index_dedup": 3, "commopgradientsc": 3, "functionctx": 3, "scale_gradient_factor": 3, "groupedembeddingslookup": 3, "grouped_config": 3, "groupedembeddingconfig": [3, 5], "baseembeddinglookup": [3, 5], "lookup": [3, 4, 8, 9, 11], "i": [3, 4, 5, 6, 8, 9], "flush": 3, "sparse_featur": [3, 5, 8], "everi": [3, 4, 5, 7, 9], "although": [3, 5, 7, 9], "recip": [3, 5, 7, 9], "instanc": [3, 5, 6, 7, 9], "afterward": [3, 5, 7, 9], "sinc": [3, 5, 7, 9], "former": [3, 5, 7, 9], "take": [3, 4, 5, 7, 9, 10], "care": [3, 5, 7, 9], "regist": [3, 5, 6, 7, 9], "hook": [3, 5, 7, 9], "while": [3, 5, 6, 7, 9], "latter": [3, 5, 7, 9], "silent": [3, 5, 7, 9], "ignor": [3, 4, 5, 7, 9], "load_state_dict": [3, 10], "state_dict": [3, 7, 9, 10], "ordereddict": [3, 6, 7, 9], "union": [3, 4, 6, 7, 9, 10], "shardedtensor": [3, 10], "strict": [3, 10], "_incompatiblekei": 3, "descend": [3, 4], "exactli": 3, "match": [3, 4, 7, 9], "assign": [3, 12], "unless": [3, 10], "get_swap_module_params_on_convers": 3, "persist": [3, 7, 9], "strictli": [3, 9], "preserv": [3, 9], "state": [3, 7, 9, 10], "except": [3, 4, 9], "requires_grad": 3, "field": [3, 9, 10, 12], "missing_kei": 3, "expect": [3, 4, 8, 9], "miss": [3, 4], "unexpected_kei": 3, "present": [3, 10], "namedtupl": 3, "exist": [3, 5, 7, 12], "rais": 3, "runtimeerror": 3, "named_buff": [3, 9], "prefix": [3, 7, 9], "recurs": [3, 9], "remove_dupl": [3, 9], "iter": [3, 4, 9, 10], "yield": [3, 9], "both": [3, 7, 8, 9, 10, 12], "itself": [3, 9], "prepend": [3, 9], "submodul": [3, 9, 10], "otherwis": [3, 4, 7, 9, 10, 12], "direct": [3, 9], "remov": [3, 6, 9], "duplic": [3, 9, 10], "xdoctest": [3, 7, 9], "skip": [3, 7, 9, 10], "undefin": [3, 7, 9], "var": [3, 7, 9], "buf": [3, 9], "self": [3, 4, 9, 12], "running_var": [3, 9], "named_paramet": 3, "bia": [3, 7, 9], "named_parameters_by_t": 3, "tablebatchedembeddingslic": 3, "table_nam": 3, "embedding_weight": 3, "cw": [3, 4], "weight": [3, 4, 5, 7, 9, 10, 11, 12], "compos": [3, 7, 9], "prefetch": [3, 4], "forward_stream": 3, "purg": 3, "keep_var": [3, 7, 9], "dictionari": [3, 7, 9], "refer": [3, 7, 9, 12], "whole": [3, 7, 9], "averag": [3, 4, 7, 9], "shallow": [3, 7, 9], "posit": [3, 4, 5, 7, 9], "howev": [3, 7, 9, 10], "deprec": [3, 7, 9], "keyword": [3, 7, 9], "futur": [3, 7, 9], "releas": [3, 7, 9], "pleas": [3, 4, 7, 9, 12], "avoid": [3, 7, 9, 10], "end": [3, 4, 7, 9], "user": [3, 4, 7, 9, 10], "updat": [3, 4, 7, 9, 10], "ad": [3, 7, 9, 10], "detach": [3, 7, 9], "groupedpooledembeddingslookup": 3, "basegroupedfeatureprocessor": [3, 5, 9], "scale_weight_gradi": 3, "infercpugroupedembeddingslookup": 3, "grouped_configs_per_rank": 3, "infergroupedlookupmixin": 3, "tbetoregistermixin": 3, "get_tbes_to_regist": 3, "intnbittablebatchedembeddingbagscodegen": 3, "infergroupedembeddingslookup": 3, "abc": [3, 4, 7, 9, 10], "infergroupedpooledembeddingslookup": 3, "metainfergroupedembeddingslookup": 3, "meta": [3, 4], "tbe": [3, 4, 11], "op": [3, 4, 5, 10, 11], "metainfergroupedpooledembeddingslookup": 3, "bag": [3, 5, 6, 8, 9], "embeddings_cat_empty_rank_handl": 3, "dummy_embs_tensor": 3, "embeddings_cat_empty_rank_handle_infer": 3, "dtype": [3, 4, 5, 6, 7, 9, 11, 12], "fx_wrap_tensor_view2d": 3, "dim0": 3, "dim1": 3, "baseembeddingdist": [3, 5], "convert": [3, 6, 7, 12], "embeddinglookup": 3, "abstract": [3, 4, 7, 9, 10], "basesparsefeaturesdist": [3, 5], "f": [3, 4, 5, 8, 9, 11], "featureshardingmixin": 3, "table_wis": [3, 9], "create_input_dist": [3, 5], "create_lookup": [3, 5], "create_output_dist": [3, 5], "embedding_nam": [3, 5, 9], "embedding_names_per_rank": [3, 5], "embedding_shard_metadata": [3, 5], "shardmetadata": [3, 5], "embedding_t": [3, 5], "shardedembeddingt": [3, 5], "uncombined_embedding_dim": [3, 5], "uncombined_embedding_nam": [3, 5], "embeddingshardingcontext": [3, 5], "variable_batch_per_featur": 3, "embeddingtableconfig": [3, 9], "param_shard": 3, "nonetyp": [3, 9], "fusedkjtlistsplitsawait": 3, "kjtlistsplitsawait": 3, "kjtlistawait": 3, "info": [3, 9], "metadata": [3, 7, 9], "kjtsplitsalltoallmeta": 3, "distributed_c10d": 3, "_input": 3, "splits_tensor": 3, "listofkjtlistawait": 3, "listofkjtlist": 3, "listofkjtlistsplitsawait": 3, "bucketize_kjt_before_all2al": 3, "block_siz": [3, 5], "output_permut": 3, "bucketize_po": 3, "block_bucketize_row_po": 3, "readjust": 3, "note": [3, 4, 5, 9, 12], "memori": [3, 4, 10], "map": [3, 9, 10, 11], "unbucket": 3, "offset": [3, 4, 8, 9, 11, 12], "group_tabl": 3, "tables_per_rank": 3, "datatyp": [3, 4, 9, 11, 12], "poolingtyp": [3, 9], "embeddingcomputekernel": [3, 4], "consist": 3, "weighted": 3, "interfac": [3, 7, 9], "reli": [3, 7, 9, 11], "etc": [3, 7, 10, 12], "moduleshard": [3, 4], "compute_kernel": [3, 4], "storage_usag": 3, "resourc": 3, "processor": [3, 5, 9], "basequantembeddingshard": 3, "shardable_param": 3, "embeddingattribut": 3, "dens": [3, 4, 8, 9, 12], "enum": [3, 4, 9, 10], "fused_uvm": 3, "fused_uvm_cach": 3, "quant_uvm": 3, "quant_uvm_cach": 3, "awar": [3, 12], "feature_nam": [3, 4, 5, 8, 9, 11], "feature_names_per_rank": [3, 5], "data_typ": [3, 9], "is_weight": [3, 4, 9, 11, 12], "has_feature_processor": [3, 5, 9], "dim_sum": 3, "feature_hash_s": [3, 5], "num_featur": [3, 5, 8, 9], "moduleshardingmixin": 3, "access": [3, 4, 10, 12], "scheme": 3, "optimtyp": 3, "adagrad": [3, 10], "adam": [3, 10], "adamw": 3, "lamb": 3, "lars_sgd": 3, "lion": 3, "partial_rowwise_adam": 3, "partial_rowwise_lamb": 3, "rowwise_adagrad": 3, "sgd": 3, "shampoo": 3, "shampoo_v2": 3, "shardedconfig": 3, "local_row": [3, 4], "local_col": [3, 4], "compin": 3, "distout": 3, "out": [3, 9, 12], "shrdctx": 3, "commop": 3, "extra_repr": 3, "pretti": 3, "represent": [3, 4, 6, 9, 12], "num_embed": [3, 4, 8, 9, 11], "fp32": [3, 4, 9], "weight_init_max": [3, 9], "float": [3, 4, 6, 9, 10, 12], "weight_init_min": [3, 9], "pruning_indices_remap": [3, 9], "init_fn": [3, 9], "need_po": [3, 5, 9], "local_metadata": 3, "_shard": 3, "global_metadata": 3, "sharded_tensor": 3, "shardedtensormetadata": 3, "shardedmetaconfig": 3, "compute_kernel_to_embedding_loc": 3, "embeddingloc": 3, "embeddingawait": 3, "embeddingbagcollectionawait": 3, "lazygetitemmixin": 3, "keyedtensor": [3, 8, 9, 11, 12], "embeddingbagcollectioncontext": 3, "inverse_indic": [3, 9, 12], "divisor": 3, "embeddingbagcollectionshard": 3, "embeddingbagshard": 3, "nullshardedmodulecontext": 3, "per_sample_weight": 3, "named_modul": 3, "memo": 3, "network": [3, 4, 9, 10], "alreadi": [3, 5, 7, 10], "onc": [3, 9], "l": [3, 9, 11], "linear": [3, 4, 9, 10], "net": [3, 9], "sequenti": [3, 4, 9], "idx": 3, "in_featur": [3, 8, 9], "out_featur": [3, 9], "sharded_parameter_nam": 3, "embeddingbagcollectioninterfac": [3, 9, 11], "variablebatchembeddingbagcollectionawait": 3, "construct_output_kt": 3, "create_embedding_bag_shard": 3, "permute_embed": [3, 5], "suffix": 3, "replace_placement_with_meta_devic": 3, "placement": [3, 4], "could": [3, 4, 12], "unmatch": 3, "some": [3, 12], "scenario": [3, 9, 11], "cuda": [3, 4, 7], "embeddingshardingplann": [3, 4], "groupedpositionweightedmodul": 3, "max_feature_length": [3, 9], "dataparallelwrapp": 3, "defaultdataparallelwrapp": 3, "bucket_cap_mb": 3, "25": 3, "static_graph": 3, "find_unused_paramet": 3, "allreduce_comm_precis": 3, "unshard": [3, 4, 9, 11], "plan": [3, 4, 9], "shardingplan": [3, 4], "init_data_parallel": 3, "init_paramet": 3, "data_parallel_wrapp": 3, "entri": 3, "point": [3, 4], "collective_plan": [3, 4], "lazi": [3, 9, 10], "delai": 3, "until": 3, "still": [3, 12], "no_grad": [3, 9], "init_weight": [3, 9], "isinst": 3, "fill_": [3, 9], "elif": 3, "init": 3, "kaiming_normal_": 3, "mymodel": 3, "bare_named_paramet": 3, "new": [3, 4], "origin": [3, 4], "tor": 3, "safe": 3, "time": [3, 4, 7, 9], "ddp": 3, "fsdp": 3, "sparse_grad_parameter_nam": [3, 10], "get_modul": 3, "unwrap": 3, "so": [3, 4, 10, 12], "get_unwrapped_modul": 3, "quantembeddingbagcollectionshard": 3, "shardedquantembeddingbagcollect": 3, "quantfeatureprocessedembeddingbagcollectionshard": 3, "featureprocessedembeddingbagcollect": [3, 11], "shardedquantebcinputdist": 3, "sharding_type_to_shard": 3, "nullshardingcontext": [3, 5], "sqebc_input_dist": 3, "infertwsequenceembeddingshard": 3, "f1": [3, 8, 9, 11], "f2": [3, 8, 9, 11], "7": [3, 8, 9, 11, 12], "8": [3, 4, 8, 9, 11, 12], "shardedquantembeddingmodulest": 3, "embedding_bag_config": [3, 9, 11], "embeddingbagconfig": [3, 8, 9, 11], "execut": [3, 4, 7, 9, 11], "step": [3, 4, 10], "sharding_type_to_sharding_info": 3, "tbes_config": 3, "shardedquantfeatureprocessedembeddingbagcollect": 3, "featureprocessorscollect": [3, 11], "apply_feature_processor": 3, "kjt_list": [3, 12], "embedding_bag": [3, 11], "moduledict": [3, 9, 11], "modulelist": [3, 9, 11], "create_infer_embedding_bag_shard": 3, "flatten_feature_length": 3, "get_device_from_parameter_shard": 3, "ps": 3, "get_device_from_sharding_info": 3, "emb_shard_info": 3, "cacheparam": [3, 4], "algorithm": 3, "cachealgorithm": 3, "load_factor": [3, 4], "reserved_memori": 3, "prefetch_pipelin": [3, 4], "cachestatist": [3, 4], "cach": [3, 4], "relat": [3, 4], "most": [3, 10], "fbgemm": [3, 4, 11], "uvm": [3, 4], "lru": [3, 4], "lfu": 3, "load": [3, 4, 10], "factor": [3, 4, 9], "decid": 3, "crucial": 3, "reserv": [3, 4], "ideal": 3, "aka": 3, "statist": [3, 4], "better": 3, "tune": [3, 10], "cacheabl": [3, 4], "summar": [3, 4], "measur": [3, 4], "difficulti": [3, 4], "independ": [3, 4], "score": [3, 4, 5, 9], "mean": [3, 4, 9], "veri": [3, 4], "high": [3, 4, 9], "difficult": [3, 4], "expected_lookup": [3, 4], "distinct": [3, 4], "expected_miss_r": [3, 4], "clf": [3, 4], "rate": [3, 4, 10], "100": [3, 4, 8, 9], "hit": [3, 4], "extrem": [3, 4], "estim": [3, 4], "knowledg": [3, 4], "pooled_embeddings_all_to_al": 3, "pooled_embeddings_reduce_scatt": 3, "sequence_embeddings_all_to_al": 3, "computekernel": 3, "moduleshardingplan": 3, "describ": 3, "genericmeta": 3, "getitemlazyawait": 3, "parentw": 3, "kt": [3, 12], "__getitem__": 3, "parent": 3, "expos": [3, 10], "concret": 3, "behavior": [3, 6, 10], "achiev": 3, "late": 3, "possibl": [3, 4], "__torch_function__": 3, "below": 3, "help": 3, "doesn": [3, 9, 10], "python": [3, 6, 7], "magic": 3, "__getattr__": 3, "caveat": 3, "arbitari": 3, "mechan": [3, 9], "ensur": [3, 9, 12], "perfect": 3, "quickli": 3, "long": [3, 4, 9], "kwd": 3, "vt_co": 3, "augment": 3, "trigger": [3, 9], "keyedlazyawait": 3, "anoth": 3, "defer": 3, "mixin": 3, "inherit": [3, 9], "mro": 3, "properli": [3, 9], "select": [3, 4, 5, 12], "lazynowait": 3, "classmethod": [3, 4, 7, 11], "noopquantizedcommcodec": 3, "quantizationcontext": 3, "No": [3, 5], "calc_quantized_s": 3, "input_len": 3, "decod": 3, "input_grad": 3, "encod": 3, "quantized_dtyp": 3, "nowait": [3, 6], "obj": 3, "sharding_spec": 3, "shardingspec": 3, "cache_param": [3, 4], "enforce_hbm": [3, 4], "stochastic_round": [3, 4], "bounds_check_mod": [3, 4], "boundscheckmod": [3, 4], "output_dtyp": [3, 4, 7, 11], "hbm": [3, 4], "stochast": [3, 4], "round": [3, 4], "bound": [3, 4], "place": [3, 4, 5, 10, 12], "column_wis": [3, 9], "seen": [3, 6], "individu": 3, "table_row_wis": [3, 9], "row_wis": [3, 9], "data_parallel": [3, 4, 9], "parameterstorag": 3, "physic": 3, "constraint": [3, 4], "shardingplann": [3, 4], "ddr": [3, 4], "pooled_all_to_al": 3, "reduce_scatt": 3, "float32": [3, 7, 9, 11], "quantized_tensor": 3, "quantized_comm_codec": 3, "collective_cal": 3, "output_tensor": 3, "assert_clos": 3, "int8": 3, "addit": [3, 4, 6, 7, 9, 10, 12], "carri": 3, "session": 3, "respect": [3, 9], "sequence_all_to_al": 3, "modulenocopymixin": [3, 11], "respons": 3, "transform": [3, 7, 9], "vise": [3, 10], "versa": [3, 10], "practic": 3, "from_loc": 3, "host": [3, 4, 5], "typic": [3, 4, 6, 9, 10, 12], "from_process_group": 3, "fqn": [3, 4], "larger": 3, "desir": 3, "get_plan_for_modul": 3, "module_path": 3, "re": [3, 10], "stabil": 3, "table_column_wis": [3, 9], "get_tensor_size_byt": 3, "scope": [3, 6], "copyablemixin": 3, "target": [3, 8], "mymodul": 3, "add_params_from_parameter_shard": 3, "parameter_shard": 3, "extract": 3, "add": [3, 6, 9, 10], "ones": 3, "add_prefix_to_state_dict": 3, "filter": [3, 9], "append_prefix": 3, "append": 3, "convert_to_fbgemm_typ": 3, "copy_to_devic": 3, "current_devic": [3, 7], "to_devic": 3, "filter_state_dict": 3, "start": [3, 9, 12], "strip": 3, "begin": [3, 10], "get_unsharded_module_nam": 3, "top": [3, 9], "level": [3, 5], "don": [3, 7, 9], "merge_fused_param": 3, "param_fused_param": 3, "configur": 3, "cache_precis": 3, "preset": 3, "table_level_fused_param": 3, "precid": 3, "grouped_fused_param": 3, "null": 3, "none_throw": 3, "_t": 3, "messag": [3, 4], "unexpect": 3, "assertionerror": 3, "optimizer_type_to_emb_opt_typ": 3, "optimizer_class": 3, "emboptimtyp": 3, "sharded_model_copi": 3, "m_cpu": 3, "deepcopi": 3, "managedcollisioncollectionawait": 3, "managedcollisioncollectioncontext": 3, "managedcollisioncollectionshard": 3, "managedcollisioncollect": [3, 9], "shardedmanagedcollisioncollect": 3, "evict": [3, 9], "create_mc_shard": 3, "managedcollisionembeddingbagcollectioncontext": 3, "evictions_per_t": 3, "remapped_kjt": 3, "managedcollisionembeddingbagcollectionshard": 3, "ebc_shard": 3, "mc_sharder": 3, "basemanagedcollisionembeddingcollectionshard": 3, "managedcollisionembeddingbagcollect": [3, 9], "shardedmanagedcollisionembeddingbagcollect": 3, "baseshardedmanagedcollisionembeddingcollect": 3, "managedcollisionembeddingcollectioncontext": 3, "managedcollisionembeddingcollectionshard": 3, "ec_shard": 3, "managedcollisionembeddingcollect": [3, 9], "shardedmanagedcollisionembeddingcollect": 3, "consid": [4, 9, 11, 12], "build": 4, "perf": 4, "storag": [4, 12], "peak": 4, "elimin": 4, "might": [4, 12], "oom": 4, "customiz": 4, "partit": [4, 5], "kernel_bw_lookup": 4, "compute_devic": 4, "hbm_mem_bw": 4, "ddr_mem_bw": 4, "caching_ratio": 4, "calcul": 4, "bandwidth": 4, "ratio": 4, "embeddingenumer": 4, "parameterconstraint": 4, "shardestim": 4, "shardingopt": 4, "valid": [4, 9, 12], "popul": [4, 9], "populate_estim": 4, "sharding_opt": 4, "descript": 4, "get_partition_by_typ": 4, "string": [4, 7, 9], "partitionbytyp": 4, "greedyperfpartition": 4, "sort_bi": 4, "sortbi": 4, "balance_modul": 4, "greedi": 4, "sort": 4, "smaller": 4, "effect": [4, 9], "balanc": 4, "storage_constraint": 4, "partition_bi": 4, "uniform": [4, 9], "strategi": 4, "final": [4, 8, 9, 11, 12], "docstr": [4, 12], "partition_by_devic": 4, "done": [4, 9, 10, 12], "clariti": 4, "memorybalancedpartition": 4, "max_search_count": 4, "10": [4, 8, 9, 11, 12], "toler": 4, "02": 4, "maximum": [4, 9], "greedypartition": 4, "reject": 4, "200": 4, "wors": 4, "repeatedli": 4, "find": 4, "least": 4, "amount": 4, "ordereddevicehardwar": 4, "devicehardwar": 4, "local_world_s": 4, "shardingoptiongroup": 4, "storage_sum": 4, "perf_sum": 4, "param_count": 4, "set_hbm_per_devic": 4, "hbm_per_devic": 4, "noopperfmodel": 4, "perfmodel": 4, "among": [4, 8], "here": 4, "without": [4, 12], "noopstoragemodel": 4, "storagereserv": 4, "performance_model": 4, "debug": 4, "shardabl": 4, "heteroembeddingshardingplann": 4, "topology_group": 4, "embeddingoffloadscaleuppropos": 4, "use_depth": 4, "allocate_budget": 4, "budget": 4, "allocation_prior": 4, "build_affine_storage_model": 4, "uvm_caching_sharding_opt": 4, "clf_to_byt": 4, "feedback": 4, "perf_rat": 4, "get_budget": 4, "get_cach": 4, "get_expected_lookup": 4, "search_spac": 4, "next_plan": 4, "starting_propos": 4, "greedypropos": 4, "threshold": [4, 9], "fashion": [4, 5], "On": [4, 9], "largest": 4, "tri": [4, 10], "next": 4, "max": [4, 9, 10], "earli": 4, "stop": 4, "consecut": 4, "than": [4, 9, 10], "best_perf_r": 4, "gridsearchpropos": 4, "max_propos": 4, "10000": 4, "uniformpropos": 4, "proposers_to_proposals_list": 4, "proposers_list": 4, "static_feedback": 4, "embeddingoffloadstat": 4, "mrc_hist_count": 4, "height": 4, "uvm_fused_cach": 4, "cachebl": 4, "area": 4, "under": 4, "curv": 4, "uniqu": [4, 9], "n": [4, 7, 9, 12], "histogram": 4, "bin": 4, "nth": 4, "wa": [4, 7], "estimate_cache_miss_r": 4, "cache_s": 4, "hist": 4, "mrc": 4, "embeddingperfestim": 4, "is_infer": 4, "wall": 4, "sharder_map": 4, "perf_func_emb_wall_tim": 4, "shard_siz": 4, "input_length": 4, "input_data_type_s": 4, "table_data_type_s": 4, "output_data_type_s": 4, "fwd_a2a_comm_data_type_s": 4, "bwd_a2a_comm_data_type_s": 4, "fwd_sr_comm_data_type_s": 4, "bwd_sr_comm_data_type_s": 4, "num_pool": 4, "intra_host_bw": 4, "inter_host_bw": 4, "bwd_compute_multipli": 4, "is_pool": 4, "expected_cache_fetch": 4, "attempt": 4, "rel": [4, 9], "tw": 4, "dp": 4, "queri": 4, "fwd_comm_data_type_s": 4, "bwd_comm_data_type_s": 4, "sampl": [4, 9], "thread": 4, "machin": [4, 9], "unpool": 4, "ebc": [4, 8, 9, 11], "signifi": 4, "fetch": 4, "embeddingstorageestim": 4, "calculate_shard_storag": 4, "compris": 4, "synonym": 4, "byte": [4, 7], "embeddingstat": 4, "log": 4, "sharding_plan": 4, "num_propos": 4, "num_plan": 4, "run_tim": 4, "best_plan": 4, "tabular": 4, "view": 4, "chosen": [4, 9], "evalu": [4, 9], "successfulli": 4, "taken": 4, "noopembeddingstat": 4, "noop": 4, "round_to_one_sigfig": 4, "fixedpercentagestoragereserv": 4, "percentag": 4, "heuristicalstoragereserv": 4, "parameter_multipli": 4, "dense_tensor_estim": 4, "heurist": 4, "extra": 4, "percent": 4, "act": 4, "margin": 4, "error": [4, 9, 12], "beyond": 4, "inferencestoragereserv": 4, "512": 4, "min_partit": 4, "pooling_factor": 4, "fbgemm_gpu": 4, "split_table_batched_embeddings_ops_common": 4, "device_group": 4, "around": 4, "lower": [4, 6, 7, 10, 11], "column": [4, 5], "rang": [4, 6, 9], "divid": 4, "divis": 4, "optionallist": 4, "momentum": 4, "determinist": 4, "import": [4, 7, 9, 11], "maintain": 4, "accuraci": [4, 9], "term": [4, 9], "fp16": 4, "exce": 4, "todai": 4, "bldm": 4, "fwd_comput": 4, "fwd_comm": 4, "bwd_comput": 4, "bwd_comm": 4, "prefetch_comput": 4, "breakdown": 4, "plannererror": 4, "error_typ": 4, "plannererrortyp": 4, "classifi": 4, "insufficient_storag": 4, "strict_constraint": 4, "prospos": 4, "paritit": 4, "subset": 4, "much": [4, 10], "depend": [4, 7, 9], "One": [4, 9], "eval": 4, "job": 4, "tower": [4, 9], "cache_load_factor": 4, "module_pool": 4, "sharding_option_nam": 4, "num_input": 4, "num_shard": 4, "total_perf": 4, "total_storag": 4, "capac": 4, "hardwar": 4, "fits_in": 4, "hbm_cap": 4, "ddr_cap": 4, "963146416": 4, "128": 4, "54760833": 4, "024": 4, "644245094": 4, "13421772": 4, "binarysearchpred": 4, "extern": [4, 8], "predic": 4, "discov": 4, "binari": 4, "minim": 4, "invoc": 4, "try": 4, "prior_result": 4, "probe": 4, "prior": 4, "entir": [4, 5], "explor": 4, "reach": 4, "luusjaakolasearch": 4, "max_iter": 4, "seed": 4, "42": 4, "left_cost": 4, "clamp": 4, "variant": 4, "luu": 4, "jaakola": 4, "en": 4, "wikipedia": 4, "wiki": 4, "best": 4, "far": 4, "associ": 4, "cost": [4, 9], "left": [4, 12], "right": [4, 9], "fy": 4, "y": [4, 9], "previou": 4, "subsequ": 4, "been": [4, 9], "shrink_right": 4, "shrink": 4, "boundari": 4, "infin": 4, "bytes_to_gb": 4, "num_byt": 4, "bytes_to_mb": 4, "gb_to_byt": 4, "gb": 4, "local_s": [4, 5], "format": [4, 7, 12], "prod": 4, "reset_shard_rank": 4, "sharder_nam": 4, "storage_repr_in_gb": 4, "basecwembeddingshard": 5, "basetwembeddingshard": 5, "cwpooledembeddingshard": 5, "infercwpooledembeddingdist": 5, "infercwpooledembeddingdistwithpermut": 5, "infercwpooledembeddingshard": 5, "basedpembeddingshard": 5, "dppooledembeddingdist": 5, "dppooledembeddingshard": 5, "dpsparsefeaturesdist": 5, "sparsefeatur": 5, "baserwembeddingshard": 5, "infercpurwsparsefeaturesdist": 5, "is_sequ": 5, "emb_shard": 5, "inferrwpooledembeddingdist": 5, "inferrwpooledembeddingshard": 5, "inferrwsparsefeaturesdist": 5, "rwpooledembeddingdist": 5, "share": [5, 9], "rwpooledembeddingshard": 5, "evenli": 5, "rwsparsefeaturesdist": 5, "intra_pg": 5, "hash": [5, 9], "get_block_sizes_runtime_devic": 5, "runtime_devic": 5, "tensor_cach": 5, "int32": [5, 12], "get_embedding_shard_metadata": 5, "grouped_embedding_configs_per_rank": 5, "infertwembeddingshard": 5, "infertwpooledembeddingdist": 5, "infertwsparsefeaturesdist": 5, "twpooledembeddingdist": 5, "twpooledembeddingshard": 5, "twsparsefeaturesdist": 5, "twcwpooledembeddingshard": 5, "basetwrwembeddingshard": 5, "twrwpooledembeddingdist": 5, "cross_pg": 5, "dim_sum_per_nod": 5, "emb_dim_per_node_per_featur": 5, "twrwpooledembeddingshard": 5, "twrwsparsefeaturesdist": 5, "id_list_features_per_rank": 5, "id_score_list_features_per_rank": 5, "id_list_feature_hash_s": 5, "id_score_list_feature_hash_s": 5, "shuffl": 5, "look": [5, 6, 12], "reorder": 5, "document": [6, 8], "leaf_modul": 6, "trace": [6, 7], "torchscript": 6, "create_arg": 6, "complex": 6, "memory_format": 6, "opoverload": 6, "prepar": [6, 9], "graph": 6, "emit": 6, "appropri": 6, "is_leaf_modul": 6, "module_qualified_nam": 6, "module_stack": 6, "node_name_to_scop": 6, "path_of_modul": 6, "mod": 6, "abil": 6, "made": [6, 10], "root": 6, "concrete_arg": 6, "guarante": [6, 10], "is_fx_trac": 6, "symbolic_trac": 6, "graphmodul": 6, "symbol": 6, "record": [6, 9], "partial": 6, "special": [6, 9, 10], "your": 6, "structur": [6, 10], "deploi": 7, "packag": 7, "predictmodul": 7, "predictfactori": 7, "contract": 7, "serv": 7, "predictfactorypackag": 7, "batchingqueu": 7, "config": [7, 9], "gpuexecutor": 7, "insid": 7, "dlrm_packag": 7, "py": 7, "demonstr": 7, "export": 7, "dlrm_predict": 7, "show": 7, "save_predict_factori": 7, "predict_factori": 7, "pathlib": 7, "binaryio": 7, "extra_fil": 7, "loader_cod": 7, "nimport": 7, "nmodule_factori": 7, "package_import": 7, "_sysimport": 7, "set_extern_modul": 7, "decor": 7, "abstractmethod": 7, "set_mocked_modul": 7, "load_config_text": 7, "load_pickle_config": 7, "clazz": 7, "batchingmetadata": 7, "pin": 7, "kept": 7, "sync": [7, 12], "learn": [7, 8, 9, 10], "batching_metadata": 7, "infom": 7, "batching_metadata_json": 7, "serial": 7, "json": 7, "eas": [7, 9], "pars": 7, "create_predict_modul": 7, "transformmodul": 7, "transform_state_dict": 7, "init_process_group": 7, "get_world_s": 7, "model_inputs_data": 7, "benchmark": 7, "qualname_metadata": 7, "qualnamemetadata": 7, "qualnam": 7, "inform": [7, 12], "qualname_metadata_json": 7, "result_metadata": 7, "run_weights_dependent_transform": 7, "predict_modul": 7, "predict": 7, "run_weights_independent_tranform": 7, "predict_forward": 7, "need_preproc": 7, "quantize_dens": 7, "additional_embedding_module_typ": 7, "quantize_embed": 7, "inplac": [7, 11], "additional_qconfig_spec_kei": 7, "additional_map": 7, "per_table_weight_dtyp": [7, 9], "quantize_featur": 7, "trim_torch_package_prefix_from_typenam": 7, "typenam": 7, "densearch": 8, "hidden_layer_s": 8, "deepfmnn": 8, "layer": [8, 9, 10], "embedding_dimens": 8, "dimension": 8, "hidden": [8, 9], "sparsearch": 8, "20": [8, 9], "dense_arch": 8, "dense_arch_input": 8, "dense_embed": 8, "fminteractionarch": 8, "fm_in_featur": 8, "sparse_feature_nam": 8, "deep_fm_dimens": 8, "dense_featur": [8, 9], "interact": [8, 9], "paper": [8, 9], "arxiv": 8, "pdf": 8, "1703": 8, "04247": 8, "cat": [8, 9], "dense_modul": [8, 9], "deep": [8, 9], "di": 8, "arch": 8, "fm_inter_arch": 8, "length_per_kei": [8, 12], "cat_fm_output": 8, "overarch": 8, "simpl": 8, "over_arch": 8, "logit": 8, "simpledeepfmnn": 8, "num_dense_featur": 8, "embedding_bag_collect": [8, 9], "basic": [8, 12], "relationship": 8, "project": 8, "those": [8, 9], "deep_fm": 8, "notat": 8, "throughout": 8, "eb1_config": [8, 11], "f3": 8, "eb2_config": [8, 11], "t2": [8, 9, 11], "sparse_nn": 8, "over_embedding_dim": 8, "9": 8, "from_offsets_sync": [8, 9, 11, 12], "sparse_arch": 8, "extens": 9, "establish": 9, "pattern": 9, "swishlayernorm": 9, "positionweightedmodul": 9, "lazymoduleextensionmixin": 9, "embeddingtow": 9, "embeddingtowercollect": 9, "logic": 9, "input_dim": 9, "swish": 9, "normal": 9, "sigmoid": 9, "layernorm": 9, "d1": 9, "d2": 9, "d3": 9, "last": [9, 12], "sln": 9, "num_lay": 9, "stack": 9, "learnabl": 9, "polynom": 9, "full": [9, 10, 12], "matrix": 9, "nxn": 9, "cover": 9, "bit": 9, "x_": 9, "x_0": 9, "w_l": 9, "cdot": 9, "x_l": 9, "b_l": 9, "squar": 9, "element": 9, "dcn": 9, "lowrankcrossnet": 9, "low_rank": 9, "low": 9, "highli": 9, "effici": 9, "matric": 9, "simplifi": 9, "v_l": 9, "vector": 9, "smartli": 9, "setup": 9, "alwai": [9, 12], "lowrankmixturecrossnet": 9, "num_expert": 9, "relu": 9, "mixtur": 9, "expert": 9, "compar": [9, 12], "leverag": 9, "k": 9, "subspac": 9, "adapt": 9, "gate": 9, "moe": 9, "expert_i": 9, "k_": 9, "u_": 9, "li": 9, "c_": 9, "v_": 9, "vectorcrossnet": 9, "keep": 9, "nx1": 9, "dot": 9, "thu": [9, 10], "further": [9, 12], "cut": 9, "off": 9, "implent": 9, "framework": 9, "factorizationmachin": 9, "fm": 9, "abov": [9, 12], "publish": 9, "compon": 9, "learnt": 9, "To": 9, "flexibl": 9, "raw": 9, "limit": 9, "architectur": 9, "90": 9, "30": 9, "40": 9, "equal": [9, 12], "count": 9, "fb": 9, "lazymlp": 9, "output_dim": 9, "64": 9, "32": 9, "192": 9, "deep_fm_output": 9, "common_spars": 9, "specialized_spars": 9, "embedding_featur": 9, "raw_embedding_featur": 9, "nativ": 9, "trained_embed": 9, "native_embed": 9, "ident": 9, "mention": 9, "2nd": 9, "baseembeddingconfig": 9, "get_weight_init_max": 9, "get_weight_init_min": 9, "embeddingconfig": [9, 11], "quantconfig": 9, "placeholderobserv": [9, 11], "alia": 9, "data_type_to_dtyp": 9, "data_type_to_sparse_typ": 9, "sparsetyp": 9, "dtype_to_data_typ": 9, "pooling_type_to_pooling_mod": 9, "pooling_typ": 9, "poolingmod": 9, "pooling_type_to_str": 9, "sensit": [9, 11], "jag": [9, 11, 12], "table_0": [9, 11], "table_1": [9, 11], "pooled_embed": 9, "8899": 9, "1342": 9, "9060": 9, "0905": 9, "2814": 9, "9369": 9, "7783": 9, "0000": 9, "1598": 9, "0695": 9, "3265": 9, "1011": 9, "4256": 9, "1846": 9, "1648": 9, "0893": 9, "3590": 9, "9784": 9, "7681": 9, "grad_fn": [9, 11], "catbackward0": 9, "offset_per_kei": [9, 12], "need_indic": [9, 11], "e1_config": [9, 11], "e2_config": [9, 11], "ec": [9, 11], "feature_embed": [9, 11], "2050": [9, 11], "5478": [9, 11], "6054": [9, 11], "7352": [9, 11], "3210": [9, 11], "0399": [9, 11], "1279": [9, 11], "1756": [9, 11], "4130": [9, 11], "7519": [9, 11], "4341": [9, 11], "0499": [9, 11], "9329": [9, 11], "0697": [9, 11], "8095": [9, 11], "embeddingbackward": [9, 11], "embedding_names_by_t": [9, 11], "get_embedding_names_by_t": 9, "process_pooled_embed": 9, "reorder_inverse_indic": 9, "basefeatureprocessor": 9, "max_length": 9, "truncat": 9, "positionweightedprocessor": 9, "feature_length": 9, "feature0": [9, 12], "feature1": [9, 12], "feature2": 9, "from_lengths_sync": [9, 12], "pw": 9, "featureprocessorcollect": 9, "feature_processor_modul": 9, "positionweightedfeatureprocessor": 9, "fp_featur": 9, "non_fp_featur": 9, "non_fp": 9, "feature_process": 9, "come": 9, "And": 9, "offsets_to_range_tracebl": 9, "position_weighted_module_update_featur": 9, "weighted_featur": 9, "lazymodulemixin": 9, "temporari": 9, "upstream": 9, "59923": 9, "testlazymoduleextensionmixin": 9, "unit": 9, "test": 9, "_infer_paramet": 9, "code": 9, "pariti": 9, "_call_impl": 9, "pre": [9, 10], "children": 9, "uniniti": 9, "dummi": [9, 10], "lazylinear": 9, "fail": [9, 12], "becaus": [9, 10], "hasn": 9, "yet": 9, "now": [9, 12], "lazy_appli": 9, "attach": 9, "numer": 9, "immedi": 9, "seq": 9, "in_siz": 9, "layer_s": 9, "perceptron": 9, "multi": 9, "out_siz": 9, "swish_layernorm": 9, "won": 9, "constructor": 9, "mlp_modul": 9, "assert": 9, "o": 9, "channel": 9, "check_module_output_dimens": 9, "verifi": 9, "construct_jagged_tensor": 9, "features_to_permute_indic": 9, "original_featur": 9, "construct_jagged_tensors_infer": 9, "construct_modulelist_from_single_modul": 9, "nest": 9, "reiniti": 9, "convert_list_of_modules_to_modulelist": 9, "extract_module_or_tensor_cal": 9, "module_or_cal": 9, "get_module_output_dimens": 9, "init_mlp_weights_xavier_uniform": 9, "distancelfu_evictionpolici": 9, "decay_expon": 9, "threshold_filtering_func": 9, "mchevictionpolici": 9, "coalesce_history_metadata": 9, "current_it": 9, "history_metadata": 9, "unique_ids_count": 9, "unique_inverse_map": 9, "additional_id": 9, "threshold_mask": 9, "histori": 9, "invers": [9, 12], "history_accumul": 9, "coalesc": 9, "metadata_info": 9, "mchevictionpolicymetadatainfo": 9, "record_history_metadata": 9, "incoming_id": 9, "incom": 9, "polici": [9, 10], "update_metadata_and_generate_eviction_scor": 9, "mch_size": 9, "coalesced_history_argsort_map": 9, "coalesced_history_sorted_unique_ids_count": 9, "coalesced_history_mch_matching_elements_mask": 9, "coalesced_history_mch_matching_indic": 9, "mch_metadata": 9, "coalesced_history_metadata": 9, "evicted_indic": 9, "selected_new_indic": 9, "mch": 9, "lfu_evictionpolici": 9, "lru_evictionpolici": 9, "metadata_nam": 9, "is_mch_metadata": 9, "is_history_metadata": 9, "mchmanagedcollisionmodul": 9, "zch_size": 9, "eviction_polici": 9, "eviction_interv": 9, "input_hash_s": 9, "9223372036854775808": 9, "input_hash_func": 9, "mch_hash_func": 9, "output_global_offset": 9, "managedcollisionmodul": 9, "zch": 9, "manag": 9, "collis": 9, "output_size_offset": 9, "interv": 9, "drive": 9, "greater": 9, "residu": 9, "legaci": 9, "intern": [9, 12], "shift": 9, "zch_output_rang": 9, "down": 9, "applic": 9, "slot": 9, "reset": [9, 10], "assumptionn": 9, "downstream": 9, "modifi": [9, 10], "jt": [9, 12], "rtype": 9, "output_s": 9, "vs": 9, "preprocess": 9, "profil": 9, "rebuild_with_output_id_rang": 9, "output_id_rang": 9, "mc": 9, "hack": 9, "remap": 9, "managed_collision_modul": 9, "mcc": 9, "embedding_confg": 9, "collsion": 9, "max_output_id": 9, "remapping_range_start_index": 9, "mcm": 9, "mcm_jt": 9, "fp": 9, "apply_mc_method_to_jt_dict": 9, "features_dict": 9, "table_to_featur": 9, "managed_collis": 9, "average_threshold_filt": 9, "id_count": 9, "dynamic_threshold_filt": 9, "threshold_skew_multipli": 9, "total_count": 9, "num_id": 9, "probabilistic_threshold_filt": 9, "per_id_prob": 9, "01": 9, "probabl": 9, "appear": 9, "60": 9, "randomli": 9, "chanc": 9, "basemanagedcollisionembeddingcollect": 9, "managed_collision_collect": 9, "return_remapped_featur": 9, "embedding_collect": 9, "meaning": 10, "prohibit": 10, "empti": [10, 12], "sever": 10, "combinedoptim": 10, "optimizerwrapp": 10, "rowwis": 10, "gradientclip": 10, "norm": 10, "gradientclippingoptim": 10, "max_gradi": 10, "closur": 10, "reevalu": 10, "loss": 10, "emptyfusedoptim": 10, "fusedoptim": 10, "zero_grad": 10, "set_to_non": 10, "zero": [10, 12], "footprint": 10, "modestli": 10, "improv": 10, "certain": 10, "0s": 10, "behav": 10, "did": 10, "altogeth": 10, "param_group": 10, "meant": 10, "post_load_state_dict": 10, "prepend_opt_kei": 10, "opt_kei": 10, "save_param_group": 10, "stricter": 10, "old": 10, "switch": 10, "flag": 10, "reason": 10, "identifi": 10, "littl": 10, "add_param_group": 10, "fine": 10, "frozen": 10, "trainabl": 10, "progress": 10, "what": 10, "init_st": 10, "checkpoint": 10, "usabl": 10, "sure": 10, "sd": 10, "load_checkpoint": 10, "replac": 10, "advanc": 10, "protocol": 10, "keyedoptimizerwrapp": 10, "optim_factori": 10, "conveni": 10, "warmupoptim": 10, "stage": 10, "warmupstag": 10, "lr": 10, "lr_param": 10, "param_nam": 10, "__warmup": 10, "adjust": 10, "schedul": 10, "go": 10, "fake": 10, "warmuppolici": 10, "invsqrt": 10, "inv_sqrt": 10, "poli": 10, "max_it": 10, "lr_scale": 10, "decay_it": 10, "speed": 11, "trec_quant": 11, "trec": 11, "qconfig": 11, "with_arg": 11, "qint8": 11, "quantize_dynam": 11, "qconfig_spec": 11, "table_name_to_quantized_weight": 11, "register_tb": 11, "quant_state_dict_split_scale_bia": 11, "row_align": 11, "qebc": 11, "quantembeddingbagcollect": 11, "from_float": 11, "quantized_embed": 11, "features_to_dict": 11, "for_each_module_of_type_do": 11, "pruned_num_embed": 11, "pruning_indices_map": 11, "quant_prep_customize_row_align": 11, "quant_prep_enable_quant_state_dict_split_scale_bia": 11, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 11, "quant_prep_enable_register_tb": 11, "quantize_state_dict": 11, "table_name_to_data_typ": 11, "table_name_to_pruning_indices_map": 11, "whose": 12, "dimes": 12, "computejtdicttokjt": 12, "jt_dict": 12, "v5": 12, "v6": 12, "v7": 12, "dim_1": 12, "dim_0": 12, "computekjttojtdict": 12, "keyed_jagged_tensor": 12, "jit": 12, "abl": 12, "NOT": 12, "expens": 12, "values_dtyp": 12, "weights_dtyp": 12, "lengths_dtyp": 12, "from_dens": 12, "2d": 12, "11": 12, "12": 12, "j1": 12, "from_dense_length": 12, "lengths_or_non": 12, "offsets_or_non": 12, "non_block": 12, "new_devic": 12, "to_dens": 12, "inttensor": 12, "values_list": 12, "to_dense_weight": 12, "weights_list": 12, "to_padded_dens": 12, "desired_length": 12, "padding_valu": 12, "longest": 12, "pad": 12, "dt": 12, "to_padded_dense_weight": 12, "d_wt": 12, "weights_or_non": 12, "jaggedtensormeta": 12, "namespac": 12, "abcmeta": 12, "proxyableclassmeta": 12, "stride_per_key_per_rank": 12, "outer": 12, "inner": 12, "index_per_kei": 12, "expand": 12, "dedupl": 12, "dim_2": 12, "w0": 12, "w1": 12, "w2": 12, "w3": 12, "w4": 12, "w5": 12, "w6": 12, "w7": 12, "dist_init": 12, "variable_stride_per_kei": 12, "num_work": 12, "dist_label": 12, "dist_split": 12, "key_split": 12, "dist_tensor": 12, "empty_lik": 12, "flatten_length": 12, "from_jt_dict": 12, "implicit": 12, "visual": 12, "variable_feature_dim": 12, "But": 12, "That": 12, "didn": 12, "notic": 12, "correctli": 12, "technic": 12, "know": 12, "violat": 12, "precondit": 12, "fix": 12, "inverse_indices_or_non": 12, "length_per_key_or_non": 12, "lengths_offset_per_kei": 12, "offset_per_key_or_non": 12, "indices_tensor": 12, "include_inverse_indic": 12, "pin_memori": 12, "segment": 12, "stride_per_kei": 12, "to_dict": 12, "unsync": 12, "key_dim": 12, "tensor_list": 12, "from_tensor_list": 12, "regroup": 12, "keyed_tensor": 12, "regroup_as_dict": 12, "flatten_kjt_list": 12, "kjt_arr": 12, "is_non_strict_export": 12, "jt_is_equ": 12, "jt_1": 12, "jt_2": 12, "comparison": 12, "themselv": 12, "treat": 12, "kjt_is_equ": 12, "kjt_1": 12, "kjt_2": 12, "unflatten_kjt_list": 12}, "objects": {"torchrec": [[3, 0, 0, "-", "distributed"], [6, 0, 0, "module-0", "fx"], [7, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "modules"], [10, 0, 0, "module-0", "optim"], [11, 0, 0, "module-0", "quant"], [12, 0, 0, "module-0", "sparse"]], "torchrec.distributed": [[3, 0, 0, "-", "collective_utils"], [3, 0, 0, "-", "comm"], [3, 0, 0, "-", "comm_ops"], [5, 0, 0, "-", "dist_data"], [3, 0, 0, "-", "embedding"], [3, 0, 0, "-", "embedding_lookup"], [3, 0, 0, "-", "embedding_sharding"], [3, 0, 0, "-", "embedding_types"], [3, 0, 0, "-", "embeddingbag"], [3, 0, 0, "-", "grouped_position_weighted"], [3, 0, 0, "-", "mc_embedding"], [3, 0, 0, "-", "mc_embeddingbag"], [3, 0, 0, "-", "mc_modules"], [3, 0, 0, "-", "model_parallel"], [4, 0, 0, "-", "planner"], [3, 0, 0, "-", "quant_embeddingbag"], [5, 0, 0, "-", "sharding"], [3, 0, 0, "-", "train_pipeline"], [3, 0, 0, "-", "types"], [3, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[3, 1, 1, "", "invoke_on_rank_and_broadcast_result"], [3, 1, 1, "", "is_leader"], [3, 1, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[3, 1, 1, "", "get_group_rank"], [3, 1, 1, "", "get_local_rank"], [3, 1, 1, "", "get_local_size"], [3, 1, 1, "", "get_num_groups"], [3, 1, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[3, 2, 1, "", "All2AllDenseInfo"], [3, 2, 1, "", "All2AllPooledInfo"], [3, 2, 1, "", "All2AllSequenceInfo"], [3, 2, 1, "", "All2AllVInfo"], [3, 2, 1, "", "All2All_Pooled_Req"], [3, 2, 1, "", "All2All_Pooled_Wait"], [3, 2, 1, "", "All2All_Seq_Req"], [3, 2, 1, "", "All2All_Seq_Req_Wait"], [3, 2, 1, "", "All2Allv_Req"], [3, 2, 1, "", "All2Allv_Wait"], [3, 2, 1, "", "AllGatherBaseInfo"], [3, 2, 1, "", "AllGatherBase_Req"], [3, 2, 1, "", "AllGatherBase_Wait"], [3, 2, 1, "", "ReduceScatterBaseInfo"], [3, 2, 1, "", "ReduceScatterBase_Req"], [3, 2, 1, "", "ReduceScatterBase_Wait"], [3, 2, 1, "", "ReduceScatterInfo"], [3, 2, 1, "", "ReduceScatterVInfo"], [3, 2, 1, "", "ReduceScatterV_Req"], [3, 2, 1, "", "ReduceScatterV_Wait"], [3, 2, 1, "", "ReduceScatter_Req"], [3, 2, 1, "", "ReduceScatter_Wait"], [3, 2, 1, "", "Request"], [3, 2, 1, "", "VariableBatchAll2AllPooledInfo"], [3, 2, 1, "", "Variable_Batch_All2All_Pooled_Req"], [3, 2, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [3, 1, 1, "", "all2all_pooled_sync"], [3, 1, 1, "", "all2all_sequence_sync"], [3, 1, 1, "", "all2allv_sync"], [3, 1, 1, "", "all_gather_base_pooled"], [3, 1, 1, "", "all_gather_base_sync"], [3, 1, 1, "", "alltoall_pooled"], [3, 1, 1, "", "alltoall_sequence"], [3, 1, 1, "", "alltoallv"], [3, 1, 1, "", "fn"], [3, 1, 1, "", "get_gradient_division"], [3, 1, 1, "", "reduce_scatter_base_pooled"], [3, 1, 1, "", "reduce_scatter_base_sync"], [3, 1, 1, "", "reduce_scatter_pooled"], [3, 1, 1, "", "reduce_scatter_sync"], [3, 1, 1, "", "reduce_scatter_v_per_feature_pooled"], [3, 1, 1, "", "reduce_scatter_v_pooled"], [3, 1, 1, "", "reduce_scatter_v_sync"], [3, 1, 1, "", "set_gradient_division"], [3, 1, 1, "", "variable_batch_all2all_pooled_sync"], [3, 1, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[3, 3, 1, "", "batch_size"], [3, 3, 1, "", "input_shape"], [3, 3, 1, "", "input_splits"], [3, 3, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[3, 3, 1, "id0", "batch_size_per_rank"], [3, 3, 1, "id1", "codecs"], [3, 3, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [3, 3, 1, "id3", "dim_sum_per_rank"], [3, 3, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[3, 3, 1, "id5", "backward_recat_tensor"], [3, 3, 1, "id6", "codecs"], [3, 3, 1, "id7", "embedding_dim"], [3, 3, 1, "id8", "forward_recat_tensor"], [3, 3, 1, "id9", "input_splits"], [3, 3, 1, "id10", "lengths_after_sparse_data_all2all"], [3, 3, 1, "id11", "output_splits"], [3, 3, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [3, 3, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[3, 3, 1, "id14", "B_global"], [3, 3, 1, "id15", "B_local"], [3, 3, 1, "id16", "B_local_list"], [3, 3, 1, "id17", "D_local_list"], [3, 3, 1, "", "codecs"], [3, 3, 1, "", "dim_sum_per_rank"], [3, 3, 1, "", "dims_sum_per_rank"], [3, 3, 1, "id18", "input_split_sizes"], [3, 3, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[3, 3, 1, "id23", "codecs"], [3, 3, 1, "id24", "equal_splits"], [3, 3, 1, "id25", "input_sizes"], [3, 3, 1, "id26", "input_splits"], [3, 3, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[3, 3, 1, "id28", "batch_size_per_feature_pre_a2a"], [3, 3, 1, "id29", "batch_size_per_rank_per_feature"], [3, 3, 1, "id30", "codecs"], [3, 3, 1, "id31", "emb_dim_per_rank_per_feature"], [3, 3, 1, "id32", "input_splits"], [3, 3, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.dist_data": [[5, 2, 1, "", "EmbeddingsAllToOne"], [5, 2, 1, "", "EmbeddingsAllToOneReduce"], [5, 2, 1, "", "KJTAllToAll"], [5, 2, 1, "", "KJTAllToAllSplitsAwaitable"], [5, 2, 1, "", "KJTAllToAllTensorsAwaitable"], [5, 2, 1, "", "KJTOneToAll"], [5, 2, 1, "", "PooledEmbeddingsAllGather"], [5, 2, 1, "", "PooledEmbeddingsAllToAll"], [5, 2, 1, "", "PooledEmbeddingsAwaitable"], [5, 2, 1, "", "PooledEmbeddingsReduceScatter"], [5, 2, 1, "", "SeqEmbeddingsAllToOne"], [5, 2, 1, "", "SequenceEmbeddingsAllToAll"], [5, 2, 1, "", "SequenceEmbeddingsAwaitable"], [5, 2, 1, "", "SplitsAllToAllAwaitable"], [5, 2, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [5, 2, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[5, 5, 1, "", "callbacks"], [5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[5, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[5, 5, 1, "", "callbacks"], [5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.embedding": [[3, 2, 1, "", "EmbeddingCollectionAwaitable"], [3, 2, 1, "", "EmbeddingCollectionContext"], [3, 2, 1, "", "EmbeddingCollectionSharder"], [3, 2, 1, "", "ShardedEmbeddingCollection"], [3, 1, 1, "", "create_embedding_sharding"], [3, 1, 1, "", "create_sharding_infos_by_sharding"], [3, 1, 1, "", "get_ec_index_dedup"], [3, 1, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "reset_parameters"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[3, 2, 1, "", "CommOpGradientScaling"], [3, 2, 1, "", "GroupedEmbeddingsLookup"], [3, 2, 1, "", "GroupedPooledEmbeddingsLookup"], [3, 2, 1, "", "InferCPUGroupedEmbeddingsLookup"], [3, 2, 1, "", "InferGroupedEmbeddingsLookup"], [3, 2, 1, "", "InferGroupedLookupMixin"], [3, 2, 1, "", "InferGroupedPooledEmbeddingsLookup"], [3, 2, 1, "", "MetaInferGroupedEmbeddingsLookup"], [3, 2, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [3, 1, 1, "", "embeddings_cat_empty_rank_handle"], [3, 1, 1, "", "embeddings_cat_empty_rank_handle_inference"], [3, 1, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "named_parameters_by_table"], [3, 4, 1, "", "prefetch"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "named_parameters_by_table"], [3, 4, 1, "", "prefetch"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "get_tbes_to_register"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "get_tbes_to_register"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[3, 2, 1, "", "BaseEmbeddingDist"], [3, 2, 1, "", "BaseSparseFeaturesDist"], [3, 2, 1, "", "EmbeddingSharding"], [3, 2, 1, "", "EmbeddingShardingContext"], [3, 2, 1, "", "EmbeddingShardingInfo"], [3, 2, 1, "", "FusedKJTListSplitsAwaitable"], [3, 2, 1, "", "KJTListAwaitable"], [3, 2, 1, "", "KJTListSplitsAwaitable"], [3, 2, 1, "", "KJTSplitsAllToAllMeta"], [3, 2, 1, "", "ListOfKJTListAwaitable"], [3, 2, 1, "", "ListOfKJTListSplitsAwaitable"], [3, 1, 1, "", "bucketize_kjt_before_all2all"], [3, 1, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[3, 4, 1, "", "create_input_dist"], [3, 4, 1, "", "create_lookup"], [3, 4, 1, "", "create_output_dist"], [3, 4, 1, "", "embedding_dims"], [3, 4, 1, "", "embedding_names"], [3, 4, 1, "", "embedding_names_per_rank"], [3, 4, 1, "", "embedding_shard_metadata"], [3, 4, 1, "", "embedding_tables"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "uncombined_embedding_dims"], [3, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[3, 3, 1, "", "embedding_config"], [3, 3, 1, "", "fused_params"], [3, 3, 1, "", "param"], [3, 3, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[3, 3, 1, "", "device"], [3, 3, 1, "", "input_splits"], [3, 3, 1, "", "input_tensors"], [3, 3, 1, "", "keys"], [3, 3, 1, "", "labels"], [3, 3, 1, "", "pg"], [3, 3, 1, "", "splits"], [3, 3, 1, "", "splits_tensors"], [3, 3, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[3, 2, 1, "", "BaseEmbeddingLookup"], [3, 2, 1, "", "BaseEmbeddingSharder"], [3, 2, 1, "", "BaseGroupedFeatureProcessor"], [3, 2, 1, "", "BaseQuantEmbeddingSharder"], [3, 2, 1, "", "EmbeddingAttributes"], [3, 2, 1, "", "EmbeddingComputeKernel"], [3, 2, 1, "", "FeatureShardingMixIn"], [3, 2, 1, "", "GroupedEmbeddingConfig"], [3, 2, 1, "", "KJTList"], [3, 2, 1, "", "ListOfKJTList"], [3, 2, 1, "", "ModuleShardingMixIn"], [3, 2, 1, "", "OptimType"], [3, 2, 1, "", "ShardedConfig"], [3, 2, 1, "", "ShardedEmbeddingModule"], [3, 2, 1, "", "ShardedEmbeddingTable"], [3, 2, 1, "", "ShardedMetaConfig"], [3, 1, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "fused_params"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "fused_params"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[3, 3, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[3, 3, 1, "", "DENSE"], [3, 3, 1, "", "FUSED"], [3, 3, 1, "", "FUSED_UVM"], [3, 3, 1, "", "FUSED_UVM_CACHING"], [3, 3, 1, "", "QUANT"], [3, 3, 1, "", "QUANT_UVM"], [3, 3, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[3, 4, 1, "", "feature_names"], [3, 4, 1, "", "feature_names_per_rank"], [3, 4, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[3, 3, 1, "", "compute_kernel"], [3, 3, 1, "", "data_type"], [3, 4, 1, "", "dim_sum"], [3, 4, 1, "", "embedding_dims"], [3, 4, 1, "", "embedding_names"], [3, 4, 1, "", "embedding_shard_metadata"], [3, 3, 1, "", "embedding_tables"], [3, 4, 1, "", "feature_hash_sizes"], [3, 4, 1, "", "feature_names"], [3, 3, 1, "", "fused_params"], [3, 3, 1, "", "has_feature_processor"], [3, 3, 1, "", "is_weighted"], [3, 4, 1, "", "num_features"], [3, 3, 1, "", "pooling"], [3, 4, 1, "", "table_names"]], "torchrec.distributed.embedding_types.KJTList": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[3, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[3, 3, 1, "", "ADAGRAD"], [3, 3, 1, "", "ADAM"], [3, 3, 1, "", "ADAMW"], [3, 3, 1, "", "LAMB"], [3, 3, 1, "", "LARS_SGD"], [3, 3, 1, "", "LION"], [3, 3, 1, "", "PARTIAL_ROWWISE_ADAM"], [3, 3, 1, "", "PARTIAL_ROWWISE_LAMB"], [3, 3, 1, "", "ROWWISE_ADAGRAD"], [3, 3, 1, "", "SGD"], [3, 3, 1, "", "SHAMPOO"], [3, 3, 1, "", "SHAMPOO_V2"]], "torchrec.distributed.embedding_types.ShardedConfig": [[3, 3, 1, "", "local_cols"], [3, 3, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[3, 4, 1, "", "extra_repr"], [3, 4, 1, "", "prefetch"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[3, 3, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[3, 3, 1, "", "global_metadata"], [3, 3, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[3, 2, 1, "", "EmbeddingAwaitable"], [3, 2, 1, "", "EmbeddingBagCollectionAwaitable"], [3, 2, 1, "", "EmbeddingBagCollectionContext"], [3, 2, 1, "", "EmbeddingBagCollectionSharder"], [3, 2, 1, "", "EmbeddingBagSharder"], [3, 2, 1, "", "ShardedEmbeddingBag"], [3, 2, 1, "", "ShardedEmbeddingBagCollection"], [3, 2, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [3, 1, 1, "", "construct_output_kt"], [3, 1, 1, "", "create_embedding_bag_sharding"], [3, 1, 1, "", "create_sharding_infos_by_sharding"], [3, 1, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[3, 3, 1, "", "divisor"], [3, 3, 1, "", "inverse_indices"], [3, 4, 1, "", "record_stream"], [3, 3, 1, "", "sharding_contexts"], [3, 3, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_modules"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharded_parameter_names"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "reset_parameters"], [3, 3, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[3, 2, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[3, 4, 1, "", "forward"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[3, 2, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [3, 2, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[3, 4, 1, "", "create_context"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[3, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [3, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[3, 3, 1, "", "evictions_per_table"], [3, 4, 1, "", "record_stream"], [3, 3, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[3, 4, 1, "", "create_context"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_modules": [[3, 2, 1, "", "ManagedCollisionCollectionAwaitable"], [3, 2, 1, "", "ManagedCollisionCollectionContext"], [3, 2, 1, "", "ManagedCollisionCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionCollection"], [3, 1, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "evict"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharded_parameter_names"], [3, 3, 1, "", "training"]], "torchrec.distributed.model_parallel": [[3, 2, 1, "", "DataParallelWrapper"], [3, 2, 1, "", "DefaultDataParallelWrapper"], [3, 2, 1, "", "DistributedModelParallel"], [3, 1, 1, "", "get_module"], [3, 1, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[3, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[3, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[3, 4, 1, "", "bare_named_parameters"], [3, 4, 1, "", "copy"], [3, 4, 1, "", "forward"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "init_data_parallel"], [3, 4, 1, "", "load_state_dict"], [3, 5, 1, "", "module"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 5, 1, "", "plan"], [3, 4, 1, "", "sparse_grad_parameter_names"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.planner": [[4, 0, 0, "-", "constants"], [4, 0, 0, "-", "enumerators"], [4, 0, 0, "-", "partitioners"], [4, 0, 0, "-", "perf_models"], [4, 0, 0, "-", "planners"], [4, 0, 0, "-", "proposers"], [4, 0, 0, "-", "shard_estimators"], [4, 0, 0, "-", "stats"], [4, 0, 0, "-", "storage_reservations"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[4, 1, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[4, 2, 1, "", "EmbeddingEnumerator"], [4, 1, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[4, 4, 1, "", "enumerate"], [4, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[4, 2, 1, "", "GreedyPerfPartitioner"], [4, 2, 1, "", "MemoryBalancedPartitioner"], [4, 2, 1, "", "OrderedDeviceHardware"], [4, 2, 1, "", "ShardingOptionGroup"], [4, 2, 1, "", "SortBy"], [4, 1, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[4, 3, 1, "", "device"], [4, 3, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[4, 3, 1, "", "param_count"], [4, 3, 1, "", "perf_sum"], [4, 3, 1, "", "sharding_options"], [4, 3, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[4, 3, 1, "", "PERF"], [4, 3, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[4, 2, 1, "", "NoopPerfModel"], [4, 2, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[4, 2, 1, "", "EmbeddingShardingPlanner"], [4, 2, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[4, 4, 1, "", "collective_plan"], [4, 4, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[4, 4, 1, "", "collective_plan"], [4, 4, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[4, 2, 1, "", "EmbeddingOffloadScaleupProposer"], [4, 2, 1, "", "GreedyProposer"], [4, 2, 1, "", "GridSearchProposer"], [4, 2, 1, "", "UniformProposer"], [4, 1, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[4, 4, 1, "", "allocate_budget"], [4, 4, 1, "", "build_affine_storage_model"], [4, 4, 1, "", "clf_to_bytes"], [4, 4, 1, "", "feedback"], [4, 4, 1, "", "get_budget"], [4, 4, 1, "", "get_cacheability"], [4, 4, 1, "", "get_expected_lookups"], [4, 4, 1, "", "load"], [4, 4, 1, "", "next_plan"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[4, 2, 1, "", "EmbeddingOffloadStats"], [4, 2, 1, "", "EmbeddingPerfEstimator"], [4, 2, 1, "", "EmbeddingStorageEstimator"], [4, 1, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[4, 5, 1, "", "cacheability"], [4, 4, 1, "", "estimate_cache_miss_rate"], [4, 5, 1, "", "expected_lookups"], [4, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[4, 4, 1, "", "estimate"], [4, 4, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[4, 4, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[4, 2, 1, "", "EmbeddingStats"], [4, 2, 1, "", "NoopEmbeddingStats"], [4, 1, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[4, 2, 1, "", "FixedPercentageStorageReservation"], [4, 2, 1, "", "HeuristicalStorageReservation"], [4, 2, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[4, 2, 1, "", "DeviceHardware"], [4, 2, 1, "", "Enumerator"], [4, 2, 1, "", "ParameterConstraints"], [4, 2, 1, "", "PartitionByType"], [4, 2, 1, "", "Partitioner"], [4, 2, 1, "", "Perf"], [4, 2, 1, "", "PerfModel"], [4, 6, 1, "", "PlannerError"], [4, 2, 1, "", "PlannerErrorType"], [4, 2, 1, "", "Proposer"], [4, 2, 1, "", "Shard"], [4, 2, 1, "", "ShardEstimator"], [4, 2, 1, "", "ShardingOption"], [4, 2, 1, "", "Stats"], [4, 2, 1, "", "Storage"], [4, 2, 1, "", "StorageReservation"], [4, 2, 1, "", "Topology"]], "torchrec.distributed.planner.types.DeviceHardware": [[4, 3, 1, "", "perf"], [4, 3, 1, "", "rank"], [4, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[4, 4, 1, "", "enumerate"], [4, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[4, 3, 1, "id0", "batch_sizes"], [4, 3, 1, "id1", "bounds_check_mode"], [4, 3, 1, "id2", "cache_params"], [4, 3, 1, "id3", "compute_kernels"], [4, 3, 1, "id4", "device_group"], [4, 3, 1, "id5", "enforce_hbm"], [4, 3, 1, "id6", "feature_names"], [4, 3, 1, "id7", "is_weighted"], [4, 3, 1, "id8", "min_partition"], [4, 3, 1, "id9", "num_poolings"], [4, 3, 1, "id10", "output_dtype"], [4, 3, 1, "id11", "pooling_factors"], [4, 3, 1, "id12", "sharding_types"], [4, 3, 1, "id13", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[4, 3, 1, "", "DEVICE"], [4, 3, 1, "", "HOST"], [4, 3, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[4, 3, 1, "", "bwd_comms"], [4, 3, 1, "", "bwd_compute"], [4, 3, 1, "", "fwd_comms"], [4, 3, 1, "", "fwd_compute"], [4, 3, 1, "", "prefetch_compute"], [4, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[4, 3, 1, "", "INSUFFICIENT_STORAGE"], [4, 3, 1, "", "OTHER"], [4, 3, 1, "", "PARTITION"], [4, 3, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[4, 3, 1, "", "offset"], [4, 3, 1, "", "perf"], [4, 3, 1, "", "rank"], [4, 3, 1, "", "size"], [4, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[4, 4, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[4, 3, 1, "", "batch_size"], [4, 3, 1, "", "bounds_check_mode"], [4, 5, 1, "", "cache_load_factor"], [4, 3, 1, "", "cache_params"], [4, 3, 1, "", "compute_kernel"], [4, 3, 1, "", "dependency"], [4, 3, 1, "", "enforce_hbm"], [4, 3, 1, "", "feature_names"], [4, 5, 1, "", "fqn"], [4, 3, 1, "", "input_lengths"], [4, 5, 1, "id14", "is_pooled"], [4, 5, 1, "id15", "module"], [4, 4, 1, "", "module_pooled"], [4, 3, 1, "", "name"], [4, 5, 1, "", "num_inputs"], [4, 5, 1, "", "num_shards"], [4, 3, 1, "", "output_dtype"], [4, 5, 1, "", "path"], [4, 3, 1, "", "sharding_type"], [4, 3, 1, "", "shards"], [4, 3, 1, "", "stochastic_rounding"], [4, 5, 1, "id16", "tensor"], [4, 5, 1, "", "total_perf"], [4, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[4, 3, 1, "", "ddr"], [4, 4, 1, "", "fits_in"], [4, 3, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[4, 5, 1, "", "bwd_compute_multiplier"], [4, 5, 1, "", "compute_device"], [4, 5, 1, "", "ddr_mem_bw"], [4, 5, 1, "", "devices"], [4, 5, 1, "", "hbm_mem_bw"], [4, 5, 1, "", "inter_host_bw"], [4, 5, 1, "", "intra_host_bw"], [4, 5, 1, "", "local_world_size"], [4, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[4, 2, 1, "", "BinarySearchPredicate"], [4, 2, 1, "", "LuusJaakolaSearch"], [4, 1, 1, "", "bytes_to_gb"], [4, 1, 1, "", "bytes_to_mb"], [4, 1, 1, "", "gb_to_bytes"], [4, 1, 1, "", "placement"], [4, 1, 1, "", "prod"], [4, 1, 1, "", "reset_shard_rank"], [4, 1, 1, "", "sharder_name"], [4, 1, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[4, 4, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[4, 4, 1, "", "best"], [4, 4, 1, "", "clamp"], [4, 4, 1, "", "next"], [4, 4, 1, "", "shrink_right"], [4, 4, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[3, 2, 1, "", "QuantEmbeddingBagCollectionSharder"], [3, 2, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [3, 2, 1, "", "ShardedQuantEbcInputDist"], [3, 2, 1, "", "ShardedQuantEmbeddingBagCollection"], [3, 2, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [3, 1, 1, "", "create_infer_embedding_bag_sharding"], [3, 1, 1, "", "flatten_feature_lengths"], [3, 1, 1, "", "get_device_from_parameter_sharding"], [3, 1, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "copy"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "embedding_bag_configs"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharding_type_to_sharding_infos"], [3, 5, 1, "", "shardings"], [3, 4, 1, "", "tbes_configs"], [3, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[3, 4, 1, "", "apply_feature_processor"], [3, 4, 1, "", "compute"], [3, 3, 1, "", "embedding_bags"], [3, 3, 1, "", "tbes"], [3, 3, 1, "", "training"]], "torchrec.distributed.sharding": [[5, 0, 0, "-", "cw_sharding"], [5, 0, 0, "-", "dp_sharding"], [5, 0, 0, "-", "rw_sharding"], [5, 0, 0, "-", "tw_sharding"], [5, 0, 0, "-", "twcw_sharding"], [5, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[5, 2, 1, "", "BaseCwEmbeddingSharding"], [5, 2, 1, "", "CwPooledEmbeddingSharding"], [5, 2, 1, "", "InferCwPooledEmbeddingDist"], [5, 2, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [5, 2, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "uncombined_embedding_dims"], [5, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[5, 2, 1, "", "BaseDpEmbeddingSharding"], [5, 2, 1, "", "DpPooledEmbeddingDist"], [5, 2, 1, "", "DpPooledEmbeddingSharding"], [5, 2, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[5, 2, 1, "", "BaseRwEmbeddingSharding"], [5, 2, 1, "", "InferCPURwSparseFeaturesDist"], [5, 2, 1, "", "InferRwPooledEmbeddingDist"], [5, 2, 1, "", "InferRwPooledEmbeddingSharding"], [5, 2, 1, "", "InferRwSparseFeaturesDist"], [5, 2, 1, "", "RwPooledEmbeddingDist"], [5, 2, 1, "", "RwPooledEmbeddingSharding"], [5, 2, 1, "", "RwSparseFeaturesDist"], [5, 1, 1, "", "get_block_sizes_runtime_device"], [5, 1, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferCPURwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[5, 2, 1, "", "BaseTwEmbeddingSharding"], [5, 2, 1, "", "InferTwEmbeddingSharding"], [5, 2, 1, "", "InferTwPooledEmbeddingDist"], [5, 2, 1, "", "InferTwSparseFeaturesDist"], [5, 2, 1, "", "TwPooledEmbeddingDist"], [5, 2, 1, "", "TwPooledEmbeddingSharding"], [5, 2, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"], [5, 4, 1, "", "feature_names_per_rank"], [5, 4, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[5, 2, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[5, 2, 1, "", "BaseTwRwEmbeddingSharding"], [5, 2, 1, "", "TwRwPooledEmbeddingDist"], [5, 2, 1, "", "TwRwPooledEmbeddingSharding"], [5, 2, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.types": [[3, 2, 1, "", "Awaitable"], [3, 2, 1, "", "CacheParams"], [3, 2, 1, "", "CacheStatistics"], [3, 2, 1, "", "CommOp"], [3, 2, 1, "", "ComputeKernel"], [3, 2, 1, "", "EmbeddingModuleShardingPlan"], [3, 2, 1, "", "GenericMeta"], [3, 2, 1, "", "GetItemLazyAwaitable"], [3, 2, 1, "", "LazyAwaitable"], [3, 2, 1, "", "LazyGetItemMixin"], [3, 2, 1, "", "LazyNoWait"], [3, 2, 1, "", "ModuleSharder"], [3, 2, 1, "", "ModuleShardingPlan"], [3, 2, 1, "", "NoOpQuantizedCommCodec"], [3, 2, 1, "", "NoWait"], [3, 2, 1, "", "NullShardedModuleContext"], [3, 2, 1, "", "NullShardingContext"], [3, 2, 1, "", "ParameterSharding"], [3, 2, 1, "", "ParameterStorage"], [3, 2, 1, "", "QuantizedCommCodec"], [3, 2, 1, "", "QuantizedCommCodecs"], [3, 2, 1, "", "ShardedModule"], [3, 2, 1, "", "ShardingEnv"], [3, 2, 1, "", "ShardingPlan"], [3, 2, 1, "", "ShardingPlanner"], [3, 2, 1, "", "ShardingType"], [3, 1, 1, "", "get_tensor_size_bytes"], [3, 1, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[3, 5, 1, "", "callbacks"], [3, 4, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[3, 3, 1, "id34", "algorithm"], [3, 3, 1, "id35", "load_factor"], [3, 3, 1, "id36", "precision"], [3, 3, 1, "id37", "prefetch_pipeline"], [3, 3, 1, "id38", "reserved_memory"], [3, 3, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[3, 5, 1, "", "cacheability"], [3, 5, 1, "", "expected_lookups"], [3, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[3, 3, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [3, 3, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [3, 3, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[3, 3, 1, "", "DEFAULT"]], "torchrec.distributed.types.ModuleSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "module_type"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[3, 4, 1, "", "calc_quantized_size"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "decode"], [3, 4, 1, "", "encode"], [3, 4, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.types.ParameterSharding": [[3, 3, 1, "", "bounds_check_mode"], [3, 3, 1, "", "cache_params"], [3, 3, 1, "", "compute_kernel"], [3, 3, 1, "", "enforce_hbm"], [3, 3, 1, "", "output_dtype"], [3, 3, 1, "", "ranks"], [3, 3, 1, "", "sharding_spec"], [3, 3, 1, "", "sharding_type"], [3, 3, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[3, 3, 1, "", "DDR"], [3, 3, 1, "", "HBM"]], "torchrec.distributed.types.QuantizedCommCodec": [[3, 4, 1, "", "calc_quantized_size"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "decode"], [3, 4, 1, "", "encode"], [3, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[3, 3, 1, "", "backward"], [3, 3, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "sharded_parameter_names"], [3, 3, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[3, 4, 1, "", "from_local"], [3, 4, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[3, 4, 1, "", "get_plan_for_module"], [3, 3, 1, "id40", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[3, 4, 1, "", "collective_plan"], [3, 4, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[3, 3, 1, "", "COLUMN_WISE"], [3, 3, 1, "", "DATA_PARALLEL"], [3, 3, 1, "", "ROW_WISE"], [3, 3, 1, "", "TABLE_COLUMN_WISE"], [3, 3, 1, "", "TABLE_ROW_WISE"], [3, 3, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[3, 2, 1, "", "CopyableMixin"], [3, 1, 1, "", "add_params_from_parameter_sharding"], [3, 1, 1, "", "add_prefix_to_state_dict"], [3, 1, 1, "", "append_prefix"], [3, 1, 1, "", "convert_to_fbgemm_types"], [3, 1, 1, "", "copy_to_device"], [3, 1, 1, "", "filter_state_dict"], [3, 1, 1, "", "get_unsharded_module_names"], [3, 1, 1, "", "init_parameters"], [3, 1, 1, "", "merge_fused_params"], [3, 1, 1, "", "none_throws"], [3, 1, 1, "", "optimizer_type_to_emb_opt_type"], [3, 2, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[3, 4, 1, "", "copy"], [3, 3, 1, "", "training"]], "torchrec.fx": [[6, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[6, 2, 1, "", "Tracer"], [6, 1, 1, "", "is_fx_tracing"], [6, 1, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[6, 4, 1, "", "create_arg"], [6, 3, 1, "", "graph"], [6, 4, 1, "", "is_leaf_module"], [6, 3, 1, "", "module_stack"], [6, 3, 1, "", "node_name_to_scope"], [6, 4, 1, "", "path_of_module"], [6, 3, 1, "", "scope"], [6, 4, 1, "", "trace"]], "torchrec.inference": [[7, 0, 0, "-", "model_packager"], [7, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[7, 2, 1, "", "PredictFactoryPackager"], [7, 1, 1, "", "load_config_text"], [7, 1, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[7, 4, 1, "", "save_predict_factory"], [7, 4, 1, "", "set_extern_modules"], [7, 4, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[7, 2, 1, "", "BatchingMetadata"], [7, 2, 1, "", "PredictFactory"], [7, 2, 1, "", "PredictModule"], [7, 2, 1, "", "QualNameMetadata"], [7, 1, 1, "", "quantize_dense"], [7, 1, 1, "", "quantize_embeddings"], [7, 1, 1, "", "quantize_feature"], [7, 1, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[7, 3, 1, "", "device"], [7, 3, 1, "", "pinned"], [7, 3, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[7, 4, 1, "", "batching_metadata"], [7, 4, 1, "", "batching_metadata_json"], [7, 4, 1, "", "create_predict_module"], [7, 4, 1, "", "model_inputs_data"], [7, 4, 1, "", "qualname_metadata"], [7, 4, 1, "", "qualname_metadata_json"], [7, 4, 1, "", "result_metadata"], [7, 4, 1, "", "run_weights_dependent_transformations"], [7, 4, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[7, 4, 1, "", "forward"], [7, 4, 1, "", "predict_forward"], [7, 5, 1, "", "predict_module"], [7, 4, 1, "", "state_dict"], [7, 3, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[7, 3, 1, "", "need_preproc"]], "torchrec.models": [[8, 0, 0, "-", "deepfm"]], "torchrec.models.deepfm": [[8, 2, 1, "", "DenseArch"], [8, 2, 1, "", "FMInteractionArch"], [8, 2, 1, "", "OverArch"], [8, 2, 1, "", "SimpleDeepFMNN"], [8, 2, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.modules": [[9, 0, 0, "-", "activation"], [9, 0, 0, "-", "crossnet"], [9, 0, 0, "-", "deepfm"], [9, 0, 0, "-", "embedding_configs"], [9, 0, 0, "-", "embedding_modules"], [9, 0, 0, "-", "feature_processor"], [9, 0, 0, "-", "lazy_extension"], [9, 0, 0, "-", "mc_embedding_modules"], [9, 0, 0, "-", "mc_modules"], [9, 0, 0, "-", "mlp"], [9, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[9, 2, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet": [[9, 2, 1, "", "CrossNet"], [9, 2, 1, "", "LowRankCrossNet"], [9, 2, 1, "", "LowRankMixtureCrossNet"], [9, 2, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.deepfm": [[9, 2, 1, "", "DeepFM"], [9, 2, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_configs": [[9, 2, 1, "", "BaseEmbeddingConfig"], [9, 2, 1, "", "EmbeddingBagConfig"], [9, 2, 1, "", "EmbeddingConfig"], [9, 2, 1, "", "EmbeddingTableConfig"], [9, 2, 1, "", "PoolingType"], [9, 2, 1, "", "QuantConfig"], [9, 2, 1, "", "ShardingType"], [9, 1, 1, "", "data_type_to_dtype"], [9, 1, 1, "", "data_type_to_sparse_type"], [9, 1, 1, "", "dtype_to_data_type"], [9, 1, 1, "", "pooling_type_to_pooling_mode"], [9, 1, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[9, 3, 1, "", "data_type"], [9, 3, 1, "", "embedding_dim"], [9, 3, 1, "", "feature_names"], [9, 4, 1, "", "get_weight_init_max"], [9, 4, 1, "", "get_weight_init_min"], [9, 3, 1, "", "init_fn"], [9, 3, 1, "", "name"], [9, 3, 1, "", "need_pos"], [9, 3, 1, "", "num_embeddings"], [9, 4, 1, "", "num_features"], [9, 3, 1, "", "pruning_indices_remapping"], [9, 3, 1, "", "weight_init_max"], [9, 3, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[9, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[9, 3, 1, "", "embedding_dim"], [9, 3, 1, "", "feature_names"], [9, 3, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[9, 3, 1, "", "embedding_names"], [9, 3, 1, "", "has_feature_processor"], [9, 3, 1, "", "is_weighted"], [9, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[9, 3, 1, "", "MEAN"], [9, 3, 1, "", "NONE"], [9, 3, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[9, 3, 1, "", "activation"], [9, 3, 1, "", "per_table_weight_dtype"], [9, 3, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[9, 3, 1, "", "COLUMN_WISE"], [9, 3, 1, "", "DATA_PARALLEL"], [9, 3, 1, "", "ROW_WISE"], [9, 3, 1, "", "TABLE_COLUMN_WISE"], [9, 3, 1, "", "TABLE_ROW_WISE"], [9, 3, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[9, 2, 1, "", "EmbeddingBagCollection"], [9, 2, 1, "", "EmbeddingBagCollectionInterface"], [9, 2, 1, "", "EmbeddingCollection"], [9, 2, 1, "", "EmbeddingCollectionInterface"], [9, 1, 1, "", "get_embedding_names_by_table"], [9, 1, 1, "", "process_pooled_embeddings"], [9, 1, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[9, 5, 1, "", "device"], [9, 4, 1, "", "embedding_bag_configs"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "is_weighted"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[9, 4, 1, "", "embedding_bag_configs"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "is_weighted"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[9, 5, 1, "", "device"], [9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "embedding_dim"], [9, 4, 1, "", "embedding_names_by_table"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "need_indices"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "embedding_dim"], [9, 4, 1, "", "embedding_names_by_table"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "need_indices"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor": [[9, 2, 1, "", "BaseFeatureProcessor"], [9, 2, 1, "", "BaseGroupedFeatureProcessor"], [9, 2, 1, "", "PositionWeightedModule"], [9, 2, 1, "", "PositionWeightedProcessor"], [9, 1, 1, "", "offsets_to_range_traceble"], [9, 1, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[9, 4, 1, "", "forward"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[9, 4, 1, "", "forward"], [9, 4, 1, "", "named_buffers"], [9, 4, 1, "", "state_dict"], [9, 3, 1, "", "training"]], "torchrec.modules.lazy_extension": [[9, 2, 1, "", "LazyModuleExtensionMixin"], [9, 1, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[9, 4, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[9, 2, 1, "", "BaseManagedCollisionEmbeddingCollection"], [9, 2, 1, "", "ManagedCollisionEmbeddingBagCollection"], [9, 2, 1, "", "ManagedCollisionEmbeddingCollection"], [9, 1, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[9, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[9, 3, 1, "", "training"]], "torchrec.modules.mc_modules": [[9, 2, 1, "", "DistanceLFU_EvictionPolicy"], [9, 2, 1, "", "LFU_EvictionPolicy"], [9, 2, 1, "", "LRU_EvictionPolicy"], [9, 2, 1, "", "MCHEvictionPolicy"], [9, 2, 1, "", "MCHEvictionPolicyMetadataInfo"], [9, 2, 1, "", "MCHManagedCollisionModule"], [9, 2, 1, "", "ManagedCollisionCollection"], [9, 2, 1, "", "ManagedCollisionModule"], [9, 1, 1, "", "apply_mc_method_to_jt_dict"], [9, 1, 1, "", "average_threshold_filter"], [9, 1, 1, "", "dynamic_threshold_filter"], [9, 1, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[9, 3, 1, "", "is_history_metadata"], [9, 3, 1, "", "is_mch_metadata"], [9, 3, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "input_size"], [9, 4, 1, "", "output_size"], [9, 4, 1, "", "preprocess"], [9, 4, 1, "", "profile"], [9, 4, 1, "", "rebuild_with_output_id_range"], [9, 4, 1, "", "remap"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[9, 5, 1, "", "device"], [9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "input_size"], [9, 4, 1, "", "output_size"], [9, 4, 1, "", "preprocess"], [9, 4, 1, "", "rebuild_with_output_id_range"], [9, 3, 1, "", "training"]], "torchrec.modules.mlp": [[9, 2, 1, "", "MLP"], [9, 2, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.utils": [[9, 1, 1, "", "check_module_output_dimension"], [9, 1, 1, "", "construct_jagged_tensors"], [9, 1, 1, "", "construct_jagged_tensors_inference"], [9, 1, 1, "", "construct_modulelist_from_single_module"], [9, 1, 1, "", "convert_list_of_modules_to_modulelist"], [9, 1, 1, "", "extract_module_or_tensor_callable"], [9, 1, 1, "", "get_module_output_dimension"], [9, 1, 1, "", "init_mlp_weights_xavier_uniform"]], "torchrec.optim": [[10, 0, 0, "-", "clipping"], [10, 0, 0, "-", "fused"], [10, 0, 0, "-", "keyed"], [10, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[10, 2, 1, "", "GradientClipping"], [10, 2, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[10, 3, 1, "", "NONE"], [10, 3, 1, "", "NORM"], [10, 3, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[10, 4, 1, "", "step"]], "torchrec.optim.fused": [[10, 2, 1, "", "EmptyFusedOptimizer"], [10, 2, 1, "", "FusedOptimizer"], [10, 2, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[10, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[10, 2, 1, "", "CombinedOptimizer"], [10, 2, 1, "", "KeyedOptimizer"], [10, 2, 1, "", "KeyedOptimizerWrapper"], [10, 2, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[10, 5, 1, "", "optimizers"], [10, 5, 1, "", "param_groups"], [10, 5, 1, "", "params"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "prepend_opt_key"], [10, 4, 1, "", "save_param_groups"], [10, 5, 1, "", "state"], [10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[10, 4, 1, "", "add_param_group"], [10, 4, 1, "", "init_state"], [10, 4, 1, "", "load_state_dict"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "save_param_groups"], [10, 4, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[10, 4, 1, "", "add_param_group"], [10, 4, 1, "", "load_state_dict"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "save_param_groups"], [10, 4, 1, "", "state_dict"], [10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[10, 2, 1, "", "WarmupOptimizer"], [10, 2, 1, "", "WarmupPolicy"], [10, 2, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[10, 3, 1, "", "CONSTANT"], [10, 3, 1, "", "INVSQRT"], [10, 3, 1, "", "LINEAR"], [10, 3, 1, "", "NONE"], [10, 3, 1, "", "POLY"], [10, 3, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[10, 3, 1, "", "decay_iters"], [10, 3, 1, "", "lr_scale"], [10, 3, 1, "", "max_iters"], [10, 3, 1, "", "policy"], [10, 3, 1, "", "value"]], "torchrec.quant": [[11, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[11, 2, 1, "", "EmbeddingBagCollection"], [11, 2, 1, "", "EmbeddingCollection"], [11, 2, 1, "", "FeatureProcessedEmbeddingBagCollection"], [11, 1, 1, "", "features_to_dict"], [11, 1, 1, "", "for_each_module_of_type_do"], [11, 1, 1, "", "pruned_num_embeddings"], [11, 1, 1, "", "quant_prep_customize_row_alignment"], [11, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [11, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [11, 1, 1, "", "quant_prep_enable_register_tbes"], [11, 1, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_bag_configs"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "output_dtype"], [11, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_configs"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "embedding_names_by_table"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 4, 1, "", "need_indices"], [11, 4, 1, "", "output_dtype"], [11, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[11, 3, 1, "", "embedding_bags"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 3, 1, "", "tbes"], [11, 3, 1, "", "training"]], "torchrec.sparse": [[12, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[12, 2, 1, "", "ComputeJTDictToKJT"], [12, 2, 1, "", "ComputeKJTToJTDict"], [12, 2, 1, "", "JaggedTensor"], [12, 2, 1, "", "JaggedTensorMeta"], [12, 2, 1, "", "KeyedJaggedTensor"], [12, 2, 1, "", "KeyedTensor"], [12, 1, 1, "", "flatten_kjt_list"], [12, 1, 1, "", "is_non_strict_exporting"], [12, 1, 1, "", "jt_is_equal"], [12, 1, 1, "", "kjt_is_equal"], [12, 1, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[12, 4, 1, "", "forward"], [12, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[12, 4, 1, "", "forward"], [12, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[12, 4, 1, "", "empty"], [12, 4, 1, "", "from_dense"], [12, 4, 1, "", "from_dense_lengths"], [12, 4, 1, "", "lengths"], [12, 4, 1, "", "lengths_or_none"], [12, 4, 1, "", "offsets"], [12, 4, 1, "", "offsets_or_none"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dense"], [12, 4, 1, "", "to_dense_weights"], [12, 4, 1, "", "to_padded_dense"], [12, 4, 1, "", "to_padded_dense_weights"], [12, 4, 1, "", "values"], [12, 4, 1, "", "weights"], [12, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[12, 4, 1, "", "concat"], [12, 4, 1, "", "device"], [12, 4, 1, "", "dist_init"], [12, 4, 1, "", "dist_labels"], [12, 4, 1, "", "dist_splits"], [12, 4, 1, "", "dist_tensors"], [12, 4, 1, "", "empty"], [12, 4, 1, "", "empty_like"], [12, 4, 1, "", "flatten_lengths"], [12, 4, 1, "", "from_jt_dict"], [12, 4, 1, "", "from_lengths_sync"], [12, 4, 1, "", "from_offsets_sync"], [12, 4, 1, "", "index_per_key"], [12, 4, 1, "", "inverse_indices"], [12, 4, 1, "", "inverse_indices_or_none"], [12, 4, 1, "", "keys"], [12, 4, 1, "", "length_per_key"], [12, 4, 1, "", "length_per_key_or_none"], [12, 4, 1, "", "lengths"], [12, 4, 1, "", "lengths_offset_per_key"], [12, 4, 1, "", "lengths_or_none"], [12, 4, 1, "", "offset_per_key"], [12, 4, 1, "", "offset_per_key_or_none"], [12, 4, 1, "", "offsets"], [12, 4, 1, "", "offsets_or_none"], [12, 4, 1, "", "permute"], [12, 4, 1, "", "pin_memory"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "split"], [12, 4, 1, "", "stride"], [12, 4, 1, "", "stride_per_key"], [12, 4, 1, "", "stride_per_key_per_rank"], [12, 4, 1, "", "sync"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dict"], [12, 4, 1, "", "unsync"], [12, 4, 1, "", "values"], [12, 4, 1, "", "variable_stride_per_key"], [12, 4, 1, "", "weights"], [12, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[12, 4, 1, "", "from_tensor_list"], [12, 4, 1, "", "key_dim"], [12, 4, 1, "", "keys"], [12, 4, 1, "", "length_per_key"], [12, 4, 1, "", "offset_per_key"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "regroup"], [12, 4, 1, "", "regroup_as_dict"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dict"], [12, 4, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:attribute", "4": "py:method", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "method", "Python method"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "document": 0, "tutori": 0, "api": 0, "content": [0, 6, 7, 8, 10, 11, 12], "indic": 0, "tabl": 0, "dataset": [1, 2], "criteo": 1, "movielen": 1, "random": 1, "util": [1, 3, 4, 9], "script": 2, "contiguous_preproc_criteo": 2, "npy_preproc_criteo": 2, "distribut": [3, 4, 5], "collective_util": 3, "comm": 3, "comm_op": 3, "dist_data": [3, 5], "embed": 3, "embedding_lookup": 3, "embedding_shard": 3, "embedding_typ": 3, "embeddingbag": 3, "grouped_position_weight": 3, "model_parallel": 3, "quant_embeddingbag": 3, "train_pipelin": 3, "type": [3, 4], "mc_modul": [3, 9], "mc_embeddingbag": 3, "mc_embed": 3, "planner": 4, "constant": 4, "enumer": 4, "partition": 4, "perf_model": 4, "propos": 4, "shard_estim": 4, "stat": 4, "storage_reserv": 4, "shard": 5, "cw_shard": 5, "dp_shard": 5, "rw_shard": 5, "tw_shard": 5, "twcw_shard": 5, "twrw_shard": 5, "fx": 6, "tracer": 6, "modul": [6, 7, 8, 9, 10, 11, 12], "infer": 7, "model_packag": 7, "model": 8, "deepfm": [8, 9], "dlrm": 8, "activ": 9, "crossnet": 9, "embedding_config": 9, "embedding_modul": [9, 11], "feature_processor": 9, "lazy_extens": 9, "mlp": 9, "mc_embedding_modul": 9, "optim": 10, "clip": 10, "fuse": 10, "kei": 10, "warmup": 10, "quant": 11, "spars": 12, "jagged_tensor": 12}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 56}})
\ No newline at end of file
+Search.setIndex({"docnames": ["index", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"pytorch": [0, 3, 9, 10, 12], "domain": 0, "librari": [0, 7], "built": [0, 9], "provid": [0, 3, 4, 5, 7, 9, 11], "common": [0, 9, 12], "sparsiti": 0, "parallel": [0, 3, 5], "primit": [0, 3, 5], "need": [0, 3, 5, 6, 7, 9, 10, 11, 12], "larg": [0, 4], "scale": 0, "recommend": 0, "system": [0, 3, 4], "recsi": [0, 8, 10], "It": [0, 3, 4, 5, 7, 9, 10, 11, 12], "allow": [0, 3, 4, 6, 9, 10], "author": [0, 3, 7], "train": [0, 3, 4, 5, 7, 8, 9, 10, 11, 12], "model": [0, 3, 4, 5, 6, 7, 9, 10, 11], "embed": [0, 4, 5, 6, 8, 9, 11, 12], "shard": [0, 3, 4, 7, 9, 10, 11], "across": [0, 3, 4, 5], "mani": [0, 3, 5], "gpu": [0, 3, 4, 7], "For": [0, 3, 4, 5, 8, 9, 10, 11, 12], "instal": 0, "instruct": 0, "visit": 0, "http": [0, 3, 4, 8, 9, 12], "github": [0, 9], "com": [0, 9], "readm": 0, "In": [0, 3, 4, 9, 10, 12], "thi": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "we": [0, 3, 4, 5, 6, 7, 9, 10, 11, 12], "introduc": [0, 10], "primari": [0, 7], "call": [0, 3, 4, 5, 7, 9, 10, 11], "distributedmodelparallel": [0, 3], "dmp": [0, 3], "like": [0, 3, 4, 5, 6, 9, 10, 12], "s": [0, 3, 4, 6, 7, 8, 9, 10, 11, 12], "distributeddataparallel": 0, "wrap": [0, 3, 5, 10], "enabl": [0, 3, 4, 10], "distribut": [0, 7, 9, 10, 12], "sourc": [0, 8, 9], "open": 0, "googl": 0, "colab": 0, "dataset": [0, 3, 4], "criteo": 0, "movielen": 0, "random": [0, 4], "util": [0, 5], "script": [0, 12], "contiguous_preproc_criteo": 0, "npy_preproc_criteo": 0, "collective_util": 0, "comm": 0, "comm_op": 0, "dist_data": 0, "embedding_lookup": 0, "embedding_shard": 0, "embedding_typ": 0, "embeddingbag": [0, 4, 6, 8, 9, 11], "grouped_position_weight": 0, "model_parallel": 0, "quant_embeddingbag": 0, "train_pipelin": 0, "type": [0, 5, 6, 7, 8, 9, 11, 12], "mc_modul": 0, "mc_embeddingbag": 0, "mc_embed": 0, "planner": [0, 3], "constant": [0, 10], "enumer": [0, 3, 9, 10], "partition": 0, "perf_model": 0, "propos": [0, 8], "shard_estim": 0, "stat": [0, 3], "storage_reserv": 0, "cw_shard": 0, "dp_shard": 0, "rw_shard": 0, "tw_shard": 0, "twcw_shard": 0, "twrw_shard": 0, "fx": [0, 7], "tracer": 0, "modul": [0, 3, 4, 5], "infer": [0, 3, 4, 5, 11, 12], "model_packag": 0, "deepfm": 0, "dlrm": [0, 7], "activ": [0, 11], "crossnet": 0, "embedding_config": [0, 3, 11], "embedding_modul": 0, "feature_processor": [0, 3, 5, 11], "lazy_extens": 0, "mlp": [0, 8], "mc_embedding_modul": 0, "optim": [0, 3, 4, 9, 11], "clip": 0, "fuse": [0, 3, 5], "kei": [0, 3, 4, 5, 7, 8, 9, 11, 12], "warmup": 0, "quant": [0, 3], "spars": [0, 3, 5, 8, 9, 11], "jagged_tensor": [0, 3], "index": [0, 9, 12], "search": [0, 4], "page": 0, "necessari": [3, 4, 5], "oper": [3, 4, 5, 6, 9, 12], "These": [3, 4, 7, 9], "includ": [3, 4, 6, 7, 9, 12], "through": [3, 6, 10], "collect": [3, 5, 8, 9, 10, 11], "all": [3, 4, 5, 7, 8, 9, 10, 12], "reduc": [3, 5, 9, 11], "scatter": [3, 5], "wrapper": [3, 10], "featur": [3, 4, 5, 8, 9, 11, 12], "kjt": [3, 4, 5, 8, 9, 11, 12], "variou": [3, 7, 9], "implement": [3, 4, 5, 7, 9, 10, 12], "shardedembeddingbag": 3, "nn": [3, 4, 6, 9, 11], "shardedembeddingbagcollect": [3, 9, 11], "embeddingbagcollect": [3, 8, 9, 11], "sharder": [3, 4], "defin": [3, 5, 7, 8, 9], "ani": [3, 4, 5, 6, 7, 9, 10, 12], "support": [3, 4, 5, 6, 9, 10], "comput": [3, 4, 5, 7, 8, 9, 11], "kernel": [3, 4, 9], "which": [3, 4, 5, 7, 9, 10, 12], "ar": [3, 4, 5, 7, 9, 10, 11, 12], "devic": [3, 4, 5, 6, 7, 8, 9, 11, 12], "cpu": [3, 4], "mai": [3, 12], "batch": [3, 4, 5, 6, 7, 8, 9, 11, 12], "togeth": [3, 9], "tabl": [3, 4, 5, 6, 8, 9, 11], "fusion": 3, "pipelin": [3, 4, 9, 12], "trainpipelinesparsedist": 3, "overlap": 3, "dataload": 3, "transfer": 3, "copi": [3, 5, 7, 9, 10, 12], "inter": [3, 9], "commun": [3, 4, 5], "input_dist": [3, 9], "forward": [3, 4, 5, 7, 8, 9, 11, 12], "backward": [3, 4, 6, 10], "increas": 3, "perform": [3, 4, 5, 7, 9, 10, 11], "quantiz": [3, 5, 6, 11], "precis": [3, 9, 11], "file": 3, "contain": [3, 4, 7, 9, 10, 11], "construct": [3, 6, 9, 12], "base": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "control": [3, 6], "flow": [3, 6], "invoke_on_rank_and_broadcast_result": 3, "pg": [3, 4, 5], "processgroup": [3, 4, 5], "rank": [3, 4, 5, 9, 10, 12], "int": [3, 4, 5, 6, 8, 9, 10, 11, 12], "func": 3, "callabl": [3, 5, 6, 9, 10, 11], "t": [3, 4, 5, 6, 7, 9, 10, 12], "arg": [3, 4, 7, 9, 11, 12], "kwarg": [3, 9, 12], "invok": [3, 4], "function": [3, 4, 5, 6, 7, 9, 10, 12], "design": [3, 7, 9], "broadcast": [3, 4], "result": [3, 4, 5, 7, 9, 11], "member": [3, 9], "within": [3, 4, 5, 7, 9, 12], "group": [3, 4, 5, 9, 10, 12], "exampl": [3, 4, 5, 7, 8, 9, 10, 11, 12], "id": [3, 4, 5, 9], "0": [3, 4, 5, 8, 9, 10, 11, 12], "allocate_id": 3, "is_lead": 3, "option": [3, 4, 5, 6, 7, 9, 10, 11, 12], "leader_rank": 3, "bool": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "check": [3, 4, 9, 10, 12], "current": [3, 4, 7, 9], "processs": 3, "leader": 3, "paramet": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "dist": [3, 5], "process": [3, 4, 5, 8, 9, 11], "us": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "determin": [3, 4, 5], "being": [3, 4, 7, 9], "none": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "impli": 3, "onli": [3, 4, 5, 9, 12], "e": [3, 4, 5, 6, 7, 8, 9, 10], "g": [3, 4, 7, 9, 10], "singl": [3, 4, 5, 9, 10], "program": 3, "definit": [3, 6, 7], "default": [3, 4, 6, 7, 8, 9, 10, 11, 12], "The": [3, 4, 5, 6, 7, 8, 9, 10, 12], "caller": 3, "can": [3, 4, 7, 9, 10, 12], "overrid": [3, 4, 6, 7], "context": [3, 5, 12], "specif": [3, 4, 7, 10], "run_on_lead": 3, "get_group_rank": 3, "world_siz": [3, 4, 5], "get": [3, 4, 5], "worker": 3, "also": [3, 4, 7, 9, 10], "avail": [3, 4, 5], "group_rank": 3, "environ": [3, 7], "varibl": 3, "A": [3, 4, 5, 6, 7, 10, 12], "number": [3, 4, 5, 8, 9, 12], "between": [3, 4, 7, 8, 9, 12], "get_num_group": 3, "see": [3, 4, 5, 6, 9, 12], "org": [3, 4, 8, 12], "doc": [3, 12], "stabl": [3, 12], "elast": 3, "run": [3, 4, 5, 7, 9, 10], "html": [3, 12], "get_local_rank": 3, "local": [3, 4, 5, 9], "usual": [3, 4, 9], "its": [3, 4, 5, 9, 10, 12], "node": [3, 6], "get_local_s": 3, "equival": 3, "max_nnod": 3, "intra_and_cross_node_pg": 3, "backend": [3, 5, 7], "str": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "tupl": [3, 4, 5, 6, 7, 9, 10, 11, 12], "creat": [3, 6, 7, 9, 10, 12], "sub": 3, "intra": 3, "cross": [3, 9], "class": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "all2alldenseinfo": 3, "output_split": [3, 5], "list": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "batch_siz": [3, 4, 5, 9, 12], "input_shap": 3, "input_split": [3, 5], "object": [3, 4, 7, 9, 10], "data": [3, 4, 5, 6, 7, 9, 10, 11, 12], "attribut": [3, 4, 10], "when": [3, 4, 6, 9, 10], "alltoall_dens": 3, "all2allpooledinfo": 3, "batch_size_per_rank": [3, 5], "dim_sum_per_rank": [3, 5], "dim_sum_per_rank_tensor": 3, "tensor": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "cumsum_dim_sum_per_rank_tensor": 3, "codec": [3, 5], "quantizedcommcodec": [3, 5], "alltoall_pool": [3, 5], "size": [3, 4, 5, 8, 9, 11, 12], "each": [3, 4, 5, 8, 9, 11, 12], "sum": [3, 4, 5, 9], "dimens": [3, 4, 5, 8, 9, 11, 12], "version": [3, 9, 11], "fast": 3, "_recat_pooled_embedding_grad_out": 3, "cumul": [3, 12], "all2allsequenceinfo": 3, "embedding_dim": [3, 4, 5, 8, 9, 11], "lengths_after_sparse_data_all2al": 3, "forward_recat_tensor": 3, "backward_recat_tensor": 3, "variable_batch_s": 3, "fals": [3, 4, 5, 7, 9, 10, 11, 12], "permuted_lengths_after_sparse_data_all2al": 3, "alltoall_sequ": 3, "length": [3, 4, 5, 8, 9, 11, 12], "after": [3, 4, 5, 9], "alltoal": [3, 5], "recat": [3, 5, 12], "input": [3, 4, 5, 6, 7, 8, 9, 11, 12], "split": [3, 4, 5, 7, 12], "output": [3, 4, 5, 7, 8, 9, 11, 12], "whether": [3, 4, 6, 9, 11], "variabl": [3, 5, 9, 11, 12], "befor": [3, 5, 9, 10], "all2allvinfo": 3, "dims_sum_per_rank": 3, "b_global": 3, "b_local": 3, "b_local_list": 3, "d_local_list": 3, "input_split_s": 3, "factori": [3, 4, 9], "output_split_s": 3, "alltoallv": 3, "global": [3, 4, 5], "my": 3, "rememb": [3, 12], "how": [3, 4, 5, 7, 10], "do": [3, 4, 9, 10, 12], "all_to_all_singl": 3, "fill": 3, "all2all_pooled_req": 3, "static": [3, 4, 10, 12], "ctx": 3, "unus": 3, "formula": 3, "differenti": 3, "mode": [3, 4], "automat": [3, 4, 12], "overridden": [3, 5, 7, 9], "subclass": [3, 5, 7, 9, 10], "vjp": 3, "must": [3, 4, 5, 7, 9], "accept": [3, 4, 7, 9], "first": [3, 4, 5, 9, 10, 12], "argument": [3, 6, 7, 9], "follow": [3, 4, 5, 8, 9, 10, 12], "return": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "pass": [3, 4, 5, 7, 9, 10, 11, 12], "non": [3, 4, 5, 6, 9, 11], "should": [3, 4, 5, 7, 8, 9, 10, 12], "were": 3, "gradient": [3, 4, 10], "w": [3, 5, 9, 12], "r": [3, 9], "given": [3, 4, 5, 6, 9], "valu": [3, 4, 5, 6, 8, 9, 10, 11, 12], "correspond": [3, 4, 5, 7, 9, 12], "If": [3, 4, 7, 9, 10, 12], "an": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "requir": [3, 4, 9, 10], "grad": [3, 10], "you": [3, 5, 6, 12], "just": [3, 4, 8, 9, 12], "retriev": 3, "save": [3, 9, 10], "dure": [3, 4, 10], "ha": [3, 4, 9, 12], "needs_input_grad": 3, "boolean": 3, "repres": [3, 4, 7, 8, 9, 11, 12], "have": [3, 4, 5, 8, 9, 10, 12], "true": [3, 4, 7, 9, 10, 12], "myreq": 3, "request": [3, 7, 10], "a2ai": 3, "input_embed": [3, 9], "custom": [3, 4, 6, 9], "autograd": [3, 7, 9], "There": 3, "two": [3, 4, 9, 12], "wai": [3, 4], "usag": [3, 4], "1": [3, 4, 5, 8, 9, 10, 11, 12], "combin": [3, 9, 10], "staticmethod": 3, "def": [3, 9], "other": [3, 4, 10], "more": [3, 4, 5, 9], "detail": [3, 4, 5, 9], "2": [3, 4, 5, 8, 9, 10, 11, 12], "separ": 3, "setup_context": 3, "longer": [3, 4], "instead": [3, 5, 7, 9, 10], "torch": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "handl": [3, 4, 5, 6, 7, 9, 10], "set": [3, 4, 7, 9, 10], "up": [3, 11], "extend": 3, "store": [3, 4, 5, 12], "arbitrari": 3, "directli": [3, 10], "though": 3, "enforc": [3, 7, 9], "compat": [3, 6, 10], "either": [3, 9], "save_for_backward": 3, "thei": [3, 12], "intend": 3, "save_for_forward": 3, "jvp": 3, "all2all_pooled_wait": 3, "grad_output": 3, "dummy_tensor": 3, "all2all_seq_req": 3, "sharded_input_embed": 3, "all2all_seq_req_wait": 3, "sharded_grad_output": 3, "all2allv_req": 3, "all2allv_wait": 3, "allgatherbaseinfo": 3, "input_s": [3, 9], "all_gatther_base_pool": 3, "allgatherbase_req": 3, "agi": 3, "allgatherbase_wait": 3, "reducescatterbaseinfo": 3, "reduce_scatter_base_pool": 3, "flatten": [3, 5, 9], "reducescatterbase_req": 3, "rsi": 3, "reducescatterbase_wait": 3, "reducescatterinfo": 3, "reduce_scatter_pool": 3, "produc": [3, 4], "reducescattervinfo": 3, "equal_split": 3, "total_input_s": 3, "reduce_scatter_v_pool": 3, "along": [3, 5, 10, 12], "dim": [3, 5], "total": [3, 4, 5], "reducescatterv_req": 3, "reducescatterv_wait": 3, "reducescatter_req": 3, "reducescatter_wait": 3, "await": [3, 5, 6], "variablebatchall2allpooledinfo": 3, "batch_size_per_rank_per_featur": [3, 5], "batch_size_per_feature_pre_a2a": [3, 5], "emb_dim_per_rank_per_featur": [3, 5], "variable_batch_alltoall_pool": [3, 5], "per": [3, 4, 5, 9, 12], "variable_batch_all2all_pooled_req": 3, "variable_batch_all2all_pooled_wait": 3, "all2all_pooled_sync": 3, "all2all_sequence_sync": 3, "all2allv_sync": 3, "all_gather_base_pool": 3, "gather": [3, 5], "from": [3, 4, 5, 6, 7, 9, 10, 12], "form": [3, 9, 11], "pool": [3, 4, 5, 8, 9, 11, 12], "output_tensor_s": 3, "work": [3, 4, 7, 9, 12], "async": [3, 5], "wait": [3, 5], "later": [3, 9], "experiment": [3, 9], "subject": 3, "chang": [3, 9, 10], "all_gather_base_sync": 3, "a2a_pooled_embs_tensor": 3, "world": [3, 5], "Then": 3, "concaten": [3, 5, 9, 12], "receiv": [3, 10], "Its": 3, "shape": [3, 5, 9, 12], "b": [3, 4, 5, 8, 9, 11, 12], "x": [3, 4, 5, 8, 9, 11, 12], "d_local_sum": 3, "where": [3, 4, 5, 9, 11], "a2a_sequence_embs_tensor": 3, "sequenc": [3, 4, 5], "doe": [3, 8, 9, 10, 12], "mix": 3, "out_split": 3, "per_rank_split_length": 3, "one": [3, 4, 5, 7, 8, 9, 10], "differ": [3, 4, 5, 9, 10, 12], "specifi": [3, 4, 5, 6, 9, 10], "assumpt": [3, 12], "emb": 3, "same": [3, 4, 5, 7, 8, 9, 12], "fn": [3, 9], "get_gradient_divis": 3, "reduce_scatter_base_sync": 3, "chunk": [3, 5], "reduce_scatter_sync": 3, "reduce_scatter_v_per_feature_pool": 3, "v": [3, 5, 9, 12], "d": [3, 8, 9, 11, 12], "unevenli": 3, "accord": [3, 4, 5, 7, 8, 10, 12], "reduce_scatter_v_sync": 3, "set_gradient_divis": 3, "val": 3, "variable_batch_all2all_pooled_sync": 3, "embeddingsalltoon": [3, 5], "cat_dim": [3, 5, 12], "merg": [3, 5], "buffer": [3, 5, 7, 9], "alloc": [3, 5, 7], "topolog": [3, 4, 5], "would": [3, 5, 12], "alltoon": [3, 5], "set_devic": [3, 5], "device_str": [3, 5], "embeddingsalltoonereduc": [3, 5], "kjtalltoal": [3, 5], "stagger": [3, 5, 12], "redistribut": [3, 5], "keyedjaggedtensor": [3, 5, 8, 9, 11, 12], "part": [3, 4, 5, 9, 10], "kjtalltoallsplitsawait": [3, 5], "transmit": [3, 5], "correct": [3, 5, 12], "space": [3, 4, 5, 8], "kjtalltoalltensorsawait": [3, 5], "actual": [3, 4, 5, 7, 9], "asynchron": [3, 5], "len": [3, 5, 8], "indic": [3, 5, 7, 9, 10, 11, 12], "send": [3, 5], "assum": [3, 4, 5, 7, 8, 10], "order": [3, 4, 5, 7, 9, 12], "destin": [3, 5, 7, 9], "appli": [3, 5, 8, 9], "_get_recat": [3, 5], "c": [3, 5, 7, 12], "kjta2a": [3, 5], "rank0_input": [3, 5], "hold": [3, 4, 5, 10, 12], "v0": [3, 5, 12], "v1": [3, 5, 9, 12], "v2": [3, 5, 9, 12], "rank1_input": [3, 5], "v3": [3, 5, 12], "v4": [3, 5, 12], "rank0_output": [3, 5], "3": [3, 4, 5, 8, 9, 10, 11, 12], "4": [3, 4, 5, 8, 9, 11, 12], "5": [3, 5, 8, 9, 11, 12], "rank1_output": [3, 5], "relev": [3, 4, 5], "issu": [3, 5, 9], "second": [3, 4, 5, 9, 12], "label": [3, 5], "tensor_split": [3, 5], "input_tensor": [3, 5], "dict": [3, 4, 5, 6, 7, 9, 10, 11, 12], "ie": [3, 4, 5, 9, 12], "stride_per_rank": [3, 5, 12], "stride": [3, 5, 12], "case": [3, 4, 5, 9, 10, 12], "kjtonetoal": [3, 5], "onetoal": [3, 5], "essenti": [3, 5, 12], "p2p": [3, 5], "keyjaggedtensor": [3, 5], "them": [3, 5, 7, 9, 10], "kjtlist": [3, 5], "slice": [3, 5, 6, 12], "pooledembeddingsallgath": [3, 5], "layout": [3, 5, 6], "want": [3, 5], "nccl": [3, 5], "happen": [3, 5], "init_distribut": [3, 5], "new_group": [3, 5], "randn": [3, 5, 8, 9], "m": [3, 5, 6, 9], "local_emb": [3, 5], "pooledembeddingsawait": [3, 5], "num_bucket": [3, 5], "pooledembeddingsalltoal": [3, 5], "callback": [3, 5], "a2a": [3, 5], "t0": [3, 5], "rand": [3, 5, 8], "6": [3, 4, 5, 8, 9, 11, 12], "t1": [3, 5, 8, 9, 11], "print": [3, 5, 9, 11], "properti": [3, 4, 5, 7, 9, 10, 11], "tensor_await": [3, 5], "pooledembeddingsreducescatt": [3, 5], "row": [3, 4, 5], "wise": [3, 4, 5, 9], "twrw": [3, 4, 5], "over": [3, 5, 9, 10], "unequ": [3, 5], "bucket": [3, 5], "seqembeddingsalltoon": [3, 5], "concat": [3, 5, 9, 12], "sequenceembeddingsalltoal": [3, 5], "features_per_rank": [3, 5], "sharding_ctx": [3, 5], "sequenceshardingcontext": [3, 5], "lengths_after_input_dist": [3, 5], "unbucketize_permute_tensor": [3, 5], "sparse_features_recat": [3, 5], "sequenceembeddingsawait": [3, 5], "permut": [3, 5, 12], "splitsalltoallawait": [3, 5], "variablebatchpooledembeddingsalltoal": [3, 5], "kjt_split": [3, 5], "24": [3, 5], "r0_batch_siz": [3, 5], "r1_batch_siz": [3, 5], "f_0": [3, 5], "f_1": [3, 5], "f_2": [3, 5], "r0_batch_size_per_rank_per_featur": [3, 5], "r1_batch_size_per_rank_per_featur": [3, 5], "r0_batch_size_per_feature_pre_a2a": [3, 5], "r1_batch_size_per_feature_pre_a2a": [3, 5], "r0": [3, 5], "r1": [3, 5], "16": [3, 5, 9, 11], "14": [3, 5], "post": [3, 5], "rank_0": [3, 5], "rank_1": [3, 5], "variablebatchpooledembeddingsreducescatt": [3, 5], "rw": [3, 4, 5, 9], "1d": [3, 4, 5], "multipli": [3, 4, 5], "batch_size_r0_f0": [3, 5], "emb_dim_f0": [3, 5], "embeddingcollectionawait": 3, "lazyawait": 3, "jaggedtensor": [3, 9, 11, 12], "embeddingcollectioncontext": 3, "sharding_context": 3, "input_featur": 3, "reverse_indic": [3, 9], "multistream": 3, "record_stream": [3, 12], "stream": [3, 12], "gener": [3, 4, 6, 7, 8, 9, 10, 12], "embeddingcollectionshard": 3, "fused_param": [3, 5], "qcomm_codecs_registri": [3, 5], "use_index_dedup": 3, "baseembeddingshard": 3, "embeddingcollect": [3, 9, 11], "module_typ": [3, 11], "param": [3, 10], "parametershard": 3, "env": [3, 5], "shardingenv": [3, 5], "shardedembeddingcollect": [3, 9, 11], "locat": 3, "replic": [3, 4, 5], "embeddingmoduleshardingplan": 3, "fulli": [3, 4, 10], "qualifi": 3, "name": [3, 4, 7, 8, 9, 10, 11, 12], "path": [3, 4, 7], "spec": 3, "shardedmodul": 3, "shardable_paramet": 3, "sharding_typ": [3, 4, 9], "compute_device_typ": 3, "shardingtyp": [3, 4, 9], "well": [3, 4, 9], "known": [3, 4, 9], "table_name_to_parameter_shard": 3, "shardedembeddingmodul": 3, "fusedoptimizermodul": [3, 10], "public": [3, 9], "api": [3, 6, 9], "manual": [3, 10], "dist_input": 3, "compute_and_output_dist": 3, "multipl": [3, 4, 9, 10], "make": [3, 9, 10], "sens": [3, 10], "method": [3, 6, 7, 9], "initi": [3, 9, 10], "distibut": 3, "soon": 3, "complet": [3, 4], "create_context": 3, "fused_optim": [3, 10], "keyedoptim": [3, 10], "output_dist": 3, "reset_paramet": [3, 9], "create_embedding_shard": 3, "sharding_info": [3, 5], "embeddingshardinginfo": [3, 5], "embeddingshard": [3, 5], "create_sharding_infos_by_shard": 3, "embeddingcollectioninterfac": [3, 9, 11], "get_ec_index_dedup": 3, "set_ec_index_dedup": 3, "commopgradientsc": 3, "functionctx": 3, "scale_gradient_factor": 3, "groupedembeddingslookup": 3, "grouped_config": 3, "groupedembeddingconfig": [3, 5], "baseembeddinglookup": [3, 5], "lookup": [3, 4, 8, 9, 11], "i": [3, 4, 5, 6, 8, 9], "flush": 3, "sparse_featur": [3, 5, 8], "everi": [3, 4, 5, 7, 9], "although": [3, 5, 7, 9], "recip": [3, 5, 7, 9], "instanc": [3, 5, 6, 7, 9], "afterward": [3, 5, 7, 9], "sinc": [3, 5, 7, 9], "former": [3, 5, 7, 9], "take": [3, 4, 5, 7, 9, 10], "care": [3, 5, 7, 9], "regist": [3, 5, 6, 7, 9], "hook": [3, 5, 7, 9], "while": [3, 5, 6, 7, 9], "latter": [3, 5, 7, 9], "silent": [3, 5, 7, 9], "ignor": [3, 4, 5, 7, 9], "load_state_dict": [3, 10], "state_dict": [3, 7, 9, 10], "ordereddict": [3, 6, 7, 9], "union": [3, 4, 6, 7, 9, 10], "shardedtensor": [3, 10], "strict": [3, 10], "_incompatiblekei": 3, "descend": [3, 4], "exactli": 3, "match": [3, 4, 7, 9], "assign": [3, 12], "unless": [3, 10], "get_swap_module_params_on_convers": 3, "persist": [3, 7, 9], "strictli": [3, 9], "preserv": [3, 9], "state": [3, 7, 9, 10], "except": [3, 4, 9], "requires_grad": 3, "field": [3, 9, 10, 12], "missing_kei": 3, "expect": [3, 4, 8, 9], "miss": [3, 4], "unexpected_kei": 3, "present": [3, 10], "namedtupl": 3, "exist": [3, 5, 7, 12], "rais": 3, "runtimeerror": 3, "named_buff": [3, 9], "prefix": [3, 7, 9], "recurs": [3, 9], "remove_dupl": [3, 9], "iter": [3, 4, 9, 10], "yield": [3, 9], "both": [3, 7, 8, 9, 10, 12], "itself": [3, 9], "prepend": [3, 9], "submodul": [3, 9, 10], "otherwis": [3, 4, 7, 9, 10, 12], "direct": [3, 9], "remov": [3, 6, 9], "duplic": [3, 9, 10], "xdoctest": [3, 7, 9], "skip": [3, 7, 9, 10], "undefin": [3, 7, 9], "var": [3, 7, 9], "buf": [3, 9], "self": [3, 4, 9, 12], "running_var": [3, 9], "named_paramet": 3, "bia": [3, 7, 9], "named_parameters_by_t": 3, "tablebatchedembeddingslic": 3, "table_nam": 3, "embedding_weight": 3, "cw": [3, 4], "weight": [3, 4, 5, 7, 9, 10, 11, 12], "compos": [3, 7, 9], "prefetch": [3, 4], "forward_stream": 3, "purg": 3, "keep_var": [3, 7, 9], "dictionari": [3, 7, 9], "refer": [3, 7, 9, 12], "whole": [3, 7, 9], "averag": [3, 4, 7, 9], "shallow": [3, 7, 9], "posit": [3, 4, 5, 7, 9], "howev": [3, 7, 9, 10], "deprec": [3, 7, 9], "keyword": [3, 7, 9], "futur": [3, 7, 9], "releas": [3, 7, 9], "pleas": [3, 4, 7, 9, 12], "avoid": [3, 7, 9, 10], "end": [3, 4, 7, 9], "user": [3, 4, 7, 9, 10], "updat": [3, 4, 7, 9, 10], "ad": [3, 7, 9, 10], "detach": [3, 7, 9], "groupedpooledembeddingslookup": 3, "basegroupedfeatureprocessor": [3, 5, 9], "scale_weight_gradi": 3, "infercpugroupedembeddingslookup": 3, "grouped_configs_per_rank": 3, "infergroupedlookupmixin": 3, "tbetoregistermixin": 3, "get_tbes_to_regist": 3, "intnbittablebatchedembeddingbagscodegen": 3, "infergroupedembeddingslookup": 3, "abc": [3, 4, 7, 9, 10], "infergroupedpooledembeddingslookup": 3, "metainfergroupedembeddingslookup": 3, "meta": [3, 4], "tbe": [3, 4, 11], "op": [3, 4, 5, 10, 11], "metainfergroupedpooledembeddingslookup": 3, "bag": [3, 5, 6, 8, 9], "embeddings_cat_empty_rank_handl": 3, "dummy_embs_tensor": 3, "embeddings_cat_empty_rank_handle_infer": 3, "dtype": [3, 4, 5, 6, 7, 9, 11, 12], "fx_wrap_tensor_view2d": 3, "dim0": 3, "dim1": 3, "baseembeddingdist": [3, 5], "convert": [3, 6, 7, 12], "embeddinglookup": 3, "abstract": [3, 4, 7, 9, 10], "basesparsefeaturesdist": [3, 5], "f": [3, 4, 5, 8, 9, 11], "featureshardingmixin": 3, "table_wis": [3, 9], "create_input_dist": [3, 5], "create_lookup": [3, 5], "create_output_dist": [3, 5], "embedding_nam": [3, 5, 9], "embedding_names_per_rank": [3, 5], "embedding_shard_metadata": [3, 5], "shardmetadata": [3, 5], "embedding_t": [3, 5], "shardedembeddingt": [3, 5], "uncombined_embedding_dim": [3, 5], "uncombined_embedding_nam": [3, 5], "embeddingshardingcontext": [3, 5], "variable_batch_per_featur": 3, "embeddingtableconfig": [3, 9], "param_shard": 3, "nonetyp": [3, 9], "fusedkjtlistsplitsawait": 3, "kjtlistsplitsawait": 3, "kjtlistawait": 3, "info": [3, 9], "metadata": [3, 7, 9], "kjtsplitsalltoallmeta": 3, "distributed_c10d": 3, "_input": 3, "splits_tensor": 3, "listofkjtlistawait": 3, "listofkjtlist": 3, "listofkjtlistsplitsawait": 3, "bucketize_kjt_before_all2al": 3, "block_siz": [3, 5], "output_permut": 3, "bucketize_po": 3, "block_bucketize_row_po": 3, "readjust": 3, "note": [3, 4, 5, 9, 12], "memori": [3, 4, 10], "map": [3, 9, 10, 11], "unbucket": 3, "offset": [3, 4, 8, 9, 11, 12], "group_tabl": 3, "tables_per_rank": 3, "datatyp": [3, 4, 9, 11, 12], "poolingtyp": [3, 9], "embeddingcomputekernel": [3, 4], "consist": 3, "weighted": 3, "interfac": [3, 7, 9], "reli": [3, 7, 9, 11], "etc": [3, 7, 10, 12], "moduleshard": [3, 4], "compute_kernel": [3, 4], "storage_usag": 3, "resourc": 3, "processor": [3, 5, 9], "basequantembeddingshard": 3, "shardable_param": 3, "embeddingattribut": 3, "dens": [3, 4, 8, 9, 12], "enum": [3, 4, 9, 10], "fused_uvm": 3, "fused_uvm_cach": 3, "quant_uvm": 3, "quant_uvm_cach": 3, "awar": [3, 12], "feature_nam": [3, 4, 5, 8, 9, 11], "feature_names_per_rank": [3, 5], "data_typ": [3, 9], "is_weight": [3, 4, 9, 11, 12], "has_feature_processor": [3, 5, 9], "dim_sum": 3, "feature_hash_s": [3, 5], "num_featur": [3, 5, 8, 9], "moduleshardingmixin": 3, "access": [3, 4, 10, 12], "scheme": 3, "optimtyp": 3, "adagrad": [3, 10], "adam": [3, 10], "adamw": 3, "lamb": 3, "lars_sgd": 3, "lion": 3, "partial_rowwise_adam": 3, "partial_rowwise_lamb": 3, "rowwise_adagrad": 3, "sgd": 3, "shampoo": 3, "shampoo_v2": 3, "shardedconfig": 3, "local_row": [3, 4], "local_col": [3, 4], "compin": 3, "distout": 3, "out": [3, 9, 12], "shrdctx": 3, "commop": 3, "extra_repr": 3, "pretti": 3, "represent": [3, 4, 6, 9, 12], "num_embed": [3, 4, 8, 9, 11], "fp32": [3, 4, 9], "weight_init_max": [3, 9], "float": [3, 4, 6, 9, 10, 12], "weight_init_min": [3, 9], "pruning_indices_remap": [3, 9], "init_fn": [3, 9], "need_po": [3, 5, 9], "local_metadata": 3, "_shard": 3, "global_metadata": 3, "sharded_tensor": 3, "shardedtensormetadata": 3, "shardedmetaconfig": 3, "compute_kernel_to_embedding_loc": 3, "embeddingloc": 3, "embeddingawait": 3, "embeddingbagcollectionawait": 3, "lazygetitemmixin": 3, "keyedtensor": [3, 8, 9, 11, 12], "embeddingbagcollectioncontext": 3, "inverse_indic": [3, 9, 12], "divisor": 3, "embeddingbagcollectionshard": 3, "embeddingbagshard": 3, "nullshardedmodulecontext": 3, "per_sample_weight": 3, "named_modul": 3, "memo": 3, "network": [3, 4, 9, 10], "alreadi": [3, 5, 7, 10], "onc": [3, 9], "l": [3, 9, 11], "linear": [3, 4, 9, 10], "net": [3, 9], "sequenti": [3, 4, 9], "idx": 3, "in_featur": [3, 8, 9], "out_featur": [3, 9], "sharded_parameter_nam": 3, "embeddingbagcollectioninterfac": [3, 9, 11], "variablebatchembeddingbagcollectionawait": 3, "construct_output_kt": 3, "create_embedding_bag_shard": 3, "permute_embed": [3, 5], "suffix": 3, "replace_placement_with_meta_devic": 3, "placement": [3, 4], "could": [3, 4, 12], "unmatch": 3, "some": [3, 12], "scenario": [3, 9, 11], "cuda": [3, 4, 7], "embeddingshardingplann": [3, 4], "groupedpositionweightedmodul": 3, "max_feature_length": [3, 9], "dataparallelwrapp": 3, "defaultdataparallelwrapp": 3, "bucket_cap_mb": 3, "25": 3, "static_graph": 3, "find_unused_paramet": 3, "allreduce_comm_precis": 3, "unshard": [3, 4, 9, 11], "plan": [3, 4, 9], "shardingplan": [3, 4], "init_data_parallel": 3, "init_paramet": 3, "data_parallel_wrapp": 3, "entri": 3, "point": [3, 4], "collective_plan": [3, 4], "lazi": [3, 9, 10], "delai": 3, "until": 3, "still": [3, 12], "no_grad": [3, 9], "init_weight": [3, 9], "isinst": 3, "fill_": [3, 9], "elif": 3, "init": 3, "kaiming_normal_": 3, "mymodel": 3, "bare_named_paramet": 3, "new": [3, 4], "origin": [3, 4], "tor": 3, "safe": 3, "time": [3, 4, 7, 9], "ddp": 3, "fsdp": 3, "sparse_grad_parameter_nam": [3, 10], "get_modul": 3, "unwrap": 3, "so": [3, 4, 10, 12], "get_unwrapped_modul": 3, "quantembeddingbagcollectionshard": 3, "shardedquantembeddingbagcollect": 3, "quantfeatureprocessedembeddingbagcollectionshard": 3, "featureprocessedembeddingbagcollect": [3, 11], "shardedquantebcinputdist": 3, "sharding_type_to_shard": 3, "nullshardingcontext": [3, 5], "sqebc_input_dist": 3, "infertwsequenceembeddingshard": 3, "f1": [3, 8, 9, 11], "f2": [3, 8, 9, 11], "7": [3, 8, 9, 11, 12], "8": [3, 4, 8, 9, 11, 12], "shardedquantembeddingmodulest": 3, "embedding_bag_config": [3, 9, 11], "embeddingbagconfig": [3, 8, 9, 11], "execut": [3, 4, 7, 9, 11], "step": [3, 4, 10], "sharding_type_to_sharding_info": 3, "tbes_config": 3, "shardedquantfeatureprocessedembeddingbagcollect": 3, "featureprocessorscollect": [3, 11], "apply_feature_processor": 3, "kjt_list": [3, 12], "embedding_bag": [3, 11], "moduledict": [3, 9, 11], "modulelist": [3, 9, 11], "create_infer_embedding_bag_shard": 3, "flatten_feature_length": 3, "get_device_from_parameter_shard": 3, "ps": 3, "get_device_from_sharding_info": 3, "emb_shard_info": 3, "cacheparam": [3, 4], "algorithm": 3, "cachealgorithm": 3, "load_factor": [3, 4], "reserved_memori": 3, "prefetch_pipelin": [3, 4], "cachestatist": [3, 4], "cach": [3, 4], "relat": [3, 4], "most": [3, 10], "fbgemm": [3, 4, 11], "uvm": [3, 4], "lru": [3, 4], "lfu": 3, "load": [3, 4, 10], "factor": [3, 4, 9], "decid": 3, "crucial": 3, "reserv": [3, 4], "ideal": 3, "aka": 3, "statist": [3, 4], "better": 3, "tune": [3, 10], "cacheabl": [3, 4], "summar": [3, 4], "measur": [3, 4], "difficulti": [3, 4], "independ": [3, 4], "score": [3, 4, 5, 9], "mean": [3, 4, 9], "veri": [3, 4], "high": [3, 4, 9], "difficult": [3, 4], "expected_lookup": [3, 4], "distinct": [3, 4], "expected_miss_r": [3, 4], "clf": [3, 4], "rate": [3, 4, 10], "100": [3, 4, 8, 9], "hit": [3, 4], "extrem": [3, 4], "estim": [3, 4], "knowledg": [3, 4], "pooled_embeddings_all_to_al": 3, "pooled_embeddings_reduce_scatt": 3, "sequence_embeddings_all_to_al": 3, "computekernel": 3, "moduleshardingplan": 3, "describ": 3, "genericmeta": 3, "getitemlazyawait": 3, "parentw": 3, "kt": [3, 12], "__getitem__": 3, "parent": 3, "expos": [3, 10], "concret": 3, "behavior": [3, 6, 10], "achiev": 3, "late": 3, "possibl": [3, 4], "__torch_function__": 3, "below": 3, "help": 3, "doesn": [3, 9, 10], "python": [3, 6, 7], "magic": 3, "__getattr__": 3, "caveat": 3, "arbitari": 3, "mechan": [3, 9], "ensur": [3, 9, 12], "perfect": 3, "quickli": 3, "long": [3, 4, 9], "kwd": 3, "vt_co": 3, "augment": 3, "trigger": [3, 9], "keyedlazyawait": 3, "anoth": 3, "defer": 3, "mixin": 3, "inherit": [3, 9], "mro": 3, "properli": [3, 9], "select": [3, 4, 5, 12], "lazynowait": 3, "classmethod": [3, 4, 7, 11], "noopquantizedcommcodec": 3, "quantizationcontext": 3, "No": [3, 5], "calc_quantized_s": 3, "input_len": 3, "decod": 3, "input_grad": 3, "encod": 3, "quantized_dtyp": 3, "nowait": [3, 6], "obj": 3, "sharding_spec": 3, "shardingspec": 3, "cache_param": [3, 4], "enforce_hbm": [3, 4], "stochastic_round": [3, 4], "bounds_check_mod": [3, 4], "boundscheckmod": [3, 4], "output_dtyp": [3, 4, 7, 11], "hbm": [3, 4], "stochast": [3, 4], "round": [3, 4], "bound": [3, 4], "place": [3, 4, 5, 10, 12], "column_wis": [3, 9], "seen": [3, 6], "individu": [3, 4], "table_row_wis": [3, 9], "row_wis": [3, 9], "data_parallel": [3, 4, 9], "parameterstorag": 3, "physic": 3, "constraint": [3, 4], "shardingplann": [3, 4], "ddr": [3, 4], "pooled_all_to_al": 3, "reduce_scatt": 3, "float32": [3, 7, 9, 11], "quantized_tensor": 3, "quantized_comm_codec": 3, "collective_cal": 3, "output_tensor": 3, "assert_clos": 3, "int8": 3, "addit": [3, 4, 6, 7, 9, 10, 12], "carri": 3, "session": 3, "respect": [3, 9], "sequence_all_to_al": 3, "modulenocopymixin": [3, 11], "respons": 3, "transform": [3, 7, 9], "vise": [3, 10], "versa": [3, 10], "practic": 3, "from_loc": 3, "host": [3, 4, 5], "typic": [3, 4, 6, 9, 10, 12], "from_process_group": 3, "fqn": [3, 4], "larger": 3, "desir": 3, "get_plan_for_modul": 3, "module_path": 3, "re": [3, 10], "stabil": 3, "table_column_wis": [3, 9], "get_tensor_size_byt": 3, "scope": [3, 6], "copyablemixin": 3, "target": [3, 8], "mymodul": 3, "add_params_from_parameter_shard": 3, "parameter_shard": 3, "extract": 3, "add": [3, 6, 9, 10], "ones": 3, "add_prefix_to_state_dict": 3, "filter": [3, 9], "append_prefix": 3, "append": 3, "convert_to_fbgemm_typ": 3, "copy_to_devic": 3, "current_devic": [3, 7], "to_devic": 3, "filter_state_dict": 3, "start": [3, 9, 12], "strip": 3, "begin": [3, 10], "get_unsharded_module_nam": 3, "top": [3, 9], "level": [3, 5], "don": [3, 7, 9], "merge_fused_param": 3, "param_fused_param": 3, "configur": 3, "cache_precis": 3, "preset": 3, "table_level_fused_param": 3, "precid": 3, "grouped_fused_param": 3, "null": 3, "none_throw": 3, "_t": 3, "messag": [3, 4], "unexpect": 3, "assertionerror": 3, "optimizer_type_to_emb_opt_typ": 3, "optimizer_class": 3, "emboptimtyp": 3, "sharded_model_copi": 3, "m_cpu": 3, "deepcopi": 3, "managedcollisioncollectionawait": 3, "managedcollisioncollectioncontext": 3, "managedcollisioncollectionshard": 3, "managedcollisioncollect": [3, 9], "shardedmanagedcollisioncollect": 3, "evict": [3, 9], "create_mc_shard": 3, "managedcollisionembeddingbagcollectioncontext": 3, "evictions_per_t": 3, "remapped_kjt": 3, "managedcollisionembeddingbagcollectionshard": 3, "ebc_shard": 3, "mc_sharder": 3, "basemanagedcollisionembeddingcollectionshard": 3, "managedcollisionembeddingbagcollect": [3, 9], "shardedmanagedcollisionembeddingbagcollect": 3, "baseshardedmanagedcollisionembeddingcollect": 3, "managedcollisionembeddingcollectioncontext": 3, "managedcollisionembeddingcollectionshard": 3, "ec_shard": 3, "managedcollisionembeddingcollect": [3, 9], "shardedmanagedcollisionembeddingcollect": 3, "consid": [4, 9, 11, 12], "build": 4, "perf": 4, "storag": [4, 12], "peak": 4, "elimin": 4, "might": [4, 12], "oom": 4, "customiz": 4, "partit": [4, 5], "kernel_bw_lookup": 4, "compute_devic": 4, "hbm_mem_bw": 4, "ddr_mem_bw": 4, "caching_ratio": 4, "calcul": 4, "bandwidth": 4, "ratio": 4, "embeddingenumer": 4, "parameterconstraint": 4, "shardestim": 4, "shardingopt": 4, "valid": [4, 9, 12], "popul": [4, 9], "populate_estim": 4, "sharding_opt": 4, "descript": 4, "get_partition_by_typ": 4, "string": [4, 7, 9], "partitionbytyp": 4, "greedyperfpartition": 4, "sort_bi": 4, "sortbi": 4, "balance_modul": 4, "greedi": 4, "sort": 4, "smaller": 4, "effect": [4, 9], "balanc": 4, "storage_constraint": 4, "partition_bi": 4, "uniform": [4, 9], "strategi": 4, "final": [4, 8, 9, 11, 12], "docstr": [4, 12], "partition_by_devic": 4, "done": [4, 9, 10, 12], "clariti": 4, "memorybalancedpartition": 4, "max_search_count": 4, "10": [4, 8, 9, 11, 12], "toler": 4, "02": 4, "maximum": [4, 9], "greedypartition": 4, "reject": 4, "200": 4, "wors": 4, "repeatedli": 4, "find": 4, "least": 4, "amount": 4, "ordereddevicehardwar": 4, "devicehardwar": 4, "local_world_s": 4, "shardingoptiongroup": 4, "storage_sum": 4, "perf_sum": 4, "param_count": 4, "set_hbm_per_devic": 4, "hbm_per_devic": 4, "noopperfmodel": 4, "perfmodel": 4, "among": [4, 8], "here": 4, "without": [4, 12], "noopstoragemodel": 4, "storagereserv": 4, "performance_model": 4, "debug": 4, "shardabl": 4, "heteroembeddingshardingplann": 4, "topology_group": 4, "embeddingoffloadscaleuppropos": 4, "use_depth": 4, "allocate_budget": 4, "budget": 4, "allocation_prior": 4, "build_affine_storage_model": 4, "uvm_caching_sharding_opt": 4, "clf_to_byt": 4, "feedback": 4, "perf_rat": 4, "get_budget": 4, "get_cach": 4, "get_expected_lookup": 4, "search_spac": 4, "next_plan": 4, "starting_propos": 4, "greedypropos": 4, "threshold": [4, 9], "fashion": [4, 5], "On": [4, 9], "largest": 4, "tri": [4, 10], "next": 4, "max": [4, 9, 10], "earli": 4, "stop": 4, "consecut": 4, "than": [4, 9, 10], "best_perf_r": 4, "gridsearchpropos": 4, "max_propos": 4, "10000": 4, "uniformpropos": 4, "proposers_to_proposals_list": 4, "proposers_list": 4, "static_feedback": 4, "embeddingoffloadstat": 4, "mrc_hist_count": 4, "height": 4, "uvm_fused_cach": 4, "cachebl": 4, "area": 4, "under": 4, "curv": 4, "uniqu": [4, 9], "n": [4, 7, 9, 12], "histogram": 4, "bin": 4, "nth": 4, "wa": [4, 7], "estimate_cache_miss_r": 4, "cache_s": 4, "hist": 4, "mrc": 4, "embeddingperfestim": 4, "is_infer": 4, "wall": 4, "sharder_map": 4, "perf_func_emb_wall_tim": 4, "shard_siz": 4, "input_length": 4, "input_data_type_s": 4, "table_data_type_s": 4, "output_data_type_s": 4, "fwd_a2a_comm_data_type_s": 4, "bwd_a2a_comm_data_type_s": 4, "fwd_sr_comm_data_type_s": 4, "bwd_sr_comm_data_type_s": 4, "num_pool": 4, "intra_host_bw": 4, "inter_host_bw": 4, "bwd_compute_multipli": 4, "is_pool": 4, "expected_cache_fetch": 4, "attempt": 4, "rel": [4, 9], "tw": 4, "dp": 4, "queri": 4, "fwd_comm_data_type_s": 4, "bwd_comm_data_type_s": 4, "sampl": [4, 9], "thread": 4, "machin": [4, 9], "unpool": 4, "ebc": [4, 8, 9, 11], "signifi": 4, "fetch": 4, "embeddingstorageestim": 4, "calculate_shard_storag": 4, "compris": 4, "synonym": 4, "byte": [4, 7], "embeddingstat": 4, "log": 4, "sharding_plan": 4, "num_propos": 4, "num_plan": 4, "run_tim": 4, "best_plan": 4, "tabular": 4, "view": 4, "chosen": [4, 9], "evalu": [4, 9], "successfulli": 4, "taken": 4, "noopembeddingstat": 4, "noop": 4, "round_to_one_sigfig": 4, "fixedpercentagestoragereserv": 4, "percentag": 4, "heuristicalstoragereserv": 4, "parameter_multipli": 4, "dense_tensor_estim": 4, "heurist": 4, "extra": 4, "percent": 4, "act": 4, "margin": 4, "error": [4, 9, 12], "beyond": 4, "inferencestoragereserv": 4, "customtopologydata": 4, "get_data": 4, "has_data": 4, "supported_field": 4, "ddr_cap": 4, "hbm_cap": 4, "512": 4, "min_partit": 4, "pooling_factor": 4, "fbgemm_gpu": 4, "split_table_batched_embeddings_ops_common": 4, "device_group": 4, "around": 4, "lower": [4, 6, 7, 10, 11], "column": [4, 5], "rang": [4, 6, 9], "divid": 4, "divis": 4, "optionallist": 4, "momentum": 4, "determinist": 4, "import": [4, 7, 9, 11], "maintain": 4, "accuraci": [4, 9], "term": [4, 9], "fp16": 4, "exce": 4, "todai": 4, "bldm": 4, "fwd_comput": 4, "fwd_comm": 4, "bwd_comput": 4, "bwd_comm": 4, "prefetch_comput": 4, "breakdown": 4, "plannererror": 4, "error_typ": 4, "plannererrortyp": 4, "classifi": 4, "insufficient_storag": 4, "strict_constraint": 4, "prospos": 4, "paritit": 4, "subset": 4, "much": [4, 10], "depend": [4, 7, 9], "One": [4, 9], "eval": 4, "job": 4, "tower": [4, 9], "cache_load_factor": 4, "module_pool": 4, "sharding_option_nam": 4, "num_input": 4, "num_shard": 4, "total_perf": 4, "total_storag": 4, "capac": 4, "hardwar": 4, "fits_in": 4, "963146416": 4, "128": 4, "54760833": 4, "024": 4, "644245094": 4, "13421772": 4, "custom_topology_data": 4, "binarysearchpred": 4, "extern": [4, 8], "predic": 4, "discov": 4, "binari": 4, "minim": 4, "invoc": 4, "try": 4, "prior_result": 4, "probe": 4, "prior": 4, "entir": [4, 5], "explor": 4, "reach": 4, "luusjaakolasearch": 4, "max_iter": 4, "seed": 4, "42": 4, "left_cost": 4, "clamp": 4, "variant": 4, "luu": 4, "jaakola": 4, "en": 4, "wikipedia": 4, "wiki": 4, "best": 4, "far": 4, "associ": 4, "cost": [4, 9], "left": [4, 12], "right": [4, 9], "fy": 4, "y": [4, 9], "previou": 4, "subsequ": 4, "been": [4, 9], "shrink_right": 4, "shrink": 4, "boundari": 4, "infin": 4, "bytes_to_gb": 4, "num_byt": 4, "bytes_to_mb": 4, "gb_to_byt": 4, "gb": 4, "local_s": [4, 5], "format": [4, 7, 12], "prod": 4, "reset_shard_rank": 4, "sharder_nam": 4, "storage_repr_in_gb": 4, "basecwembeddingshard": 5, "basetwembeddingshard": 5, "cwpooledembeddingshard": 5, "infercwpooledembeddingdist": 5, "infercwpooledembeddingdistwithpermut": 5, "infercwpooledembeddingshard": 5, "basedpembeddingshard": 5, "dppooledembeddingdist": 5, "dppooledembeddingshard": 5, "dpsparsefeaturesdist": 5, "sparsefeatur": 5, "baserwembeddingshard": 5, "infercpurwsparsefeaturesdist": 5, "is_sequ": 5, "emb_shard": 5, "inferrwpooledembeddingdist": 5, "inferrwpooledembeddingshard": 5, "inferrwsparsefeaturesdist": 5, "rwpooledembeddingdist": 5, "share": [5, 9], "rwpooledembeddingshard": 5, "evenli": 5, "rwsparsefeaturesdist": 5, "intra_pg": 5, "hash": [5, 9], "get_block_sizes_runtime_devic": 5, "runtime_devic": 5, "tensor_cach": 5, "int32": [5, 12], "get_embedding_shard_metadata": 5, "grouped_embedding_configs_per_rank": 5, "infertwembeddingshard": 5, "infertwpooledembeddingdist": 5, "infertwsparsefeaturesdist": 5, "twpooledembeddingdist": 5, "twpooledembeddingshard": 5, "twsparsefeaturesdist": 5, "twcwpooledembeddingshard": 5, "basetwrwembeddingshard": 5, "twrwpooledembeddingdist": 5, "cross_pg": 5, "dim_sum_per_nod": 5, "emb_dim_per_node_per_featur": 5, "twrwpooledembeddingshard": 5, "twrwsparsefeaturesdist": 5, "id_list_features_per_rank": 5, "id_score_list_features_per_rank": 5, "id_list_feature_hash_s": 5, "id_score_list_feature_hash_s": 5, "shuffl": 5, "look": [5, 6, 12], "reorder": 5, "document": [6, 8], "leaf_modul": 6, "trace": [6, 7], "torchscript": 6, "create_arg": 6, "complex": 6, "memory_format": 6, "opoverload": 6, "prepar": [6, 9], "graph": 6, "emit": 6, "appropri": 6, "is_leaf_modul": 6, "module_qualified_nam": 6, "module_stack": 6, "node_name_to_scop": 6, "path_of_modul": 6, "mod": 6, "abil": 6, "made": [6, 10], "root": 6, "concrete_arg": 6, "guarante": [6, 10], "is_fx_trac": 6, "symbolic_trac": 6, "graphmodul": 6, "symbol": 6, "record": [6, 9], "partial": 6, "special": [6, 9, 10], "your": 6, "structur": [6, 10], "deploi": 7, "packag": 7, "predictmodul": 7, "predictfactori": 7, "contract": 7, "serv": 7, "predictfactorypackag": 7, "batchingqueu": 7, "config": [7, 9], "gpuexecutor": 7, "insid": 7, "dlrm_packag": 7, "py": 7, "demonstr": 7, "export": 7, "dlrm_predict": 7, "show": 7, "save_predict_factori": 7, "predict_factori": 7, "pathlib": 7, "binaryio": 7, "extra_fil": 7, "loader_cod": 7, "nimport": 7, "nmodule_factori": 7, "package_import": 7, "_sysimport": 7, "set_extern_modul": 7, "decor": 7, "abstractmethod": 7, "set_mocked_modul": 7, "load_config_text": 7, "load_pickle_config": 7, "clazz": 7, "batchingmetadata": 7, "pin": 7, "kept": 7, "sync": [7, 12], "learn": [7, 8, 9, 10], "batching_metadata": 7, "infom": 7, "batching_metadata_json": 7, "serial": 7, "json": 7, "eas": [7, 9], "pars": 7, "create_predict_modul": 7, "transformmodul": 7, "transform_state_dict": 7, "init_process_group": 7, "get_world_s": 7, "model_inputs_data": 7, "benchmark": 7, "qualname_metadata": 7, "qualnamemetadata": 7, "qualnam": 7, "inform": [7, 12], "qualname_metadata_json": 7, "result_metadata": 7, "run_weights_dependent_transform": 7, "predict_modul": 7, "predict": 7, "run_weights_independent_tranform": 7, "predict_forward": 7, "need_preproc": 7, "quantize_dens": 7, "additional_embedding_module_typ": 7, "quantize_embed": 7, "inplac": [7, 11], "additional_qconfig_spec_kei": 7, "additional_map": 7, "per_table_weight_dtyp": [7, 9], "quantize_featur": 7, "trim_torch_package_prefix_from_typenam": 7, "typenam": 7, "densearch": 8, "hidden_layer_s": 8, "deepfmnn": 8, "layer": [8, 9, 10], "embedding_dimens": 8, "dimension": 8, "hidden": [8, 9], "sparsearch": 8, "20": [8, 9], "dense_arch": 8, "dense_arch_input": 8, "dense_embed": 8, "fminteractionarch": 8, "fm_in_featur": 8, "sparse_feature_nam": 8, "deep_fm_dimens": 8, "dense_featur": [8, 9], "interact": [8, 9], "paper": [8, 9], "arxiv": 8, "pdf": 8, "1703": 8, "04247": 8, "cat": [8, 9], "dense_modul": [8, 9], "deep": [8, 9], "di": 8, "arch": 8, "fm_inter_arch": 8, "length_per_kei": [8, 12], "cat_fm_output": 8, "overarch": 8, "simpl": 8, "over_arch": 8, "logit": 8, "simpledeepfmnn": 8, "num_dense_featur": 8, "embedding_bag_collect": [8, 9], "basic": [8, 12], "relationship": 8, "project": 8, "those": [8, 9], "deep_fm": 8, "notat": 8, "throughout": 8, "eb1_config": [8, 11], "f3": 8, "eb2_config": [8, 11], "t2": [8, 9, 11], "sparse_nn": 8, "over_embedding_dim": 8, "9": 8, "from_offsets_sync": [8, 9, 11, 12], "sparse_arch": 8, "extens": 9, "establish": 9, "pattern": 9, "swishlayernorm": 9, "positionweightedmodul": 9, "lazymoduleextensionmixin": 9, "embeddingtow": 9, "embeddingtowercollect": 9, "logic": 9, "input_dim": 9, "swish": 9, "normal": 9, "sigmoid": 9, "layernorm": 9, "d1": 9, "d2": 9, "d3": 9, "last": [9, 12], "sln": 9, "num_lay": 9, "stack": 9, "learnabl": 9, "polynom": 9, "full": [9, 10, 12], "matrix": 9, "nxn": 9, "cover": 9, "bit": 9, "x_": 9, "x_0": 9, "w_l": 9, "cdot": 9, "x_l": 9, "b_l": 9, "squar": 9, "element": 9, "dcn": 9, "lowrankcrossnet": 9, "low_rank": 9, "low": 9, "highli": 9, "effici": 9, "matric": 9, "simplifi": 9, "v_l": 9, "vector": 9, "smartli": 9, "setup": 9, "alwai": [9, 12], "lowrankmixturecrossnet": 9, "num_expert": 9, "relu": 9, "mixtur": 9, "expert": 9, "compar": [9, 12], "leverag": 9, "k": 9, "subspac": 9, "adapt": 9, "gate": 9, "moe": 9, "expert_i": 9, "k_": 9, "u_": 9, "li": 9, "c_": 9, "v_": 9, "vectorcrossnet": 9, "keep": 9, "nx1": 9, "dot": 9, "thu": [9, 10], "further": [9, 12], "cut": 9, "off": 9, "implent": 9, "framework": 9, "factorizationmachin": 9, "fm": 9, "abov": [9, 12], "publish": 9, "compon": 9, "learnt": 9, "To": 9, "flexibl": 9, "raw": 9, "limit": 9, "architectur": 9, "90": 9, "30": 9, "40": 9, "equal": [9, 12], "count": 9, "fb": 9, "lazymlp": 9, "output_dim": 9, "64": 9, "32": 9, "192": 9, "deep_fm_output": 9, "common_spars": 9, "specialized_spars": 9, "embedding_featur": 9, "raw_embedding_featur": 9, "nativ": 9, "trained_embed": 9, "native_embed": 9, "ident": 9, "mention": 9, "2nd": 9, "baseembeddingconfig": 9, "get_weight_init_max": 9, "get_weight_init_min": 9, "embeddingconfig": [9, 11], "quantconfig": 9, "placeholderobserv": [9, 11], "alia": 9, "data_type_to_dtyp": 9, "data_type_to_sparse_typ": 9, "sparsetyp": 9, "dtype_to_data_typ": 9, "pooling_type_to_pooling_mod": 9, "pooling_typ": 9, "poolingmod": 9, "pooling_type_to_str": 9, "sensit": [9, 11], "jag": [9, 11, 12], "table_0": [9, 11], "table_1": [9, 11], "pooled_embed": 9, "8899": 9, "1342": 9, "9060": 9, "0905": 9, "2814": 9, "9369": 9, "7783": 9, "0000": 9, "1598": 9, "0695": 9, "3265": 9, "1011": 9, "4256": 9, "1846": 9, "1648": 9, "0893": 9, "3590": 9, "9784": 9, "7681": 9, "grad_fn": [9, 11], "catbackward0": 9, "offset_per_kei": [9, 12], "need_indic": [9, 11], "e1_config": [9, 11], "e2_config": [9, 11], "ec": [9, 11], "feature_embed": [9, 11], "2050": [9, 11], "5478": [9, 11], "6054": [9, 11], "7352": [9, 11], "3210": [9, 11], "0399": [9, 11], "1279": [9, 11], "1756": [9, 11], "4130": [9, 11], "7519": [9, 11], "4341": [9, 11], "0499": [9, 11], "9329": [9, 11], "0697": [9, 11], "8095": [9, 11], "embeddingbackward": [9, 11], "embedding_names_by_t": [9, 11], "get_embedding_names_by_t": 9, "process_pooled_embed": 9, "reorder_inverse_indic": 9, "basefeatureprocessor": 9, "max_length": 9, "truncat": 9, "positionweightedprocessor": 9, "feature_length": 9, "feature0": [9, 12], "feature1": [9, 12], "feature2": 9, "from_lengths_sync": [9, 12], "pw": 9, "featureprocessorcollect": 9, "feature_processor_modul": 9, "positionweightedfeatureprocessor": 9, "fp_featur": 9, "non_fp_featur": 9, "non_fp": 9, "feature_process": 9, "come": 9, "And": 9, "offsets_to_range_tracebl": 9, "position_weighted_module_update_featur": 9, "weighted_featur": 9, "lazymodulemixin": 9, "temporari": 9, "upstream": 9, "59923": 9, "testlazymoduleextensionmixin": 9, "unit": 9, "test": 9, "_infer_paramet": 9, "code": 9, "pariti": 9, "_call_impl": 9, "pre": [9, 10], "children": 9, "uniniti": 9, "dummi": [9, 10], "lazylinear": 9, "fail": [9, 12], "becaus": [9, 10], "hasn": 9, "yet": 9, "now": [9, 12], "lazy_appli": 9, "attach": 9, "numer": 9, "immedi": 9, "seq": 9, "in_siz": 9, "layer_s": 9, "perceptron": 9, "multi": 9, "out_siz": 9, "swish_layernorm": 9, "won": 9, "constructor": 9, "mlp_modul": 9, "assert": 9, "o": 9, "channel": 9, "check_module_output_dimens": 9, "verifi": 9, "construct_jagged_tensor": 9, "features_to_permute_indic": 9, "original_featur": 9, "construct_jagged_tensors_infer": 9, "construct_modulelist_from_single_modul": 9, "nest": 9, "reiniti": 9, "convert_list_of_modules_to_modulelist": 9, "extract_module_or_tensor_cal": 9, "module_or_cal": 9, "get_module_output_dimens": 9, "init_mlp_weights_xavier_uniform": 9, "distancelfu_evictionpolici": 9, "decay_expon": 9, "threshold_filtering_func": 9, "mchevictionpolici": 9, "coalesce_history_metadata": 9, "current_it": 9, "history_metadata": 9, "unique_ids_count": 9, "unique_inverse_map": 9, "additional_id": 9, "threshold_mask": 9, "histori": 9, "invers": [9, 12], "history_accumul": 9, "coalesc": 9, "metadata_info": 9, "mchevictionpolicymetadatainfo": 9, "record_history_metadata": 9, "incoming_id": 9, "incom": 9, "polici": [9, 10], "update_metadata_and_generate_eviction_scor": 9, "mch_size": 9, "coalesced_history_argsort_map": 9, "coalesced_history_sorted_unique_ids_count": 9, "coalesced_history_mch_matching_elements_mask": 9, "coalesced_history_mch_matching_indic": 9, "mch_metadata": 9, "coalesced_history_metadata": 9, "evicted_indic": 9, "selected_new_indic": 9, "mch": 9, "lfu_evictionpolici": 9, "lru_evictionpolici": 9, "metadata_nam": 9, "is_mch_metadata": 9, "is_history_metadata": 9, "mchmanagedcollisionmodul": 9, "zch_size": 9, "eviction_polici": 9, "eviction_interv": 9, "input_hash_s": 9, "9223372036854775808": 9, "input_hash_func": 9, "mch_hash_func": 9, "output_global_offset": 9, "managedcollisionmodul": 9, "zch": 9, "manag": 9, "collis": 9, "output_size_offset": 9, "interv": 9, "drive": 9, "greater": 9, "residu": 9, "legaci": 9, "intern": [9, 12], "shift": 9, "zch_output_rang": 9, "down": 9, "applic": 9, "slot": 9, "reset": [9, 10], "assumptionn": 9, "downstream": 9, "modifi": [9, 10], "jt": [9, 12], "rtype": 9, "output_s": 9, "vs": 9, "preprocess": 9, "profil": 9, "rebuild_with_output_id_rang": 9, "output_id_rang": 9, "mc": 9, "hack": 9, "remap": 9, "managed_collision_modul": 9, "mcc": 9, "embedding_confg": 9, "collsion": 9, "max_output_id": 9, "remapping_range_start_index": 9, "mcm": 9, "mcm_jt": 9, "fp": 9, "apply_mc_method_to_jt_dict": 9, "features_dict": 9, "table_to_featur": 9, "managed_collis": 9, "average_threshold_filt": 9, "id_count": 9, "dynamic_threshold_filt": 9, "threshold_skew_multipli": 9, "total_count": 9, "num_id": 9, "probabilistic_threshold_filt": 9, "per_id_prob": 9, "01": 9, "probabl": 9, "appear": 9, "60": 9, "randomli": 9, "chanc": 9, "basemanagedcollisionembeddingcollect": 9, "managed_collision_collect": 9, "return_remapped_featur": 9, "embedding_collect": 9, "meaning": 10, "prohibit": 10, "empti": [10, 12], "sever": 10, "combinedoptim": 10, "optimizerwrapp": 10, "rowwis": 10, "gradientclip": 10, "norm": 10, "gradientclippingoptim": 10, "max_gradi": 10, "closur": 10, "reevalu": 10, "loss": 10, "emptyfusedoptim": 10, "fusedoptim": 10, "zero_grad": 10, "set_to_non": 10, "zero": [10, 12], "footprint": 10, "modestli": 10, "improv": 10, "certain": 10, "0s": 10, "behav": 10, "did": 10, "altogeth": 10, "param_group": 10, "meant": 10, "post_load_state_dict": 10, "prepend_opt_kei": 10, "opt_kei": 10, "save_param_group": 10, "stricter": 10, "old": 10, "switch": 10, "flag": 10, "reason": 10, "identifi": 10, "littl": 10, "add_param_group": 10, "fine": 10, "frozen": 10, "trainabl": 10, "progress": 10, "what": 10, "init_st": 10, "checkpoint": 10, "usabl": 10, "sure": 10, "sd": 10, "load_checkpoint": 10, "replac": 10, "advanc": 10, "protocol": 10, "keyedoptimizerwrapp": 10, "optim_factori": 10, "conveni": 10, "warmupoptim": 10, "stage": 10, "warmupstag": 10, "lr": 10, "lr_param": 10, "param_nam": 10, "__warmup": 10, "adjust": 10, "schedul": 10, "go": 10, "fake": 10, "warmuppolici": 10, "invsqrt": 10, "inv_sqrt": 10, "poli": 10, "max_it": 10, "lr_scale": 10, "decay_it": 10, "speed": 11, "trec_quant": 11, "trec": 11, "qconfig": 11, "with_arg": 11, "qint8": 11, "quantize_dynam": 11, "qconfig_spec": 11, "table_name_to_quantized_weight": 11, "register_tb": 11, "quant_state_dict_split_scale_bia": 11, "row_align": 11, "qebc": 11, "quantembeddingbagcollect": 11, "from_float": 11, "quantized_embed": 11, "for_each_module_of_type_do": 11, "pruned_num_embed": 11, "pruning_indices_map": 11, "quant_prep_customize_row_align": 11, "quant_prep_enable_quant_state_dict_split_scale_bia": 11, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 11, "quant_prep_enable_register_tb": 11, "quantize_state_dict": 11, "table_name_to_data_typ": 11, "table_name_to_pruning_indices_map": 11, "whose": 12, "dimes": 12, "computejtdicttokjt": 12, "jt_dict": 12, "v5": 12, "v6": 12, "v7": 12, "dim_1": 12, "dim_0": 12, "computekjttojtdict": 12, "keyed_jagged_tensor": 12, "jit": 12, "abl": 12, "NOT": 12, "expens": 12, "values_dtyp": 12, "weights_dtyp": 12, "lengths_dtyp": 12, "from_dens": 12, "2d": 12, "11": 12, "12": 12, "j1": 12, "from_dense_length": 12, "lengths_or_non": 12, "offsets_or_non": 12, "non_block": 12, "new_devic": 12, "to_dens": 12, "inttensor": 12, "values_list": 12, "to_dense_weight": 12, "weights_list": 12, "to_padded_dens": 12, "desired_length": 12, "padding_valu": 12, "longest": 12, "pad": 12, "dt": 12, "to_padded_dense_weight": 12, "d_wt": 12, "weights_or_non": 12, "jaggedtensormeta": 12, "namespac": 12, "abcmeta": 12, "proxyableclassmeta": 12, "stride_per_key_per_rank": 12, "outer": 12, "inner": 12, "index_per_kei": 12, "expand": 12, "dedupl": 12, "dim_2": 12, "w0": 12, "w1": 12, "w2": 12, "w3": 12, "w4": 12, "w5": 12, "w6": 12, "w7": 12, "dist_init": 12, "variable_stride_per_kei": 12, "num_work": 12, "dist_label": 12, "dist_split": 12, "key_split": 12, "dist_tensor": 12, "empty_lik": 12, "flatten_length": 12, "from_jt_dict": 12, "implicit": 12, "visual": 12, "variable_feature_dim": 12, "But": 12, "That": 12, "didn": 12, "notic": 12, "correctli": 12, "technic": 12, "know": 12, "violat": 12, "precondit": 12, "fix": 12, "inverse_indices_or_non": 12, "length_per_key_or_non": 12, "lengths_offset_per_kei": 12, "offset_per_key_or_non": 12, "indices_tensor": 12, "include_inverse_indic": 12, "pin_memori": 12, "segment": 12, "stride_per_kei": 12, "to_dict": 12, "unsync": 12, "key_dim": 12, "tensor_list": 12, "from_tensor_list": 12, "regroup": 12, "keyed_tensor": 12, "regroup_as_dict": 12, "flatten_kjt_list": 12, "kjt_arr": 12, "is_non_strict_export": 12, "jt_is_equ": 12, "jt_1": 12, "jt_2": 12, "comparison": 12, "themselv": 12, "treat": 12, "kjt_is_equ": 12, "kjt_1": 12, "kjt_2": 12, "unflatten_kjt_list": 12}, "objects": {"torchrec": [[3, 0, 0, "-", "distributed"], [6, 0, 0, "module-0", "fx"], [7, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "modules"], [10, 0, 0, "module-0", "optim"], [11, 0, 0, "module-0", "quant"], [12, 0, 0, "module-0", "sparse"]], "torchrec.distributed": [[3, 0, 0, "-", "collective_utils"], [3, 0, 0, "-", "comm"], [3, 0, 0, "-", "comm_ops"], [5, 0, 0, "-", "dist_data"], [3, 0, 0, "-", "embedding"], [3, 0, 0, "-", "embedding_lookup"], [3, 0, 0, "-", "embedding_sharding"], [3, 0, 0, "-", "embedding_types"], [3, 0, 0, "-", "embeddingbag"], [3, 0, 0, "-", "grouped_position_weighted"], [3, 0, 0, "-", "mc_embedding"], [3, 0, 0, "-", "mc_embeddingbag"], [3, 0, 0, "-", "mc_modules"], [3, 0, 0, "-", "model_parallel"], [4, 0, 0, "-", "planner"], [3, 0, 0, "-", "quant_embeddingbag"], [5, 0, 0, "-", "sharding"], [3, 0, 0, "-", "train_pipeline"], [3, 0, 0, "-", "types"], [3, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[3, 1, 1, "", "invoke_on_rank_and_broadcast_result"], [3, 1, 1, "", "is_leader"], [3, 1, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[3, 1, 1, "", "get_group_rank"], [3, 1, 1, "", "get_local_rank"], [3, 1, 1, "", "get_local_size"], [3, 1, 1, "", "get_num_groups"], [3, 1, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[3, 2, 1, "", "All2AllDenseInfo"], [3, 2, 1, "", "All2AllPooledInfo"], [3, 2, 1, "", "All2AllSequenceInfo"], [3, 2, 1, "", "All2AllVInfo"], [3, 2, 1, "", "All2All_Pooled_Req"], [3, 2, 1, "", "All2All_Pooled_Wait"], [3, 2, 1, "", "All2All_Seq_Req"], [3, 2, 1, "", "All2All_Seq_Req_Wait"], [3, 2, 1, "", "All2Allv_Req"], [3, 2, 1, "", "All2Allv_Wait"], [3, 2, 1, "", "AllGatherBaseInfo"], [3, 2, 1, "", "AllGatherBase_Req"], [3, 2, 1, "", "AllGatherBase_Wait"], [3, 2, 1, "", "ReduceScatterBaseInfo"], [3, 2, 1, "", "ReduceScatterBase_Req"], [3, 2, 1, "", "ReduceScatterBase_Wait"], [3, 2, 1, "", "ReduceScatterInfo"], [3, 2, 1, "", "ReduceScatterVInfo"], [3, 2, 1, "", "ReduceScatterV_Req"], [3, 2, 1, "", "ReduceScatterV_Wait"], [3, 2, 1, "", "ReduceScatter_Req"], [3, 2, 1, "", "ReduceScatter_Wait"], [3, 2, 1, "", "Request"], [3, 2, 1, "", "VariableBatchAll2AllPooledInfo"], [3, 2, 1, "", "Variable_Batch_All2All_Pooled_Req"], [3, 2, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [3, 1, 1, "", "all2all_pooled_sync"], [3, 1, 1, "", "all2all_sequence_sync"], [3, 1, 1, "", "all2allv_sync"], [3, 1, 1, "", "all_gather_base_pooled"], [3, 1, 1, "", "all_gather_base_sync"], [3, 1, 1, "", "alltoall_pooled"], [3, 1, 1, "", "alltoall_sequence"], [3, 1, 1, "", "alltoallv"], [3, 1, 1, "", "fn"], [3, 1, 1, "", "get_gradient_division"], [3, 1, 1, "", "reduce_scatter_base_pooled"], [3, 1, 1, "", "reduce_scatter_base_sync"], [3, 1, 1, "", "reduce_scatter_pooled"], [3, 1, 1, "", "reduce_scatter_sync"], [3, 1, 1, "", "reduce_scatter_v_per_feature_pooled"], [3, 1, 1, "", "reduce_scatter_v_pooled"], [3, 1, 1, "", "reduce_scatter_v_sync"], [3, 1, 1, "", "set_gradient_division"], [3, 1, 1, "", "variable_batch_all2all_pooled_sync"], [3, 1, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[3, 3, 1, "", "batch_size"], [3, 3, 1, "", "input_shape"], [3, 3, 1, "", "input_splits"], [3, 3, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[3, 3, 1, "id0", "batch_size_per_rank"], [3, 3, 1, "id1", "codecs"], [3, 3, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [3, 3, 1, "id3", "dim_sum_per_rank"], [3, 3, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[3, 3, 1, "id5", "backward_recat_tensor"], [3, 3, 1, "id6", "codecs"], [3, 3, 1, "id7", "embedding_dim"], [3, 3, 1, "id8", "forward_recat_tensor"], [3, 3, 1, "id9", "input_splits"], [3, 3, 1, "id10", "lengths_after_sparse_data_all2all"], [3, 3, 1, "id11", "output_splits"], [3, 3, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [3, 3, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[3, 3, 1, "id14", "B_global"], [3, 3, 1, "id15", "B_local"], [3, 3, 1, "id16", "B_local_list"], [3, 3, 1, "id17", "D_local_list"], [3, 3, 1, "", "codecs"], [3, 3, 1, "", "dim_sum_per_rank"], [3, 3, 1, "", "dims_sum_per_rank"], [3, 3, 1, "id18", "input_split_sizes"], [3, 3, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[3, 3, 1, "", "codecs"], [3, 3, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[3, 3, 1, "id23", "codecs"], [3, 3, 1, "id24", "equal_splits"], [3, 3, 1, "id25", "input_sizes"], [3, 3, 1, "id26", "input_splits"], [3, 3, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[3, 3, 1, "id28", "batch_size_per_feature_pre_a2a"], [3, 3, 1, "id29", "batch_size_per_rank_per_feature"], [3, 3, 1, "id30", "codecs"], [3, 3, 1, "id31", "emb_dim_per_rank_per_feature"], [3, 3, 1, "id32", "input_splits"], [3, 3, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.dist_data": [[5, 2, 1, "", "EmbeddingsAllToOne"], [5, 2, 1, "", "EmbeddingsAllToOneReduce"], [5, 2, 1, "", "KJTAllToAll"], [5, 2, 1, "", "KJTAllToAllSplitsAwaitable"], [5, 2, 1, "", "KJTAllToAllTensorsAwaitable"], [5, 2, 1, "", "KJTOneToAll"], [5, 2, 1, "", "PooledEmbeddingsAllGather"], [5, 2, 1, "", "PooledEmbeddingsAllToAll"], [5, 2, 1, "", "PooledEmbeddingsAwaitable"], [5, 2, 1, "", "PooledEmbeddingsReduceScatter"], [5, 2, 1, "", "SeqEmbeddingsAllToOne"], [5, 2, 1, "", "SequenceEmbeddingsAllToAll"], [5, 2, 1, "", "SequenceEmbeddingsAwaitable"], [5, 2, 1, "", "SplitsAllToAllAwaitable"], [5, 2, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [5, 2, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[5, 5, 1, "", "callbacks"], [5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[5, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[5, 4, 1, "", "forward"], [5, 4, 1, "", "set_device"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[5, 5, 1, "", "callbacks"], [5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.embedding": [[3, 2, 1, "", "EmbeddingCollectionAwaitable"], [3, 2, 1, "", "EmbeddingCollectionContext"], [3, 2, 1, "", "EmbeddingCollectionSharder"], [3, 2, 1, "", "ShardedEmbeddingCollection"], [3, 1, 1, "", "create_embedding_sharding"], [3, 1, 1, "", "create_sharding_infos_by_sharding"], [3, 1, 1, "", "get_ec_index_dedup"], [3, 1, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "reset_parameters"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[3, 2, 1, "", "CommOpGradientScaling"], [3, 2, 1, "", "GroupedEmbeddingsLookup"], [3, 2, 1, "", "GroupedPooledEmbeddingsLookup"], [3, 2, 1, "", "InferCPUGroupedEmbeddingsLookup"], [3, 2, 1, "", "InferGroupedEmbeddingsLookup"], [3, 2, 1, "", "InferGroupedLookupMixin"], [3, 2, 1, "", "InferGroupedPooledEmbeddingsLookup"], [3, 2, 1, "", "MetaInferGroupedEmbeddingsLookup"], [3, 2, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [3, 1, 1, "", "embeddings_cat_empty_rank_handle"], [3, 1, 1, "", "embeddings_cat_empty_rank_handle_inference"], [3, 1, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[3, 4, 1, "", "backward"], [3, 4, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "named_parameters_by_table"], [3, 4, 1, "", "prefetch"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "named_parameters_by_table"], [3, 4, 1, "", "prefetch"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[3, 4, 1, "", "forward"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "get_tbes_to_register"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "get_tbes_to_register"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[3, 4, 1, "", "flush"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "get_tbes_to_register"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "purge"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[3, 2, 1, "", "BaseEmbeddingDist"], [3, 2, 1, "", "BaseSparseFeaturesDist"], [3, 2, 1, "", "EmbeddingSharding"], [3, 2, 1, "", "EmbeddingShardingContext"], [3, 2, 1, "", "EmbeddingShardingInfo"], [3, 2, 1, "", "FusedKJTListSplitsAwaitable"], [3, 2, 1, "", "KJTListAwaitable"], [3, 2, 1, "", "KJTListSplitsAwaitable"], [3, 2, 1, "", "KJTSplitsAllToAllMeta"], [3, 2, 1, "", "ListOfKJTListAwaitable"], [3, 2, 1, "", "ListOfKJTListSplitsAwaitable"], [3, 1, 1, "", "bucketize_kjt_before_all2all"], [3, 1, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[3, 4, 1, "", "create_input_dist"], [3, 4, 1, "", "create_lookup"], [3, 4, 1, "", "create_output_dist"], [3, 4, 1, "", "embedding_dims"], [3, 4, 1, "", "embedding_names"], [3, 4, 1, "", "embedding_names_per_rank"], [3, 4, 1, "", "embedding_shard_metadata"], [3, 4, 1, "", "embedding_tables"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "uncombined_embedding_dims"], [3, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[3, 3, 1, "", "embedding_config"], [3, 3, 1, "", "fused_params"], [3, 3, 1, "", "param"], [3, 3, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[3, 3, 1, "", "device"], [3, 3, 1, "", "input_splits"], [3, 3, 1, "", "input_tensors"], [3, 3, 1, "", "keys"], [3, 3, 1, "", "labels"], [3, 3, 1, "", "pg"], [3, 3, 1, "", "splits"], [3, 3, 1, "", "splits_tensors"], [3, 3, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[3, 2, 1, "", "BaseEmbeddingLookup"], [3, 2, 1, "", "BaseEmbeddingSharder"], [3, 2, 1, "", "BaseGroupedFeatureProcessor"], [3, 2, 1, "", "BaseQuantEmbeddingSharder"], [3, 2, 1, "", "EmbeddingAttributes"], [3, 2, 1, "", "EmbeddingComputeKernel"], [3, 2, 1, "", "FeatureShardingMixIn"], [3, 2, 1, "", "GroupedEmbeddingConfig"], [3, 2, 1, "", "KJTList"], [3, 2, 1, "", "ListOfKJTList"], [3, 2, 1, "", "ModuleShardingMixIn"], [3, 2, 1, "", "OptimType"], [3, 2, 1, "", "ShardedConfig"], [3, 2, 1, "", "ShardedEmbeddingModule"], [3, 2, 1, "", "ShardedEmbeddingTable"], [3, 2, 1, "", "ShardedMetaConfig"], [3, 1, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "fused_params"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "fused_params"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[3, 3, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[3, 3, 1, "", "DENSE"], [3, 3, 1, "", "FUSED"], [3, 3, 1, "", "FUSED_UVM"], [3, 3, 1, "", "FUSED_UVM_CACHING"], [3, 3, 1, "", "QUANT"], [3, 3, 1, "", "QUANT_UVM"], [3, 3, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[3, 4, 1, "", "feature_names"], [3, 4, 1, "", "feature_names_per_rank"], [3, 4, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[3, 3, 1, "", "compute_kernel"], [3, 3, 1, "", "data_type"], [3, 4, 1, "", "dim_sum"], [3, 4, 1, "", "embedding_dims"], [3, 4, 1, "", "embedding_names"], [3, 4, 1, "", "embedding_shard_metadata"], [3, 3, 1, "", "embedding_tables"], [3, 4, 1, "", "feature_hash_sizes"], [3, 4, 1, "", "feature_names"], [3, 3, 1, "", "fused_params"], [3, 3, 1, "", "has_feature_processor"], [3, 3, 1, "", "is_weighted"], [3, 4, 1, "", "num_features"], [3, 3, 1, "", "pooling"], [3, 4, 1, "", "table_names"]], "torchrec.distributed.embedding_types.KJTList": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[3, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[3, 3, 1, "", "ADAGRAD"], [3, 3, 1, "", "ADAM"], [3, 3, 1, "", "ADAMW"], [3, 3, 1, "", "LAMB"], [3, 3, 1, "", "LARS_SGD"], [3, 3, 1, "", "LION"], [3, 3, 1, "", "PARTIAL_ROWWISE_ADAM"], [3, 3, 1, "", "PARTIAL_ROWWISE_LAMB"], [3, 3, 1, "", "ROWWISE_ADAGRAD"], [3, 3, 1, "", "SGD"], [3, 3, 1, "", "SHAMPOO"], [3, 3, 1, "", "SHAMPOO_V2"]], "torchrec.distributed.embedding_types.ShardedConfig": [[3, 3, 1, "", "local_cols"], [3, 3, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[3, 4, 1, "", "extra_repr"], [3, 4, 1, "", "prefetch"], [3, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[3, 3, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[3, 3, 1, "", "global_metadata"], [3, 3, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[3, 2, 1, "", "EmbeddingAwaitable"], [3, 2, 1, "", "EmbeddingBagCollectionAwaitable"], [3, 2, 1, "", "EmbeddingBagCollectionContext"], [3, 2, 1, "", "EmbeddingBagCollectionSharder"], [3, 2, 1, "", "EmbeddingBagSharder"], [3, 2, 1, "", "ShardedEmbeddingBag"], [3, 2, 1, "", "ShardedEmbeddingBagCollection"], [3, 2, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [3, 1, 1, "", "construct_output_kt"], [3, 1, 1, "", "create_embedding_bag_sharding"], [3, 1, 1, "", "create_sharding_infos_by_sharding"], [3, 1, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[3, 3, 1, "", "divisor"], [3, 3, 1, "", "inverse_indices"], [3, 4, 1, "", "record_stream"], [3, 3, 1, "", "sharding_contexts"], [3, 3, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "load_state_dict"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_modules"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharded_parameter_names"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "reset_parameters"], [3, 3, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[3, 2, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[3, 4, 1, "", "forward"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[3, 2, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [3, 2, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[3, 4, 1, "", "create_context"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[3, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [3, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[3, 3, 1, "", "evictions_per_table"], [3, 4, 1, "", "record_stream"], [3, 3, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[3, 4, 1, "", "create_context"], [3, 3, 1, "", "training"]], "torchrec.distributed.mc_modules": [[3, 2, 1, "", "ManagedCollisionCollectionAwaitable"], [3, 2, 1, "", "ManagedCollisionCollectionContext"], [3, 2, 1, "", "ManagedCollisionCollectionSharder"], [3, 2, 1, "", "ShardedManagedCollisionCollection"], [3, 1, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "evict"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharded_parameter_names"], [3, 3, 1, "", "training"]], "torchrec.distributed.model_parallel": [[3, 2, 1, "", "DataParallelWrapper"], [3, 2, 1, "", "DefaultDataParallelWrapper"], [3, 2, 1, "", "DistributedModelParallel"], [3, 1, 1, "", "get_module"], [3, 1, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[3, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[3, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[3, 4, 1, "", "bare_named_parameters"], [3, 4, 1, "", "copy"], [3, 4, 1, "", "forward"], [3, 5, 1, "", "fused_optimizer"], [3, 4, 1, "", "init_data_parallel"], [3, 4, 1, "", "load_state_dict"], [3, 5, 1, "", "module"], [3, 4, 1, "", "named_buffers"], [3, 4, 1, "", "named_parameters"], [3, 5, 1, "", "plan"], [3, 4, 1, "", "sparse_grad_parameter_names"], [3, 4, 1, "", "state_dict"], [3, 3, 1, "", "training"]], "torchrec.distributed.planner": [[4, 0, 0, "-", "constants"], [4, 0, 0, "-", "enumerators"], [4, 0, 0, "-", "partitioners"], [4, 0, 0, "-", "perf_models"], [4, 0, 0, "-", "planners"], [4, 0, 0, "-", "proposers"], [4, 0, 0, "-", "shard_estimators"], [4, 0, 0, "-", "stats"], [4, 0, 0, "-", "storage_reservations"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[4, 1, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[4, 2, 1, "", "EmbeddingEnumerator"], [4, 1, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[4, 4, 1, "", "enumerate"], [4, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[4, 2, 1, "", "GreedyPerfPartitioner"], [4, 2, 1, "", "MemoryBalancedPartitioner"], [4, 2, 1, "", "OrderedDeviceHardware"], [4, 2, 1, "", "ShardingOptionGroup"], [4, 2, 1, "", "SortBy"], [4, 1, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[4, 3, 1, "", "device"], [4, 3, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[4, 3, 1, "", "param_count"], [4, 3, 1, "", "perf_sum"], [4, 3, 1, "", "sharding_options"], [4, 3, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[4, 3, 1, "", "PERF"], [4, 3, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[4, 2, 1, "", "NoopPerfModel"], [4, 2, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[4, 2, 1, "", "EmbeddingShardingPlanner"], [4, 2, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[4, 4, 1, "", "collective_plan"], [4, 4, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[4, 4, 1, "", "collective_plan"], [4, 4, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[4, 2, 1, "", "EmbeddingOffloadScaleupProposer"], [4, 2, 1, "", "GreedyProposer"], [4, 2, 1, "", "GridSearchProposer"], [4, 2, 1, "", "UniformProposer"], [4, 1, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[4, 4, 1, "", "allocate_budget"], [4, 4, 1, "", "build_affine_storage_model"], [4, 4, 1, "", "clf_to_bytes"], [4, 4, 1, "", "feedback"], [4, 4, 1, "", "get_budget"], [4, 4, 1, "", "get_cacheability"], [4, 4, 1, "", "get_expected_lookups"], [4, 4, 1, "", "load"], [4, 4, 1, "", "next_plan"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[4, 2, 1, "", "EmbeddingOffloadStats"], [4, 2, 1, "", "EmbeddingPerfEstimator"], [4, 2, 1, "", "EmbeddingStorageEstimator"], [4, 1, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[4, 5, 1, "", "cacheability"], [4, 4, 1, "", "estimate_cache_miss_rate"], [4, 5, 1, "", "expected_lookups"], [4, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[4, 4, 1, "", "estimate"], [4, 4, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[4, 4, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[4, 2, 1, "", "EmbeddingStats"], [4, 2, 1, "", "NoopEmbeddingStats"], [4, 1, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[4, 2, 1, "", "FixedPercentageStorageReservation"], [4, 2, 1, "", "HeuristicalStorageReservation"], [4, 2, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[4, 2, 1, "", "CustomTopologyData"], [4, 2, 1, "", "DeviceHardware"], [4, 2, 1, "", "Enumerator"], [4, 2, 1, "", "ParameterConstraints"], [4, 2, 1, "", "PartitionByType"], [4, 2, 1, "", "Partitioner"], [4, 2, 1, "", "Perf"], [4, 2, 1, "", "PerfModel"], [4, 6, 1, "", "PlannerError"], [4, 2, 1, "", "PlannerErrorType"], [4, 2, 1, "", "Proposer"], [4, 2, 1, "", "Shard"], [4, 2, 1, "", "ShardEstimator"], [4, 2, 1, "", "ShardingOption"], [4, 2, 1, "", "Stats"], [4, 2, 1, "", "Storage"], [4, 2, 1, "", "StorageReservation"], [4, 2, 1, "", "Topology"]], "torchrec.distributed.planner.types.CustomTopologyData": [[4, 4, 1, "", "get_data"], [4, 4, 1, "", "has_data"], [4, 3, 1, "", "supported_fields"]], "torchrec.distributed.planner.types.DeviceHardware": [[4, 3, 1, "", "perf"], [4, 3, 1, "", "rank"], [4, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[4, 4, 1, "", "enumerate"], [4, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[4, 3, 1, "id0", "batch_sizes"], [4, 3, 1, "id1", "bounds_check_mode"], [4, 3, 1, "id2", "cache_params"], [4, 3, 1, "id3", "compute_kernels"], [4, 3, 1, "id4", "device_group"], [4, 3, 1, "id5", "enforce_hbm"], [4, 3, 1, "id6", "feature_names"], [4, 3, 1, "id7", "is_weighted"], [4, 3, 1, "id8", "min_partition"], [4, 3, 1, "id9", "num_poolings"], [4, 3, 1, "id10", "output_dtype"], [4, 3, 1, "id11", "pooling_factors"], [4, 3, 1, "id12", "sharding_types"], [4, 3, 1, "id13", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[4, 3, 1, "", "DEVICE"], [4, 3, 1, "", "HOST"], [4, 3, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[4, 4, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[4, 3, 1, "", "bwd_comms"], [4, 3, 1, "", "bwd_compute"], [4, 3, 1, "", "fwd_comms"], [4, 3, 1, "", "fwd_compute"], [4, 3, 1, "", "prefetch_compute"], [4, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[4, 4, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[4, 3, 1, "", "INSUFFICIENT_STORAGE"], [4, 3, 1, "", "OTHER"], [4, 3, 1, "", "PARTITION"], [4, 3, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[4, 4, 1, "", "feedback"], [4, 4, 1, "", "load"], [4, 4, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[4, 3, 1, "", "offset"], [4, 3, 1, "", "perf"], [4, 3, 1, "", "rank"], [4, 3, 1, "", "size"], [4, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[4, 4, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[4, 3, 1, "", "batch_size"], [4, 3, 1, "", "bounds_check_mode"], [4, 5, 1, "", "cache_load_factor"], [4, 3, 1, "", "cache_params"], [4, 3, 1, "", "compute_kernel"], [4, 3, 1, "", "dependency"], [4, 3, 1, "", "enforce_hbm"], [4, 3, 1, "", "feature_names"], [4, 5, 1, "", "fqn"], [4, 3, 1, "", "input_lengths"], [4, 5, 1, "id14", "is_pooled"], [4, 5, 1, "id15", "module"], [4, 4, 1, "", "module_pooled"], [4, 3, 1, "", "name"], [4, 5, 1, "", "num_inputs"], [4, 5, 1, "", "num_shards"], [4, 3, 1, "", "output_dtype"], [4, 5, 1, "", "path"], [4, 3, 1, "", "sharding_type"], [4, 3, 1, "", "shards"], [4, 3, 1, "", "stochastic_rounding"], [4, 5, 1, "id16", "tensor"], [4, 5, 1, "", "total_perf"], [4, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[4, 4, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[4, 3, 1, "", "ddr"], [4, 4, 1, "", "fits_in"], [4, 3, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[4, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[4, 5, 1, "", "bwd_compute_multiplier"], [4, 5, 1, "", "compute_device"], [4, 5, 1, "", "ddr_mem_bw"], [4, 5, 1, "", "devices"], [4, 5, 1, "", "hbm_mem_bw"], [4, 5, 1, "", "inter_host_bw"], [4, 5, 1, "", "intra_host_bw"], [4, 5, 1, "", "local_world_size"], [4, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[4, 2, 1, "", "BinarySearchPredicate"], [4, 2, 1, "", "LuusJaakolaSearch"], [4, 1, 1, "", "bytes_to_gb"], [4, 1, 1, "", "bytes_to_mb"], [4, 1, 1, "", "gb_to_bytes"], [4, 1, 1, "", "placement"], [4, 1, 1, "", "prod"], [4, 1, 1, "", "reset_shard_rank"], [4, 1, 1, "", "sharder_name"], [4, 1, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[4, 4, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[4, 4, 1, "", "best"], [4, 4, 1, "", "clamp"], [4, 4, 1, "", "next"], [4, 4, 1, "", "shrink_right"], [4, 4, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[3, 2, 1, "", "QuantEmbeddingBagCollectionSharder"], [3, 2, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [3, 2, 1, "", "ShardedQuantEbcInputDist"], [3, 2, 1, "", "ShardedQuantEmbeddingBagCollection"], [3, 2, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [3, 1, 1, "", "create_infer_embedding_bag_sharding"], [3, 1, 1, "", "flatten_feature_lengths"], [3, 1, 1, "", "get_device_from_parameter_sharding"], [3, 1, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "module_type"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[3, 4, 1, "", "forward"], [3, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "copy"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "embedding_bag_configs"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 4, 1, "", "sharding_type_to_sharding_infos"], [3, 5, 1, "", "shardings"], [3, 4, 1, "", "tbes_configs"], [3, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[3, 4, 1, "", "apply_feature_processor"], [3, 4, 1, "", "compute"], [3, 3, 1, "", "embedding_bags"], [3, 3, 1, "", "tbes"], [3, 3, 1, "", "training"]], "torchrec.distributed.sharding": [[5, 0, 0, "-", "cw_sharding"], [5, 0, 0, "-", "dp_sharding"], [5, 0, 0, "-", "rw_sharding"], [5, 0, 0, "-", "tw_sharding"], [5, 0, 0, "-", "twcw_sharding"], [5, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[5, 2, 1, "", "BaseCwEmbeddingSharding"], [5, 2, 1, "", "CwPooledEmbeddingSharding"], [5, 2, 1, "", "InferCwPooledEmbeddingDist"], [5, 2, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [5, 2, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "uncombined_embedding_dims"], [5, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[5, 2, 1, "", "BaseDpEmbeddingSharding"], [5, 2, 1, "", "DpPooledEmbeddingDist"], [5, 2, 1, "", "DpPooledEmbeddingSharding"], [5, 2, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[5, 2, 1, "", "BaseRwEmbeddingSharding"], [5, 2, 1, "", "InferCPURwSparseFeaturesDist"], [5, 2, 1, "", "InferRwPooledEmbeddingDist"], [5, 2, 1, "", "InferRwPooledEmbeddingSharding"], [5, 2, 1, "", "InferRwSparseFeaturesDist"], [5, 2, 1, "", "RwPooledEmbeddingDist"], [5, 2, 1, "", "RwPooledEmbeddingSharding"], [5, 2, 1, "", "RwSparseFeaturesDist"], [5, 1, 1, "", "get_block_sizes_runtime_device"], [5, 1, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferCPURwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[5, 2, 1, "", "BaseTwEmbeddingSharding"], [5, 2, 1, "", "InferTwEmbeddingSharding"], [5, 2, 1, "", "InferTwPooledEmbeddingDist"], [5, 2, 1, "", "InferTwSparseFeaturesDist"], [5, 2, 1, "", "TwPooledEmbeddingDist"], [5, 2, 1, "", "TwPooledEmbeddingSharding"], [5, 2, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "embedding_tables"], [5, 4, 1, "", "feature_names"], [5, 4, 1, "", "feature_names_per_rank"], [5, 4, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[5, 2, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[5, 2, 1, "", "BaseTwRwEmbeddingSharding"], [5, 2, 1, "", "TwRwPooledEmbeddingDist"], [5, 2, 1, "", "TwRwPooledEmbeddingSharding"], [5, 2, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[5, 4, 1, "", "embedding_dims"], [5, 4, 1, "", "embedding_names"], [5, 4, 1, "", "embedding_names_per_rank"], [5, 4, 1, "", "embedding_shard_metadata"], [5, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[5, 4, 1, "", "create_input_dist"], [5, 4, 1, "", "create_lookup"], [5, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[5, 4, 1, "", "forward"], [5, 3, 1, "", "training"]], "torchrec.distributed.types": [[3, 2, 1, "", "Awaitable"], [3, 2, 1, "", "CacheParams"], [3, 2, 1, "", "CacheStatistics"], [3, 2, 1, "", "CommOp"], [3, 2, 1, "", "ComputeKernel"], [3, 2, 1, "", "EmbeddingModuleShardingPlan"], [3, 2, 1, "", "GenericMeta"], [3, 2, 1, "", "GetItemLazyAwaitable"], [3, 2, 1, "", "LazyAwaitable"], [3, 2, 1, "", "LazyGetItemMixin"], [3, 2, 1, "", "LazyNoWait"], [3, 2, 1, "", "ModuleSharder"], [3, 2, 1, "", "ModuleShardingPlan"], [3, 2, 1, "", "NoOpQuantizedCommCodec"], [3, 2, 1, "", "NoWait"], [3, 2, 1, "", "NullShardedModuleContext"], [3, 2, 1, "", "NullShardingContext"], [3, 2, 1, "", "ParameterSharding"], [3, 2, 1, "", "ParameterStorage"], [3, 2, 1, "", "QuantizedCommCodec"], [3, 2, 1, "", "QuantizedCommCodecs"], [3, 2, 1, "", "ShardedModule"], [3, 2, 1, "", "ShardingEnv"], [3, 2, 1, "", "ShardingPlan"], [3, 2, 1, "", "ShardingPlanner"], [3, 2, 1, "", "ShardingType"], [3, 1, 1, "", "get_tensor_size_bytes"], [3, 1, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[3, 5, 1, "", "callbacks"], [3, 4, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[3, 3, 1, "id34", "algorithm"], [3, 3, 1, "id35", "load_factor"], [3, 3, 1, "id36", "precision"], [3, 3, 1, "id37", "prefetch_pipeline"], [3, 3, 1, "id38", "reserved_memory"], [3, 3, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[3, 5, 1, "", "cacheability"], [3, 5, 1, "", "expected_lookups"], [3, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[3, 3, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [3, 3, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [3, 3, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[3, 3, 1, "", "DEFAULT"]], "torchrec.distributed.types.ModuleSharder": [[3, 4, 1, "", "compute_kernels"], [3, 5, 1, "", "module_type"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "shard"], [3, 4, 1, "", "shardable_parameters"], [3, 4, 1, "", "sharding_types"], [3, 4, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[3, 4, 1, "", "calc_quantized_size"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "decode"], [3, 4, 1, "", "encode"], [3, 4, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[3, 4, 1, "", "record_stream"]], "torchrec.distributed.types.ParameterSharding": [[3, 3, 1, "", "bounds_check_mode"], [3, 3, 1, "", "cache_params"], [3, 3, 1, "", "compute_kernel"], [3, 3, 1, "", "enforce_hbm"], [3, 3, 1, "", "output_dtype"], [3, 3, 1, "", "ranks"], [3, 3, 1, "", "sharding_spec"], [3, 3, 1, "", "sharding_type"], [3, 3, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[3, 3, 1, "", "DDR"], [3, 3, 1, "", "HBM"]], "torchrec.distributed.types.QuantizedCommCodec": [[3, 4, 1, "", "calc_quantized_size"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "decode"], [3, 4, 1, "", "encode"], [3, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[3, 3, 1, "", "backward"], [3, 3, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[3, 4, 1, "", "compute"], [3, 4, 1, "", "compute_and_output_dist"], [3, 4, 1, "", "create_context"], [3, 4, 1, "", "forward"], [3, 4, 1, "", "input_dist"], [3, 4, 1, "", "output_dist"], [3, 5, 1, "", "qcomm_codecs_registry"], [3, 4, 1, "", "sharded_parameter_names"], [3, 3, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[3, 4, 1, "", "from_local"], [3, 4, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[3, 4, 1, "", "get_plan_for_module"], [3, 3, 1, "id40", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[3, 4, 1, "", "collective_plan"], [3, 4, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[3, 3, 1, "", "COLUMN_WISE"], [3, 3, 1, "", "DATA_PARALLEL"], [3, 3, 1, "", "ROW_WISE"], [3, 3, 1, "", "TABLE_COLUMN_WISE"], [3, 3, 1, "", "TABLE_ROW_WISE"], [3, 3, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[3, 2, 1, "", "CopyableMixin"], [3, 1, 1, "", "add_params_from_parameter_sharding"], [3, 1, 1, "", "add_prefix_to_state_dict"], [3, 1, 1, "", "append_prefix"], [3, 1, 1, "", "convert_to_fbgemm_types"], [3, 1, 1, "", "copy_to_device"], [3, 1, 1, "", "filter_state_dict"], [3, 1, 1, "", "get_unsharded_module_names"], [3, 1, 1, "", "init_parameters"], [3, 1, 1, "", "merge_fused_params"], [3, 1, 1, "", "none_throws"], [3, 1, 1, "", "optimizer_type_to_emb_opt_type"], [3, 2, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[3, 4, 1, "", "copy"], [3, 3, 1, "", "training"]], "torchrec.fx": [[6, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[6, 2, 1, "", "Tracer"], [6, 1, 1, "", "is_fx_tracing"], [6, 1, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[6, 4, 1, "", "create_arg"], [6, 3, 1, "", "graph"], [6, 4, 1, "", "is_leaf_module"], [6, 3, 1, "", "module_stack"], [6, 3, 1, "", "node_name_to_scope"], [6, 4, 1, "", "path_of_module"], [6, 3, 1, "", "scope"], [6, 4, 1, "", "trace"]], "torchrec.inference": [[7, 0, 0, "-", "model_packager"], [7, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[7, 2, 1, "", "PredictFactoryPackager"], [7, 1, 1, "", "load_config_text"], [7, 1, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[7, 4, 1, "", "save_predict_factory"], [7, 4, 1, "", "set_extern_modules"], [7, 4, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[7, 2, 1, "", "BatchingMetadata"], [7, 2, 1, "", "PredictFactory"], [7, 2, 1, "", "PredictModule"], [7, 2, 1, "", "QualNameMetadata"], [7, 1, 1, "", "quantize_dense"], [7, 1, 1, "", "quantize_embeddings"], [7, 1, 1, "", "quantize_feature"], [7, 1, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[7, 3, 1, "", "device"], [7, 3, 1, "", "pinned"], [7, 3, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[7, 4, 1, "", "batching_metadata"], [7, 4, 1, "", "batching_metadata_json"], [7, 4, 1, "", "create_predict_module"], [7, 4, 1, "", "model_inputs_data"], [7, 4, 1, "", "qualname_metadata"], [7, 4, 1, "", "qualname_metadata_json"], [7, 4, 1, "", "result_metadata"], [7, 4, 1, "", "run_weights_dependent_transformations"], [7, 4, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[7, 4, 1, "", "forward"], [7, 4, 1, "", "predict_forward"], [7, 5, 1, "", "predict_module"], [7, 4, 1, "", "state_dict"], [7, 3, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[7, 3, 1, "", "need_preproc"]], "torchrec.models": [[8, 0, 0, "-", "deepfm"]], "torchrec.models.deepfm": [[8, 2, 1, "", "DenseArch"], [8, 2, 1, "", "FMInteractionArch"], [8, 2, 1, "", "OverArch"], [8, 2, 1, "", "SimpleDeepFMNN"], [8, 2, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[8, 4, 1, "", "forward"], [8, 3, 1, "", "training"]], "torchrec.modules": [[9, 0, 0, "-", "activation"], [9, 0, 0, "-", "crossnet"], [9, 0, 0, "-", "deepfm"], [9, 0, 0, "-", "embedding_configs"], [9, 0, 0, "-", "embedding_modules"], [9, 0, 0, "-", "feature_processor"], [9, 0, 0, "-", "lazy_extension"], [9, 0, 0, "-", "mc_embedding_modules"], [9, 0, 0, "-", "mc_modules"], [9, 0, 0, "-", "mlp"], [9, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[9, 2, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet": [[9, 2, 1, "", "CrossNet"], [9, 2, 1, "", "LowRankCrossNet"], [9, 2, 1, "", "LowRankMixtureCrossNet"], [9, 2, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.deepfm": [[9, 2, 1, "", "DeepFM"], [9, 2, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_configs": [[9, 2, 1, "", "BaseEmbeddingConfig"], [9, 2, 1, "", "EmbeddingBagConfig"], [9, 2, 1, "", "EmbeddingConfig"], [9, 2, 1, "", "EmbeddingTableConfig"], [9, 2, 1, "", "PoolingType"], [9, 2, 1, "", "QuantConfig"], [9, 2, 1, "", "ShardingType"], [9, 1, 1, "", "data_type_to_dtype"], [9, 1, 1, "", "data_type_to_sparse_type"], [9, 1, 1, "", "dtype_to_data_type"], [9, 1, 1, "", "pooling_type_to_pooling_mode"], [9, 1, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[9, 3, 1, "", "data_type"], [9, 3, 1, "", "embedding_dim"], [9, 3, 1, "", "feature_names"], [9, 4, 1, "", "get_weight_init_max"], [9, 4, 1, "", "get_weight_init_min"], [9, 3, 1, "", "init_fn"], [9, 3, 1, "", "name"], [9, 3, 1, "", "need_pos"], [9, 3, 1, "", "num_embeddings"], [9, 4, 1, "", "num_features"], [9, 3, 1, "", "pruning_indices_remapping"], [9, 3, 1, "", "weight_init_max"], [9, 3, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[9, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[9, 3, 1, "", "embedding_dim"], [9, 3, 1, "", "feature_names"], [9, 3, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[9, 3, 1, "", "embedding_names"], [9, 3, 1, "", "has_feature_processor"], [9, 3, 1, "", "is_weighted"], [9, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[9, 3, 1, "", "MEAN"], [9, 3, 1, "", "NONE"], [9, 3, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[9, 3, 1, "", "activation"], [9, 3, 1, "", "per_table_weight_dtype"], [9, 3, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[9, 3, 1, "", "COLUMN_WISE"], [9, 3, 1, "", "DATA_PARALLEL"], [9, 3, 1, "", "ROW_WISE"], [9, 3, 1, "", "TABLE_COLUMN_WISE"], [9, 3, 1, "", "TABLE_ROW_WISE"], [9, 3, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[9, 2, 1, "", "EmbeddingBagCollection"], [9, 2, 1, "", "EmbeddingBagCollectionInterface"], [9, 2, 1, "", "EmbeddingCollection"], [9, 2, 1, "", "EmbeddingCollectionInterface"], [9, 1, 1, "", "get_embedding_names_by_table"], [9, 1, 1, "", "process_pooled_embeddings"], [9, 1, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[9, 5, 1, "", "device"], [9, 4, 1, "", "embedding_bag_configs"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "is_weighted"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[9, 4, 1, "", "embedding_bag_configs"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "is_weighted"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[9, 5, 1, "", "device"], [9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "embedding_dim"], [9, 4, 1, "", "embedding_names_by_table"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "need_indices"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "embedding_dim"], [9, 4, 1, "", "embedding_names_by_table"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "need_indices"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor": [[9, 2, 1, "", "BaseFeatureProcessor"], [9, 2, 1, "", "BaseGroupedFeatureProcessor"], [9, 2, 1, "", "PositionWeightedModule"], [9, 2, 1, "", "PositionWeightedProcessor"], [9, 1, 1, "", "offsets_to_range_traceble"], [9, 1, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[9, 4, 1, "", "forward"], [9, 4, 1, "", "reset_parameters"], [9, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[9, 4, 1, "", "forward"], [9, 4, 1, "", "named_buffers"], [9, 4, 1, "", "state_dict"], [9, 3, 1, "", "training"]], "torchrec.modules.lazy_extension": [[9, 2, 1, "", "LazyModuleExtensionMixin"], [9, 1, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[9, 4, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[9, 2, 1, "", "BaseManagedCollisionEmbeddingCollection"], [9, 2, 1, "", "ManagedCollisionEmbeddingBagCollection"], [9, 2, 1, "", "ManagedCollisionEmbeddingCollection"], [9, 1, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[9, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[9, 3, 1, "", "training"]], "torchrec.modules.mc_modules": [[9, 2, 1, "", "DistanceLFU_EvictionPolicy"], [9, 2, 1, "", "LFU_EvictionPolicy"], [9, 2, 1, "", "LRU_EvictionPolicy"], [9, 2, 1, "", "MCHEvictionPolicy"], [9, 2, 1, "", "MCHEvictionPolicyMetadataInfo"], [9, 2, 1, "", "MCHManagedCollisionModule"], [9, 2, 1, "", "ManagedCollisionCollection"], [9, 2, 1, "", "ManagedCollisionModule"], [9, 1, 1, "", "apply_mc_method_to_jt_dict"], [9, 1, 1, "", "average_threshold_filter"], [9, 1, 1, "", "dynamic_threshold_filter"], [9, 1, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[9, 4, 1, "", "coalesce_history_metadata"], [9, 5, 1, "", "metadata_info"], [9, 4, 1, "", "record_history_metadata"], [9, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[9, 3, 1, "", "is_history_metadata"], [9, 3, 1, "", "is_mch_metadata"], [9, 3, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "input_size"], [9, 4, 1, "", "output_size"], [9, 4, 1, "", "preprocess"], [9, 4, 1, "", "profile"], [9, 4, 1, "", "rebuild_with_output_id_range"], [9, 4, 1, "", "remap"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[9, 4, 1, "", "embedding_configs"], [9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[9, 5, 1, "", "device"], [9, 4, 1, "", "evict"], [9, 4, 1, "", "forward"], [9, 4, 1, "", "input_size"], [9, 4, 1, "", "output_size"], [9, 4, 1, "", "preprocess"], [9, 4, 1, "", "rebuild_with_output_id_range"], [9, 3, 1, "", "training"]], "torchrec.modules.mlp": [[9, 2, 1, "", "MLP"], [9, 2, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[9, 4, 1, "", "forward"], [9, 3, 1, "", "training"]], "torchrec.modules.utils": [[9, 1, 1, "", "check_module_output_dimension"], [9, 1, 1, "", "construct_jagged_tensors"], [9, 1, 1, "", "construct_jagged_tensors_inference"], [9, 1, 1, "", "construct_modulelist_from_single_module"], [9, 1, 1, "", "convert_list_of_modules_to_modulelist"], [9, 1, 1, "", "extract_module_or_tensor_callable"], [9, 1, 1, "", "get_module_output_dimension"], [9, 1, 1, "", "init_mlp_weights_xavier_uniform"]], "torchrec.optim": [[10, 0, 0, "-", "clipping"], [10, 0, 0, "-", "fused"], [10, 0, 0, "-", "keyed"], [10, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[10, 2, 1, "", "GradientClipping"], [10, 2, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[10, 3, 1, "", "NONE"], [10, 3, 1, "", "NORM"], [10, 3, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[10, 4, 1, "", "step"]], "torchrec.optim.fused": [[10, 2, 1, "", "EmptyFusedOptimizer"], [10, 2, 1, "", "FusedOptimizer"], [10, 2, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[10, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[10, 2, 1, "", "CombinedOptimizer"], [10, 2, 1, "", "KeyedOptimizer"], [10, 2, 1, "", "KeyedOptimizerWrapper"], [10, 2, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[10, 5, 1, "", "optimizers"], [10, 5, 1, "", "param_groups"], [10, 5, 1, "", "params"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "prepend_opt_key"], [10, 4, 1, "", "save_param_groups"], [10, 5, 1, "", "state"], [10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[10, 4, 1, "", "add_param_group"], [10, 4, 1, "", "init_state"], [10, 4, 1, "", "load_state_dict"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "save_param_groups"], [10, 4, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[10, 4, 1, "", "add_param_group"], [10, 4, 1, "", "load_state_dict"], [10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "save_param_groups"], [10, 4, 1, "", "state_dict"], [10, 4, 1, "", "step"], [10, 4, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[10, 2, 1, "", "WarmupOptimizer"], [10, 2, 1, "", "WarmupPolicy"], [10, 2, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[10, 4, 1, "", "post_load_state_dict"], [10, 4, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[10, 3, 1, "", "CONSTANT"], [10, 3, 1, "", "INVSQRT"], [10, 3, 1, "", "LINEAR"], [10, 3, 1, "", "NONE"], [10, 3, 1, "", "POLY"], [10, 3, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[10, 3, 1, "", "decay_iters"], [10, 3, 1, "", "lr_scale"], [10, 3, 1, "", "max_iters"], [10, 3, 1, "", "policy"], [10, 3, 1, "", "value"]], "torchrec.quant": [[11, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[11, 2, 1, "", "EmbeddingBagCollection"], [11, 2, 1, "", "EmbeddingCollection"], [11, 2, 1, "", "FeatureProcessedEmbeddingBagCollection"], [11, 1, 1, "", "for_each_module_of_type_do"], [11, 1, 1, "", "pruned_num_embeddings"], [11, 1, 1, "", "quant_prep_customize_row_alignment"], [11, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [11, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [11, 1, 1, "", "quant_prep_enable_register_tbes"], [11, 1, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_bag_configs"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "output_dtype"], [11, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_configs"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "embedding_names_by_table"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 4, 1, "", "need_indices"], [11, 4, 1, "", "output_dtype"], [11, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[11, 3, 1, "", "embedding_bags"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "from_float"], [11, 3, 1, "", "tbes"], [11, 3, 1, "", "training"]], "torchrec.sparse": [[12, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[12, 2, 1, "", "ComputeJTDictToKJT"], [12, 2, 1, "", "ComputeKJTToJTDict"], [12, 2, 1, "", "JaggedTensor"], [12, 2, 1, "", "JaggedTensorMeta"], [12, 2, 1, "", "KeyedJaggedTensor"], [12, 2, 1, "", "KeyedTensor"], [12, 1, 1, "", "flatten_kjt_list"], [12, 1, 1, "", "is_non_strict_exporting"], [12, 1, 1, "", "jt_is_equal"], [12, 1, 1, "", "kjt_is_equal"], [12, 1, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[12, 4, 1, "", "forward"], [12, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[12, 4, 1, "", "forward"], [12, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[12, 4, 1, "", "empty"], [12, 4, 1, "", "from_dense"], [12, 4, 1, "", "from_dense_lengths"], [12, 4, 1, "", "lengths"], [12, 4, 1, "", "lengths_or_none"], [12, 4, 1, "", "offsets"], [12, 4, 1, "", "offsets_or_none"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dense"], [12, 4, 1, "", "to_dense_weights"], [12, 4, 1, "", "to_padded_dense"], [12, 4, 1, "", "to_padded_dense_weights"], [12, 4, 1, "", "values"], [12, 4, 1, "", "weights"], [12, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[12, 4, 1, "", "concat"], [12, 4, 1, "", "device"], [12, 4, 1, "", "dist_init"], [12, 4, 1, "", "dist_labels"], [12, 4, 1, "", "dist_splits"], [12, 4, 1, "", "dist_tensors"], [12, 4, 1, "", "empty"], [12, 4, 1, "", "empty_like"], [12, 4, 1, "", "flatten_lengths"], [12, 4, 1, "", "from_jt_dict"], [12, 4, 1, "", "from_lengths_sync"], [12, 4, 1, "", "from_offsets_sync"], [12, 4, 1, "", "index_per_key"], [12, 4, 1, "", "inverse_indices"], [12, 4, 1, "", "inverse_indices_or_none"], [12, 4, 1, "", "keys"], [12, 4, 1, "", "length_per_key"], [12, 4, 1, "", "length_per_key_or_none"], [12, 4, 1, "", "lengths"], [12, 4, 1, "", "lengths_offset_per_key"], [12, 4, 1, "", "lengths_or_none"], [12, 4, 1, "", "offset_per_key"], [12, 4, 1, "", "offset_per_key_or_none"], [12, 4, 1, "", "offsets"], [12, 4, 1, "", "offsets_or_none"], [12, 4, 1, "", "permute"], [12, 4, 1, "", "pin_memory"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "split"], [12, 4, 1, "", "stride"], [12, 4, 1, "", "stride_per_key"], [12, 4, 1, "", "stride_per_key_per_rank"], [12, 4, 1, "", "sync"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dict"], [12, 4, 1, "", "unsync"], [12, 4, 1, "", "values"], [12, 4, 1, "", "variable_stride_per_key"], [12, 4, 1, "", "weights"], [12, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[12, 4, 1, "", "from_tensor_list"], [12, 4, 1, "", "key_dim"], [12, 4, 1, "", "keys"], [12, 4, 1, "", "length_per_key"], [12, 4, 1, "", "offset_per_key"], [12, 4, 1, "", "record_stream"], [12, 4, 1, "", "regroup"], [12, 4, 1, "", "regroup_as_dict"], [12, 4, 1, "", "to"], [12, 4, 1, "", "to_dict"], [12, 4, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:attribute", "4": "py:method", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "method", "Python method"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], "document": 0, "tutori": 0, "api": 0, "content": [0, 6, 7, 8, 10, 11, 12], "indic": 0, "tabl": 0, "dataset": [1, 2], "criteo": 1, "movielen": 1, "random": 1, "util": [1, 3, 4, 9], "script": 2, "contiguous_preproc_criteo": 2, "npy_preproc_criteo": 2, "distribut": [3, 4, 5], "collective_util": 3, "comm": 3, "comm_op": 3, "dist_data": [3, 5], "embed": 3, "embedding_lookup": 3, "embedding_shard": 3, "embedding_typ": 3, "embeddingbag": 3, "grouped_position_weight": 3, "model_parallel": 3, "quant_embeddingbag": 3, "train_pipelin": 3, "type": [3, 4], "mc_modul": [3, 9], "mc_embeddingbag": 3, "mc_embed": 3, "planner": 4, "constant": 4, "enumer": 4, "partition": 4, "perf_model": 4, "propos": 4, "shard_estim": 4, "stat": 4, "storage_reserv": 4, "shard": 5, "cw_shard": 5, "dp_shard": 5, "rw_shard": 5, "tw_shard": 5, "twcw_shard": 5, "twrw_shard": 5, "fx": 6, "tracer": 6, "modul": [6, 7, 8, 9, 10, 11, 12], "infer": 7, "model_packag": 7, "model": 8, "deepfm": [8, 9], "dlrm": 8, "activ": 9, "crossnet": 9, "embedding_config": 9, "embedding_modul": [9, 11], "feature_processor": 9, "lazy_extens": 9, "mlp": 9, "mc_embedding_modul": 9, "optim": 10, "clip": 10, "fuse": 10, "kei": 10, "warmup": 10, "quant": 11, "spars": 12, "jagged_tensor": 12}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 56}})
\ No newline at end of file
diff --git a/torchrec.distributed.planner.html b/torchrec.distributed.planner.html
index 32f7d01de..864e7bc64 100644
--- a/torchrec.distributed.planner.html
+++ b/torchrec.distributed.planner.html
@@ -1167,6 +1167,28 @@
torchrec.distributed.planner.types
+
+-
+class torchrec.distributed.planner.types.CustomTopologyData(data: Dict[str, List[int]], world_size: int)
+Bases: object
+Custom device data for individual device in a topology.
+
+-
+get_data(key: str) → List[int]
+
+
+
+-
+has_data(key: str) → bool
+
+
+
+-
+supported_fields = ['ddr_cap', 'hbm_cap']
+
+
+
+
-
class torchrec.distributed.planner.types.DeviceHardware(rank: int, storage: Storage, perf: Perf)
@@ -1959,7 +1981,7 @@
-
-class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, hbm_mem_bw: float = 963146416.128, ddr_mem_bw: float = 54760833.024, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8, bwd_compute_multiplier: float = 2)
+class torchrec.distributed.planner.types.Topology(world_size: int, compute_device: str, hbm_cap: Optional[int] = None, ddr_cap: Optional[int] = None, local_world_size: Optional[int] = None, hbm_mem_bw: float = 963146416.128, ddr_mem_bw: float = 54760833.024, intra_host_bw: float = 644245094.4, inter_host_bw: float = 13421772.8, bwd_compute_multiplier: float = 2, custom_topology_data: Optional[CustomTopologyData] = None)
Bases: object
-
diff --git a/torchrec.quant.html b/torchrec.quant.html
index 4a33d4bd1..5177abda6 100644
--- a/torchrec.quant.html
+++ b/torchrec.quant.html
@@ -647,11 +647,6 @@
-
--
-torchrec.quant.embedding_modules.features_to_dict(features: KeyedJaggedTensor) → Dict[str, JaggedTensor]
-
-
-
torchrec.quant.embedding_modules.for_each_module_of_type_do(module: Module, module_types: List[Type[Module]], op: Callable[[Module], None]) → None
|