diff --git a/tests/ttnn/unit_tests/operations/test_max.py b/tests/ttnn/unit_tests/operations/test_max.py
index 9dcbfffb073..5e7a8007fe0 100644
--- a/tests/ttnn/unit_tests/operations/test_max.py
+++ b/tests/ttnn/unit_tests/operations/test_max.py
@@ -44,10 +44,7 @@ def test_max_global(device, batch_size, h, w):
     input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

     output_tensor = ttnn.max(input_tensor)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor)
-    output_tensor = output_tensor[0, 0, 0, 0]
+    output_tensor = output_tensor[0, 0, 0]

     assert_with_pcc(torch_output_tensor, output_tensor)

diff --git a/tests/ttnn/unit_tests/operations/test_min.py b/tests/ttnn/unit_tests/operations/test_min.py
index e38ee7082dc..767fb4a407d 100644
--- a/tests/ttnn/unit_tests/operations/test_min.py
+++ b/tests/ttnn/unit_tests/operations/test_min.py
@@ -51,6 +51,6 @@ def test_min_global(device, batch_size, h, w):
     output_tensor = ttnn.from_device(output_tensor)
     output_tensor = ttnn.to_torch(output_tensor)
-    output_tensor = output_tensor[0, 0, 0, 0]
+    output_tensor = output_tensor[0, 0, 0]

     assert_with_pcc(torch_output_tensor, output_tensor)

diff --git a/tests/ttnn/unit_tests/operations/test_sum.py b/tests/ttnn/unit_tests/operations/test_sum.py
index a099991e426..80d16bdf69f 100644
--- a/tests/ttnn/unit_tests/operations/test_sum.py
+++ b/tests/ttnn/unit_tests/operations/test_sum.py
@@ -51,6 +51,6 @@ def test_sum_global(device, batch_size, h, w):
     output_tensor = ttnn.from_device(output_tensor)
     output_tensor = ttnn.to_torch(output_tensor)
-    output_tensor = output_tensor[0, 0, 0, 0]
+    output_tensor = output_tensor[0, 0, 0]

     assert_with_pcc(torch_output_tensor, output_tensor)

diff --git a/ttnn/ttnn/operations/reduction.py b/ttnn/ttnn/operations/reduction.py
index 520e7cde0a4..b40529e41da 100644
--- a/ttnn/ttnn/operations/reduction.py
+++ b/ttnn/ttnn/operations/reduction.py
@@ -2,7 +2,7 @@

 # SPDX-License-Identifier: Apache-2.0

-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

 import tt_lib as ttl

@@ -10,13 +10,21 @@
 import ttnn


-def _golden_function(input_tensor: ttnn.Tensor, dim: int, keepdim=False, **_):
+def _create_golden_function(torch_function_name):
     import torch

-    return torch.std(input_tensor, dim=dim, keepdim=keepdim)
+    torch_function = getattr(torch, torch_function_name)
+
+    def golden_function(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim=False, **_):
+        if dim == None:
+            return torch_function(input_tensor, keepdim=keepdim)
+        else:
+            return torch_function(input_tensor, dim=dim, keepdim=keepdim)
+
+    return golden_function


-def _std_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
+def _validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
     ttnn.validate_input_tensor(
         operation_name,
         input_tensor,
@@ -28,46 +36,35 @@ def _validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
     )


-@ttnn.register_operation(
-    name="ttnn.std",
-    validate_input_tensors=_std_validate_input_tensors,
-    golden_function=_golden_function,
-)
-def std(
+def reduce(
     input_tensor: ttnn.Tensor,
-    dim: Union[int, Tuple[int]],
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
-) -> ttnn.Tensor:
-    """
-    std(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int]]) -> ttnn.Tensor
-    """
+    reduction_op: str,
+    dim: Optional[Union[int, Tuple[int]]],
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
+):
+    if not keepdim:
+        raise RuntimeError("keepdim=False is not supported")

     input_shape = tuple(input_tensor.shape)
     rank = len(input_shape)
+    memory_config = memory_config or input_tensor.memory_config()

+    original_dim = dim
     if isinstance(dim, int):
         if dim < 0:
             dim = rank + dim
         dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
+    elif dim is None:
+        dim = list(range(rank))

     output_shape = []
     padded_output_shape = []
     for axis, size in enumerate(input_shape):
         if axis in dim:
-            output_shape.append(1)
-            padded_output_shape.append(ttnn.TILE_SIZE)
+            if keepdim:
+                output_shape.append(1)
+                padded_output_shape.append(ttnn.TILE_SIZE if axis >= rank - 2 else 1)
         else:
             output_shape.append(size)
             padded_output_shape.append(size)
@@ -76,89 +73,136 @@ def std(

     input_tensor = ttnn.unsqueeze_to_4D(input_tensor)

-    mean_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1])
-    mean_square_tensor = ttl.tensor.reduce(
-        ttl.tensor.pow(input_tensor, 2.0), ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1]
-    )
-    output_tensor = ttl.tensor.sqrt(ttl.tensor.sub(mean_square_tensor, (ttl.tensor.pow(mean_tensor, 2.0))))
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-
-    return output_tensor
+    if original_dim is None:
+        if reduction_op == "mean":
+            output_tensor = ttl.tensor.global_mean(input_tensor, output_mem_config=memory_config)
+        elif reduction_op == "sum":
+            output_tensor = ttl.tensor.global_sum(input_tensor, output_mem_config=memory_config)
+        elif reduction_op == "max":
+            output_tensor = ttl.tensor.global_max(input_tensor, output_mem_config=memory_config)
+        elif reduction_op == "min":
+            output_tensor = ttl.tensor.global_min(input_tensor, output_mem_config=memory_config)
+        else:
+            raise RuntimeError("Unsupported reduction operation")
+    else:
+        if isinstance(dim, tuple):
+            if dim == (rank - 1,):
+                reduce_op_dim = ttl.tensor.ReduceOpDim.W
+            elif dim == (rank - 2,):
+                reduce_op_dim = ttl.tensor.ReduceOpDim.H
+            elif dim == (rank - 1, rank - 2):
+                reduce_op_dim = ttl.tensor.ReduceOpDim.HW
+            else:
+                raise RuntimeError("Unsupported dim")
+        else:
+            raise RuntimeError("Invalid dim")
+
+        reduced_volume = 1
+        for axis in dim:
+            reduced_volume *= input_shape[axis]
+
+        if reduction_op == "sum":
+            output_tensor = ttl.tensor.reduce(
+                input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1.0, output_mem_config=memory_config
+            )
+        elif reduction_op == "mean":
+            output_tensor = ttl.tensor.reduce(
+                input_tensor,
+                ttl.tensor.ReduceOpMath.SUM,
+                reduce_op_dim,
+                1 / reduced_volume,
+                output_mem_config=memory_config,
+            )
+        elif reduction_op == "max":
+            output_tensor = ttl.tensor.reduce(
+                input_tensor, ttl.tensor.ReduceOpMath.MAX, reduce_op_dim, 1.0, output_mem_config=memory_config
+            )
+        elif reduction_op == "min":
+            output_tensor = ttl.tensor.reduce(
+                input_tensor, ttl.tensor.ReduceOpMath.MIN, reduce_op_dim, 1.0, output_mem_config=memory_config
+            )
+        elif reduction_op == "std":
+            mean_tensor = ttl.tensor.reduce(
+                input_tensor,
+                ttl.tensor.ReduceOpMath.SUM,
+                reduce_op_dim,
+                1 / reduced_volume,
+                output_mem_config=memory_config,
+            )
+            mean_square_tensor = ttl.tensor.reduce(
+                ttl.tensor.pow(input_tensor, 2.0),
+                ttl.tensor.ReduceOpMath.SUM,
+                reduce_op_dim,
+                1 / reduced_volume,
+                output_mem_config=memory_config,
+            )
+            output_tensor = ttl.tensor.sqrt(
+                ttl.tensor.sub(
+                    mean_square_tensor,
+                    ttl.tensor.pow(mean_tensor, 2.0, output_mem_config=memory_config),
+                    output_mem_config=memory_config,
+                ),
+                output_mem_config=memory_config,
+            )
+        elif reduction_op == "var":
+            mean_tensor = ttl.tensor.reduce(
+                input_tensor,
+                ttl.tensor.ReduceOpMath.SUM,
+                reduce_op_dim,
+                1 / reduced_volume,
+                output_mem_config=memory_config,
+            )
+            mean_square_tensor = ttl.tensor.reduce(
+                ttl.tensor.pow(input_tensor, 2.0),
+                ttl.tensor.ReduceOpMath.SUM,
+                reduce_op_dim,
+                1 / reduced_volume,
+                output_mem_config=memory_config,
+            )
+            output_tensor = ttl.tensor.sub(
+                mean_square_tensor, ttl.tensor.pow(mean_tensor, 2.0), output_mem_config=memory_config
+            )
+        else:
+            raise RuntimeError("Unsupported reduction operation")

-def _golden_function(input_tensor: ttnn.Tensor, dim: int, keepdim=False, **_):
-    import torch
+    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))

-    return torch.var(input_tensor, dim=dim, keepdim=keepdim)
+    return output_tensor


-def _var_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-    ttnn.validate_input_tensor(
-        operation_name,
-        input_tensor,
-        ranks=(2, 3, 4),
-        dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-        layouts=(ttnn.TILE_LAYOUT,),
-        can_be_on_device=True,
-        can_be_on_cpu=False,
-    )
+@ttnn.register_operation(
+    name="ttnn.std",
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("std"),
+)
+def std(
+    input_tensor: ttnn.Tensor,
+    dim: Optional[Union[int, Tuple[int]]],
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
+) -> ttnn.Tensor:
+    """
+    std(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]]) -> ttnn.Tensor
+    """
+    return reduce(input_tensor, "std", dim, keepdim, memory_config)


 @ttnn.register_operation(
     name="ttnn.var",
-    validate_input_tensors=_var_validate_input_tensors,
-    golden_function=_golden_function,
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("var"),
 )
 def var(
     input_tensor: ttnn.Tensor,
-    dim: Union[int, Tuple[int]],
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
+    dim: Optional[Union[int, Tuple[int]]],
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
 ) -> ttnn.Tensor:
     """
-    var(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int]]) -> ttnn.Tensor
+    var(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]]) -> ttnn.Tensor
     """
-
-    input_shape = tuple(input_tensor.shape)
-    rank = len(input_shape)
-
-    if isinstance(dim, int):
-        if dim < 0:
-            dim = rank + dim
-        dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
-
-    output_shape = []
-    padded_output_shape = []
-    for axis, size in enumerate(input_shape):
-        if axis in dim:
-            output_shape.append(1)
-            padded_output_shape.append(ttnn.TILE_SIZE)
-        else:
-            output_shape.append(size)
-            padded_output_shape.append(size)
-    output_shape = tuple(output_shape)
-    padded_output_shape = tuple(padded_output_shape)
-
-    input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-
-    mean_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1])
-    mean_square_tensor = ttl.tensor.reduce(
-        ttl.tensor.pow(input_tensor, 2.0), ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1]
-    )
-    output_tensor = ttl.tensor.sub(mean_square_tensor, ttl.tensor.pow(mean_tensor, 2.0))
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-    return output_tensor
+    return reduce(input_tensor, "var", dim, keepdim, memory_config)


 def _golden_function(input_tensor: ttnn.Tensor, dim: Union[int, None], keepdim=False, **_):
@@ -170,232 +214,55 @@ def _golden_function(input_tensor: ttnn.Tensor, dim: Union[int, None], keepdim=F
         return torch.max(input_tensor, dim=dim, keepdim=keepdim)


-def _max_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-    ttnn.validate_input_tensor(
-        operation_name,
-        input_tensor,
-        ranks=(2, 3, 4),
-        dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-        layouts=(ttnn.TILE_LAYOUT,),
-        can_be_on_device=True,
-        can_be_on_cpu=False,
-    )
-
-
 @ttnn.register_operation(
     name="ttnn.max",
-    validate_input_tensors=_max_validate_input_tensors,
-    golden_function=_golden_function,
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("max"),
 )
 def max(
     input_tensor: ttnn.Tensor,
-    dim: Union[int, Tuple[int], None] = None,
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
 ) -> ttnn.Tensor:
     """
-    max(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int], None]) -> ttnn.Tensor
+    max(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]]) -> ttnn.Tensor
     """
-
-    input_shape = tuple(input_tensor.shape)
-    rank = len(input_shape)
-
-    if dim == None:
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-        output_tensor = ttl.tensor.global_max(input_tensor)
-        return output_tensor
-
-    if isinstance(dim, int):
-        if dim < 0:
-            dim = rank + dim
-        dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
-
-    output_shape = []
-    padded_output_shape = []
-    for axis, size in enumerate(input_shape):
-        if axis in dim:
-            output_shape.append(1)
-            padded_output_shape.append(ttnn.TILE_SIZE)
-        else:
-            output_shape.append(size)
-            padded_output_shape.append(size)
-    output_shape = tuple(output_shape)
-    padded_output_shape = tuple(padded_output_shape)
-
-    input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-    output_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.MAX, reduce_op_dim, 1.0)
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-
-    return output_tensor
-
-
-def _golden_function(input_tensor: ttnn.Tensor, dim: Union[int, None], keepdim=False, **_):
-    import torch
-
-    if dim == None:
-        return torch.min(input_tensor)
-    else:
-        return torch.min(input_tensor, dim=dim, keepdim=keepdim)
-
-
-def _min_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-    ttnn.validate_input_tensor(
-        operation_name,
-        input_tensor,
-        ranks=(2, 3, 4),
-        dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-        layouts=(ttnn.TILE_LAYOUT,),
-        can_be_on_device=True,
-        can_be_on_cpu=False,
-    )
+    return reduce(input_tensor, "max", dim, keepdim, memory_config)


 @ttnn.register_operation(
     name="ttnn.min",
-    validate_input_tensors=_min_validate_input_tensors,
-    golden_function=_golden_function,
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("min"),
 )
 def min(
     input_tensor: ttnn.Tensor,
-    dim: Union[int, Tuple[int], None] = None,
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
 ) -> ttnn.Tensor:
     """
-    min(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int], None]) -> ttnn.Tensor
+    min(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]]) -> ttnn.Tensor
     """
-
-    input_shape = tuple(input_tensor.shape)
-    rank = len(input_shape)
-
-    if dim == None:
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-        output_tensor = ttl.tensor.global_min(input_tensor)
-        return output_tensor
-
-    if isinstance(dim, int):
-        if dim < 0:
-            dim = rank + dim
-        dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
-
-    output_shape = []
-    padded_output_shape = []
-    for axis, size in enumerate(input_shape):
-        if axis in dim:
-            output_shape.append(1)
-            padded_output_shape.append(ttnn.TILE_SIZE)
-        else:
-            output_shape.append(size)
-            padded_output_shape.append(size)
-    output_shape = tuple(output_shape)
-    padded_output_shape = tuple(padded_output_shape)
-
-    input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-    output_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.MIN, reduce_op_dim, 1.0)
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-
-    return output_tensor
-
-
-def _golden_function(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int], None] = None, keepdim=False, **_):
-    import torch
-
-    if dim == None:
-        return torch.sum(input_tensor)
-    else:
-        return torch.sum(input_tensor, dim=dim, keepdim=keepdim)
-
-
-def _sum_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-    ttnn.validate_input_tensor(
-        operation_name,
-        input_tensor,
-        ranks=(2, 3, 4),
-        dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-        layouts=(ttnn.TILE_LAYOUT,),
-        can_be_on_device=True,
-        can_be_on_cpu=False,
-    )
+    return reduce(input_tensor, "min", dim, keepdim, memory_config)


 @ttnn.register_operation(
     name="ttnn.sum",
-    validate_input_tensors=_sum_validate_input_tensors,
-    golden_function=_golden_function,
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("sum"),
 )
 def sum(
     input_tensor: ttnn.Tensor,
-    dim: Union[int, Tuple[int], None] = None,
-    memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
 ) -> ttnn.Tensor:
     """
-    sum(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int], None]) -> ttnn.Tensor
+    sum(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]]) -> ttnn.Tensor
     """
-
-    input_shape = tuple(input_tensor.shape)
-    rank = len(input_shape)
-
-    if dim == None:
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-        output_tensor = ttl.tensor.global_sum(input_tensor)
-        return output_tensor
-
-    if isinstance(dim, int):
-        if dim < 0:
-            dim = rank + dim
-        dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
-
-    output_shape = []
-    padded_output_shape = []
-    for axis, size in enumerate(input_shape):
-        if axis in dim:
-            output_shape.append(1)
-            padded_output_shape.append(ttnn.TILE_SIZE)
-        else:
-            output_shape.append(size)
-            padded_output_shape.append(size)
-    output_shape = tuple(output_shape)
-    padded_output_shape = tuple(padded_output_shape)
-
-    input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-    output_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1.0)
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-
-    return output_tensor
+    return reduce(input_tensor, "sum", dim, keepdim, memory_config)


 def _golden_function(input_tensor: ttnn.Tensor, dim: int, keepdim=False, **_):
@@ -404,66 +271,21 @@ def _golden_function(input_tensor: ttnn.Tensor, dim: int, keepdim=False, **_):
     return torch.mean(input_tensor, dim=dim, keepdim=keepdim)


-def _mean_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
-    ttnn.validate_input_tensor(
-        operation_name,
-        input_tensor,
-        ranks=(2, 3, 4),
-        dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
-        layouts=(ttnn.TILE_LAYOUT,),
-        can_be_on_device=True,
-        can_be_on_cpu=False,
-    )
-
-
 @ttnn.register_operation(
     name="ttnn.mean",
-    validate_input_tensors=_mean_validate_input_tensors,
-    golden_function=_golden_function,
+    validate_input_tensors=_validate_input_tensors,
+    golden_function=_create_golden_function("mean"),
 )
-def mean(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int]], keepdim: bool = False) -> ttnn.Tensor:
+def mean(
+    input_tensor: ttnn.Tensor,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+    keepdim: bool = True,
+    memory_config: Optional[ttnn.MemoryConfig] = None,
+) -> ttnn.Tensor:
     """
-    mean(input_tensor: ttnn.Tensor, dim: Union[int, Tuple[int]], keepdim: bool = False) -> ttnn.Tensor
+    mean(input_tensor: ttnn.Tensor, dim: Optional[Union[int, Tuple[int]]] = None, keepdim: bool = True, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
     """
-
-    input_shape = tuple(input_tensor.shape)
-    rank = len(input_shape)
-
-    if isinstance(dim, int):
-        if dim < 0:
-            dim = rank + dim
-        dim = (dim,)
-
-    if isinstance(dim, tuple):
-        if dim == (rank - 1,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.W
-        elif dim == (rank - 2,):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.H
-        elif dim == (rank - 1, rank - 2):
-            reduce_op_dim = ttl.tensor.ReduceOpDim.HW
-        else:
-            raise RuntimeError("Unsupported dim")
-    else:
-        raise RuntimeError("Invalid dim")
-
-    output_shape = []
-    padded_output_shape = []
-    for axis, size in enumerate(input_shape):
-        if axis in dim:
-            if keepdim:
-                output_shape.append(1)
-                padded_output_shape.append(ttnn.TILE_SIZE)
-        else:
-            output_shape.append(size)
-            padded_output_shape.append(size)
-    output_shape = tuple(output_shape)
-    padded_output_shape = tuple(padded_output_shape)
-
-    input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-    output_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1])
-    output_tensor = ttl.tensor.reduce(input_tensor, ttl.tensor.ReduceOpMath.SUM, reduce_op_dim, 1 / input_shape[-1])
-    output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(output_shape, padded_output_shape))
-    return output_tensor
+    return reduce(input_tensor, "mean", dim, keepdim, memory_config)


 __all__ = []
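
Not part of the patch: a minimal usage sketch of the consolidated reduction API introduced above, written in the style of the updated unit tests. It assumes the fixtures the existing tests already rely on (the `device` pytest fixture and `assert_with_pcc`, assumed importable as in those tests); the helper name `check_reductions` is hypothetical.

    # Hypothetical sketch, not part of the patch. With dim=None the new reduce() helper
    # dispatches to ttl.tensor.global_*; keepdim=True is the only supported mode, so a
    # full reduction of a rank-3 input comes back as a 1x1x1 tensor, which is why the
    # updated tests index the result as [0, 0, 0].
    import torch
    import ttnn
    from tests.ttnn.utils_for_testing import assert_with_pcc

    def check_reductions(device, h=64, w=96):
        torch_input_tensor = torch.randn((1, h, w), dtype=torch.bfloat16)
        input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

        # Global reduction: no dim argument.
        output_tensor = ttnn.sum(input_tensor)
        output_tensor = ttnn.to_torch(ttnn.from_device(output_tensor))
        assert_with_pcc(torch.sum(torch_input_tensor), output_tensor[0, 0, 0])

        # Reduction over the last dimension; keepdim defaults to True.
        output_tensor = ttnn.max(input_tensor, dim=-1)
        output_tensor = ttnn.to_torch(ttnn.from_device(output_tensor))
        assert_with_pcc(torch.max(torch_input_tensor, dim=-1, keepdim=True).values, output_tensor)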