# [ET-VK] Update partitioner to account for custom packed arguments
## Problem

Convolution operators, especially pointwise convolutions, may have weight tensors with sizes like

```
W=1, H=1, C=320, N=1280
```

When represented as a channels-packed texture, this tensor would require image extents of

```
(1, 1, 320 / 4 * 1280 = 102400)
```

which far exceeds the texture limits of typical devices. The new partitioner system detects this and prevents nodes with such weights from being lowered to Vulkan. However, it does not account for the fact that the operator implementation uses a specialized prepacking algorithm, which produces packed weights whose texture extents are valid.
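
For concreteness, here is a minimal sketch of the naive extent calculation; the helper name is illustrative, and 2048 stands in for a typical `maxImageDimension3D` limit rather than a universal value:

```python
import math

def channels_packed_extents(W: int, H: int, C: int, N: int) -> tuple:
    # Channels-packed layout folds 4 channels into each texel, so the
    # texture's depth extent is ceil(C / 4) * N.
    return (W, H, math.ceil(C / 4) * N)

extents = channels_packed_extents(W=1, H=1, C=320, N=1280)
print(extents)  # (1, 1, 102400) -- far beyond an assumed limit of 2048
```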

## Changes

* Add a field to the `OpFeatures` class to annotate that some arguments of an op should be skipped when checking against texture limits (a usage sketch follows this list)
* Update the metadata tagging pass to skip annotating constant tensor nodes so that they don't influence memory layout and storage type proposals. Without this change, the tagging pass would propose buffer storage for the pointwise convolution, since the unpacked weight can only be represented as a buffer.
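
To illustrate the first change, here is a hypothetical registration in the style of `op_registry.py`; `register_my_linear_op` and the import paths are assumptions for illustration, not part of this diff. For `aten.convolution`, argument indices 1 and 2 are the weight and bias, which is why the convolution registration in the op_registry.py diff below sets `skip_limits_check = {1, 2}`.

```python
# Import paths are assumptions for illustration.
from executorch.backends.vulkan.op_registry import OpFeatures
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

def register_my_linear_op(features: OpFeatures) -> OpFeatures:
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    # The op prepacks its own weight (arg index 1), so the partitioner should
    # not reject it just because the unpacked weight cannot fit in a texture.
    features.handles_own_prepacking = True
    features.skip_limits_check = {1}
    return features
```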

Differential Revision: [D65759236](https://our.internmc.facebook.com/intern/diff/D65759236/)

ghstack-source-id: 252885980
Pull Request resolved: #6753
SS-JIA committed Nov 11, 2024
1 parent 793f17e · commit dcce9da
Showing 4 changed files with 73 additions and 23 deletions.
### backends/vulkan/_passes/insert_prepack_nodes.py (4 additions, 0 deletions)
```diff
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
     if not is_param_node(program, node):
         return True
 
+    # Annotate that this node is going to be represented as a tensorref in the
+    # Vulkan compute graph. This will be useful for later graph passes.
+    node.meta["vkdg_tensorref"] = True
+
     for user in node.users:
         if user.op == "call_function" and handles_own_prepacking(
             # pyre-ignore
```
### backends/vulkan/_passes/tag_memory_meta_pass.py (52 additions, 21 deletions)
```diff
@@ -39,6 +39,30 @@ def set_memory_metadata(
     utils.set_node_spec_attr(node, "vk_memory_layout", layout)
 
 
+def insert_transition_node(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    arg: torch.fx.Node,
+    storage: VkStorageType,
+    layout: VkMemoryLayout,
+) -> None:
+    """
+    Insert a clone node to copy the original tensor to a tensor with the desired
+    storage type and memory layout.
+    """
+    with graph_module.graph.inserting_before(node):
+        clone_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.clone.default,
+            (arg,),
+        )
+        clone_node.meta["val"] = arg.meta["val"]
+        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+        clone_node.meta["spec"].const = False
+        set_memory_metadata(clone_node, storage, layout)
+        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+
 class TagMemoryMetaPass(ExportPass):
     """
     There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
         else:
             return next(iter(valid_layouts))
 
+    def should_annotate(self, node) -> bool:
+        if not isinstance(node, torch.fx.Node):
+            return False
+
+        if not isinstance(node.meta["val"], FakeTensor):
+            return False
+
+        # Storage type and memory layout for tensorref will be determined at runtime
+        # so there's no use in setting those attributes ahead of time.
+        if node.meta.get("vkdg_tensorref", False):
+            return False
+
+        return True
+
+    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
+        # For prepack nodes, delay setting the storage type and memory layout as long as
+        # possible. This is to minimize the number of transitions, since it can be
+        # difficult to predict what storage type and memory layout should be used at the
+        # time the prepack node is observed.
+        return node.target == exir_ops.edge.et_vk.prepack.default
+
+    # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
 
         for node in sorted_nodes:
-            if not isinstance(node.meta["val"], FakeTensor):
-                continue
-
-            if node.target == exir_ops.edge.et_vk.prepack.default:
+            if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
             storage = self.propose_node_storage(node)
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
             inserting_transitions_for_node = False
             for i, arg in enumerate(node.args):
-                if not isinstance(arg, torch.fx.Node):
-                    continue
-                if not isinstance(arg.meta["val"], FakeTensor):
+                if not self.should_annotate(arg):
                     continue
 
+                assert isinstance(arg, torch.fx.Node)
+
                 arg_storage = utils.get_node_storage_type(arg)
                 arg_layout = utils.get_node_memory_layout(arg)
 
@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                     )
 
+                    insert_transition_node(graph_module, node, arg, storage, layout)
+
                     logger.info(
                         f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                    )
 
-                    # Insert a clone node to copy the original tensor to a tensor with the
-                    # desired storage type and memory layout.
-                    with graph_module.graph.inserting_before(node):
-                        clone_node = graph_module.graph.create_node(
-                            "call_function",
-                            exir_ops.edge.aten.clone.default,
-                            (arg,),
-                        )
-                        clone_node.meta["val"] = arg.meta["val"]
-                        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
-                        clone_node.meta["spec"].const = False
-                        set_memory_metadata(clone_node, storage, layout)
-                        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
-
         return PassResult(graph_module, True)
```
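
One subtlety in `insert_transition_node` is the callback passed to `replace_all_uses_with`: `torch.fx` only reroutes the users for which the callback returns `True`, so only `node` is switched over to the clone, while every other consumer of `arg` (including the clone itself, which takes `arg` as input) keeps the original tensor. A standalone sketch of the same pattern, with an illustrative traced function:

```python
import operator
import torch
import torch.fx as fx

def f(x):
    a = x + 1
    b = a * 2  # only this user of `a` should see the clone
    c = a - 3
    return b, c

gm = fx.symbolic_trace(f)
add = next(n for n in gm.graph.nodes if n.target is operator.add)
mul = next(n for n in gm.graph.nodes if n.target is operator.mul)

with gm.graph.inserting_before(mul):
    clone = gm.graph.create_node("call_function", torch.clone, (add,))

# Without the callback, every use of `add` would be rerouted -- including the
# clone's own input, which would make the clone feed itself.
add.replace_all_uses_with(clone, lambda user: user is mul)
gm.recompile()
```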
### backends/vulkan/op_registry.py (10 additions, 0 deletions)
```diff
@@ -90,6 +90,9 @@ class OpFeatures:
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
         "handles_own_prepacking",
+        # Optional set of argument indices that should be skipped when checking
+        # against texture limits during partitioning.
+        "skip_limits_check",
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "check_node_fn",
@@ -103,6 +106,7 @@ def __init__(
         optimal_storage: Optional[VkStorageType] = None,
         optimal_layout: Optional[VkMemoryLayout] = None,
         handles_own_prepacking: bool = False,
+        skip_limits_check: Optional[Set[int]] = None,
         check_node_fn: Optional[Callable] = None,
     ):
         self.texture_impl: Optional[TextureImplFeatures] = texture_impl
@@ -111,6 +115,11 @@ def __init__(
         self.optimal_storage: Optional[VkStorageType] = optimal_storage
         self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
         self.handles_own_prepacking: bool = handles_own_prepacking
+
+        self.skip_limits_check: Set[int] = set()
+        if skip_limits_check is not None:
+            self.skip_limits_check = skip_limits_check
+
         self.check_node_fn: Callable = allow_node
         if check_node_fn is not None:
             self.check_node_fn = check_node_fn
@@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1, 2}
     return features
```
### backends/vulkan/partitioner/vulkan_partitioner.py (7 additions, 2 deletions)
```diff
@@ -82,8 +82,13 @@ def op_node_is_compatible(
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
-        for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
+
+        for i, arg in enumerate(node.args):
+            if (
+                isinstance(arg, torch.fx.Node)
+                and utils.is_tensor_node(arg)
+                and i not in features.skip_limits_check
+            ):
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
```
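
Taken together, the pointwise convolution from the Problem section can now be lowered: only argument 0 (the input activation) is checked against texture limits, while the weight and bias (indices 1 and 2) are left to the op's custom prepacking. A toy check under the same assumptions as the earlier extent sketch:

```python
TEXTURE_LIMIT = 2048  # assumed maxImageDimension3D, as in the earlier sketch

def fits_in_texture(extents: tuple) -> bool:
    # An image is usable only if every extent is within the device limit.
    return all(e <= TEXTURE_LIMIT for e in extents)

# The unpacked conv weight would fail the naive per-argument check...
print(fits_in_texture((1, 1, 102400)))  # False
# ...but with skip_limits_check = {1, 2} only the activation is checked, e.g.:
print(fits_in_texture((64, 64, 80)))    # True -> the node can be partitioned
```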
