Skip to content

Commit

Permalink
#4916: Add avg pool to ttnn
Browse files Browse the repository at this point in the history
  • Loading branch information
mywoodstock authored and arakhmati committed Jan 25, 2024
1 parent cfe78bb commit d75c31a
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 1 deletion.
46 changes: 46 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/average_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch
from tests.ttnn.utils_for_testing import check_with_pcc
import ttnn


parameters = {
"act_shape": [[1, 7, 7, 2048], [1, 1, 32, 64]],
"dtype": [ttnn.bfloat16],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
return False, None


def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
return False, None


def run(
act_shape,
dtype,
device,
) -> Tuple[bool, Optional[str]]:
torch.manual_seed(0)

act = torch.randn(act_shape, dtype=torch.bfloat16)
ttact = ttnn.from_torch(act, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.average_pool2d(ttact)

out_pytorch = ttnn.to_torch(ttnn.from_device(out))

## reference
act_channels_first = torch.permute(act, (0, 3, 1, 2)) # Torch operates on channels-first tensors
golden_pytorch = torch.nn.AdaptiveAvgPool2d((1, 1))(act_channels_first)
golden_pytorch = torch.permute(golden_pytorch, (0, 2, 3, 1))

## test for equivalance
return check_with_pcc(golden_pytorch, out_pytorch)
48 changes: 48 additions & 0 deletions tests/ttnn/unit_tests/test_average_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger

import torch
import pytest
import math
from models.utility_functions import skip_for_wormhole_b0
from tests.ttnn.utils_for_testing import assert_with_pcc
import ttnn


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"act_shape",
(([1, 7, 7, 2048], ([1, 1, 32, 64]))),
ids=["resnet50_unpadded", "tile_divisible"],
)
@pytest.mark.parametrize(
"dtype",
(ttnn.bfloat16,),
ids=[
"BFLOAT16",
],
)
def test_run_average_pool(
act_shape,
dtype,
device,
):
torch.manual_seed(0)

act = torch.randn(act_shape, dtype=torch.bfloat16)
ttact = ttnn.from_torch(act, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.average_pool2d(ttact)

out_pytorch = ttnn.to_torch(ttnn.from_device(out))

## reference
act_channels_first = torch.permute(act, (0, 3, 1, 2)) # Torch operates on channels-first tensors
golden_pytorch = torch.nn.AdaptiveAvgPool2d((1, 1))(act_channels_first)
golden_pytorch = torch.permute(golden_pytorch, (0, 2, 3, 1))

## test for equivalance
assert_with_pcc(golden_pytorch, out_pytorch)
5 changes: 4 additions & 1 deletion ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@

from ttnn.operations import transformer
from ttnn.operations.conv import Conv2D
from ttnn.operations.max_pool import MaxPool2D
from ttnn.operations.pooling import (
MaxPool2D,
average_pool2d,
)
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from typing import Tuple, Union, Dict

import tt_lib as ttl

import ttnn.core as ttnn
from ttnn.decorators import decorate_operation

from tt_eager.tt_dnn.op_library.sliding_window_op_infra.tt_py_max_pool import (
TTPyMaxPool,
Expand Down Expand Up @@ -81,3 +84,21 @@ def copy_input_to_device(self, input: ttnn.Tensor):

def copy_output_from_device(self, output: ttnn.Tensor):
return ttnn.Tensor(self.max_pool.copy_output_from_device(output.value))


def _torch_average_pool2d(input_tensor: ttnn.Tensor):
import torch

output_size = (1, 1)
input_tensor = ttnn.from_device(input_tensor)
input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
input_tensor = ttnn.to_torch(input_tensor)

return torch.nn.AdaptiveAvgPool2d(output_size)(input_tensor)


@decorate_operation(torch_function=_torch_average_pool2d)
def average_pool2d(input_tensor: ttnn.Tensor) -> ttnn.Tensor:
output = ttl.tensor.average_pool_2d(input_tensor.value)

return ttnn.Tensor(output)

0 comments on commit d75c31a

Please sign in to comment.