[ET][dim order] AOT support for dim order variant empty op #7168

Merged
3 commits merged on Dec 4, 2024
22 changes: 21 additions & 1 deletion exir/passes/dim_order_ops_registry.py
@@ -15,11 +15,19 @@
"_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

# Out variant drops TensorOptions
lib.define(
"_empty_dim_order(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int[]? dim_order=None) -> Tensor"
)

# The out variants of aten::_to_copy and aten::empty drop TensorOptions, and so do their dim order variants
lib.define(
"_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs):
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -39,18 +47,30 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)


@impl(lib, "_empty_dim_order", "CompositeImplicitAutograd")
def _empty_dim_order_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.memory_format, *args, **kwargs)


@impl(lib, "_empty_dim_order.out", "CompositeImplicitAutograd")
def _empty_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


"""
Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
"""
DimOrderOpsMap = {
"aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
"aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
}

"""
Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
"""
MemoryFormatOpsMap = {
"dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
"dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
}

# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
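For readers new to dim orders: the _empty_dim_order variant encodes the output layout as an explicit permutation of dimensions instead of a torch.memory_format value, and the CompositeImplicitAutograd impls above translate it back to aten.empty. A minimal sketch of that correspondence using only stock PyTorch calls (illustrative, not part of the PR):

import torch

# For a rank-4 NCHW tensor, contiguous corresponds to dim order (0, 1, 2, 3)
# and channels_last to (0, 2, 3, 1), i.e. the physical order N, H, W, C.
x = torch.empty((1, 10, 24, 24), memory_format=torch.contiguous_format)
print(x.dim_order())  # (0, 1, 2, 3)

y = torch.empty((1, 10, 24, 24), memory_format=torch.channels_last)
print(y.dim_order())  # (0, 2, 3, 1)

Once this registry module is imported, a call such as torch.ops.dim_order_ops._empty_dim_order([1, 10, 24, 24], dim_order=[0, 2, 3, 1]) is expected to dispatch to aten.empty.memory_format with memory_format=torch.channels_last; the dim_order_ops namespace here is an assumption inferred from the edge ops referenced in DimOrderOpsMap.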
25 changes: 15 additions & 10 deletions exir/passes/memory_format_ops_pass.py
@@ -39,6 +39,7 @@ def call_operator(self, op, args, kwargs, meta):
kwargs,
meta,
)

# new kwargs with dim_order, and no memory_format for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

@@ -50,17 +51,20 @@ def call_operator(self, op, args, kwargs, meta):
ndim = args[0].to_tensor().dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
ndim = len(args[0])
else:
assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
assert (
0
), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
logger.debug(
f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
)

t = DimOrderOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"
t = DimOrderOpsMap[op.__name__]

return super().call_operator(
t,
@@ -92,8 +96,10 @@ def call_operator(self, op, args, kwargs, meta):
ndim = args[0].to_tensor().dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
ndim = len(args[0])
else:
assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"

# get the "to" memory format for the EdgeOp
default_dim_order = list(range(ndim))
@@ -102,12 +108,11 @@
nkwargs["memory_format"] = get_memory_format(dim_order)

logger.debug(
f" _to_dim_order_copy = dim_order: {dim_order}."
f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
f" {op.__name__} = dim_order: {dim_order}."
f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
)

t = MemoryFormatOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"
t = MemoryFormatOpsMap[op.__name__]

return super().call_operator(
t,
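Both passes above rely on get_dim_order and get_memory_format to translate between the two layout encodings. A hedged, self-contained sketch of that conversion, with illustrative helper names rather than the actual executorch utilities, covering only the contiguous and rank-4 channels_last cases:

import torch
from typing import List, Optional

def to_dim_order(memory_format: Optional[torch.memory_format], ndim: int) -> List[int]:
    # Contiguous (or unspecified) layout keeps the logical order 0..ndim-1.
    if memory_format in (None, torch.contiguous_format):
        return list(range(ndim))
    # channels_last applies to rank-4 tensors: physical order N, H, W, C.
    if memory_format == torch.channels_last and ndim == 4:
        return [0, 2, 3, 1]
    raise AssertionError(f"unsupported memory format: {memory_format}")

def to_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
    if dim_order is None or dim_order == list(range(len(dim_order))):
        return torch.contiguous_format
    if dim_order == [0, 2, 3, 1]:
        return torch.channels_last
    raise AssertionError(f"unsupported dim order: {dim_order}")

assert to_dim_order(torch.channels_last, 4) == [0, 2, 3, 1]
assert to_memory_format([0, 1, 2, 3]) == torch.contiguous_format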
34 changes: 34 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -27,6 +27,8 @@
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCopyChannalsLastModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguousModule,
SimpleToCopyChannelsLastModule,
SimpleToCopyContiguousModule,
)
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_channels_last(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyChannelLastModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_contiguous(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyContiguousModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleToCopyChannelsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
32 changes: 32 additions & 0 deletions exir/tests/test_memory_format_ops_pass_aten.py
@@ -13,6 +13,8 @@
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCopyChannalsLastModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguousModule,
SimpleToCopyChannelsLastModule,
SimpleToCopyContiguousModule,
)
@@ -28,6 +30,7 @@ def test_op_to_copy_replacement_2d_aten(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -39,17 +42,43 @@ def test_op_to_copy_replacement_4d_aten(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_channels_last(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyChannelLastModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_contiguous(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyContiguousModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update_aten(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleToCopyChannelsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -67,6 +96,7 @@ def test_op_dim_order_propagation_aten(self) -> None:
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -85,6 +115,7 @@ def test_resnet18(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -100,6 +131,7 @@ def test_mobilenet_v3(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
42 changes: 38 additions & 4 deletions exir/tests/test_memory_format_ops_pass_utils.py
@@ -8,7 +8,7 @@

import unittest
from dataclasses import dataclass
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple

import torch

@@ -26,11 +26,24 @@
from torch.utils._pytree import tree_flatten


MemoryFormatOps2Str: Dict[torch._ops.OpOverload, Tuple[str, str]] = {
torch.ops.aten._to_copy.default: (
"torch.ops.aten._to_copy.default",
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
),
torch.ops.aten.empty.memory_format: (
"torch.ops.aten.empty.memory_format",
"executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
),
}


@dataclass
class MemoryFormatTestSet:
module: torch.nn.Module
sample_input: Tuple[Any, ...]
target_memory_format: torch.memory_format
op: torch._ops.OpOverload
_load_for_executorch_from_buffer: Any
op_level_check: bool = True
use_xnnpack: bool = False
@@ -54,6 +67,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=torch.double, memory_format=torch.channels_last)


class SimpleEmptyContiguousModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
empty_tensor = torch.empty(x.size(), memory_format=torch.contiguous_format)
x = x.to(memory_format=torch.contiguous_format)
empty_tensor.copy_(x)
return empty_tensor


class SimpleEmptyChannelLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)
empty_tensor.copy_(x)
return empty_tensor


class PropagateToCopyChannalsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -86,9 +121,7 @@ def memory_format_test_runner(

# check memory format ops, if needed
if test_set.op_level_check:
aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
# check op strings before
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
edge_op_str
@@ -126,6 +159,7 @@ def memory_format_test_runner(
runtime_output = executorch_module.run_method(
"forward", tuple(inputs_flattened)
)[0]

test_class.assertTrue(
torch.allclose(
runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
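As a rough end-to-end picture of what the new tests exercise: export one of the empty modules, lower it to edge, and confirm that aten.empty.memory_format has been rewritten to the dim order variant. A hedged sketch of that flow; the import path is assumed from the file location above, and older executorch versions gate the dim order rewrite behind an EdgeCompileConfig flag, so the exact lowering call may differ:

import torch
from executorch.exir import to_edge
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
    SimpleEmptyChannelLastModule,
)

module = SimpleEmptyChannelLastModule().eval()
sample_input = (torch.randn(1, 10, 24, 24, dtype=torch.float32),)

# Export to an ExportedProgram, then lower to the edge dialect; the memory
# format pass is expected to replace aten.empty.memory_format with
# dim_order_ops._empty_dim_order along the way.
exported = torch.export.export(module, sample_input)
edge = to_edge(exported)
print(edge.exported_program().graph_module.code)  # look for _empty_dim_order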