Skip to content

Commit

Permalink
directly using dim order for memory format comparsion. (#2170)
Browse files Browse the repository at this point in the history
Summary:

bypass-github-export-checks

Reviewed By: digantdesai

Differential Revision: D54341919
  • Loading branch information
Gasoonjia authored and facebook-github-bot committed Mar 11, 2024
1 parent 64aa335 commit b088c0b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
18 changes: 18 additions & 0 deletions exir/dim_order_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,21 @@ def get_dim_order(
raise AssertionError(
f"Failed to generate dim_order for a given memory format: {memory_format}"
)


def is_channel_last_dim_order(tensor: torch.Tensor) -> bool:
"""
Check if a tensor has channels last dim order
"""
if tensor.dim() != 4:
# Only support 4D tensors for channel list memory format.
return False

return tensor.dim_order() == tuple(_get_channels_last_dim_order(tensor.dim()))


def is_contiguous_dim_order(tensor: torch.Tensor) -> bool:
"""
Check if a tensor has contiguous dim order
"""
return tensor.dim_order() == tuple(_get_contiguous_dim_order(tensor.dim()))
22 changes: 9 additions & 13 deletions exir/tests/test_memory_format_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.exir.dim_order_utils import (
is_channel_last_dim_order,
is_contiguous_dim_order,
)
from torch.export import export
from torch.testing import FileCheck

Expand All @@ -22,15 +27,6 @@ class MemoryFormatTestSet:


class TestMemoryFormatOpsPass(unittest.TestCase):
def is_channel_last(self, x: torch.Tensor):
# This is a heuristic to determine if the input tensor is in NHWC (channel last)
# due to we do not have a good way to infer the dimension order or the memory format
# of the input tensor. Please not this function is specific for contiguous tensors
# whose dim(1) is channel one only, other types of tensors may not work well
# due to different channel configuration and memory arrangement.

return x.stride(1) == 1

def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
aten_op_str = "torch.ops.aten._to_copy.default"
edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
Expand Down Expand Up @@ -60,13 +56,13 @@ def memory_format_test_runner(self, test_set: MemoryFormatTestSet):
actual = epm.exported_program().module()(*test_set.sample_input)
self.assertTrue(torch.allclose(actual, expected))
self.assertEqual(
self.is_channel_last(actual),
self.is_channel_last(expected),
is_channel_last_dim_order(actual),
is_channel_last_dim_order(expected),
)
if test_set.target_memory_format == torch.channels_last:
self.assertTrue(self.is_channel_last(actual))
self.assertTrue(is_channel_last_dim_order(actual))
elif test_set.target_memory_format == torch.contiguous_format:
self.assertFalse(self.is_channel_last(actual))
self.assertTrue(is_contiguous_dim_order(actual))
else:
raise RuntimeError("Unknown memory format")

Expand Down

0 comments on commit b088c0b

Please sign in to comment.