Skip to content

Commit

Permalink
make dict attribute of train input pipelineable (pytorch#2278)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2278

When the training input has attribute of dict, accessing those with "get" will break the train pipeline. Here we add support to make dict attribute pipelineable.

Reviewed By: ge0405

Differential Revision: D60635872

fbshipit-source-id: f0b2c314ab90a5e4247bc5c75846633b39630cf2
  • Loading branch information
Yinbin Ma authored and facebook-github-bot committed Aug 8, 2024
1 parent 9573fea commit 26d4244
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,12 @@ def _get_node_args_helper(
arg = child_node.kwargs["values"]
else:
arg = child_node.args[1]
elif child_node.op == "call_method" and child_node.target == "get":
# pyre-ignore[6]
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, True)
arg_info.preproc_modules.insert(0, None)
arg = child_node.args[0]
elif child_node.op == "call_module":
preproc_module_fqn = str(child_node.target)
preproc_module = getattr(model, preproc_module_fqn, None)
Expand Down

0 comments on commit 26d4244

Please sign in to comment.