[et][dim order] aot support for dim order variant empty op
Pull Request resolved: #7168

This diff adds AOT support for the dim-order variant of the empty op, including the operator implementation and registration, an update to the memory format pass, and end-to-end tests in both ATen and lean modes.

Differential Revision: [D66738618](https://our.internmc.facebook.com/intern/diff/D66738618/)
ghstack-source-id: 256405766
Gasoonjia committed Dec 4, 2024
1 parent c0de556 commit e67c472
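
For context, a small illustrative sketch (not part of this commit) of what a dim order is: the permutation of tensor dimensions ordered from largest to smallest stride. On a rank-4 tensor, the default contiguous layout corresponds to dim_order [0, 1, 2, 3] and channels_last to [0, 2, 3, 1]; the pass below rewrites aten.empty.memory_format into dim_order_ops._empty_dim_order carrying such a list instead of a memory_format. The helper name here is made up for illustration; the real conversion goes through the get_dim_order / get_memory_format utilities used in the pass.

import torch

def dim_order_from_tensor(t: torch.Tensor) -> list:
    # Order dimensions from largest to smallest stride (outermost to innermost).
    return [dim for _, dim in sorted(
        ((stride, dim) for dim, stride in enumerate(t.stride())), reverse=True
    )]

nchw = torch.empty((1, 10, 24, 24), memory_format=torch.contiguous_format)
nhwc = torch.empty((1, 10, 24, 24), memory_format=torch.channels_last)
print(dim_order_from_tensor(nchw))  # [0, 1, 2, 3]
print(dim_order_from_tensor(nhwc))  # [0, 2, 3, 1]
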
Showing 5 changed files with 140 additions and 15 deletions.
22 changes: 21 additions & 1 deletion exir/passes/dim_order_ops_registry.py
@@ -15,11 +15,19 @@
"_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

# Out variant drops TensorOptions
lib.define(
"_empty_dim_order(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int[]? dim_order=None) -> Tensor"
)

# Out variant of aten::_to_copy and aten::empty drops TensorOptions, so do their dim order variants
lib.define(
"_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs):
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -39,18 +39,30 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)


@impl(lib, "_empty_dim_order", "CompositeImplicitAutograd")
def _empty_dim_order_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.memory_format, *args, **kwargs)


@impl(lib, "_empty_dim_order.out", "CompositeImplicitAutograd")
def _empty_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


"""
Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
"""
DimOrderOpsMap = {
"aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
"aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
}

"""
Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
"""
MemoryFormatOpsMap = {
"dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
"dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
}

# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
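
# --- Illustrative usage sketch; editorial addition, not part of this diff. ---
# It assumes the library above registers under the "dim_order_ops" namespace
# (as the exir_ops.edge.dim_order_ops.* names suggest) and that this module has
# been imported so the CompositeImplicitAutograd impls are installed.
import torch

t = torch.ops.dim_order_ops._empty_dim_order(
    [1, 10, 24, 24], dim_order=[0, 2, 3, 1]
)
# The impl forwards to aten.empty.memory_format, so a rank-4 dim_order of
# [0, 2, 3, 1] should come back as a channels_last tensor.
assert t.is_contiguous(memory_format=torch.channels_last)
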
25 changes: 15 additions & 10 deletions exir/passes/memory_format_ops_pass.py
@@ -39,6 +39,7 @@ def call_operator(self, op, args, kwargs, meta):
kwargs,
meta,
)

# new kwargs with dim_order, and no memory_format for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

@@ -50,17 +51,20 @@ def call_operator(self, op, args, kwargs, meta):
ndim = args[0].to_tensor().dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
ndim = len(args[0])
else:
assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
assert (
0
), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}"

nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
logger.debug(
f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
)

t = DimOrderOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"
t = DimOrderOpsMap[op.__name__]

return super().call_operator(
t,
@@ -92,8 +96,10 @@ def call_operator(self, op, args, kwargs, meta):
ndim = args[0].to_tensor().dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
ndim = len(args[0])
else:
assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"
assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"

# get the "to" memory format for the EdgeOp
default_dim_order = list(range(ndim))
Expand All @@ -102,12 +108,11 @@ def call_operator(self, op, args, kwargs, meta):
nkwargs["memory_format"] = get_memory_format(dim_order)

logger.debug(
f" _to_dim_order_copy = dim_order: {dim_order}."
f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
f" {op.__name__} = dim_order: {dim_order}."
f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
)

t = MemoryFormatOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"
t = MemoryFormatOpsMap[op.__name__]

return super().call_operator(
t,
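
# --- Editorial aside (illustrative, not part of the pass): why the new
# immutable_list branch is needed. For aten._to_copy the first argument is the
# input tensor, so the rank comes from Tensor.dim(); for aten.empty the first
# argument is the size list, so the rank is simply its length.
import torch

x = torch.randn(2, 3, 4)
rank_from_tensor = x.dim()   # 3, the _to_copy case
size = [2, 3, 4]
rank_from_size = len(size)   # 3, the empty case
assert rank_from_tensor == rank_from_size
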
34 changes: 34 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -27,6 +27,8 @@
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCopyChannalsLastModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguoustModule,
SimpleToCopyChannelsLastModule,
SimpleToCopyContiguousModule,
)
@@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_channels_last(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyChannelLastModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_contiguous(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyContiguoustModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleToCopyChannelsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None:
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -273,6 +303,7 @@ def test_resnet18(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
32 changes: 32 additions & 0 deletions exir/tests/test_memory_format_ops_pass_aten.py
@@ -13,6 +13,8 @@
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCopyChannalsLastModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguoustModule,
SimpleToCopyChannelsLastModule,
SimpleToCopyContiguousModule,
)
@@ -28,6 +30,7 @@ def test_op_to_copy_replacement_2d_aten(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
@@ -39,17 +42,43 @@ def test_op_to_copy_replacement_4d_aten(self) -> None:
self,
MemoryFormatTestSet(
module=SimpleToCopyContiguousModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_channels_last(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyChannelLastModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_empty_replacement_contiguous(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleEmptyContiguoustModule().eval(),
op=torch.ops.aten.empty.memory_format,
sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update_aten(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=SimpleToCopyChannelsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -67,6 +96,7 @@ def test_op_dim_order_propagation_aten(self) -> None:
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
@@ -85,6 +115,7 @@ def test_resnet18(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
@@ -100,6 +131,7 @@ def test_mobilenet_v3(self) -> None:
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(torch.randn(1, 3, 224, 224),),
target_memory_format=torch.contiguous_format,
op_level_check=False,
42 changes: 38 additions & 4 deletions exir/tests/test_memory_format_ops_pass_utils.py
@@ -8,7 +8,7 @@

import unittest
from dataclasses import dataclass
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple

import torch

@@ -26,11 +26,24 @@
from torch.utils._pytree import tree_flatten


MemoryFormatOps2Str: Dict[torch._ops.OpOverload, List[str]] = {
torch.ops.aten._to_copy.default: (
"torch.ops.aten._to_copy.default",
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
),
torch.ops.aten.empty.memory_format: (
"torch.ops.aten.empty.memory_format",
"executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
),
}


@dataclass
class MemoryFormatTestSet:
module: torch.nn.Module
sample_input: Tuple[Any, ...]
target_memory_format: torch.memory_format
op: torch._ops.OpOverload
_load_for_executorch_from_buffer: Any
op_level_check: bool = True
use_xnnpack: bool = False
@@ -54,6 +67,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=torch.double, memory_format=torch.channels_last)


class SimpleEmptyContiguoustModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
empty_tensor = torch.empty(x.size(), memory_format=torch.contiguous_format)
x = x.to(memory_format=torch.contiguous_format)
empty_tensor.copy_(x)
return empty_tensor


class SimpleEmptyChannelLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)
empty_tensor.copy_(x)
return empty_tensor

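# --- Editorial sanity-check sketch; not part of this diff. It runs the new
# module eagerly to confirm the returned empty tensor carries the
# channels_last layout before the full export pipeline is exercised.
_m = SimpleEmptyChannelLastModule().eval()
_out = _m(torch.randn(1, 10, 24, 24))
assert _out.is_contiguous(memory_format=torch.channels_last)
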

class PropagateToCopyChannalsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -86,9 +121,7 @@ def memory_format_test_runner(

# check memory format ops, if needed
if test_set.op_level_check:
aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
# check op strings before
FileCheck().check_count(aten_op_str, 1, exactly=True).check_not(
edge_op_str
@@ -126,6 +159,7 @@ def memory_format_test_runner(
runtime_output = executorch_module.run_method(
"forward", tuple(inputs_flattened)
)[0]

test_class.assertTrue(
torch.allclose(
runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol
