From 24b210a40829b1ca2bec31cac33439953f685bec Mon Sep 17 00:00:00 2001 From: Anasuya G Nair Date: Tue, 25 Jun 2024 13:04:11 +0530 Subject: [PATCH] #8683: Add Unary bitwise XOR, NOT (#9436) * #8683: Add bitwise_xor op * #8683: Add bitwise_not op --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 4 + .../python_api_testing/sweep_tests/op_map.py | 8 ++ .../pytests/tt_dnn/test_bitwise_not.py | 73 +++++++++++++++++++ .../pytests/tt_dnn/test_bitwise_xor.py | 72 ++++++++++++++++++ .../sweep_tests/pytorch_ops.py | 11 +++ .../sweep_tests/tt_lib_ops.py | 36 +++++++++ .../eltwise_unary/eltwise_unary_op.cpp | 27 ++++++- .../eltwise_unary/eltwise_unary_op.hpp | 6 ++ .../csrc/tt_lib_bindings_tensor_xary_ops.cpp | 33 +++++++++ .../metal/llk_api/llk_math_unary_sfpu_api.h | 2 + .../llk_sfpu/ckernel_sfpu_bitwise_not.h | 40 ++++++++++ .../llk_sfpu/ckernel_sfpu_bitwise_xor.h | 35 +++++++++ .../llk_math_eltwise_unary_sfpu_bitwise_not.h | 26 +++++++ .../llk_math_eltwise_unary_sfpu_bitwise_xor.h | 29 ++++++++ .../metal/llk_api/llk_sfpu_types.h | 2 + .../eltwise_unary/bitwise_not.h | 43 +++++++++++ .../eltwise_unary/bitwise_xor.h | 44 +++++++++++ .../eltwise_unary/sfpu_split_includes.h | 8 ++ 18 files changed, 496 insertions(+), 3 deletions(-) create mode 100644 tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_not.py create mode 100644 tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_xor.py create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_not.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_xor.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_not.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_xor.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_not.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_xor.h diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index 620a8ae6054..b617ccc9345 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -405,6 +405,10 @@ Tensor elementwise operations .. autofunction:: tt_lib.tensor.heaviside +.. autofunction:: tt_lib.tensor.bitwise_xor + +.. autofunction:: tt_lib.tensor.bitwise_not + .. autofunction:: tt_lib.tensor.right_shift .. autofunction:: tt_lib.tensor.left_shift diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index cbe6e9c8785..73d23e276f2 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -492,6 +492,14 @@ "tt_op": tt_lib_ops.eltwise_heaviside, "pytorch_op": pytorch_ops.heaviside, }, + "eltwise-bitwise_xor": { + "tt_op": tt_lib_ops.eltwise_bitwise_xor, + "pytorch_op": pytorch_ops.bitwise_xor, + }, + "eltwise-bitwise_not": { + "tt_op": tt_lib_ops.eltwise_bitwise_not, + "pytorch_op": pytorch_ops.bitwise_not, + }, "eltwise-right_shift": { "tt_op": tt_lib_ops.eltwise_right_shift, "pytorch_op": pytorch_ops.right_shift, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_not.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_not.py new file mode 100644 index 00000000000..0cdcdd67f71 --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_not.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from functools import partial +import tt_lib as ttl + + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) +from models.utility_functions import skip_for_grayskull + +mem_configs = [ + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), +] + + +@pytest.mark.parametrize( + "scalar", + (1, 1), +) +@pytest.mark.parametrize( + "input_shapes", + [ + [[1, 1, 32, 32]], + [[4, 3, 32, 32]], + [[2, 2, 32, 32]], + ], +) +@pytest.mark.parametrize( + "dst_mem_config", + mem_configs, +) +@skip_for_grayskull("#TODO: GS implementation needs to be done") +class TestBitwiseNot: + def test_run_bitwise_not_op( + self, + scalar, + input_shapes, + dst_mem_config, + device, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast( + partial(generation_funcs.gen_rand, low=-2147483647, high=2147483647), torch.int + ) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update( + { + "value": scalar, + "dtype": [(ttl.tensor.DataType.INT32)], + } + ) + test_args.update({"output_mem_config": dst_mem_config}) + comparison_func = comparison_funcs.comp_equal + + run_single_pytorch_test( + "eltwise-bitwise_not", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_xor.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_xor.py new file mode 100644 index 00000000000..39b86ab95c4 --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_bitwise_xor.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import random +from functools import partial +import tt_lib as ttl + + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) +from models.utility_functions import skip_for_grayskull + +mem_configs = [ + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), +] + + +@pytest.mark.parametrize( + "scalar", + {random.randint(-100, 100) for _ in range(10)}, +) +@pytest.mark.parametrize( + "input_shapes", + [ + [[1, 1, 32, 32]], + [[4, 3, 32, 32]], + [[2, 2, 32, 32]], + ], +) +@pytest.mark.parametrize( + "dst_mem_config", + mem_configs, +) +@skip_for_grayskull("#TODO: GS implementation needs to be done") +class TestBitwiseXor: + def test_run_bitwise_xor_op( + self, + scalar, + input_shapes, + dst_mem_config, + device, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=0, high=2147483647), torch.int) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update( + { + "value": scalar, + "dtype": [(ttl.tensor.DataType.INT32)], + } + ) + test_args.update({"output_mem_config": dst_mem_config}) + comparison_func = comparison_funcs.comp_equal + + run_single_pytorch_test( + "eltwise-bitwise_xor", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index d3d01c45e31..918ebdc87c4 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -527,6 +527,17 @@ def heaviside(x, *args, **kwargs): return result +def bitwise_xor(x, *args, **kwargs): + value = kwargs.pop("value") + result = torch.bitwise_xor(x, value) + return result + + +def bitwise_not(x, *args, **kwargs): + result = torch.bitwise_not(x) + return result + + def right_shift(x, *args, **kwargs): value = kwargs.pop("value") result = torch.bitwise_right_shift(x, value) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index 54f30b79d97..3fa455cb021 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1206,6 +1206,42 @@ def lamb_optimizer( return [tt2torch_tensor(t4[0]), tt2torch_tensor(t4[1]), tt2torch_tensor(t4[2])] +@setup_host_and_device +def eltwise_bitwise_xor( + x, + *args, + value, + device, + dtype, + layout, + input_mem_config, + output_mem_config, + **kwargs, +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = ttl.tensor.bitwise_xor(t0, value, output_mem_config=output_mem_config) + + return tt2torch_tensor(t1) + + +@setup_host_and_device +def eltwise_bitwise_not( + x, + *args, + value, + device, + dtype, + layout, + input_mem_config, + output_mem_config, + **kwargs, +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = ttl.tensor.bitwise_not(t0, value, output_mem_config=output_mem_config) + + return tt2torch_tensor(t1) + + @setup_host_and_device def eltwise_right_shift( x, diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index bd3f7a119e5..37c287490ec 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -65,6 +65,8 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_parameterized( op_init_and_name = { "heaviside_tile_init();", fmt::format("heaviside_tile({}, {}u);", idst, Converter::to_hex(param0))}; break; + case UnaryOpType::BITWISE_XOR: + op_init_and_name = { + "bitwise_xor_tile_init();", fmt::format("bitwise_xor_tile({}, {}u);", idst, std::to_string((uint)param0))}; + break; + case UnaryOpType::BITWISE_NOT: + op_init_and_name = { + "bitwise_not_tile_init();", fmt::format("bitwise_not_tile({}, {}u);", idst, std::to_string((uint)param0))}; + break; case UnaryOpType::RIGHT_SHIFT: op_init_and_name = { "right_shift_tile_init();", @@ -341,7 +351,7 @@ namespace tt { namespace tt_metal { -inline void validate_supported_arch(tt::ARCH arch, UnaryOpType op_type) { +inline void validate_supported_arch_dtype(tt::ARCH arch, DataType input_datatype, DataType output_datatype, UnaryOpType op_type) { switch (op_type) { case UnaryOpType::REMAINDER: case UnaryOpType::FLOOR: @@ -349,6 +359,12 @@ inline void validate_supported_arch(tt::ARCH arch, UnaryOpType op_type) { case UnaryOpType::RIGHT_SHIFT: TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole"); break; + case UnaryOpType::BITWISE_XOR: + case UnaryOpType::BITWISE_NOT: + TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole"); + TT_FATAL(input_datatype == DataType::INT32, "Data type is not supported for Bitwise operations"); + TT_FATAL(output_datatype == DataType::INT32, "Data type is not supported for Bitwise operations"); + break; default: return; } @@ -357,10 +373,15 @@ inline void validate_supported_arch(tt::ARCH arch, UnaryOpType op_type) { void EltwiseUnary::validate_with_output_tensors(const std::vector &input_tensors, const std::vector> &optional_output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto out_mem_config = (!optional_output_tensors.empty() && optional_output_tensors.at(0).has_value()) ? optional_output_tensors.at(0).value().memory_config() : this->output_mem_config; - + auto output_datatype = output_dtype; + if(!optional_output_tensors.empty() && optional_output_tensors.at(0).has_value()){ + const auto& out = optional_output_tensors.at(0); + output_datatype = out->get_dtype(); + } auto arch = input_tensor_a.device()->arch(); + auto input_datatype = input_tensor_a.get_dtype(); for (const auto& unary_op : this->op_chain) { - validate_supported_arch(arch, unary_op.op_type); + validate_supported_arch_dtype(arch, input_datatype, output_datatype, unary_op.op_type); } TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to eltwise unary need to be on device!"); TT_FATAL( diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 944d74ce750..00e6f7fcc61 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -79,6 +79,8 @@ enum class UnaryOpType { UNARY_LT, TILED_PROD, TYPECAST, + BITWISE_XOR, + BITWISE_NOT, RIGHT_SHIFT, FLOOR, LEFT_SHIFT, @@ -110,6 +112,8 @@ bool is_parametrized_type(T val) { case UnaryOpType::UNARY_GT: case UnaryOpType::UNARY_LT: case UnaryOpType::TYPECAST: + case UnaryOpType::BITWISE_XOR: + case UnaryOpType::BITWISE_NOT: case UnaryOpType::RIGHT_SHIFT: case UnaryOpType::LEFT_SHIFT: case UnaryOpType::REMAINDER: return true; @@ -415,6 +419,8 @@ constexpr auto leaky_relu = make_eltwise_unary_with_param{}; constexpr auto heaviside = make_eltwise_unary_with_param{}; +constexpr auto bitwise_xor = make_eltwise_unary_with_param{}; +constexpr auto bitwise_not = make_eltwise_unary_with_param{}; constexpr auto right_shift = make_eltwise_unary_with_param{}; constexpr auto left_shift = make_eltwise_unary_with_param{}; constexpr auto unary_remainder = make_eltwise_unary_with_param{}; diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index a4575e335be..307acc33119 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -195,6 +195,39 @@ namespace tt::tt_metal::detail { )doc"); + m_tensor.def("bitwise_xor",bitwise_xor, + py::arg("input").noconvert(),py::arg("value"),py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc( + Computes bitwise_xor of input tensor ``input`` by a scalar ``value``. Input tensor needs to be positive. Support provided only for Wormhole_B0. + + Input tensor must have INT32 data type. + + Output tensor will have INT32 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "value", "scalar value", "int", "", "Yes" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + + )doc"); + + m_tensor.def("bitwise_not",bitwise_not, + py::arg("input").noconvert(),py::arg("value"),py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc( + Computes bitwise_not of input tensor ``input``. Input tensor needs to be in the range [-2147483647, 2147483647]. Support provided only for Wormhole_B0. + + Input tensor must have INT32 data type. + + Output tensor will have INT32 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + + )doc"); + m_tensor.def("right_shift",right_shift, py::arg("input").noconvert(),py::arg("shift_amt"),py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc( Computes right shift of input tensor ``input`` by ``shift_amt`` bits. ``shift_amt`` range must be [0, 31]. Support provided only for Wormhole_B0. diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h index ba6de069bf0..6a4223dbe64 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h @@ -27,5 +27,7 @@ #include "llk_math_eltwise_unary_sfpu_trigonometry.h" #include "llk_math_eltwise_unary_sfpu_unary_comp.h" #include "llk_math_eltwise_unary_sfpu_remainder.h" +#include "llk_math_eltwise_unary_sfpu_bitwise_xor.h" +#include "llk_math_eltwise_unary_sfpu_bitwise_not.h" #include "llk_math_eltwise_unary_sfpu_right_shift.h" #include "llk_math_eltwise_unary_sfpu_left_shift.h" diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_not.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_not.h new file mode 100644 index 00000000000..8d22000daca --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_not.h @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" +#include "noc_nonblocking_api.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_bitwise_not(const uint value) { +#pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + vInt input = dst_reg[0]; + vInt res; + + v_if(input < 0) { + vInt unsigned_mag = input & 0x7FFFFFFF; + res = unsigned_mag - 1; + } + v_else { + res = setsgn(input, -1); + res = res + 1; + } + v_endif; + + dst_reg[0] = res; + dst_reg++; + } +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_xor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_xor.h new file mode 100644 index 00000000000..25c1462e8c4 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_bitwise_xor.h @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "noc_nonblocking_api.h" +#include "limits.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_bitwise_xor(const uint value) { +#pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + vInt input = dst_reg[0]; + vInt v = value; + vInt res = input ^ v; + v_if(res > INT_MIN && res < 0) + { + res = 0 - res; + res = setsgn(res, v); + } + v_endif + dst_reg[0] = res; + dst_reg++; + } +} +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_not.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_not.h new file mode 100644 index 00000000000..6746cd9bdc9 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_not.h @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_bitwise_not.h" +#include "llk_math_eltwise_unary_sfpu_params.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_bitwise_not_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_bitwise_not(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params + (ckernel::sfpu::calculate_bitwise_not, + dst_index, vector_mode, param0); +} +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_xor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_xor.h new file mode 100644 index 00000000000..a6a8815e9e5 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_bitwise_xor.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_bitwise_xor.h" +#include "llk_math_eltwise_unary_sfpu_params.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_bitwise_xor_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_bitwise_xor(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_bitwise_xor, + dst_index, + vector_mode, + param0); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index 8bbebe2e15e..917f4c08aa1 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -75,6 +75,8 @@ enum SfpuType { unary_lt, softplus, tiled_prod, + bitwise_xor, + bitwise_not, right_shift, floor, left_shift, diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_not.h b/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_not.h new file mode 100644 index 00000000000..b3de376febc --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_not.h @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_bitwise_not.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Performs element-wise bitwise_not computation on input x , where x is each element of a tile + * in DST register at index tile_index. The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to modify the computation of | uint32_t | Must be less than the size of the DST register buffer | True | + */ +ALWI void bitwise_not_tile(uint32_t idst, uint32_t param0) { + MATH((llk_math_eltwise_unary_sfpu_bitwise_not(idst, param0))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void bitwise_not_tile_init() { + MATH((llk_math_eltwise_unary_sfpu_bitwise_not_init())); } + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_xor.h b/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_xor.h new file mode 100644 index 00000000000..831e235869e --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/bitwise_xor.h @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_bitwise_xor.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Performs element-wise bitwise_xor computation on input x , where x is each element of a tile + * in DST register at index tile_index. The value is provided as const param0 The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking xor is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True | + * | param0 | The value the output is if the input is greater than 0 | uint32_t | | True | + */ +ALWI void bitwise_xor_tile(uint32_t idst, uint32_t param0) { + MATH((llk_math_eltwise_unary_sfpu_bitwise_xor(idst, param0))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void bitwise_xor_tile_init() { MATH((llk_math_eltwise_unary_sfpu_bitwise_xor_init())); } + + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h index d882e4d3ab1..fc7bc8fd056 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h @@ -68,6 +68,14 @@ #include "compute_kernel_api/eltwise_unary/typecast.h" #endif +#if SFPU_OP_BITWISE_XOR_INCLUDE +#include "compute_kernel_api/eltwise_unary/bitwise_xor.h" +#endif + +#if SFPU_OP_BITWISE_NOT_INCLUDE +#include "compute_kernel_api/eltwise_unary/bitwise_not.h" +#endif + #if SFPU_OP_RIGHT_SHIFT_INCLUDE #include "compute_kernel_api/eltwise_unary/right_shift.h" #endif