[ET-VK] Update partitioner to account for custom packed arguments #6763

Merged: 1 commit, Nov 11, 2024
4 changes: 4 additions & 0 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
if not is_param_node(program, node):
return True

# Annotate that this node is going to be represented as a tensorref in the Vulkan
# compute graph. This will be useful for later graph passes.
node.meta["vkdg_tensorref"] = True

for user in node.users:
if user.op == "call_function" and handles_own_prepacking(
# pyre-ignore
73 changes: 52 additions & 21 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -39,6 +39,30 @@ def set_memory_metadata(
utils.set_node_spec_attr(node, "vk_memory_layout", layout)


def insert_transition_node(
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
arg: torch.fx.Node,
storage: VkStorageType,
layout: VkMemoryLayout,
) -> None:
"""
Insert a clone node to copy the original tensor to a tensor with the desired storage
type and memory layout.
"""
with graph_module.graph.inserting_before(node):
clone_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten.clone.default,
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)


class TagMemoryMetaPass(ExportPass):
"""
There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
else:
return next(iter(valid_layouts))

def should_annotate(self, node) -> bool:
if not isinstance(node, torch.fx.Node):
return False

if not isinstance(node.meta["val"], FakeTensor):
return False

# Storage type and memory layout for tensorref will be determined at runtime
# so there's no use in setting those attributes ahead of time.
if node.meta.get("vkdg_tensorref", False):
return False

return True

def should_delay_annotation(self, node: torch.fx.Node) -> bool:
# For prepack nodes, delay setting the storage type and memory layout as long as
# possible. This is to minimize the number of transitions, since it can be
# difficult to predict what storage type and memory layout should be used at the
# time the prepack node is observed.
return node.target == exir_ops.edge.et_vk.prepack.default

# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

for node in sorted_nodes:
if not isinstance(node.meta["val"], FakeTensor):
continue

if node.target == exir_ops.edge.et_vk.prepack.default:
if not self.should_annotate(node) or self.should_delay_annotation(node):
continue

storage = self.propose_node_storage(node)
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

inserting_transitions_for_node = False
for i, arg in enumerate(node.args):
if not isinstance(arg, torch.fx.Node):
continue
if not isinstance(arg.meta["val"], FakeTensor):
if not self.should_annotate(arg):
continue

assert isinstance(arg, torch.fx.Node)

arg_storage = utils.get_node_storage_type(arg)
arg_layout = utils.get_node_memory_layout(arg)

@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
)

insert_transition_node(graph_module, node, arg, storage, layout)

logger.info(
f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
)

# Insert a clone node to copy the original tensor to a tensor with the
# desired storage type and memory layout.
with graph_module.graph.inserting_before(node):
clone_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten.clone.default,
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

return PassResult(graph_module, True)
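
For context, the delete_user_cb lambda passed to replace_all_uses_with in insert_transition_node limits the rewrite to the single consuming node: every other user of the argument keeps reading the original tensor. Below is a minimal standalone sketch of that torch.fx pattern; the toy module and node names are hypothetical and only the rewiring idiom is taken from the diff above.

import operator
import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, a):
        b = a + 1   # this user keeps the original tensor
        c = a * 2   # pretend only this user needs a storage/layout transition
        return b, c

gm = fx.symbolic_trace(Toy())
a = next(n for n in gm.graph.nodes if n.op == "placeholder")
mul = next(n for n in gm.graph.nodes if n.target is operator.mul)

# Insert a clone of `a` right before the mul, then redirect only the mul's use of `a`;
# the callback plays the same role as the lambda in insert_transition_node.
with gm.graph.inserting_before(mul):
    clone = gm.graph.call_function(torch.clone, (a,))
a.replace_all_uses_with(clone, lambda user, target=mul: user is target)

gm.recompile()
print(gm.graph)  # the add node still reads `a`; the mul node now reads the clone
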
10 changes: 10 additions & 0 deletions backends/vulkan/op_registry.py
@@ -90,6 +90,9 @@ class OpFeatures:
# then the insert_prepack_nodes pass will not insert prepack nodes for the args
# of the op.
"handles_own_prepacking",
# Optional set of argument indices for which the texture limits check should be
# skipped, since the op packs those arguments itself with a custom layout.
"skip_limits_check",
# Optional check function used during partitioning to determine if a node's
# inputs are supported by the operator implementation.
"check_node_fn",
@@ -103,6 +106,7 @@ def __init__(
optimal_storage: Optional[VkStorageType] = None,
optimal_layout: Optional[VkMemoryLayout] = None,
handles_own_prepacking: bool = False,
skip_limits_check: Optional[Set[int]] = None,
check_node_fn: Optional[Callable] = None,
):
self.texture_impl: Optional[TextureImplFeatures] = texture_impl
@@ -111,6 +115,11 @@ def __init__(
self.optimal_storage: Optional[VkStorageType] = optimal_storage
self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
self.handles_own_prepacking: bool = handles_own_prepacking

self.skip_limits_check: Set[int] = set()
if skip_limits_check is not None:
self.skip_limits_check = skip_limits_check

self.check_node_fn: Callable = allow_node
if check_node_fn is not None:
self.check_node_fn = check_node_fn
@@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures):
features.optimal_storage = VkStorageType.TEXTURE_3D
features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
features.handles_own_prepacking = True
features.skip_limits_check = {1, 2}
return features


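
For comparison, a hedged sketch of how another operator with a custom-packed weight could opt into the new field. The function name is hypothetical and the registration decorator is omitted; only the OpFeatures attributes shown in the diff above are taken from the source.

# Hypothetical registration (illustrative sketch, not part of this PR): an op that
# prepacks its weight argument (arg index 1) itself can exempt that argument from the
# texture limits check, mirroring register_convolution_op above.
def register_custom_weight_op(features: OpFeatures):
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.handles_own_prepacking = True
    features.skip_limits_check = {1}
    return features
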
9 changes: 7 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -82,8 +82,13 @@ def op_node_is_compatible(
valid_texture_layouts = utils.possible_node_memory_layouts(
node, self.texture_limits
)
for arg in node.args:
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):

for i, arg in enumerate(node.args):
if (
isinstance(arg, torch.fx.Node)
and utils.is_tensor_node(arg)
and i not in features.skip_limits_check
):
arg_texture_layouts = utils.possible_node_memory_layouts(
arg, self.texture_limits
)
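
To make the effect of the new condition concrete, here is a small self-contained check. The argument positions assume the standard aten.convolution schema; this is an illustration, not code from the PR.

# For a convolution node, args follow (input, weight, bias, stride, padding, ...).
# With skip_limits_check = {1, 2}, only the input tensor still goes through the limits
# check; stride, padding, etc. are not torch.fx.Node tensors and never reached it anyway.
skip_limits_check = {1, 2}
tensor_arg_indices = [0, 1, 2]  # input, weight, bias
assert [i for i in tensor_arg_indices if i not in skip_limits_check] == [0]
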