diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index 4d70e6b70d6..d6bf0b1ab5f 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -803,6 +803,14 @@ "tt_op": tt_lib_ops.where, "pytorch_op": pytorch_ops.where, }, + "eltwise-where-optional": { + "tt_op": tt_lib_ops.where_optional, + "pytorch_op": pytorch_ops.where, + }, + "eltwise-where-scalar-optional": { + "tt_op": tt_lib_ops.where_scalar_optional, + "pytorch_op": pytorch_ops.where_scalar, + }, "where-bw": { "tt_op": tt_lib_ops.where_bw, "pytorch_op": pytorch_ops.where_bw, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py index 4ddb18dde5c..fee9e99be3c 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py @@ -4,6 +4,7 @@ import pytest import torch +import random from functools import partial from math import pi @@ -36,3 +37,48 @@ def test_run_eltwise_where_test(input_shapes, device, function_level_defaults): comparison_func, device, ) + + +@pytest.mark.parametrize("input_shapes", shapes) +def test_run_eltwise_where_test_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-5, high=+5), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=+10), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-optional", + [input_shapes[0], input_shapes[0], input_shapes[0], input_shapes[0]], + datagen_func, + comparison_func, + device, + ) + + +shapes_scalar = ( + [[1, 1, 32, 32], [1, 1, 32, 32]], # Single core + [[1, 1, 320, 384], [1, 1, 320, 384]], # Multi core + [[1, 3, 320, 384], [1, 3, 320, 384]], # Multi core +) + + +@pytest.mark.parametrize("input_shapes", shapes_scalar) +def test_run_eltwise_where_scalar_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0] + test_args.update({"scalar_true": random.uniform(0.5, 75.5), "scalar_false": random.uniform(0.5, 95.5)}) + + comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-scalar-optional", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 6a804785513..1b0f4c27a1a 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -96,6 +96,12 @@ def where(x, y, z, *args, **kwargs): return torch.where(x > 0, y, z) +def where_scalar(x, *args, **kwargs): + y = kwargs.pop("scalar_true") + z = kwargs.pop("scalar_false") + return torch.where(x > 0, y, z) + + def where_bw(x, y, z, w, *args, **kwargs): grad_data = x in_data = y diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index e22d6558329..b9dac18fd1b 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1518,6 +1518,28 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, * return tt2torch_tensor(t3) +@setup_host_and_device +def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output_mem_config, **kwargs): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) + t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2]) + t3 = setup_tt_tensor(out, device, layout[3], input_mem_config[3], dtype[3]) + ttl.tensor.where(t0, t1, t2, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + +@setup_host_and_device +def where_scalar_optional( + x, out, device, dtype, layout, input_mem_config, output_mem_config, scalar_true, scalar_false, **kwargs +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t3 = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1]) + ttl.tensor.where(t0, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + @setup_host_and_device def eltwise_div_unary( x, diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 7db4638049f..97bd3476238 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -1228,48 +1228,84 @@ Tensor _where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()) + { + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v1( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v2( - const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config) { - Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v3( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()){ + add(t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } else { + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } - Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config); + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config, output_tensor); } // on-device tensor creation 0s like @reference_tensor diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 0d79d22a44e..45edd04a6ac 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -316,22 +316,26 @@ Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const Tensor& value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); // on-device tensor creation 0s like @reference_tensor Tensor zeros_like( diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index b3750d8cdd8..5ea5a87f8ec 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -89,9 +89,10 @@ namespace tt::tt_metal::detail{ "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "False Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -107,9 +108,10 @@ namespace tt::tt_metal::detail{ "true_value", "float", "float", "float scalar", "Yes" "false_value", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -125,9 +127,10 @@ namespace tt::tt_metal::detail{ "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "float", "float", "float scalar", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -143,6 +146,7 @@ namespace tt::tt_metal::detail{ "true_value", "float", "float", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "float", "float", "float scalar", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); // *** composite unary ops *** detail::bind_unary_op(m_tensor, "normalize_hw", tt::tt_metal::normalize_hw, R"doc(Returns a new tensor with the Gaussian normalize of the elements of the input tensor ``{0}`` on H,W axes.)doc");