#8683: Add Unary bitwise XOR, NOT (#9436)
* #8683: Add bitwise_xor op

* #8683: Add bitwise_not op
mcw-anasuya authored Jun 25, 2024
1 parent a728de2 commit 24b210a
Showing 18 changed files with 496 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -405,6 +405,10 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.heaviside

.. autofunction:: tt_lib.tensor.bitwise_xor

.. autofunction:: tt_lib.tensor.bitwise_not

.. autofunction:: tt_lib.tensor.right_shift

.. autofunction:: tt_lib.tensor.left_shift
8 changes: 8 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -492,6 +492,14 @@
        "tt_op": tt_lib_ops.eltwise_heaviside,
        "pytorch_op": pytorch_ops.heaviside,
    },
    "eltwise-bitwise_xor": {
        "tt_op": tt_lib_ops.eltwise_bitwise_xor,
        "pytorch_op": pytorch_ops.bitwise_xor,
    },
    "eltwise-bitwise_not": {
        "tt_op": tt_lib_ops.eltwise_bitwise_not,
        "pytorch_op": pytorch_ops.bitwise_not,
    },
    "eltwise-right_shift": {
        "tt_op": tt_lib_ops.eltwise_right_shift,
        "pytorch_op": pytorch_ops.right_shift,
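Each op_map entry pairs the device-side wrapper from tt_lib_ops with a PyTorch reference from pytorch_ops; the sweep runner calls both on the same random input and compares the results. A minimal sketch of that pairing, using a host-side lambda as a stand-in for the real device wrapper (the actual runner, run_single_pytorch_test, also threads dtype/layout/memory-config arguments through):

import torch

# Stand-in for the real map: "tt_op" here is a host-side lambda, not the device op.
sketch_op_map = {
    "eltwise-bitwise_xor": {
        "tt_op": lambda x, value: x ^ value,
        "pytorch_op": lambda x, value: torch.bitwise_xor(x, value),
    },
}

def run_entry(name, x, value):
    entry = sketch_op_map[name]
    actual = entry["tt_op"](x, value)
    expected = entry["pytorch_op"](x, value)
    return torch.equal(actual, expected)  # mirrors comparison_funcs.comp_equal

x = torch.randint(0, 2**31 - 1, (1, 1, 32, 32), dtype=torch.int32)
assert run_entry("eltwise-bitwise_xor", x, 7)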
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from functools import partial
import tt_lib as ttl


from tests.tt_eager.python_api_testing.sweep_tests import (
    comparison_funcs,
    generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
    run_single_pytorch_test,
)
from models.utility_functions import skip_for_grayskull

mem_configs = [
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
    "scalar",
    (1,),  # bitwise_not ignores the scalar; a single placeholder value suffices
)
@pytest.mark.parametrize(
    "input_shapes",
    [
        [[1, 1, 32, 32]],
        [[4, 3, 32, 32]],
        [[2, 2, 32, 32]],
    ],
)
@pytest.mark.parametrize(
    "dst_mem_config",
    mem_configs,
)
@skip_for_grayskull("#TODO: GS implementation needs to be done")
class TestBitwiseNot:
    def test_run_bitwise_not_op(
        self,
        scalar,
        input_shapes,
        dst_mem_config,
        device,
    ):
        datagen_func = [
            generation_funcs.gen_func_with_cast(
                partial(generation_funcs.gen_rand, low=-2147483647, high=2147483647), torch.int
            )
        ]
        test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
        test_args.update(
            {
                "value": scalar,
                "dtype": [ttl.tensor.DataType.INT32],
            }
        )
        test_args.update({"output_mem_config": dst_mem_config})
        comparison_func = comparison_funcs.comp_equal

        run_single_pytorch_test(
            "eltwise-bitwise_not",
            input_shapes,
            datagen_func,
            comparison_func,
            device,
            test_args,
        )
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import random
from functools import partial
import tt_lib as ttl


from tests.tt_eager.python_api_testing.sweep_tests import (
    comparison_funcs,
    generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
    run_single_pytorch_test,
)
from models.utility_functions import skip_for_grayskull

mem_configs = [
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
    ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
    "scalar",
    [random.randint(-100, 100) for _ in range(10)],  # a list keeps all 10 draws; the original set could silently drop duplicates
)
@pytest.mark.parametrize(
    "input_shapes",
    [
        [[1, 1, 32, 32]],
        [[4, 3, 32, 32]],
        [[2, 2, 32, 32]],
    ],
)
@pytest.mark.parametrize(
    "dst_mem_config",
    mem_configs,
)
@skip_for_grayskull("#TODO: GS implementation needs to be done")
class TestBitwiseXor:
    def test_run_bitwise_xor_op(
        self,
        scalar,
        input_shapes,
        dst_mem_config,
        device,
    ):
        datagen_func = [
            generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=0, high=2147483647), torch.int)
        ]
        test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
        test_args.update(
            {
                "value": scalar,
                "dtype": [ttl.tensor.DataType.INT32],
            }
        )
        test_args.update({"output_mem_config": dst_mem_config})
        comparison_func = comparison_funcs.comp_equal

        run_single_pytorch_test(
            "eltwise-bitwise_xor",
            input_shapes,
            datagen_func,
            comparison_func,
            device,
            test_args,
        )
11 changes: 11 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -527,6 +527,17 @@ def heaviside(x, *args, **kwargs):
    return result


def bitwise_xor(x, *args, **kwargs):
    value = kwargs.pop("value")
    result = torch.bitwise_xor(x, value)
    return result


def bitwise_not(x, *args, **kwargs):
    result = torch.bitwise_not(x)
    return result


def right_shift(x, *args, **kwargs):
    value = kwargs.pop("value")
    result = torch.bitwise_right_shift(x, value)
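Both reference functions are thin wrappers over PyTorch. A quick, self-contained check of the semantics they pin down (note that bitwise_not ignores any scalar kwarg):

import torch

x = torch.tensor([0, 1, 5, -1], dtype=torch.int32)

# XOR with a scalar broadcasts elementwise: 0^3=3, 1^3=2, 5^3=6, -1^3=-4.
assert torch.equal(torch.bitwise_xor(x, 3), torch.tensor([3, 2, 6, -4], dtype=torch.int32))

# NOT is two's-complement inversion: ~x == -x - 1.
assert torch.equal(torch.bitwise_not(x), -x - 1)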
36 changes: 36 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1206,6 +1206,42 @@ def lamb_optimizer(
    return [tt2torch_tensor(t4[0]), tt2torch_tensor(t4[1]), tt2torch_tensor(t4[2])]


@setup_host_and_device
def eltwise_bitwise_xor(
    x,
    *args,
    value,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    t1 = ttl.tensor.bitwise_xor(t0, value, output_mem_config=output_mem_config)

    return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_bitwise_not(
    x,
    *args,
    value,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    # bitwise_not is unary; ``value`` is threaded through only to match the binding's signature
    t1 = ttl.tensor.bitwise_not(t0, value, output_mem_config=output_mem_config)

    return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_right_shift(
    x,
27 changes: 24 additions & 3 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -65,6 +65,8 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
        case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break;
        case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break;
        case UnaryOpType::TYPECAST: defines["SFPU_OP_TYPECAST_INCLUDE"] = "1"; break;
        case UnaryOpType::BITWISE_XOR: defines["SFPU_OP_BITWISE_XOR_INCLUDE"] = "1"; break;
        case UnaryOpType::BITWISE_NOT: defines["SFPU_OP_BITWISE_NOT_INCLUDE"] = "1"; break;
        case UnaryOpType::RIGHT_SHIFT: defines["SFPU_OP_RIGHT_SHIFT_INCLUDE"] = "1"; break;
        case UnaryOpType::FLOOR: defines["SFPU_OP_FLOOR_INCLUDE"] = "1"; break;
        case UnaryOpType::LEFT_SHIFT: defines["SFPU_OP_LEFT_SHIFT_INCLUDE"] = "1"; break;
@@ -112,6 +114,14 @@ std::pair<string, string> get_op_init_and_func_parameterized(
            op_init_and_name = {
                "heaviside_tile_init();", fmt::format("heaviside_tile({}, {}u);", idst, Converter::to_hex(param0))};
            break;
        case UnaryOpType::BITWISE_XOR:
            op_init_and_name = {
                "bitwise_xor_tile_init();", fmt::format("bitwise_xor_tile({}, {}u);", idst, std::to_string((uint)param0))};
            break;
        case UnaryOpType::BITWISE_NOT:
            op_init_and_name = {
                "bitwise_not_tile_init();", fmt::format("bitwise_not_tile({}, {}u);", idst, std::to_string((uint)param0))};
            break;
        case UnaryOpType::RIGHT_SHIFT:
            op_init_and_name = {
                "right_shift_tile_init();",
@@ -341,14 +351,20 @@ namespace tt {

namespace tt_metal {

-inline void validate_supported_arch(tt::ARCH arch, UnaryOpType op_type) {
+inline void validate_supported_arch_dtype(tt::ARCH arch, DataType input_datatype, DataType output_datatype, UnaryOpType op_type) {
    switch (op_type) {
        case UnaryOpType::REMAINDER:
        case UnaryOpType::FLOOR:
        case UnaryOpType::LEFT_SHIFT:
        case UnaryOpType::RIGHT_SHIFT:
            TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
            break;
        case UnaryOpType::BITWISE_XOR:
        case UnaryOpType::BITWISE_NOT:
            TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
            TT_FATAL(input_datatype == DataType::INT32, "Data type is not supported for Bitwise operations");
            TT_FATAL(output_datatype == DataType::INT32, "Data type is not supported for Bitwise operations");
            break;
        default:
            return;
    }
@@ -357,10 +373,15 @@ inline void validate_supported_arch(tt::ARCH arch, UnaryOpType op_type) {
void EltwiseUnary::validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &optional_output_tensors) const {
    const auto& input_tensor_a = input_tensors.at(0);
    auto out_mem_config = (!optional_output_tensors.empty() && optional_output_tensors.at(0).has_value()) ? optional_output_tensors.at(0).value().memory_config() : this->output_mem_config;

    auto output_datatype = output_dtype;
    if (!optional_output_tensors.empty() && optional_output_tensors.at(0).has_value()) {
        const auto& out = optional_output_tensors.at(0);
        output_datatype = out->get_dtype();
    }
    auto arch = input_tensor_a.device()->arch();
    auto input_datatype = input_tensor_a.get_dtype();
    for (const auto& unary_op : this->op_chain) {
-        validate_supported_arch(arch, unary_op.op_type);
+        validate_supported_arch_dtype(arch, input_datatype, output_datatype, unary_op.op_type);
    }
    TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to eltwise unary need to be on device!");
    TT_FATAL(
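The new validation requires Wormhole_B0 and INT32 on both the input and the output of the bitwise ops. A pure-Python mirror of that check, for illustration only (the names here are hypothetical; the real check is the C++ validate_supported_arch_dtype above):

BITWISE_OPS = {"BITWISE_XOR", "BITWISE_NOT"}

def validate_bitwise(arch: str, input_dtype: str, output_dtype: str, op_type: str) -> None:
    # Mirrors the TT_FATAL conditions in validate_supported_arch_dtype.
    if op_type not in BITWISE_OPS:
        return
    if arch != "WORMHOLE_B0":
        raise ValueError("Op is only supported on Wormhole")
    if input_dtype != "INT32" or output_dtype != "INT32":
        raise ValueError("Data type is not supported for Bitwise operations")

validate_bitwise("WORMHOLE_B0", "INT32", "INT32", "BITWISE_XOR")  # passes silently
# validate_bitwise("GRAYSKULL", "INT32", "INT32", "BITWISE_NOT")  # would raise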
6 changes: 6 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
@@ -79,6 +79,8 @@ enum class UnaryOpType {
    UNARY_LT,
    TILED_PROD,
    TYPECAST,
    BITWISE_XOR,
    BITWISE_NOT,
    RIGHT_SHIFT,
    FLOOR,
    LEFT_SHIFT,
@@ -110,6 +112,8 @@ bool is_parametrized_type(T val) {
        case UnaryOpType::UNARY_GT:
        case UnaryOpType::UNARY_LT:
        case UnaryOpType::TYPECAST:
        case UnaryOpType::BITWISE_XOR:
        case UnaryOpType::BITWISE_NOT:
        case UnaryOpType::RIGHT_SHIFT:
        case UnaryOpType::LEFT_SHIFT:
        case UnaryOpType::REMAINDER: return true;
@@ -415,6 +419,8 @@ constexpr auto leaky_relu = make_eltwise_unary_with_param<UnaryOpType::LEAKY_REL
constexpr auto prelu = leaky_relu;
constexpr auto elu = make_eltwise_unary_with_param<UnaryOpType::ELU>{};
constexpr auto heaviside = make_eltwise_unary_with_param<UnaryOpType::HEAVISIDE>{};
constexpr auto bitwise_xor = make_eltwise_unary_with_param<UnaryOpType::BITWISE_XOR>{};
constexpr auto bitwise_not = make_eltwise_unary_with_param<UnaryOpType::BITWISE_NOT>{};
constexpr auto right_shift = make_eltwise_unary_with_param<UnaryOpType::RIGHT_SHIFT>{};
constexpr auto left_shift = make_eltwise_unary_with_param<UnaryOpType::LEFT_SHIFT>{};
constexpr auto unary_remainder = make_eltwise_unary_with_param<UnaryOpType::REMAINDER>{};
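Both ops are registered through make_eltwise_unary_with_param, the same one-scalar-parameter factory used for heaviside and the shift ops. A rough Python analogy of that pattern (hypothetical; the real factory is a C++ template that routes the scalar into the generated SFPU kernel call):

import torch

def run_unary_op(tensor, op_type, param):
    # Stand-in dispatcher; on device this becomes the generated kernel call.
    if op_type == "BITWISE_XOR":
        return torch.bitwise_xor(tensor, param)
    if op_type == "BITWISE_NOT":
        return torch.bitwise_not(tensor)  # param is carried but ignored, as on device
    raise NotImplementedError(op_type)

def make_eltwise_unary_with_param(op_type):
    def op(tensor, value):
        return run_unary_op(tensor, op_type, param=value)
    return op

bitwise_xor = make_eltwise_unary_with_param("BITWISE_XOR")
bitwise_not = make_eltwise_unary_with_param("BITWISE_NOT")
assert torch.equal(bitwise_xor(torch.tensor([5], dtype=torch.int32), 3), torch.tensor([6], dtype=torch.int32))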
33 changes: 33 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
@@ -195,6 +195,39 @@ namespace tt::tt_metal::detail {
    )doc");

    m_tensor.def("bitwise_xor", bitwise_xor,
        py::arg("input").noconvert(), py::arg("value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
        Computes the bitwise XOR of the input tensor ``input`` with the scalar ``value``, elementwise. Input values must be non-negative. Supported only on Wormhole_B0.

        Input tensor must have INT32 data type.

        Output tensor will have INT32 data type.

        .. csv-table::
            :header: "Argument", "Description", "Data type", "Valid range", "Required"

            "input", "Input tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "value", "Scalar value", "int", "", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
    )doc");

    m_tensor.def("bitwise_not", bitwise_not,
        py::arg("input").noconvert(), py::arg("value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
        Computes the bitwise NOT of the input tensor ``input``, elementwise. Input values must lie in the range [-2147483647, 2147483647]. Supported only on Wormhole_B0.

        Input tensor must have INT32 data type.

        Output tensor will have INT32 data type.

        .. csv-table::
            :header: "Argument", "Description", "Data type", "Valid range", "Required"

            "input", "Input tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "value", "Scalar value (accepted for signature parity with other unary ops; unused)", "int", "", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
    )doc");

    m_tensor.def("right_shift", right_shift,
        py::arg("input").noconvert(), py::arg("shift_amt"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
        Computes the right shift of the input tensor ``input`` by ``shift_amt`` bits. ``shift_amt`` must be in the range [0, 31]. Supported only on Wormhole_B0.
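Putting the bindings together, a usage sketch for the two new ops (assumes a Wormhole_B0 device is attached; the host/device helpers shown — CreateDevice, Tensor(...).to(...), cpu(), to_torch() — follow the usual tt_lib flow and are not part of this diff):

import torch
import tt_lib as ttl

device = ttl.device.CreateDevice(0)

x = torch.randint(0, 2**31 - 1, (1, 1, 32, 32), dtype=torch.int32)
t = ttl.tensor.Tensor(x, ttl.tensor.DataType.INT32).to(ttl.tensor.Layout.TILE).to(device)

out_xor = ttl.tensor.bitwise_xor(t, 7)  # elementwise x ^ 7
out_not = ttl.tensor.bitwise_not(t, 0)  # the value argument is accepted but unused

result = out_xor.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
assert torch.equal(result, torch.bitwise_xor(x, 7))

ttl.device.CloseDevice(device)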
@@ -27,5 +27,7 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_remainder.h"
#include "llk_math_eltwise_unary_sfpu_bitwise_xor.h"
#include "llk_math_eltwise_unary_sfpu_bitwise_not.h"
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
