diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index 37665a6da8..7876806d6d 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         if not is_param_node(program, node):
             return True

+        # Annotate that this node is going to be represented as a tensorref in the
+        # Vulkan compute graph. This will be useful for later graph passes.
+        node.meta["vkdg_tensorref"] = True
+
         for user in node.users:
             if user.op == "call_function" and handles_own_prepacking(
                 # pyre-ignore
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index fd0bd3648e..0a6a2d42d4 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -39,6 +39,30 @@ def set_memory_metadata(
     utils.set_node_spec_attr(node, "vk_memory_layout", layout)


+def insert_transition_node(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    arg: torch.fx.Node,
+    storage: VkStorageType,
+    layout: VkMemoryLayout,
+) -> None:
+    """
+    Insert a clone node to copy the original tensor to a tensor with the desired storage
+    type and memory layout.
+    """
+    with graph_module.graph.inserting_before(node):
+        clone_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.clone.default,
+            (arg,),
+        )
+        clone_node.meta["val"] = arg.meta["val"]
+        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+        clone_node.meta["spec"].const = False
+        set_memory_metadata(clone_node, storage, layout)
+        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+
 class TagMemoryMetaPass(ExportPass):
     """
     There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
         else:
             return next(iter(valid_layouts))

+    def should_annotate(self, node) -> bool:
+        if not isinstance(node, torch.fx.Node):
+            return False
+
+        if not isinstance(node.meta["val"], FakeTensor):
+            return False
+
+        # Storage type and memory layout for tensorref will be determined at runtime,
+        # so there's no use in setting those attributes ahead of time.
+        if node.meta.get("vkdg_tensorref", False):
+            return False
+
+        return True
+
+    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
+        # For prepack nodes, delay setting the storage type and memory layout as long as
+        # possible. This is to minimize the number of transitions, since it can be
+        # difficult to predict what storage type and memory layout should be used at the
+        # time the prepack node is observed.
+        return node.target == exir_ops.edge.et_vk.prepack.default
+
+    # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

         for node in sorted_nodes:
-            if not isinstance(node.meta["val"], FakeTensor):
-                continue
-
-            if node.target == exir_ops.edge.et_vk.prepack.default:
+            if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue

             storage = self.propose_node_storage(node)
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             inserting_transitions_for_node = False

             for i, arg in enumerate(node.args):
-                if not isinstance(arg, torch.fx.Node):
-                    continue
-                if not isinstance(arg.meta["val"], FakeTensor):
+                if not self.should_annotate(arg):
                     continue

+                assert isinstance(arg, torch.fx.Node)
+
                 arg_storage = utils.get_node_storage_type(arg)
                 arg_layout = utils.get_node_memory_layout(arg)

@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                     )

+                insert_transition_node(graph_module, node, arg, storage, layout)
+
                 logger.info(
                     f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                 )

-                # Insert a clone node to copy the original tensor to a tensor with the
-                # desired storage type and memory layout.
-                with graph_module.graph.inserting_before(node):
-                    clone_node = graph_module.graph.create_node(
-                        "call_function",
-                        exir_ops.edge.aten.clone.default,
-                        (arg,),
-                    )
-                    clone_node.meta["val"] = arg.meta["val"]
-                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
-                    clone_node.meta["spec"].const = False
-                    set_memory_metadata(clone_node, storage, layout)
-                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
-
         return PassResult(graph_module, True)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 3a6191bccb..eeec5ab37e 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -90,6 +90,9 @@ class OpFeatures:
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
         "handles_own_prepacking",
+        # Optional set of argument indexes for which the partitioner should skip the
+        # texture limits check when determining whether the node is supported.
+        "skip_limits_check",
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
"check_node_fn", @@ -103,6 +106,7 @@ def __init__( optimal_storage: Optional[VkStorageType] = None, optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, + skip_limits_check: Optional[Set[int]] = None, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl @@ -111,6 +115,11 @@ def __init__( self.optimal_storage: Optional[VkStorageType] = optimal_storage self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking + + self.skip_limits_check: Set[int] = set() + if skip_limits_check is not None: + self.skip_limits_check = skip_limits_check + self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn @@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures): features.optimal_storage = VkStorageType.TEXTURE_3D features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True + features.skip_limits_check = {1, 2} return features diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 7b2ad3fdfd..64e672fd69 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -82,8 +82,13 @@ def op_node_is_compatible( valid_texture_layouts = utils.possible_node_memory_layouts( node, self.texture_limits ) - for arg in node.args: - if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): + + for i, arg in enumerate(node.args): + if ( + isinstance(arg, torch.fx.Node) + and utils.is_tensor_node(arg) + and i not in features.skip_limits_check + ): arg_texture_layouts = utils.possible_node_memory_layouts( arg, self.texture_limits )