#14707: refactoring return optional tensors (#14708)
* #14707: refactoring return optional tensors
hschoi4448 authored Nov 7, 2024
1 parent 4e152e0 commit 750d0f2
Showing 25 changed files with 148 additions and 256 deletions.
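
In short: composite operations whose invoke returns std::vector<std::optional<Tensor>> no longer implement the create_async_output_tensors / create_async_return_flag pair; they implement a single create_async_optional_output_tensors hook whose parameters mirror invoke and whose slots are std::nullopt for outputs that were not requested. A minimal sketch of the new op-author contract (OpWithOptionalOutputs is hypothetical and not part of this commit; the ttnn Tensor, OptionalTensors, and operation::get_workers_for_op_output declarations are assumed from the headers touched below):

// Sketch only: a hypothetical composite op following the new optional-output contract.
struct OpWithOptionalOutputs {
    // One optional slot per possible output.
    static std::vector<std::optional<Tensor>> invoke(const Tensor& input, bool want_out0, bool want_out1);

    // Required by decorators.hpp when invoke returns OptionalTensors.
    // Parameters must be identical to those of invoke; skipped outputs stay std::nullopt.
    static OptionalTensors create_async_optional_output_tensors(
        const Tensor& input, bool want_out0, bool want_out1) {
        return {
            want_out0 ? std::optional<Tensor>(Tensor(operation::get_workers_for_op_output({input}))) : std::nullopt,
            want_out1 ? std::optional<Tensor>(Tensor(operation::get_workers_for_op_output({input}))) : std::nullopt};
    }
};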
8 changes: 5 additions & 3 deletions tests/ttnn/unit_tests/operations/test_moreh_group_norm.py
@@ -202,8 +202,6 @@ def make_input_tensors(input_shape, affine, do_backward=False):


def run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device):
torch.manual_seed(2024)

H, W = HW
C, num_groups = C_num_groups
input_shape = (N, C, H, W)
@@ -284,6 +282,8 @@ def run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device):
],
)
def test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device):
torch.manual_seed(2024)

run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device)


@@ -343,7 +343,8 @@ def run_test_moreh_group_norm_backward(
if not affine and (gamma_requires_grad or beta_requires_grad):
pytest.skip("gamma_requires_grad and beta_requires_grad are only valid when affine is True.")

torch.manual_seed(2024)
if not (input_requires_grad or gamma_requires_grad or beta_requires_grad):
pytest.skip("at least one requires_grad should be True.")

C, num_groups = C_num_groups
input_shape = (N, C, H, W)
@@ -466,6 +467,7 @@ def run_test_moreh_group_norm_backward(
def test_moreh_group_norm_backward(
N, C_num_groups, HW, eps, affine, input_requires_grad, gamma_requires_grad, beta_requires_grad, device
):
torch.manual_seed(2024)
run_test_moreh_group_norm_backward(
N, C_num_groups, HW, eps, affine, input_requires_grad, gamma_requires_grad, beta_requires_grad, device
)
130 changes: 63 additions & 67 deletions ttnn/cpp/ttnn/decorators.hpp
@@ -6,14 +6,14 @@

#include <reflect>

#include "tt_metal/graph/graph_tracking.hpp"
#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/core.hpp"
#include "ttnn/device_operation.hpp"
#include "ttnn/operation.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/device_operation.hpp"
#include "tt_metal/graph/graph_tracking.hpp"

namespace ttnn {
namespace decorators {
@@ -51,17 +51,30 @@ auto extract_args_to_vector(args_t&&... args) {
return result;
}

template <typename operation_t, typename execute_on_worker_thread_return_t>
inline Tensors create_async_output_tensors(const Tensors& inputs, const OptionalConstTensors& optional_inputs) {
template <typename operation_t, typename execute_on_worker_thread_return_t, typename... args_t>
inline auto create_async_output_tensors(
const Tensors& inputs, const OptionalConstTensors& optional_inputs, args_t&&... args) {
bool enable_autoformat_device = false;

constexpr bool custom_create_async_outputs =
requires(const operation_t& t) { t.create_async_output_tensors(inputs, optional_inputs); };

if constexpr (custom_create_async_outputs) {
return operation_t::create_async_output_tensors(inputs, optional_inputs);
} else if constexpr (std::is_same_v<std::decay_t<execute_on_worker_thread_return_t>, OptionalTensors>) {
constexpr bool custom_create_async_optional_outputs = requires(const operation_t& t) {
t.create_async_optional_output_tensors(std::forward<decltype(args)>(args)...);
};
static_assert(
custom_create_async_optional_outputs,
"If the operation returns a vector of optional Tensors, it must "
"implement create_async_optional_output_tensors.");

return operation_t::create_async_optional_output_tensors(std::forward<decltype(args)>(args)...);
} else if constexpr (std::is_same_v<std::decay_t<execute_on_worker_thread_return_t>, Tensor>) {
return {Tensor(operation::get_workers_for_op_output(inputs, optional_inputs, enable_autoformat_device))};
return std::vector{
Tensor(operation::get_workers_for_op_output(inputs, optional_inputs, enable_autoformat_device))};

} else if constexpr (detail::is_homogenous_tuple<execute_on_worker_thread_return_t, Tensor>()) {
Tensors output_tensors;
output_tensors.reserve(std::tuple_size_v<execute_on_worker_thread_return_t>);
@@ -108,21 +121,13 @@ auto map_launch_op_args_to_execute_on_worker_thread_args(
}

template <typename operation_t, typename T>
auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) -> Tensors {
auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) {
if constexpr (std::is_same_v<std::decay_t<decltype(value)>, Tensors>) {
return value;
} else if constexpr (std::is_same_v<std::decay_t<decltype(value)>, Tensor>) {
return {value};
return std::vector{value};
} else if constexpr (std::is_same_v<std::decay_t<decltype(value)>, OptionalTensors>) {
Tensors output_tensors;
auto size = value.size();
output_tensors.reserve(size);

auto dummy_tensor = Tensor();
for (auto& val : value) {
output_tensors.push_back(val.value_or(dummy_tensor));
}
return output_tensors;
return value;
} else if constexpr (is_homogenous_tuple<T, Tensor>()) {
Tensors output_tensors;
output_tensors.reserve(std::tuple_size_v<T>);
@@ -135,8 +140,8 @@ auto map_execute_on_worker_thread_return_to_launch_op_return(const T&& value) ->
} else {
static_assert(
tt::stl::concepts::always_false_v<operation_t>,
"Operation must return either a single Tensor or a vector of Tensors or implement "
"map_execute_on_worker_thread_return_to_launch_op_return.");
"Operation must return either a single Tensor or a vector of Tensors or a vector of optional Tensors "
"implement map_execute_on_worker_thread_return_to_launch_op_return.");
}
}

@@ -194,7 +199,7 @@ template <typename operation_t>
concept PrimitiveOperationConcept = device_operation::DeviceOperationConcept<operation_t>;

// Composite operation allows any code to be executed
template<typename operation_t>
template <typename operation_t>
concept CompositeOperationConcept = !PrimitiveOperationConcept<operation_t>;

template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t, bool auto_launch_op>
Expand All @@ -213,23 +218,23 @@ struct registered_operation_t {
}

template <typename... args_t>
requires PrimitiveOperationConcept<operation_t>
requires PrimitiveOperationConcept<operation_t>
auto invoke(uint8_t queue_id, args_t&&... args) const {
static_assert(requires { operation_t::invoke(std::forward<decltype(args)>(args)...); },
"Primitive Operation must implement operator() method to be invoked.");
static_assert(
requires { operation_t::invoke(std::forward<decltype(args)>(args)...); },
"Primitive Operation must implement operator() method to be invoked.");
ZoneScopedN("Run primitive ttnn operation");
ZoneName(static_cast<const char*>(cpp_fully_qualified_name.data.data()), cpp_fully_qualified_name.size());
auto [operation_attributes, tensors_args] = operation_t::invoke(std::forward<decltype(args)>(args)...);
return ttnn::device_operation::detail::invoke<operation_t>(queue_id, operation_attributes, tensors_args);
}

template <typename... args_t>
requires(PrimitiveOperationConcept<operation_t>)
requires(PrimitiveOperationConcept<operation_t>)
auto invoke(args_t&&... args) const {
return invoke(DefaultQueueId, std::forward<args_t>(args)...);
}


template <typename... args_t>
requires(not auto_launch_op)
auto invoke_composite(args_t&&... args) const {
@@ -247,16 +252,14 @@ struct registered_operation_t {
// #8479: Fix and re-enable logging in cpp operation decorator
// detail::log("Arguments: ", std::forward<args_t>(args)...);

using execute_on_worker_thread_return_t =
decltype(operation_t::invoke(std::forward<decltype(args)>(args)...));
using execute_on_worker_thread_return_t = decltype(operation_t::invoke(std::forward<decltype(args)>(args)...));

const Tensors input_tensors = detail::extract_args_to_vector<ttnn::Tensor>(std::forward<args_t>(args)...);
const OptionalConstTensors optional_input_tensors =
detail::extract_args_to_vector<std::optional<const ttnn::Tensor>>(std::forward<args_t>(args)...);

auto output_tensors =
detail::create_async_output_tensors<operation_t, execute_on_worker_thread_return_t>(
input_tensors, optional_input_tensors);
auto output_tensors = detail::create_async_output_tensors<operation_t, execute_on_worker_thread_return_t>(
input_tensors, optional_input_tensors, std::forward<decltype(args)>(args)...);

const OptionalTensors optional_output_tensors =
detail::extract_args_to_vector<std::optional<ttnn::Tensor>>(std::forward<args_t>(args)...);
@@ -266,11 +269,11 @@ struct registered_operation_t {
[args...](
const Tensors& input_tensors,
const OptionalConstTensors& optional_input_tensors,
const OptionalTensors& optional_output_tensors) mutable -> Tensors {
const OptionalTensors& optional_output_tensors) mutable {
auto execute_on_worker_thread_args = detail::map_launch_op_args_to_execute_on_worker_thread_args(
input_tensors, optional_input_tensors, optional_output_tensors, std::forward<args_t>(args)...);
return std::apply(
[](auto&&... args) -> Tensors {
[](auto&&... args) {
return detail::map_execute_on_worker_thread_return_to_launch_op_return<operation_t>(
operation_t::invoke(std::forward<decltype(args)>(args)...));
},
@@ -282,24 +285,12 @@ struct registered_operation_t {
optional_output_tensors,
enable_autoformat);


if constexpr (std::is_same_v<std::decay_t<execute_on_worker_thread_return_t>, Tensor>) {
return output_tensors.at(0);
} else if constexpr (std::is_same_v<execute_on_worker_thread_return_t, Tensors>) {
return output_tensors;
} else if constexpr (std::is_same_v<execute_on_worker_thread_return_t, OptionalTensors>) {
// convert tensor to optional tensor
auto size = output_tensors.size();
std::vector<std::optional<Tensor>> ret(size);

auto return_flags = operation_t::create_async_return_flag(std::forward<decltype(args)>(args)...);

for (uint32_t i = 0 ; i < size; i++) {
if (return_flags.at(i)) {
ret[i] = output_tensors.at(i);
}
}
return ret;
return output_tensors;
} else if constexpr (detail::is_homogenous_tuple<execute_on_worker_thread_return_t, Tensor>()) {
return detail::make_tuple_from_vector<execute_on_worker_thread_return_t>(output_tensors);
} else {
@@ -312,7 +303,7 @@ struct registered_operation_t {
}

template <typename... args_t>
requires(CompositeOperationConcept<operation_t>)
requires(CompositeOperationConcept<operation_t>)
auto invoke(args_t&&... args) const {
return invoke_composite(std::forward<args_t>(args)...);
}
@@ -336,43 +327,45 @@ struct registered_operation_t {
}
};

template<reflect::fixed_string cpp_fully_qualified_name>
struct operation_name_key_t{
template <reflect::fixed_string cpp_fully_qualified_name>
struct operation_name_key_t {
friend consteval auto get(operation_name_key_t<cpp_fully_qualified_name>);
};

template<typename operation_t>
struct operation_key_t{
template <typename operation_t>
struct operation_key_t {
friend consteval auto get(operation_key_t<operation_t>);
};

template<reflect::fixed_string cpp_fully_qualified_name, typename operation_t, auto operation>
template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t, auto operation>
struct set_operation_t : std::true_type {
friend consteval auto get(operation_key_t<operation_t>) {
return operation;
}
friend consteval auto get(operation_name_key_t<cpp_fully_qualified_name>) {
return operation;
}
friend consteval auto get(operation_key_t<operation_t>) { return operation; }
friend consteval auto get(operation_name_key_t<cpp_fully_qualified_name>) { return operation; }
};

constexpr reflect::fixed_string prim_namespace = "ttnn::prim";

template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t>
consteval void assert_operation_in_correct_namespace() {
if constexpr (PrimitiveOperationConcept<operation_t>) {
if constexpr(cpp_fully_qualified_name.size() > prim_namespace.size()) {
constexpr auto namespace_substring = tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name);
static_assert(tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), "Primitive operations must be in the `ttnn::prim` namespace.");
if constexpr (cpp_fully_qualified_name.size() > prim_namespace.size()) {
constexpr auto namespace_substring =
tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name);
static_assert(
tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace),
"Primitive operations must be in the `ttnn::prim` namespace.");
} else {
#ifndef DISABLE_NAMESPACE_STATIC_ASSERT
#ifndef DISABLE_NAMESPACE_STATIC_ASSERT
static_assert(false, "Primitive operations must be in the `ttnn::prim` namespace.");
#endif
#endif
}
} else {
if constexpr (cpp_fully_qualified_name.size() > prim_namespace.size()) {
constexpr auto namespace_substring = tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name);
static_assert(not tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace), "Composite operations must not be in the `ttnn::prim` namespace.");
constexpr auto namespace_substring =
tt::stl::reflection::fixed_string_substring<0, prim_namespace.size()>(cpp_fully_qualified_name);
static_assert(
not tt::stl::reflection::fixed_string_equals(namespace_substring, prim_namespace),
"Composite operations must not be in the `ttnn::prim` namespace.");
}
}
}
@@ -381,13 +374,16 @@ template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t,
constexpr auto register_operation_impl() {
assert_operation_in_correct_namespace<cpp_fully_qualified_name, operation_t>();
constexpr auto operation = registered_operation_t<cpp_fully_qualified_name, operation_t, auto_launch_op>{};
static_assert(not requires(operation_name_key_t<cpp_fully_qualified_name> key) { get(key); }, "Operation with this `cpp_fully_qualified_name` was already registered. Please use a different name.");
static_assert(not requires(operation_key_t<operation_t> key) { get(key); }, "Operation with this `operation_t` was already registered. Please use a different type.");
static_assert(
not requires(operation_name_key_t<cpp_fully_qualified_name> key) { get(key); },
"Operation with this `cpp_fully_qualified_name` was already registered. Please use a different name.");
static_assert(
not requires(operation_key_t<operation_t> key) { get(key); },
"Operation with this `operation_t` was already registered. Please use a different type.");
static_assert(set_operation_t<cpp_fully_qualified_name, operation_t, operation>::value);
return operation;
}


template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t>
constexpr auto register_operation() {
return register_operation_impl<cpp_fully_qualified_name, operation_t, false>();
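The selection of that hook happens entirely at compile time: a requires-expression detects a user-provided create_async_output_tensors, and otherwise the return type of invoke picks the path, with a static_assert forcing create_async_optional_output_tensors for optional returns. Stripped of the ttnn specifics, the pattern looks roughly like the standalone C++20 sketch below (placeholder Tensor type; illustration only, not the exact decorators.hpp code):

#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

// Placeholder stand-ins for the ttnn types, for illustration only.
struct Tensor {};
using Tensors = std::vector<Tensor>;
using OptionalTensors = std::vector<std::optional<Tensor>>;

template <typename operation_t, typename return_t, typename... args_t>
auto create_async_output_tensors_sketch(const Tensors& inputs, args_t&&... args) {
    if constexpr (requires { operation_t::create_async_output_tensors(inputs); }) {
        // The op ships its own hook; defer to it.
        return operation_t::create_async_output_tensors(inputs);
    } else if constexpr (std::is_same_v<std::decay_t<return_t>, OptionalTensors>) {
        // Ops returning optional tensors must say which slots are real.
        constexpr bool has_optional_hook =
            requires { operation_t::create_async_optional_output_tensors(std::forward<args_t>(args)...); };
        static_assert(
            has_optional_hook,
            "If the operation returns a vector of optional Tensors, it must "
            "implement create_async_optional_output_tensors.");
        return operation_t::create_async_optional_output_tensors(std::forward<args_t>(args)...);
    } else {
        // Fallback: a single placeholder output tensor.
        return Tensors{Tensor{}};
    }
}

The real implementation in decorators.hpp additionally handles single-Tensor and homogeneous-tuple returns, as shown in the hunks above.
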
@@ -11,17 +11,11 @@ std::vector<std::optional<Tensor>> CompositeExampleMutipleReturnOperation::invoke(
return prim::example_multiple_return(input_tensor, return_output1, return_output2);
}

std::vector<Tensor> CompositeExampleMutipleReturnOperation::create_async_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
const auto& input_tensor = input_tensors.at(0);
OptionalTensors CompositeExampleMutipleReturnOperation::create_async_optional_output_tensors(
const Tensor& input_tensor, bool return_output1, bool return_output2) {
return {
Tensor(operation::get_workers_for_op_output({input_tensor})),
Tensor(operation::get_workers_for_op_output({input_tensor}))};
}

std::vector<bool> CompositeExampleMutipleReturnOperation::create_async_return_flag(const Tensor& input_tensor, bool return_output1, bool return_output2) {

return {return_output1, return_output2};
return_output1 ? std::optional<Tensor>(operation::get_workers_for_op_output({input_tensor})) : std::nullopt,
return_output2 ? std::optional<Tensor>(operation::get_workers_for_op_output({input_tensor})) : std::nullopt};
}

} // namespace ttnn::operations::examples
@@ -16,13 +16,9 @@ struct CompositeExampleMutipleReturnOperation {
// is registered
static std::vector<std::optional<Tensor>> invoke(const Tensor& input_tensor, bool return_output1, bool return_output2);

static std::vector<Tensor> create_async_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs);
static OptionalTensors create_async_optional_output_tensors(
const Tensor& input_tensor, bool return_output1, bool return_output2);

// The parameters of this function must be identical to those of invoke.
static std::vector<bool> create_async_return_flag(
const Tensor& input_tensor, bool return_output1, bool return_output2
);
};

} // namespace ttnn::operations::examples
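
For completeness, a caller-side sketch of the behaviour this enables: outputs that were not requested now come back as std::nullopt instead of a placeholder Tensor selected by create_async_return_flag. The registered name composite_example_multiple_return is an assumption; it does not appear in this excerpt.

// Hypothetical call site (operation name assumed, not shown in this diff).
auto outputs = ttnn::composite_example_multiple_return(
    input_tensor, /*return_output1=*/true, /*return_output2=*/false);
// outputs has type std::vector<std::optional<Tensor>>.
assert(outputs.at(0).has_value());   // requested, so a Tensor is present
assert(!outputs.at(1).has_value());  // skipped, so the slot stays std::nullopt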