Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#5044: Add optional output to where op #9055

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,14 @@
"tt_op": tt_lib_ops.where,
"pytorch_op": pytorch_ops.where,
},
"eltwise-where-optional": {
"tt_op": tt_lib_ops.where_optional,
"pytorch_op": pytorch_ops.where,
},
"eltwise-where-scalar-optional": {
"tt_op": tt_lib_ops.where_scalar_optional,
"pytorch_op": pytorch_ops.where_scalar,
},
"where-bw": {
"tt_op": tt_lib_ops.where_bw,
"pytorch_op": pytorch_ops.where_bw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
import torch
import random
from functools import partial
from math import pi

Expand Down Expand Up @@ -36,3 +37,48 @@ def test_run_eltwise_where_test(input_shapes, device, function_level_defaults):
comparison_func,
device,
)


@pytest.mark.parametrize("input_shapes", shapes)
def test_run_eltwise_where_test_optional(input_shapes, device, function_level_defaults):
    """Sweep the where op with a preallocated (optional) output tensor."""
    # Four generators: predicate (random ints), true branch, false branch,
    # and the preallocated optional-output tensor — all cast to float32.
    generator_specs = [
        (generation_funcs.gen_randint, -100, +100),
        (generation_funcs.gen_rand, -5, +5),
        (generation_funcs.gen_rand, -10, +10),
        (generation_funcs.gen_rand, -1, +1),
    ]
    datagen_func = [
        generation_funcs.gen_func_with_cast(partial(gen, low=lo, high=hi), torch.float32)
        for gen, lo, hi in generator_specs
    ]
    comparison_func = partial(comparison_funcs.comp_pcc)
    # Condition, true, false, and out all share the first parametrized shape.
    shape = input_shapes[0]
    run_single_pytorch_test(
        "eltwise-where-optional",
        [shape, shape, shape, shape],
        datagen_func,
        comparison_func,
        device,
    )


# Shape pairs for the scalar-where sweep: (input tensor, preallocated
# optional output tensor); both entries of each pair share a shape.
shapes_scalar = (
    [[1, 1, 32, 32], [1, 1, 32, 32]],  # Single core
    [[1, 1, 320, 384], [1, 1, 320, 384]],  # Multi core
    [[1, 3, 320, 384], [1, 3, 320, 384]],  # Multi core
)


@pytest.mark.parametrize("input_shapes", shapes_scalar)
def test_run_eltwise_where_scalar_optional(input_shapes, device, function_level_defaults):
    """Sweep the scalar variant of where with a preallocated output tensor."""
    # Two generators: the condition tensor (random ints in [-100, 100]) and
    # the preallocated output tensor (floats in [-1, 1]).
    datagen_func = [
        generation_funcs.gen_func_with_cast(
            partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32
        ),
        generation_funcs.gen_func_with_cast(
            partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32
        ),
    ]
    # Take the first default dtype/layout/device combination and attach the
    # two scalars the op selects between.
    test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0]
    test_args["scalar_true"] = random.uniform(0.5, 75.5)
    test_args["scalar_false"] = random.uniform(0.5, 95.5)

    comparison_func = partial(comparison_funcs.comp_pcc)
    run_single_pytorch_test(
        "eltwise-where-scalar-optional",
        input_shapes,
        datagen_func,
        comparison_func,
        device,
        test_args,
    )
6 changes: 6 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def where(x, y, z, *args, **kwargs):
return torch.where(x > 0, y, z)


def where_scalar(x, *args, **kwargs):
    """Reference implementation of the scalar where op.

    Pops ``scalar_true`` and ``scalar_false`` out of ``kwargs`` (the sweep
    harness forwards them inside test_args) and returns ``scalar_true``
    wherever ``x`` is positive, ``scalar_false`` elsewhere.
    """
    on_true = kwargs.pop("scalar_true")
    on_false = kwargs.pop("scalar_false")
    return torch.where(x > 0, on_true, on_false)


def where_bw(x, y, z, w, *args, **kwargs):
grad_data = x
in_data = y
Expand Down
22 changes: 22 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,28 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, *
return tt2torch_tensor(t3)


@setup_host_and_device
def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
    """Run ttl where with a preallocated output tensor and return it as torch."""
    # Stage condition, true-branch, false-branch, and the preallocated
    # output on device, each with its own layout/mem-config/dtype slot.
    cond, true_vals, false_vals, out_tt = (
        setup_tt_tensor(t, device, layout[i], input_mem_config[i], dtype[i])
        for i, t in enumerate((x, y, z, out))
    )
    # The op writes its result into the supplied output_tensor in place.
    ttl.tensor.where(cond, true_vals, false_vals, output_mem_config=output_mem_config, output_tensor=out_tt)

    return tt2torch_tensor(out_tt)


@setup_host_and_device
def where_scalar_optional(
    x, out, device, dtype, layout, input_mem_config, output_mem_config, scalar_true, scalar_false, **kwargs
):
    """Run ttl where with scalar branches and a preallocated output tensor."""
    # Only two tensors travel to device: the condition and the output buffer.
    cond = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    out_tt = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1])
    # scalar_true/scalar_false are plain floats; result lands in out_tt.
    ttl.tensor.where(cond, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=out_tt)

    return tt2torch_tensor(out_tt)


@setup_host_and_device
def eltwise_div_unary(
x,
Expand Down
76 changes: 56 additions & 20 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,48 +1228,84 @@ Tensor _where(
const Tensor& predicate,
const Tensor& value_true,
const Tensor& value_false,
const MemoryConfig& output_mem_config) {
const MemoryConfig& output_mem_config,
std::optional<Tensor> output_tensor) {

Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
return add(t2, t1, std::nullopt, output_mem_config);
if(output_tensor.has_value())
{
mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
}
else
{
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
output_tensor = add(t2, t1, std::nullopt, output_mem_config);
}
return output_tensor.value();
}
Tensor _where_v1(
const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) {
const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {

Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config);
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
return add(t2, t1, std::nullopt, output_mem_config);

if(output_tensor.has_value()){
mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
}
else
{
Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config);
output_tensor = add(t2, t1, std::nullopt, output_mem_config);
}
return output_tensor.value();
}
Tensor _where_v2(
const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config) {
Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {

Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config);
return add(t2, t1, std::nullopt, output_mem_config);

if(output_tensor.has_value()){
mul(gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
add(output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
}
else
{
Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config);
output_tensor = add(t2, t1, std::nullopt, output_mem_config);
}
return output_tensor.value();
}
Tensor _where_v3(
const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) {
const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config);
Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config);
return add(t2, t1, std::nullopt, output_mem_config);
if(output_tensor.has_value()){
add(t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value());
} else {
output_tensor = add(t2, t1, std::nullopt, output_mem_config);
}
return output_tensor.value();
}

Tensor where(
const Tensor& predicate,
const Tensor& value_true,
const Tensor& value_false,
const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config);
const MemoryConfig& output_mem_config,
std::optional<Tensor> output_tensor) {
return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config);
const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config);
const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config, output_tensor);
}
Tensor where(
const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config);
const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional<Tensor> output_tensor) {
return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config, output_tensor);
}

// on-device tensor creation 0s like @reference_tensor
Expand Down
12 changes: 8 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,22 +316,26 @@ Tensor where(
const Tensor& predicate,
const Tensor& value_true,
const Tensor& value_false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
const Tensor& predicate,
const float value_true,
const Tensor& value_false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
const Tensor& predicate,
const Tensor& value_true,
const float value_false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<Tensor> output_tensor = std::nullopt);
Tensor where(
const Tensor& predicate,
const float value_true,
const float value_false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<Tensor> output_tensor = std::nullopt);

// on-device tensor creation 0s like @reference_tensor
Tensor zeros_like(
Expand Down
20 changes: 12 additions & 8 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.

where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -89,9 +89,10 @@ namespace tt::tt_metal::detail{
"true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "False Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, float, const Tensor&, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, float, const Tensor&, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.

where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -107,9 +108,10 @@ namespace tt::tt_metal::detail{
"true_value", "float", "float", "float scalar", "Yes"
"false_value", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const float, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const float, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.

where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -125,9 +127,10 @@ namespace tt::tt_metal::detail{
"true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "float", "float", "float scalar", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
m_tensor.def("where", py::overload_cast<const Tensor&, const float, const float, const MemoryConfig&>(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
m_tensor.def("where", py::overload_cast<const Tensor&, const float, const float, const MemoryConfig&, std::optional<Tensor> >(&where),
py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
Perform an ternary where operation on two tensors based on third @predicate.

where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
Expand All @@ -143,6 +146,7 @@ namespace tt::tt_metal::detail{
"true_value", "float", "float", "Tensor of shape [W, Z, Y, X]", "Yes"
"false_value", "float", "float", "float scalar", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
"output_tensor", "optional output tensor", "Tensor", "default is None", "No"
)doc");
// *** composite unary ops ***
detail::bind_unary_op(m_tensor, "normalize_hw", tt::tt_metal::normalize_hw, R"doc(Returns a new tensor with the Gaussian normalize of the elements of the input tensor ``{0}`` on H,W axes.)doc");
Expand Down
Loading