diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py
index 27fc03f9413..c4436aaa910 100644
--- a/exir/passes/dim_order_ops_registry.py
+++ b/exir/passes/dim_order_ops_registry.py
@@ -15,11 +15,19 @@
     "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
 )
 
-# Out variant drops TensorOptions
+lib.define(
+    "_empty_dim_order(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int[]? dim_order=None) -> Tensor"
+)
+
+# Out variants of aten::_to_copy and aten::empty drop TensorOptions, and so do their dim order variants
 lib.define(
     "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
+)
+
 
 def _op_impl(target, *args, **kwargs):
     kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -39,11 +47,22 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
     return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)
 
 
+@impl(lib, "_empty_dim_order", "CompositeImplicitAutograd")
+def _empty_dim_order_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.empty.memory_format, *args, **kwargs)
+
+
+@impl(lib, "_empty_dim_order.out", "CompositeImplicitAutograd")
+def _empty_dim_order_out_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)
+
+
 """
 Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
 """
 DimOrderOpsMap = {
     "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    "aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
 }
 
 """
@@ -51,6 +70,7 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
 """
 MemoryFormatOpsMap = {
     "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
+    "dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
 }
 
 # If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py
index 32678bf4082..2e3661f36c4 100644
--- a/exir/passes/memory_format_ops_pass.py
+++ b/exir/passes/memory_format_ops_pass.py
@@ -39,6 +39,7 @@ def call_operator(self, op, args, kwargs, meta):
                 kwargs,
                 meta,
             )
+
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
@@ -50,17 +51,20 @@ def call_operator(self, op, args, kwargs, meta):
             ndim = args[0].to_tensor().dim()
         elif isinstance(args[0], torch.Tensor):
             ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            ndim = len(args[0])
         else:
-            assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
+            assert (
+                0
+            ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"
 
         nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
         logger.debug(
-            f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
-            f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
+            f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
+            f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
         )
 
-        t = DimOrderOpsMap.get(op.__name__, None)
-        assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"
+        t = DimOrderOpsMap[op.__name__]
 
         return super().call_operator(
             t,
@@ -92,8 +96,10 @@ def call_operator(self, op, args, kwargs, meta):
             ndim = args[0].to_tensor().dim()
         elif isinstance(args[0], torch.Tensor):
             ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            ndim = len(args[0])
         else:
-            assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
+            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
 
         # get the "to" memory format for the EdgeOp
         default_dim_order = list(range(ndim))
@@ -102,12 +108,11 @@ def call_operator(self, op, args, kwargs, meta):
         nkwargs["memory_format"] = get_memory_format(dim_order)
 
         logger.debug(
-            f" _to_dim_order_copy = dim_order: {dim_order}."
-            f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
+            f" {op.__name__} = dim_order: {dim_order}."
+            f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
         )
 
-        t = MemoryFormatOpsMap.get(op.__name__, None)
-        assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"
+        t = MemoryFormatOpsMap[op.__name__]
 
         return super().call_operator(
             t,
diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py
index 53befded94b..0292cf98f50 100644
--- a/exir/tests/test_memory_format_ops_pass.py
+++ b/exir/tests/test_memory_format_ops_pass.py
@@ -27,6 +27,8 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
+    SimpleEmptyChannelLastModule,
+    SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
     SimpleToCopyContiguousModule,
 )
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
 
+    def test_op_empty_replacement_channels_last(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyChannelLastModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_empty_replacement_contiguous(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyContiguoustModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyChannelsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=PropagateToCopyChannalsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
diff --git a/exir/tests/test_memory_format_ops_pass_aten.py b/exir/tests/test_memory_format_ops_pass_aten.py
index 601893fd238..5aa687e6aef 100644
--- a/exir/tests/test_memory_format_ops_pass_aten.py
+++ b/exir/tests/test_memory_format_ops_pass_aten.py
@@ -13,6 +13,8 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
+    SimpleEmptyChannelLastModule,
+    SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
     SimpleToCopyContiguousModule,
 )
@@ -28,6 +30,7 @@ def test_op_to_copy_replacement_2d_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -39,17 +42,43 @@ def test_op_to_copy_replacement_4d_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyContiguousModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
                 target_memory_format=torch.contiguous_format,
                 _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
             ),
         )
 
+    def test_op_empty_replacement_channels_last(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyChannelLastModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_empty_replacement_contiguous(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=SimpleEmptyContiguoustModule().eval(),
+                op=torch.ops.aten.empty.memory_format,
+                sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update_aten(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
             MemoryFormatTestSet(
                 module=SimpleToCopyChannelsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -67,6 +96,7 @@ def test_op_dim_order_propagation_aten(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=PropagateToCopyChannalsLastModule().eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(
                     torch.rand_like(
                         torch.zeros([2, 2, 2, 2]),
@@ -85,6 +115,7 @@ def test_resnet18(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
@@ -100,6 +131,7 @@ def test_mobilenet_v3(self) -> None:
             self,
             MemoryFormatTestSet(
                 module=model.eval(),
+                op=torch.ops.aten._to_copy.default,
                 sample_input=(torch.randn(1, 3, 224, 224),),
                 target_memory_format=torch.contiguous_format,
                 op_level_check=False,
diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py
index 8ae5c0190a4..3049f30a8cb 100644
--- a/exir/tests/test_memory_format_ops_pass_utils.py
+++ b/exir/tests/test_memory_format_ops_pass_utils.py
@@ -8,7 +8,7 @@
 import unittest
 
 from dataclasses import dataclass
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
 
@@ -26,11 +26,24 @@
 from torch.utils._pytree import tree_flatten
 
 
+MemoryFormatOps2Str: Dict[torch._ops.OpOverload, List[str]] = {
+    torch.ops.aten._to_copy.default: (
+        "torch.ops.aten._to_copy.default",
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
+    ),
+    torch.ops.aten.empty.memory_format: (
+        "torch.ops.aten.empty.memory_format",
+        "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
+    ),
+}
+
+
 @dataclass
 class MemoryFormatTestSet:
     module: torch.nn.Module
     sample_input: Tuple[Any, ...]
     target_memory_format: torch.memory_format
+    op: torch._ops.OpOverload
     _load_for_executorch_from_buffer: Any
     op_level_check: bool = True
     use_xnnpack: bool = False
@@ -54,6 +67,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.to(dtype=torch.double, memory_format=torch.channels_last)
 
 
+class SimpleEmptyContiguoustModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        empty_tensor = torch.empty(x.size(), memory_format=torch.contiguous_format)
+        x = x.to(memory_format=torch.contiguous_format)
+        empty_tensor.copy_(x)
+        return empty_tensor
+
+
+class SimpleEmptyChannelLastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last)
+        x = x.to(memory_format=torch.channels_last)
+        empty_tensor.copy_(x)
+        return empty_tensor
+
+
 class PropagateToCopyChannalsLastModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -86,9 +121,7 @@ def memory_format_test_runner(
 
         # check memory format ops, if needed
         if test_set.op_level_check:
-            aten_op_str = "torch.ops.aten._to_copy.default"
-            edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-
+            aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
             # check op strings before
             FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
                 edge_op_str
@@ -126,6 +159,7 @@ def memory_format_test_runner(
         runtime_output = executorch_module.run_method(
             "forward", tuple(inputs_flattened)
         )[0]
+
        test_class.assertTrue(
            torch.allclose(
                runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
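
Note on the dim_order convention (an illustrative sketch, not part of the patch): _empty_dim_order drops the memory_format keyword that aten.empty.memory_format carries and takes an explicit dim_order list instead, mirroring what _to_dim_order_copy already does for _to_copy. The snippet below assumes only stock PyTorch; dim_order_of is a hypothetical helper that derives the order from strides, the same value the pass computes via get_dim_order(memory_format, ndim).

    import torch

    def dim_order_of(t: torch.Tensor):
        # Dimensions sorted from outermost (largest stride) to innermost (smallest stride).
        return sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True)

    contiguous = torch.empty((1, 10, 24, 24), memory_format=torch.contiguous_format)
    channels_last = torch.empty((1, 10, 24, 24), memory_format=torch.channels_last)

    print(dim_order_of(contiguous))     # [0, 1, 2, 3]
    print(dim_order_of(channels_last))  # [0, 2, 3, 1], NCHW sizes laid out as NHWC in memory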