[ET-VK] Update partitioner to account for custom packed arguments
Differential Revision: D65759236

Pull Request resolved: #6753
SS-JIA authored Nov 11, 2024
1 parent 793f17e commit d5a0743
Showing 4 changed files with 73 additions and 23 deletions.
4 changes: 4 additions & 0 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
        if not is_param_node(program, node):
            return True

        # Annotate that this node is going to be represented as a tensorref in the
        # Vulkan compute graph. This will be useful for later graph passes.
        node.meta["vkdg_tensorref"] = True

        for user in node.users:
            if user.op == "call_function" and handles_own_prepacking(
                # pyre-ignore
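The `vkdg_tensorref` annotation set above is intended for consumption by later graph passes (the memory-metadata tagging pass below reads it). A minimal sketch of the consuming side, assuming a graph module whose parameter nodes were annotated as above; `collect_tensorrefs` is a hypothetical helper, not part of this commit:

import torch

def collect_tensorrefs(graph_module: torch.fx.GraphModule) -> list:
    # Gather the nodes that insert_prepack_nodes marked as tensorrefs, i.e.
    # parameter tensors that will be prepacked at runtime in the Vulkan
    # compute graph rather than staged as regular tensors.
    return [
        node
        for node in graph_module.graph.nodes
        if node.meta.get("vkdg_tensorref", False)
    ]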
73 changes: 52 additions & 21 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -39,6 +39,30 @@ def set_memory_metadata(
    utils.set_node_spec_attr(node, "vk_memory_layout", layout)


def insert_transition_node(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    arg: torch.fx.Node,
    storage: VkStorageType,
    layout: VkMemoryLayout,
) -> None:
    """
    Insert a clone node to copy the original tensor to a tensor with the desired storage
    type and memory layout.
    """
    with graph_module.graph.inserting_before(node):
        clone_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.aten.clone.default,
            (arg,),
        )
        clone_node.meta["val"] = arg.meta["val"]
        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
        clone_node.meta["spec"].const = False
        set_memory_metadata(clone_node, storage, layout)
        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)


class TagMemoryMetaPass(ExportPass):
    """
    There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
        else:
            return next(iter(valid_layouts))

    def should_annotate(self, node) -> bool:
        if not isinstance(node, torch.fx.Node):
            return False

        if not isinstance(node.meta["val"], FakeTensor):
            return False

        # Storage type and memory layout for tensorref will be determined at runtime
        # so there's no use in setting those attributes ahead of time.
        if node.meta.get("vkdg_tensorref", False):
            return False

        return True

    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
        # For prepack nodes, delay setting the storage type and memory layout as long as
        # possible. This is to minimize the number of transitions, since it can be
        # difficult to predict what storage type and memory layout should be used at the
        # time the prepack node is observed.
        return node.target == exir_ops.edge.et_vk.prepack.default

    # noqa
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

        for node in sorted_nodes:
-           if not isinstance(node.meta["val"], FakeTensor):
-               continue
-
-           if node.target == exir_ops.edge.et_vk.prepack.default:
+           if not self.should_annotate(node) or self.should_delay_annotation(node):
                continue

            storage = self.propose_node_storage(node)
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

            inserting_transitions_for_node = False
            for i, arg in enumerate(node.args):
-               if not isinstance(arg, torch.fx.Node):
-                   continue
-               if not isinstance(arg.meta["val"], FakeTensor):
+               if not self.should_annotate(arg):
                    continue

+               assert isinstance(arg, torch.fx.Node)
+
                arg_storage = utils.get_node_storage_type(arg)
                arg_layout = utils.get_node_memory_layout(arg)
@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                    )

+                   insert_transition_node(graph_module, node, arg, storage, layout)
+
                    logger.info(
                        f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                    )

-                   # Insert a clone node to copy the original tensor to a tensor with the
-                   # desired storage type and memory layout.
-                   with graph_module.graph.inserting_before(node):
-                       clone_node = graph_module.graph.create_node(
-                           "call_function",
-                           exir_ops.edge.aten.clone.default,
-                           (arg,),
-                       )
-                       clone_node.meta["val"] = arg.meta["val"]
-                       clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
-                       clone_node.meta["spec"].const = False
-                       set_memory_metadata(clone_node, storage, layout)
-                       arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

        return PassResult(graph_module, True)
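
A subtle point in the new `insert_transition_node` helper is the predicate passed to `replace_all_uses_with`: only the consuming node's use of `arg` is rerouted to the clone, so any other consumers keep reading the original tensor. A self-contained torch.fx sketch of the same pattern, using toy ops with no Vulkan dependencies:

import operator

import torch

graph = torch.fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(operator.add, (x, x))
b = graph.call_function(operator.mul, (x, a))

# Reroute only b's use of x through a clone; a still reads the original x,
# mirroring how insert_transition_node targets a single consumer node.
with graph.inserting_before(b):
    clone = graph.call_function(torch.clone, (x,))
x.replace_all_uses_with(clone, lambda user, consumer=b: user == consumer)
graph.lint()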
10 changes: 10 additions & 0 deletions backends/vulkan/op_registry.py
@@ -90,6 +90,9 @@ class OpFeatures:
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
        "handles_own_prepacking",
        # Optional set of argument indices for which texture limits checks should be
        # skipped during partitioning, e.g. for args stored in a custom packed format.
        "skip_limits_check",
        # Optional check function used during partitioning to determine if a node's
        # inputs are supported by the operator implementation.
        "check_node_fn",
@@ -103,6 +106,7 @@ def __init__(
        optimal_storage: Optional[VkStorageType] = None,
        optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        skip_limits_check: Optional[Set[int]] = None,
        check_node_fn: Optional[Callable] = None,
    ):
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
@@ -111,6 +115,11 @@ def __init__(
        self.optimal_storage: Optional[VkStorageType] = optimal_storage
        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking

        self.skip_limits_check: Set[int] = set()
        if skip_limits_check is not None:
            self.skip_limits_check = skip_limits_check

        self.check_node_fn: Callable = allow_node
        if check_node_fn is not None:
            self.check_node_fn = check_node_fn
@@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures):
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.handles_own_prepacking = True
    features.skip_limits_check = {1, 2}
    return features


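For context, `skip_limits_check` is populated like the other `OpFeatures` fields, via a registration hook that mutates the features object for an op. A hedged sketch for a hypothetical custom op whose weight arguments use a packed format; the op name is made up, and `update_features` is assumed to be the registration decorator used elsewhere in op_registry.py:

# Hypothetical registration mirroring register_convolution_op above.
@update_features(exir_ops.edge.et_vk.my_packed_op.default)
def register_my_packed_op(features: OpFeatures):
    features.handles_own_prepacking = True
    # Argument indices 1 and 2 (e.g. weight and bias) are stored in a custom
    # packed format, so the partitioner should not check their source shapes
    # against texture limits.
    features.skip_limits_check = {1, 2}
    return features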
9 changes: 7 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -82,8 +82,13 @@ def op_node_is_compatible(
        valid_texture_layouts = utils.possible_node_memory_layouts(
            node, self.texture_limits
        )
-       for arg in node.args:
-           if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
+
+       for i, arg in enumerate(node.args):
+           if (
+               isinstance(arg, torch.fx.Node)
+               and utils.is_tensor_node(arg)
+               and i not in features.skip_limits_check
+           ):
                arg_texture_layouts = utils.possible_node_memory_layouts(
                    arg, self.texture_limits
                )
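The net effect of the partitioner change: an argument's source shape is checked against texture limits only when its index is absent from the op's `skip_limits_check` set. Custom-packed arguments (such as convolution weights and biases) are repacked into a different layout at prepacking time, so checking their source shapes would spuriously reject supportable nodes. A standalone sketch of that gate, with illustrative names:

from typing import Set

def needs_limits_check(arg_index: int, skip_limits_check: Set[int]) -> bool:
    # Custom-packed arguments are exempt: their runtime layout differs from
    # the source tensor's, so texture limits computed from the source shape
    # would not reflect the actual image extents used.
    return arg_index not in skip_limits_check

assert needs_limits_check(0, skip_limits_check={1, 2})
assert not needs_limits_check(1, skip_limits_check={1, 2})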
