From 26d4244c4880fd7c9c132c2c23db6cc3160bfb24 Mon Sep 17 00:00:00 2001 From: Yinbin Ma Date: Wed, 7 Aug 2024 18:33:47 -0700 Subject: [PATCH] make dict attribute of train input pipelineable (#2278) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2278 When the training input has attribute of dict, accessing those with "get" will break the train pipeline. Here we add support to make dict attribute pipelineable. Reviewed By: ge0405 Differential Revision: D60635872 fbshipit-source-id: f0b2c314ab90a5e4247bc5c75846633b39630cf2 --- torchrec/distributed/train_pipeline/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 00c8a4295..8096d2fc3 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -790,6 +790,12 @@ def _get_node_args_helper( arg = child_node.kwargs["values"] else: arg = child_node.args[1] + elif child_node.op == "call_method" and child_node.target == "get": + # pyre-ignore[6] + arg_info.input_attrs.insert(0, child_node.args[1]) + arg_info.is_getitems.insert(0, True) + arg_info.preproc_modules.insert(0, None) + arg = child_node.args[0] elif child_node.op == "call_module": preproc_module_fqn = str(child_node.target) preproc_module = getattr(model, preproc_module_fqn, None)