Skip to content

Commit

Permalink
#5044: Add optional output to where op
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Jun 3, 2024
1 parent 354370a commit ef7660e
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 32 deletions.
8 changes: 8 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,14 @@
"tt_op": tt_lib_ops.where,
"pytorch_op": pytorch_ops.where,
},
"eltwise-where-optional": {
"tt_op": tt_lib_ops.where_optional,
"pytorch_op": pytorch_ops.where,
},
"eltwise-where-scalar-optional": {
"tt_op": tt_lib_ops.where_scalar_optional,
"pytorch_op": pytorch_ops.where_scalar,
},
"where-bw": {
"tt_op": tt_lib_ops.where_bw,
"pytorch_op": pytorch_ops.where_bw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
import torch
import random
from functools import partial
from math import pi

Expand Down Expand Up @@ -36,3 +37,48 @@ def test_run_eltwise_where_test(input_shapes, device, function_level_defaults):
comparison_func,
device,
)


@pytest.mark.parametrize("input_shapes", shapes)
def test_run_eltwise_where_test_optional(input_shapes, device, function_level_defaults):
    """Sweep the eltwise-where op when a preallocated output tensor is supplied.

    Generates four float32 tensors: the predicate (random integers in
    [-100, 100]), the true/false branch values, and the preallocated output
    buffer, each drawn from its own range.
    """
    generator_specs = [
        (generation_funcs.gen_randint, -100, +100),  # predicate
        (generation_funcs.gen_rand, -5, +5),         # value_true
        (generation_funcs.gen_rand, -10, +10),       # value_false
        (generation_funcs.gen_rand, -1, +1),         # preallocated output
    ]
    datagen_func = [
        generation_funcs.gen_func_with_cast(partial(gen, low=lo, high=hi), torch.float32)
        for gen, lo, hi in generator_specs
    ]
    comparison_func = partial(comparison_funcs.comp_pcc)
    # All four tensors share the first parametrized shape.
    run_single_pytorch_test(
        "eltwise-where-optional",
        [input_shapes[0]] * 4,
        datagen_func,
        comparison_func,
        device,
    )


# Shape pairs for the scalar where sweep: [input_shape, output_shape].
# Both entries match — presumably the preallocated output mirrors the input
# shape; confirm against run_single_pytorch_test's shape handling.
shapes_scalar = (
    [[1, 1, 32, 32], [1, 1, 32, 32]],  # Single core
    [[1, 1, 320, 384], [1, 1, 320, 384]],  # Multi core
    [[1, 3, 320, 384], [1, 3, 320, 384]],  # Multi core
)


@pytest.mark.parametrize("input_shapes", shapes_scalar)
def test_run_eltwise_where_scalar_optional(input_shapes, device, function_level_defaults):
    """Sweep eltwise-where with scalar true/false values and a preallocated output.

    The predicate tensor uses random integers in [-100, 100]; the output
    buffer is seeded with uniform floats in [-1, 1]. The scalar branch values
    are passed through ``test_args``.
    """
    datagen_func = [
        generation_funcs.gen_func_with_cast(
            partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32
        ),
        generation_funcs.gen_func_with_cast(
            partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32
        ),
    ]
    test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
    # Random scalar branch values; evaluation order (true then false) preserved.
    test_args["scalar_true"] = random.uniform(0.5, 75.5)
    test_args["scalar_false"] = random.uniform(0.5, 95.5)

    comparison_func = partial(comparison_funcs.comp_pcc)
    run_single_pytorch_test(
        "eltwise-where-scalar-optional",
        input_shapes,
        datagen_func,
        comparison_func,
        device,
        test_args,
    )
6 changes: 6 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
def where(x, y, z, *args, **kwargs):
    """Reference eltwise where: elementwise y where x > 0, otherwise z."""
    predicate = x > 0
    return torch.where(predicate, y, z)


def where_scalar(x, *args, **kwargs):
    """Reference where with scalar branches: scalar_true where x > 0, else scalar_false.

    The branch values are taken (and removed) from ``kwargs`` under the keys
    ``scalar_true`` and ``scalar_false``.
    """
    true_value = kwargs.pop("scalar_true")
    false_value = kwargs.pop("scalar_false")
    return torch.where(x > 0, true_value, false_value)


def where_bw(x, y, z, w, *args, **kwargs):
grad_data = x
in_data = y
Expand Down
22 changes: 22 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,28 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, *
return tt2torch_tensor(t3)


@setup_host_and_device
def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
    """Run ttl where(predicate, true, false) writing into a preallocated output tensor."""
    # Build device tensors for predicate, both branches, and the output buffer.
    tensors = [
        setup_tt_tensor(host_t, device, layout[i], input_mem_config[i], dtype[i])
        for i, host_t in enumerate((x, y, z, out))
    ]
    predicate, true_vals, false_vals, out_tensor = tensors
    ttl.tensor.where(predicate, true_vals, false_vals, output_mem_config=output_mem_config, output_tensor=out_tensor)

    return tt2torch_tensor(out_tensor)


@setup_host_and_device
def where_scalar_optional(
    x, out, device, dtype, layout, input_mem_config, output_mem_config, scalar_true, scalar_false, **kwargs
):
    """Run ttl where with float branch values, writing into a preallocated output tensor."""
    predicate = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    out_tensor = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1])
    ttl.tensor.where(predicate, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=out_tensor)

    return tt2torch_tensor(out_tensor)


@setup_host_and_device
def eltwise_div_unary(
x,
Expand Down
76 changes: 56 additions & 20 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,48 +1228,84 @@ Tensor _where(
const Tensor& predicate,
const Tensor& value_true,
const Tensor& value_false,
const MemoryConfig& output_mem_config) {
const MemoryConfig& output_mem_config,
std::optional<Tensor> output_tensor) {

Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
return add(t2, t1, std::nullopt, output_mem_config);
if(output_tensor.has_value())
{
mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
}
else
{
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
output_tensor = add(t2, t1, std::nullopt, output_mem_config);
}
return output_tensor.value();
}
Tensor _where_v1(
    const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {

    // True branch: scalar value_true broadcast where predicate > 0.
    Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config);

    if (output_tensor.has_value()) {
        // In-place path: false-branch product goes into the provided output,
        // then the true-branch product is accumulated on top.
        mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
        add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
    } else {
        Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
        output_tensor = add(t2, t1, std::nullopt, output_mem_config);
    }
    return output_tensor.value();
}
Tensor _where_v2(
    const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {

    // False branch: scalar value_false broadcast where predicate <= 0.
    Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config);

    if (output_tensor.has_value()) {
        // In-place path: true-branch product goes into the provided output,
        // then the false-branch product is accumulated on top.
        mul(gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
        add(output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
    } else {
        Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
        output_tensor = add(t2, t1, std::nullopt, output_mem_config);
    }
    return output_tensor.value();
}
Tensor _where_v3(
    const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    // Both branches are scalars: sum of the two masked broadcasts.
    Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config);
    Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config);
    if (output_tensor.has_value()) {
        // Accumulate directly into the caller-provided output tensor.
        add(t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
    } else {
        output_tensor = add(t2, t1, std::nullopt, output_mem_config);
    }
    return output_tensor.value();
}

// Public entry points for eltwise where. Each overload forwards to the matching
// composite implementation, threading through the optional preallocated output
// tensor (written in place when provided).
Tensor where(
    const Tensor& predicate,
    const Tensor& value_true,
    const Tensor& value_false,
    const MemoryConfig& output_mem_config,
    std::optional<Tensor> output_tensor) {
    return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
    const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
    const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
    const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
    return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config, output_tensor);
}

// on-device tensor creation 0s like @reference_tensor
Expand Down
12 changes: 8 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
Original file line number Diff line number Diff line change
// where(predicate, value_true, value_false): elementwise select implementing
// (predicate > 0) ? value_true : value_false. Each overload accepts an optional
// preallocated output tensor; when provided, the result is written into it.
Tensor where(
    const Tensor& predicate,
    const Tensor& value_true,
    const Tensor& value_false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
    const Tensor& predicate,
    const float value_true,
    const Tensor& value_false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
    const Tensor& predicate,
    const Tensor& value_true,
    const float value_false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
    const Tensor& predicate,
    const float value_true,
    const float value_false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
    std::optional<Tensor> output_tensor = std::nullopt);

// on-device tensor creation 0s like @reference_tensor
Tensor zeros_like(
Expand Down
20 changes: 12 additions & 8 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.
where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -89,9 +89,10 @@ namespace tt::tt_metal::detail{
"true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "False Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, float, const Tensor&, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, float, const Tensor&, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.
where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -107,9 +108,10 @@ namespace tt::tt_metal::detail{
"true_value", "float", "float", "float scalar", "Yes"
"false_value", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const float, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const float, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.
where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -125,9 +127,10 @@ namespace tt::tt_metal::detail{
"true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "float", "float", "float scalar", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, const float, const float, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const float, const float, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.
where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -143,6 +146,7 @@ namespace tt::tt_metal::detail{
"true_value", "float", "float", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "float", "float", "float scalar", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
// *** composite unary ops ***
detail::bind_unary_op(m_tensor, "normalize_hw", tt::tt_metal::normalize_hw, R"doc(Returns a new tensor with the Gaussian normalize of the elements of the input tensor ``{0}`` on H,W axes.)doc");
Expand Down

0 comments on commit ef7660e

Please sign in to comment.