From 6f8479b62cc13dd2d87af3355a03069758e211b7 Mon Sep 17 00:00:00 2001 From: KalaivaniMCW Date: Tue, 21 May 2024 13:14:20 +0000 Subject: [PATCH] #5044: add optional output to BW ops EQ, add, addalpha, mul --- .../backward_ops/test_backward_add.py | 46 + .../backward_ops/test_backward_addalpha.py | 48 + .../backward_ops/test_backward_binary_eq.py | 52 +- .../backward_ops/test_backward_mul.py | 46 + .../op_library/backward/backward_ops.cpp | 1987 +++++++++++------ .../op_library/backward/backward_ops.hpp | 973 +++++--- .../tt_lib_bindings_tensor_backward_ops.cpp | 21 +- 7 files changed, 2216 insertions(+), 957 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_add.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_add.py index 6c03bd8feb9..8feb7effcf6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_add.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_add.py @@ -34,3 +34,49 @@ def test_bw_add(input_shapes, device): status = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert status + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) +def test_bw_add_with_opt_output(input_shapes, device, are_required_outputs): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) + input_grad = None + other_grad = None + + if are_required_outputs[0]: + _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) + if are_required_outputs[1]: + _, other_grad = data_gen_with_range(input_shapes, -1, 1, device) + + tt_output_tensor_on_device = tt_lib.tensor.add_bw( + grad_tensor, + input_tensor, + other_tensor, + are_required_outputs=are_required_outputs, + input_grad=input_grad, + other_grad=other_grad, + ) + + in_data.retain_grad() + other_data.retain_grad() + + pyt_y = torch.add(in_data, other_data) + + pyt_y.backward(gradient=grad_data) + + golden_tensor = [in_data.grad, other_data.grad] + + status = True + for i in range(len(are_required_outputs)): + if are_required_outputs[i]: + status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) + assert status diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py index fcf32c3ba3b..65332cd6660 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_addalpha.py @@ -35,3 +35,51 @@ def test_bw_addalpha(input_shapes, alpha, device): status = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert status + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize("alpha", [0.05, 2.0, 1.5, 0.12]) +@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) +def test_bw_addalpha_with_opt_output(input_shapes, alpha, device, are_required_outputs): + in_data, input_tensor = 
data_gen_with_range(input_shapes, -100, 100, device, True) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) + input_grad = None + other_grad = None + + if are_required_outputs[0]: + _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) + if are_required_outputs[1]: + _, other_grad = data_gen_with_range(input_shapes, -1, 1, device) + + tt_output_tensor_on_device = tt_lib.tensor.addalpha_bw( + grad_tensor, + input_tensor, + other_tensor, + alpha, + are_required_outputs=are_required_outputs, + input_grad=input_grad, + other_grad=other_grad, + ) + + in_data.retain_grad() + other_data.retain_grad() + + pyt_y = torch.add(in_data, other_data, alpha=alpha) + + pyt_y.backward(gradient=grad_data) + + golden_tensor = [in_data.grad, other_data.grad] + + status = True + for i in range(len(are_required_outputs)): + if are_required_outputs[i]: + status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) + assert status diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_eq.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_eq.py index 586f1d8fa9b..5cbb5a147ad 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_eq.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_binary_eq.py @@ -18,10 +18,54 @@ ) def test_bw_binary_eq(input_shapes, device): in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device) - tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(grad_tensor, input_tensor) - pt_y = torch.zeros_like(grad_data) - golden_tensor = [pt_y, pt_y] + tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw(grad_tensor, input_tensor, other_tensor) + in_grad = torch.zeros_like(in_data) + other_grad = torch.zeros_like(other_data) + + golden_tensor = [in_grad, other_grad] comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) +def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + _, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device) + input_grad = None + other_grad = None + if are_required_outputs[0]: + _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) + if are_required_outputs[1]: + _, other_grad = data_gen_with_range(input_shapes, -1, 1, device) + + tt_output_tensor_on_device = tt_lib.tensor.binary_eq_bw( + grad_tensor, + input_tensor, + other_tensor, + are_required_outputs=are_required_outputs, + input_grad=input_grad, + other_grad=other_grad, + ) + + in_grad = torch.zeros_like(in_data) + other_grad = torch.zeros_like(other_data) + + golden_tensor = [in_grad, other_grad] + + status = True + for i in range(len(are_required_outputs)): + if 
are_required_outputs[i]: + status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) + assert status diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mul.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mul.py index e6ca9dba20f..3293b3af8d6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_mul.py @@ -34,3 +34,49 @@ def test_bw_mul(input_shapes, device): status = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert status + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) +def test_bw_mul_opt_output(input_shapes, device, are_required_outputs): + in_data_a, input_tensor_a = data_gen_with_range(input_shapes, -90, 80, device, True) + in_data_b, input_tensor_b = data_gen_with_range(input_shapes, -70, 90, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -60, 60, device) + input_a_grad = None + input_b_grad = None + + if are_required_outputs[0]: + _, input_a_grad = data_gen_with_range(input_shapes, -1, 1, device) + if are_required_outputs[1]: + _, input_b_grad = data_gen_with_range(input_shapes, -1, 1, device) + + tt_output_tensor_on_device = tt_lib.tensor.mul_bw( + grad_tensor, + input_tensor_a, + input_tensor_b, + are_required_outputs=are_required_outputs, + input_a_grad=input_a_grad, + input_b_grad=input_b_grad, + ) + + in_data_a.retain_grad() + in_data_b.retain_grad() + + pyt_y = torch.mul(in_data_a, in_data_b) + + pyt_y.backward(gradient=grad_data) + + golden_tensor = [in_data_a.grad, in_data_b.grad] + + status = True + for i in range(len(are_required_outputs)): + if are_required_outputs[i]: + status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) + assert status diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp index 81c43548804..777baf144cc 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp @@ -2,106 +2,189 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_dnn/op_library/backward/backward_ops.hpp" -#include "tt_dnn/op_library/reduce/reduce_op.hpp" -#include "tt_dnn/op_library/reshape/reshape_op.hpp" -#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" + +#include "tt_dnn/op_library/complex/complex_ops.hpp" +#include "tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_dnn/op_library/embeddings/embeddings_op.hpp" -#include "tt_numpy/functions.hpp" -#include "tt_eager/tensor/tensor_utils.hpp" #include "tt_dnn/op_library/math.hpp" +#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" +#include "tt_dnn/op_library/permute/permute_op.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_dnn/op_library/reshape/reshape_op.hpp" #include "tt_dnn/op_library/unpad/unpad_op.hpp" -#include "tt_dnn/op_library/complex/complex_ops.hpp" +#include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp" -#include "tt_dnn/op_library/permute/permute_op.hpp" +#include "tt_numpy/functions.hpp" +#include "tt_dnn/op_library/copy/copy_op.hpp" namespace tt { namespace 
tt_metal { +std::vector> _addalpha_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& other, + float alpha, + const MemoryConfig& output_mem_config, + const std::vector& are_required_outputs, + std::optional input_grad, + std::optional other_grad) { + std::vector> result; + + if (are_required_outputs.at(0)) { + if(input_grad.has_value()){ + assign(grad, input_grad.value()); + } else { + input_grad = grad; + } + result.push_back(input_grad.value()); + } else { + result.push_back(std::nullopt); + } + if (are_required_outputs.at(1)) { + if(other_grad.has_value()){ + mul(grad, full_like(grad, alpha, output_mem_config), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, other_grad.value() ); + } else { + other_grad = mul_unary(grad, alpha, output_mem_config); + } + result.push_back(other_grad.value()); + } else { + result.push_back(std::nullopt); + } -std::vector _addalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - grad_tensor.emplace_back(grad); - Tensor grad_b = mul_unary(grad, alpha, output_mem_config); - grad_tensor.emplace_back(grad_b); - - return grad_tensor; + return std::move(result); } -std::vector addalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) -{ - return operation::decorate_as_composite(__func__, _addalpha_bw)(grad, input, other, alpha, output_mem_config); +std::vector> addalpha_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& other, + float alpha, + const MemoryConfig& output_mem_config, + const std::vector& are_required_outputs, + std::optional input_grad, + std::optional other_grad) { + return operation::decorate_as_composite(__func__, _addalpha_bw)( + grad, input, other, alpha, output_mem_config, are_required_outputs, input_grad, other_grad); } -std::vector add_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ - return operation::decorate_as_composite(__func__, _addalpha_bw)(grad, input, other, 1, output_mem_config); +std::vector> add_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& other, + const MemoryConfig& output_mem_config, + const std::vector& are_required_outputs, + std::optional input_grad, + std::optional other_grad) { + return operation::decorate_as_composite(__func__, _addalpha_bw)( + grad, input, other, 1, output_mem_config, are_required_outputs, input_grad, other_grad); } -std::vector _unary_mul_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { +std::vector _unary_mul_bw( + const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor result = mul_unary(grad, scalar, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector unary_mul_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) -{ +std::vector unary_mul_bw( + const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_mul_bw)(grad, input, scalar, output_mem_config); } // unary_pow: // grad_input = grad * exponent * torch.pow(input, exponent - 1) -std::vector _unary_pow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) { +std::vector _unary_pow_bw( + const Tensor& grad, const Tensor& 
input, float exponent, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - const float ZERO_THRESHOLD = std::numeric_limits::epsilon()*10.0f; + const float ZERO_THRESHOLD = std::numeric_limits::epsilon() * 10.0f; TT_FATAL(exponent >= 0.0, "negative exponents are not supported; use recip(pow(input,abs(exponent)))"); - if ( std::abs(exponent) < ZERO_THRESHOLD ) { - grad_tensor.emplace_back( zeros_like( input, output_mem_config) ); + if (std::abs(exponent) < ZERO_THRESHOLD) { + grad_tensor.emplace_back(zeros_like(input, output_mem_config)); return grad_tensor; } Tensor power_input = power(input, fabs(exponent - 1.0f), output_mem_config); - if ( exponent < 1.0f ) { - power_input = recip(power_input,output_mem_config); + if (exponent < 1.0f) { + power_input = recip(power_input, output_mem_config); } Tensor result = mul_unary(power_input, exponent, output_mem_config); Tensor final_result = mul(result, grad, std::nullopt, output_mem_config); - final_result = where(gte_unary(final_result, 3.4e+38, output_mem_config), std::numeric_limits::infinity(), where(lte_unary(final_result, -3.4e+38, output_mem_config), -std::numeric_limits::infinity(), final_result, output_mem_config), output_mem_config); + final_result = where( + gte_unary(final_result, 3.4e+38, output_mem_config), + std::numeric_limits::infinity(), + where( + lte_unary(final_result, -3.4e+38, output_mem_config), + -std::numeric_limits::infinity(), + final_result, + output_mem_config), + output_mem_config); grad_tensor.emplace_back(final_result); return grad_tensor; } -std::vector unary_pow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) -{ +std::vector unary_pow_bw( + const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_pow_bw)(grad, input, exponent, output_mem_config); } -std::vector _unary_add_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { +std::vector _unary_add_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { std::vector grad_tensor; grad_tensor.emplace_back(grad); return grad_tensor; } -std::vector unary_add_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) -{ +std::vector unary_add_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_add_bw)(grad, input, alpha, output_mem_config); } +std::vector> _mul_bw( + const Tensor& grad, + const Tensor& input_a, + const Tensor& input_b, + const MemoryConfig& output_mem_config, + const std::vector& are_required_outputs, + std::optional input_a_grad, + std::optional input_b_grad) { + std::vector> result; + + if (are_required_outputs.at(0)) { + if(input_a_grad.has_value()) { + mul(grad, input_b, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, input_a_grad.value()); + } else { + input_a_grad = mul(grad, input_b, std::nullopt, output_mem_config); + } + result.push_back(input_a_grad.value()); + } else { + result.push_back(std::nullopt); + } + if (are_required_outputs.at(1)) { + if(input_b_grad.has_value()) { + mul(grad, input_a, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, input_b_grad.value()); + } else { + input_b_grad = mul(grad, input_a, std::nullopt, output_mem_config); + } + result.push_back(input_b_grad.value()); + } else { + 
result.push_back(std::nullopt); + } -std::vector _mul_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) { - std::vector grad_tensor; - Tensor grad_a = mul(grad, input_b, std::nullopt, output_mem_config); - grad_tensor.emplace_back(grad_a); - Tensor grad_b = mul(grad, input_a, std::nullopt, output_mem_config); - grad_tensor.emplace_back(grad_b); - return grad_tensor; + return std::move(result); } -std::vector mul_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) -{ - return operation::decorate_as_composite(__func__, _mul_bw)(grad, input_a, input_b, output_mem_config); +std::vector> mul_bw( + const Tensor& grad, + const Tensor& input_a, + const Tensor& input_b, + const MemoryConfig& output_mem_config, + const std::vector& are_required_outputs, + std::optional input_a_grad, + std::optional input_b_grad) { + return operation::decorate_as_composite(__func__, _mul_bw)( + grad, input_a, input_b, output_mem_config, are_required_outputs, input_a_grad, input_b_grad); } - std::vector _exp_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; float t_inf = std::numeric_limits::infinity(); @@ -109,17 +192,29 @@ std::vector _exp_bw(const Tensor& grad, const Tensor& input, const Memor Tensor result = mul(grad, exp_result, std::nullopt, output_mem_config); result = where(gte_unary(result, 1e+38, output_mem_config), t_inf, result, output_mem_config); result = where(lte_unary(result, -1e+38, output_mem_config), -t_inf, result, output_mem_config); - result = where(logical_and(gte_unary(abs(exp_result, output_mem_config), 1e+38, output_mem_config),ltz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config); + result = where( + logical_and( + gte_unary(abs(exp_result, output_mem_config), 1e+38, output_mem_config), + ltz(grad, output_mem_config), + std::nullopt, + output_mem_config), + -t_inf, + result, + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector exp_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector exp_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _exp_bw)(grad, input, output_mem_config); } - -std::vector _addcmul_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config) { +std::vector _addcmul_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& tensor1, + const Tensor& tensor2, + float value, + const MemoryConfig& output_mem_config) { std::vector grad_tensor; grad_tensor.emplace_back(grad); Tensor grad_a = mul_unary(mul(grad, tensor2, std::nullopt, output_mem_config), value, output_mem_config); @@ -129,88 +224,131 @@ std::vector _addcmul_bw(const Tensor& grad, const Tensor& input, const T return grad_tensor; } -std::vector addcmul_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config) -{ - return operation::decorate_as_composite(__func__, _addcmul_bw)(grad, input, tensor1, tensor2, value, output_mem_config); +std::vector addcmul_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& tensor1, + const Tensor& tensor2, + float value, + const MemoryConfig& output_mem_config) { + return operation::decorate_as_composite(__func__, 
_addcmul_bw)( + grad, input, tensor1, tensor2, value, output_mem_config); } - std::vector _unary_assign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; grad_tensor.emplace_back(grad); return grad_tensor; } -std::vector unary_assign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector unary_assign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_assign_bw)(grad, input, output_mem_config); } -std::vector binary_assign_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector binary_assign_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_assign_bw)(grad, input, output_mem_config); } std::vector _sqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor sqrt_result = sqrt(input, output_mem_config); - Tensor result = mul(grad, recip(mul_unary(sqrt_result, 2.0, output_mem_config), output_mem_config), std::nullopt, output_mem_config); - float t_nan = std::nanf(""); + Tensor result = + mul(grad, + recip(mul_unary(sqrt_result, 2.0, output_mem_config), output_mem_config), + std::nullopt, + output_mem_config); + float t_nan = std::nanf(""); float t_inf = std::numeric_limits::infinity(); result = where(lez(input, output_mem_config), t_nan, result, output_mem_config); - result = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config); - result = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), t_inf, result, output_mem_config); + result = where( + logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), + -t_inf, + result, + output_mem_config); + result = where( + logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), + t_inf, + result, + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector sqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector sqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _sqrt_bw)(grad, input, output_mem_config); } - -std::vector _unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { +std::vector _unary_div_bw( + const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float inv_scalar = 1.0f/scalar; - if (round_mode=="None"){ + float inv_scalar = 1.0f / scalar; + if (round_mode == "None") { Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); - if(scalar == 0.0){ - float t_nan = std::nanf(""); - grad_tensor.emplace_back( where(eqz(grad, output_mem_config), t_nan, mul( sign(grad, output_mem_config), t_inf, std::nullopt, output_mem_config), output_mem_config) ); - }else{ - grad_tensor.emplace_back( mul_unary(grad, inv_scalar, output_mem_config) ); + if (scalar == 0.0) { + float t_nan = std::nanf(""); + grad_tensor.emplace_back(where( 
+ eqz(grad, output_mem_config), + t_nan, + mul(sign(grad, output_mem_config), t_inf, std::nullopt, output_mem_config), + output_mem_config)); + } else { + grad_tensor.emplace_back(mul_unary(grad, inv_scalar, output_mem_config)); } - } - else{ + } else { Tensor result = zeros_like(grad, output_mem_config); grad_tensor.emplace_back(result); } return grad_tensor; } -std::vector unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) -{ - return operation::decorate_as_composite(__func__, _unary_div_bw)(grad, input, scalar, round_mode, output_mem_config); +std::vector unary_div_bw( + const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { + return operation::decorate_as_composite(__func__, _unary_div_bw)( + grad, input, scalar, round_mode, output_mem_config); } - -std::vector _div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config) { +std::vector _div_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& other, + string round_mode, + const MemoryConfig& output_mem_config) { std::vector grad_tensor; - if (round_mode=="None"){ + if (round_mode == "None") { Tensor grad_a = mul(grad, recip(other, output_mem_config), std::nullopt, output_mem_config); Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); Tensor t_nan = full_like(input, std::nanf(""), output_mem_config); - grad_tensor.emplace_back( where(eqz(other, output_mem_config), - where(eqz(grad, output_mem_config), - t_nan, - mul(t_inf, sign(grad, output_mem_config), std::nullopt, output_mem_config), output_mem_config), - grad_a, output_mem_config)); - Tensor grad_b = mul(neg(grad, output_mem_config) , (mul(input, recip(square(other, output_mem_config), output_mem_config), std::nullopt, output_mem_config)), std::nullopt, output_mem_config); - grad_tensor.emplace_back(where(eqz(other, output_mem_config), - where(eqz(grad, output_mem_config), - t_nan, - where(eqz(input, output_mem_config), - t_nan, - mul( mul( neg(t_inf, output_mem_config), sign(input, output_mem_config), std::nullopt, output_mem_config), sign(grad, output_mem_config), std::nullopt, output_mem_config), output_mem_config), output_mem_config), - grad_b, output_mem_config)); - } else{ + grad_tensor.emplace_back(where( + eqz(other, output_mem_config), + where( + eqz(grad, output_mem_config), + t_nan, + mul(t_inf, sign(grad, output_mem_config), std::nullopt, output_mem_config), + output_mem_config), + grad_a, + output_mem_config)); + Tensor grad_b = mul( + neg(grad, output_mem_config), + (mul(input, recip(square(other, output_mem_config), output_mem_config), std::nullopt, output_mem_config)), + std::nullopt, + output_mem_config); + grad_tensor.emplace_back(where( + eqz(other, output_mem_config), + where( + eqz(grad, output_mem_config), + t_nan, + where( + eqz(input, output_mem_config), + t_nan, + mul(mul(neg(t_inf, output_mem_config), + sign(input, output_mem_config), + std::nullopt, + output_mem_config), + sign(grad, output_mem_config), + std::nullopt, + output_mem_config), + output_mem_config), + output_mem_config), + grad_b, + output_mem_config)); + } else { Tensor grad_a = zeros_like(grad, output_mem_config); grad_tensor.emplace_back(grad_a); Tensor grad_b = zeros_like(grad, output_mem_config); @@ -219,35 +357,65 @@ std::vector _div_bw(const Tensor& grad, const Tensor& input, const Tenso return grad_tensor; } -std::vector div_bw(const Tensor& 
grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config) -{ +std::vector div_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& other, + string round_mode, + const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _div_bw)(grad, input, other, round_mode, output_mem_config); } -std::vector _rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { +std::vector _rdiv_bw( + const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); float t_inf = std::numeric_limits::infinity(); - if (round_mode=="None"){ - Tensor result = where(nez(input), mul(neg(grad, output_mem_config) , (mul_unary(recip(square(input, output_mem_config)), scalar, output_mem_config)), std::nullopt, output_mem_config), t_nan, output_mem_config); - if (scalar>0){ - result = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), t_inf, result, output_mem_config); - result = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config); - } - else if (scalar<0){ - result = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config); - result = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), t_inf, result, output_mem_config); + if (round_mode == "None") { + Tensor result = where( + nez(input), + mul(neg(grad, output_mem_config), + (mul_unary(recip(square(input, output_mem_config)), scalar, output_mem_config)), + std::nullopt, + output_mem_config), + t_nan, + output_mem_config); + if (scalar > 0) { + result = where( + logical_and( + eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), + t_inf, + result, + output_mem_config); + result = where( + logical_and( + eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), + -t_inf, + result, + output_mem_config); + } else if (scalar < 0) { + result = where( + logical_and( + eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), + -t_inf, + result, + output_mem_config); + result = where( + logical_and( + eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), + t_inf, + result, + output_mem_config); } grad_tensor.emplace_back(result); - } - else{ + } else { Tensor result = zeros_like(grad, output_mem_config); grad_tensor.emplace_back(result); } return grad_tensor; } -std::vector rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) -{ +std::vector rdiv_bw( + const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _rdiv_bw)(grad, input, scalar, round_mode, output_mem_config); } @@ -260,41 +428,47 @@ std::vector _tanh_bw(const Tensor& grad, const Tensor& input, const Memo grad_tensor.emplace_back(result); return grad_tensor; } -std::vector tanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector tanh_bw(const Tensor& grad, 
const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _tanh_bw)(grad, input, output_mem_config); } // grad(sigmoid) = grad*(1 - sigmoid(x))*sigmoid(x) -std::vector _sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { +std::vector _sigmoid_bw( + const Tensor& grad, + const Tensor& input, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { std::vector grad_tensor; Tensor sig_result = sigmoid(input, output_mem_config); Tensor rsub_term = rsub(sig_result, 1.0f, output_mem_config); - Tensor prod_term_1 = mul(sig_result, rsub_term,{},output_mem_config); - Tensor prod_term_2 = mul(prod_term_1, grad,{},output_mem_config); + Tensor prod_term_1 = mul(sig_result, rsub_term, {}, output_mem_config); + Tensor prod_term_2 = mul(prod_term_1, grad, {}, output_mem_config); grad_tensor.emplace_back(prod_term_2); return grad_tensor; } -std::vector sigmoid_bw(const Tensor& grad, const Tensor& input, - const MemoryConfig& output_mem_config) { +std::vector sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _sigmoid_bw)(grad, input, output_mem_config); } - std::vector _tan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor tan_result = tan(input, output_mem_config); - Tensor result = mul(grad, add1(square(tan_result, output_mem_config), output_mem_config), std::nullopt, output_mem_config); + Tensor result = + mul(grad, add1(square(tan_result, output_mem_config), output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector tan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector tan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _tan_bw)(grad, input, output_mem_config); } -std::vector _addcdiv_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config) { +std::vector _addcdiv_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& tensor1, + const Tensor& tensor2, + float value, + const MemoryConfig& output_mem_config) { std::vector grad_tensor; grad_tensor.emplace_back(grad); Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); @@ -307,7 +481,8 @@ std::vector _addcdiv_bw(const Tensor& grad, const Tensor& input, const T output_mem_config)); Tensor tmp = mul( mul_unary(neg(grad, output_mem_config), value, output_mem_config), tensor1, std::nullopt, output_mem_config); - Tensor grad_b = mul(tmp, recip(square(tensor2, output_mem_config), output_mem_config), std::nullopt, output_mem_config); + Tensor grad_b = + mul(tmp, recip(square(tensor2, output_mem_config), output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(where( eqz(tensor2, output_mem_config), where(eqz(grad, output_mem_config), t_nan, neg(t_inf, output_mem_config), output_mem_config), @@ -315,12 +490,23 @@ std::vector _addcdiv_bw(const Tensor& grad, const Tensor& input, const T output_mem_config)); return grad_tensor; } -std::vector addcdiv_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config) -{ - return 
operation::decorate_as_composite(__func__, _addcdiv_bw)(grad, input, tensor1, tensor2, value, output_mem_config); +std::vector addcdiv_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& tensor1, + const Tensor& tensor2, + float value, + const MemoryConfig& output_mem_config) { + return operation::decorate_as_composite(__func__, _addcdiv_bw)( + grad, input, tensor1, tensor2, value, output_mem_config); } -std::vector _where_bw(const Tensor& grad, const Tensor& condition, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { +std::vector _where_bw( + const Tensor& grad, + const Tensor& condition, + const Tensor& input, + const Tensor& other, + const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor grad_a = where(condition, grad, 0.0f, output_mem_config); grad_tensor.emplace_back(grad_a); @@ -329,30 +515,43 @@ std::vector _where_bw(const Tensor& grad, const Tensor& condition, const return grad_tensor; } -std::vector where_bw(const Tensor& grad, const Tensor& condition, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector where_bw( + const Tensor& grad, + const Tensor& condition, + const Tensor& input, + const Tensor& other, + const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _where_bw)(grad, condition, input, other, output_mem_config); } -//template parameter min_or_max = TRUE for MAX, FALSE for MIN -template -std::vector _min_or_max_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { +// template parameter min_or_max = TRUE for MAX, FALSE for MIN +template +std::vector _min_or_max_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { Tensor zeros_t = zeros_like(input, output_mem_config); std::vector grad_tensor; Tensor t_scale_grad = mul_unary(grad, 0.5, output_mem_config); Tensor t_sub = sub(other, input, std::nullopt, output_mem_config); - Tensor t_sub_gtz = gtz(t_sub,output_mem_config); - Tensor t_sub_eqz = eqz(t_sub,output_mem_config); - Tensor t_sub_ltz = ltz(t_sub,output_mem_config); - Tensor grad_other = add(mul(t_sub_ltz, grad,{},output_mem_config),mul(t_sub_eqz, t_scale_grad,{},output_mem_config), std::nullopt, output_mem_config); - Tensor grad_input = add(mul(t_sub_gtz, grad,{},output_mem_config),mul(t_sub_eqz, t_scale_grad,{},output_mem_config), std::nullopt, output_mem_config); + Tensor t_sub_gtz = gtz(t_sub, output_mem_config); + Tensor t_sub_eqz = eqz(t_sub, output_mem_config); + Tensor t_sub_ltz = ltz(t_sub, output_mem_config); + Tensor grad_other = + add(mul(t_sub_ltz, grad, {}, output_mem_config), + mul(t_sub_eqz, t_scale_grad, {}, output_mem_config), + std::nullopt, + output_mem_config); + Tensor grad_input = + add(mul(t_sub_gtz, grad, {}, output_mem_config), + mul(t_sub_eqz, t_scale_grad, {}, output_mem_config), + std::nullopt, + output_mem_config); if (min_or_max) { - //MAX + // MAX grad_tensor.emplace_back(grad_other); grad_tensor.emplace_back(grad_input); } else { - //MIN + // MIN grad_tensor.emplace_back(grad_input); grad_tensor.emplace_back(grad_other); } @@ -361,25 +560,23 @@ std::vector _min_or_max_bw(const Tensor& grad, const Tensor& input, cons auto _max_bw = _min_or_max_bw; auto _min_bw = _min_or_max_bw; -std::vector max_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector max_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, 
const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _max_bw)(grad, input, other, output_mem_config); } -std::vector min_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector min_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _min_bw)(grad, input, other, output_mem_config); } - std::vector _fill_zero_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor result = zeros_like(grad, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector fill_zero_bw(const Tensor& grad, const MemoryConfig& output_mem_config) -{ +std::vector fill_zero_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _fill_zero_bw)(grad, output_mem_config); } @@ -388,27 +585,31 @@ std::vector _fill_bw(const Tensor& grad, const MemoryConfig& output_mem_ Tensor val = grad; val = global_sum(val, output_mem_config); Tensor result = zeros_like(grad, output_mem_config); - result = bcast(result, val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + result = bcast(result, val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector fill_bw(const Tensor& grad, const MemoryConfig& output_mem_config) -{ +std::vector fill_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _fill_bw)(grad, output_mem_config); } -std::vector _embedding_bw(const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config) { +std::vector _embedding_bw( + const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config) { TT_FATAL(input.get_dtype() == DataType::UINT32, "Input must be UINT32"); - TT_FATAL(grad.get_legacy_shape()[0] == 1 && grad.get_legacy_shape()[1] == 1, "First two dimensions for the grad must be 1"); - TT_FATAL(input.get_legacy_shape()[1] == 1 && input.get_legacy_shape()[2] == 1, "Only dim 0 && 3 for the input can be non 1"); + TT_FATAL( + grad.get_legacy_shape()[0] == 1 && grad.get_legacy_shape()[1] == 1, + "First two dimensions for the grad must be 1"); + TT_FATAL( + input.get_legacy_shape()[1] == 1 && input.get_legacy_shape()[2] == 1, + "Only dim 0 && 3 for the input can be non 1"); std::vector grad_tensor; Tensor grad_a = embeddings(input, grad, false); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector embedding_bw(const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config) -{ +std::vector embedding_bw( + const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _embedding_bw)(grad, input, weight, output_mem_config); } @@ -416,20 +617,21 @@ std::vector embedding_bw(const Tensor& grad, const Tensor& input, const // self: grad // other: -grad * alpha -std::vector _subalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) { +std::vector _subalpha_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) { std::vector grad_tensor; grad_tensor.emplace_back(grad); Tensor grad_b = mul_unary(neg(grad, 
output_mem_config), alpha, output_mem_config); grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector subalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) -{ +std::vector subalpha_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _subalpha_bw)(grad, input, other, alpha, output_mem_config); } -std::vector sub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector sub_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _subalpha_bw)(grad, input, other, 1.0, output_mem_config); } @@ -440,8 +642,7 @@ std::vector _unary_sub_bw(const Tensor& grad, const Tensor& input, const grad_tensor.emplace_back(grad); return grad_tensor; } -std::vector unary_sub_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector unary_sub_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _unary_sub_bw)(grad, input, output_mem_config); } @@ -451,18 +652,18 @@ std::vector _neg_bw(const Tensor& grad, const Tensor& input, const Memor grad_tensor.emplace_back(result); return grad_tensor; } -std::vector neg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector neg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _neg_bw)(grad, input, output_mem_config); } -std::vector _rsub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { - std::vector grad_tensor = _subalpha_bw(grad,input,other, 1.0f, output_mem_config); - std::swap(grad_tensor[0],grad_tensor[1]); +std::vector _rsub_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { + std::vector grad_tensor = _subalpha_bw(grad, input, other, 1.0f, output_mem_config); + std::swap(grad_tensor[0], grad_tensor[1]); return grad_tensor; } -std::vector rsub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector rsub_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _rsub_bw)(grad, input, other, output_mem_config); } @@ -472,8 +673,7 @@ std::vector _lt_bw(const Tensor& grad, const MemoryConfig& output_mem_co grad_tensor.emplace_back(t_zero); return grad_tensor; } -std::vector lt_bw(const Tensor& grad, const MemoryConfig& output_mem_config) -{ +std::vector lt_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _lt_bw)(grad, output_mem_config); } @@ -483,8 +683,7 @@ std::vector _gt_bw(const Tensor& grad, const MemoryConfig& output_mem_co grad_tensor.emplace_back(t_zero); return grad_tensor; } -std::vector gt_bw(const Tensor& grad, const MemoryConfig& output_mem_config) -{ +std::vector gt_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _gt_bw)(grad, output_mem_config); } @@ -494,8 +693,7 @@ std::vector _ne_bw(const Tensor& grad, const MemoryConfig& output_mem_co 
grad_tensor.emplace_back(t_zero); return grad_tensor; } -std::vector ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config) -{ +std::vector ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _ne_bw)(grad, output_mem_config); } @@ -504,15 +702,18 @@ std::vector _log_bw(const Tensor& grad, const Tensor& input, const Memor Tensor grad_a = mul(grad, recip(input, output_mem_config), std::nullopt, output_mem_config); Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); Tensor t_nan = full_like(input, std::nanf(""), output_mem_config); - grad_tensor.emplace_back( where(eqz(input, output_mem_config), - where(eqz(grad, output_mem_config), - t_nan, - mul(t_inf, sign(grad, output_mem_config), std::nullopt, output_mem_config), output_mem_config), - grad_a, output_mem_config)); + grad_tensor.emplace_back(where( + eqz(input, output_mem_config), + where( + eqz(grad, output_mem_config), + t_nan, + mul(t_inf, sign(grad, output_mem_config), std::nullopt, output_mem_config), + output_mem_config), + grad_a, + output_mem_config)); return grad_tensor; } -std::vector log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _log_bw)(grad, input, output_mem_config); } @@ -522,8 +723,7 @@ std::vector _abs_bw(const Tensor& grad, const Tensor& input, const Memor grad_tensor.emplace_back(result); return grad_tensor; } -std::vector abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _abs_bw)(grad, input, output_mem_config); } @@ -535,31 +735,32 @@ std::vector _binary_le_bw(const Tensor& grad, const Tensor& input, const grad_tensor.emplace_back(zero_input); return grad_tensor; } -std::vector binary_le_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector binary_le_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _binary_le_bw)(grad, input, output_mem_config); } std::vector _rsqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor rsqrt_result = power(rsqrt(input, true, output_mem_config), 3, output_mem_config); - Tensor result = mul_unary(mul(grad, rsqrt_result, std::nullopt, output_mem_config) , -0.5, output_mem_config); + Tensor result = mul_unary(mul(grad, rsqrt_result, std::nullopt, output_mem_config), -0.5, output_mem_config); float t_inf = std::numeric_limits::infinity(); result = where(eqz(input, output_mem_config), t_inf, result, output_mem_config); - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); result = where(ltz(input, output_mem_config), t_nan, result, output_mem_config); - result = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), t_nan, result, output_mem_config); + result = where( + logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), + t_nan, + result, + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector rsqrt_bw(const Tensor& grad, const Tensor& input, const 
MemoryConfig& output_mem_config) -{ +std::vector rsqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _rsqrt_bw)(grad, input, output_mem_config); } - -std::vector _clamp_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) -{ +std::vector _clamp_bw( + const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor minT = gte_unary(input, min, output_mem_config); Tensor maxT = lte_unary(input, max, output_mem_config); @@ -568,55 +769,54 @@ std::vector _clamp_bw(const Tensor& grad, const Tensor& input, float min grad_tensor.emplace_back(result); return grad_tensor; } -std::vector clamp_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) -{ +std::vector clamp_bw( + const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _clamp_bw)(grad, input, min, max, output_mem_config); } - -std::vector _clamp_min_bw(const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) -{ +std::vector _clamp_min_bw( + const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor minT = gte_unary(input, min, output_mem_config); Tensor result = mul(grad, minT, std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector clamp_min_bw(const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) -{ +std::vector clamp_min_bw( + const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _clamp_min_bw)(grad, input, min, output_mem_config); } - -std::vector _clamp_max_bw(const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) -{ +std::vector _clamp_max_bw( + const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor maxT = lte_unary(input, max, output_mem_config); Tensor result = mul(grad, maxT, std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector clamp_max_bw(const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) -{ +std::vector clamp_max_bw( + const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _clamp_max_bw)(grad, input, max, output_mem_config); } std::vector _relu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = mul(gtz(input,output_mem_config), grad, std::nullopt, output_mem_config); + Tensor result = mul(gtz(input, output_mem_config), grad, std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector relu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector relu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _relu_bw)(grad, input, output_mem_config); } -std::vector _atan2_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { +std::vector 
_atan2_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float t_nan = std::nanf(""); - UnaryWithParam op1 {UnaryOpType::SQUARE}; - UnaryWithParam op2 {UnaryOpType::RECIP}; - Tensor recip_mul = mul(grad, unary_chain(hypot(input,other), {op1, op2}, output_mem_config), std::nullopt, output_mem_config); + float t_nan = std::nanf(""); + UnaryWithParam op1{UnaryOpType::SQUARE}; + UnaryWithParam op2{UnaryOpType::RECIP}; + Tensor recip_mul = + mul(grad, unary_chain(hypot(input, other), {op1, op2}, output_mem_config), std::nullopt, output_mem_config); Tensor grad_a = mul(other, recip_mul, std::nullopt, output_mem_config); Tensor cond = logical_and(eqz(input, output_mem_config), eqz(other, output_mem_config)); grad_a = where(cond, t_nan, grad_a, output_mem_config); @@ -628,40 +828,41 @@ std::vector _atan2_bw(const Tensor& grad, const Tensor& input, const Ten grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector atan2_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector atan2_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _atan2_bw)(grad, input, other, output_mem_config); } -std::vector _hypot_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { +std::vector _hypot_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor result_recip = recip(hypot(input, other, output_mem_config), output_mem_config); - Tensor grad_a = mul(grad, mul(input, result_recip, std::nullopt, output_mem_config), std::nullopt, output_mem_config); + Tensor grad_a = + mul(grad, mul(input, result_recip, std::nullopt, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(grad_a); - Tensor grad_b = mul(grad, mul(other, result_recip, std::nullopt, output_mem_config), std::nullopt, output_mem_config); + Tensor grad_b = + mul(grad, mul(other, result_recip, std::nullopt, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector hypot_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector hypot_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _hypot_bw)(grad, input, other, output_mem_config); } -//bw(expm1) = grad * expm1(input) + 1 +// bw(expm1) = grad * expm1(input) + 1 std::vector _expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor eresult = expm1(input, output_mem_config); - Tensor rp1 = add1(eresult , output_mem_config); + Tensor rp1 = add1(eresult, output_mem_config); Tensor result = mul(grad, rp1, std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _expm1_bw)(grad, input, output_mem_config); } - // # bw (exp2) = grad * exp2(input) * M_LN2 // # M_LN2 = 0.693147180559945309417 std::vector 
_exp2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { @@ -672,13 +873,13 @@ std::vector _exp2_bw(const Tensor& grad, const Tensor& input, const Memo grad_tensor.emplace_back(result); return grad_tensor; } -std::vector exp2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector exp2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _exp2_bw)(grad, input, output_mem_config); } // lerp(input, end, weight) = self: grad * (1 - weight), end: grad * weight -std::vector _lerp(const Tensor& grad, const Tensor& input, const Tensor& end, float weight, const MemoryConfig& output_mem_config) { +std::vector _lerp( + const Tensor& grad, const Tensor& input, const Tensor& end, float weight, const MemoryConfig& output_mem_config) { std::vector grad_tensor; float sub_scalar = 1.0f - weight; Tensor result_1 = mul_unary(grad, sub_scalar, output_mem_config); @@ -687,13 +888,18 @@ std::vector _lerp(const Tensor& grad, const Tensor& input, const Tensor& grad_tensor.emplace_back(result_2); return grad_tensor; } -std::vector lerp_bw(const Tensor& grad, const Tensor& input, const Tensor& end, float weight, const MemoryConfig& output_mem_config) -{ +std::vector lerp_bw( + const Tensor& grad, const Tensor& input, const Tensor& end, float weight, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _lerp)(grad, input, end, weight, output_mem_config); } // lerp(input, end, weight) = self: grad * (1 - weight), end: grad * weight -std::vector _lerp_overload(const Tensor& grad, const Tensor& input, const Tensor& end, const Tensor& weight, const MemoryConfig& output_mem_config) { +std::vector _lerp_overload( + const Tensor& grad, + const Tensor& input, + const Tensor& end, + const Tensor& weight, + const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor result_1 = mul(grad, sub_unary(1.0, weight, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(result_1); @@ -701,77 +907,108 @@ std::vector _lerp_overload(const Tensor& grad, const Tensor& input, cons grad_tensor.emplace_back(result_2); return grad_tensor; } -std::vector lerp_bw(const Tensor& grad, const Tensor& input, const Tensor& end, const Tensor& weight, const MemoryConfig& output_mem_config) -{ +std::vector lerp_bw( + const Tensor& grad, + const Tensor& input, + const Tensor& end, + const Tensor& weight, + const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _lerp_overload)(grad, input, end, weight, output_mem_config); } -std::vector _gelu_bw(const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config) { +std::vector _gelu_bw( + const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - if (approximate == "tanh"){ + if (approximate == "tanh") { float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; float kKappa = 0.044715; - Tensor x_sq = mul(input , input, std::nullopt, output_mem_config); - Tensor x_cube = mul(x_sq , input, std::nullopt, output_mem_config); - Tensor inner = mul_unary(kBeta , add(input , mul_unary(kKappa , x_cube, output_mem_config)), output_mem_config); + Tensor x_sq = mul(input, input, std::nullopt, output_mem_config); + Tensor x_cube = mul(x_sq, input, std::nullopt, output_mem_config); + Tensor inner = mul_unary(kBeta, add(input, mul_unary(kKappa, x_cube, 
-std::vector<Tensor> _gelu_bw(const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _gelu_bw(
+    const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    if (approximate == "tanh"){
+    if (approximate == "tanh") {
         float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
         float kKappa = 0.044715;
-        Tensor x_sq = mul(input , input, std::nullopt, output_mem_config);
-        Tensor x_cube = mul(x_sq , input, std::nullopt, output_mem_config);
-        Tensor inner = mul_unary(kBeta , add(input , mul_unary(kKappa , x_cube, output_mem_config)), output_mem_config);
+        Tensor x_sq = mul(input, input, std::nullopt, output_mem_config);
+        Tensor x_cube = mul(x_sq, input, std::nullopt, output_mem_config);
+        Tensor inner = mul_unary(kBeta, add(input, mul_unary(kKappa, x_cube, output_mem_config)), output_mem_config);
         Tensor tanh_inner = tanh(inner, output_mem_config);
-        Tensor left = mul_unary(0.5 , input, output_mem_config);
-        Tensor right = add_unary(1 , tanh_inner, output_mem_config);
-
-        Tensor left_derivative = mul_unary(0.5 , right, output_mem_config);
-
-        Tensor tanh_derivative = neg(sub_unary(mul(tanh_inner , tanh_inner, std::nullopt, output_mem_config),1, output_mem_config), output_mem_config);
-        Tensor inner_derivative = mul_unary(kBeta , (add_unary(1 , mul_unary(3 , mul_unary(kKappa , x_sq, output_mem_config), output_mem_config), output_mem_config)));
-        Tensor right_derivative = mul(mul(left , tanh_derivative, std::nullopt, output_mem_config) , inner_derivative, std::nullopt, output_mem_config);
-
-        Tensor grad_a = mul(grad , (add(left_derivative , right_derivative)), std::nullopt, output_mem_config);
+        Tensor left = mul_unary(0.5, input, output_mem_config);
+        Tensor right = add_unary(1, tanh_inner, output_mem_config);
+
+        Tensor left_derivative = mul_unary(0.5, right, output_mem_config);
+
+        Tensor tanh_derivative =
+            neg(sub_unary(mul(tanh_inner, tanh_inner, std::nullopt, output_mem_config), 1, output_mem_config),
+                output_mem_config);
+        Tensor inner_derivative = mul_unary(
+            kBeta,
+            (add_unary(
+                1, mul_unary(3, mul_unary(kKappa, x_sq, output_mem_config), output_mem_config), output_mem_config)));
+        Tensor right_derivative =
+            mul(mul(left, tanh_derivative, std::nullopt, output_mem_config),
+                inner_derivative,
+                std::nullopt,
+                output_mem_config);
+
+        Tensor grad_a = mul(grad, (add(left_derivative, right_derivative)), std::nullopt, output_mem_config);
         grad_tensor.emplace_back(grad_a);
-    }
-    else{
+    } else {
         float kAlpha = M_SQRT1_2;
         float kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-        Tensor cdf = mul_unary(0.5 , (add_unary(1 , erf(mul_unary(input , kAlpha, output_mem_config)), output_mem_config)));
-        Tensor pdf = mul_unary(kBeta , exp(mul_unary(mul(input , input) , -0.5), output_mem_config), output_mem_config);
-        Tensor grad_a = mul(grad , (add(cdf , mul(input , pdf))));
+        Tensor cdf =
+            mul_unary(0.5, (add_unary(1, erf(mul_unary(input, kAlpha, output_mem_config)), output_mem_config)));
+        Tensor pdf = mul_unary(kBeta, exp(mul_unary(mul(input, input), -0.5), output_mem_config), output_mem_config);
+        Tensor grad_a = mul(grad, (add(cdf, mul(input, pdf))));
        grad_tensor.emplace_back(grad_a);
     }
     return grad_tensor;
 }
-std::vector<Tensor> gelu_bw(const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> gelu_bw(
+    const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _gelu_bw)(grad, input, approximate, output_mem_config);
 }
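The tanh branch assembles d/dx gelu(x) as left_derivative + right_derivative; a minimal plain-C++ sketch checking that closed form against a central finite difference (constants copied from the kernel, where M_SQRT2 * M_2_SQRTPI * 0.5 equals sqrt(2/pi); step size and tolerance are arbitrary):

#include <cassert>
#include <cmath>

static double gelu_tanh(double x) {
    const double kBeta = std::sqrt(2.0 / M_PI), kKappa = 0.044715;
    return 0.5 * x * (1.0 + std::tanh(kBeta * (x + kKappa * x * x * x)));
}

int main() {
    const double kBeta = std::sqrt(2.0 / M_PI), kKappa = 0.044715, h = 1e-6;
    for (double x = -3.0; x <= 3.0; x += 0.5) {
        const double t = std::tanh(kBeta * (x + kKappa * x * x * x));
        // Same decomposition as the kernel: product rule on 0.5 * x * (1 + t).
        const double left_derivative = 0.5 * (1.0 + t);
        const double right_derivative = 0.5 * x * (1.0 - t * t) * kBeta * (1.0 + 3.0 * kKappa * x * x);
        const double fd = (gelu_tanh(x + h) - gelu_tanh(x - h)) / (2.0 * h);
        assert(std::fabs((left_derivative + right_derivative) - fd) < 1e-5);
    }
    return 0;
}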
-std::vector<Tensor> _bias_gelu_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, string approximate, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _bias_gelu_bw(
+    const Tensor& grad,
+    const Tensor& input_a,
+    const Tensor& input_b,
+    string approximate,
+    const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor input = add(input_a, input_b);
-    grad_tensor = gelu_bw(grad, input, approximate=approximate);
+    grad_tensor = gelu_bw(grad, input, approximate, output_mem_config);
     return grad_tensor;
 }
-std::vector<Tensor> bias_gelu_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, string approximate, const MemoryConfig& output_mem_config)
-{
-    return operation::decorate_as_composite(__func__, _bias_gelu_bw)(grad, input_a, input_b, approximate, output_mem_config);
+std::vector<Tensor> bias_gelu_bw(
+    const Tensor& grad,
+    const Tensor& input_a,
+    const Tensor& input_b,
+    string approximate,
+    const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _bias_gelu_bw)(
+        grad, input_a, input_b, approximate, output_mem_config);
 }

-std::vector<Tensor> _bias_gelu_unary_bw(const Tensor& grad, const Tensor& input_tensor, float bias, string approximate, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _bias_gelu_unary_bw(
+    const Tensor& grad,
+    const Tensor& input_tensor,
+    float bias,
+    string approximate,
+    const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor input = add_unary(input_tensor, bias);
-    grad_tensor = gelu_bw(grad, input, approximate=approximate);
+    grad_tensor = gelu_bw(grad, input, approximate, output_mem_config);
     return grad_tensor;
 }
-std::vector<Tensor> bias_gelu_unary_bw(const Tensor& grad, const Tensor& input, float bias, string approximate, const MemoryConfig& output_mem_config)
-{
-    return operation::decorate_as_composite(__func__, _bias_gelu_unary_bw)(grad, input, bias, approximate, output_mem_config);
+std::vector<Tensor> bias_gelu_unary_bw(
+    const Tensor& grad, const Tensor& input, float bias, string approximate, const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _bias_gelu_unary_bw)(
+        grad, input, bias, approximate, output_mem_config);
 }

-std::vector<Tensor> _squared_difference_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _squared_difference_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor difference = sub(input, other);
     Tensor grad_a = mul_unary(2, mul(grad, difference, std::nullopt, output_mem_config), output_mem_config);
@@ -780,18 +1017,18 @@ std::vector<Tensor> _squared_difference_bw(const Tensor& grad, const Tensor& inp
     grad_tensor.emplace_back(grad_b);
     return grad_tensor;
 }
-std::vector<Tensor> squared_difference_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> squared_difference_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _squared_difference_bw)(grad, input, other, output_mem_config);
 }
-
 // torch reference
 // - name: ldexp(Tensor self, Tensor other) -> Tensor
 //   self: grad * 2^other
 //   other: grad * self * ln(2) * (2^other)
 // # M_LN2 = ln(2)= 0.693147180559945309417
-std::vector<Tensor> _ldexp_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _ldexp_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor tpow_o = mul(grad, rpow(other, 2.0, output_mem_config), std::nullopt, output_mem_config);
     grad_tensor.emplace_back(tpow_o);
@@ -799,29 +1036,42 @@ std::vector<Tensor> _ldexp_bw(const Tensor& grad, const Tensor& input, const Ten
     grad_tensor.emplace_back(result);
     return grad_tensor;
 }
-std::vector<Tensor> ldexp_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> ldexp_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _ldexp_bw)(grad, input, other, output_mem_config);
 }
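The quoted torch reference can be checked numerically: for ldexp(a, b) = a * 2^b, d/da = 2^b and d/db = a * ln(2) * 2^b. A minimal plain-C++ sketch (sample values and tolerance are arbitrary):

#include <cassert>
#include <cmath>

int main() {
    const double a = 1.7, b = 2.3, h = 1e-6;
    // ldexp generalized to a real exponent, as in the formulas above.
    auto ldexp_real = [](double x, double e) { return x * std::pow(2.0, e); };
    const double da = (ldexp_real(a + h, b) - ldexp_real(a - h, b)) / (2.0 * h);
    const double db = (ldexp_real(a, b + h) - ldexp_real(a, b - h)) / (2.0 * h);
    assert(std::fabs(da - std::pow(2.0, b)) < 1e-5);
    assert(std::fabs(db - a * M_LN2 * std::pow(2.0, b)) < 1e-5);
    return 0;
}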
-
-std::vector<Tensor> _xlogy_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _xlogy_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad1_result = mul(grad, log(other, output_mem_config), std::nullopt, output_mem_config);
     Tensor zero_tensor = full_like(other, 0.0, output_mem_config);
-    grad1_result = where(logical_and(eqz(input, output_mem_config), lte(other, zero_tensor, std::nullopt, output_mem_config), std::nullopt, output_mem_config) , zero_tensor,
-                  where(ltz(other, output_mem_config), std::nanf(" "), grad1_result, output_mem_config), output_mem_config);
-    grad1_result = where(eq_unary(input, std::nanf(" "), output_mem_config), std::nanf(" "), grad1_result, output_mem_config);
+    grad1_result = where(
+        logical_and(
+            eqz(input, output_mem_config),
+            lte(other, zero_tensor, std::nullopt, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        zero_tensor,
+        where(ltz(other, output_mem_config), std::nanf(" "), grad1_result, output_mem_config),
+        output_mem_config);
+    grad1_result =
+        where(eq_unary(input, std::nanf(" "), output_mem_config), std::nanf(" "), grad1_result, output_mem_config);
     grad_tensor.emplace_back(grad1_result);
     Tensor div_result = mul(input, recip(other, output_mem_config), std::nullopt, output_mem_config);
-    Tensor grad2_result = mul(grad, div_result , std::nullopt, output_mem_config);
-    grad2_result = where(eqz(other, output_mem_config), mul_unary(sign(grad, output_mem_config), std::numeric_limits<float>::infinity(), output_mem_config), grad2_result, output_mem_config);
-    grad2_result = where(eq_unary(other, std::nanf(" "), output_mem_config), std::nanf(" "), grad2_result, output_mem_config);
+    Tensor grad2_result = mul(grad, div_result, std::nullopt, output_mem_config);
+    grad2_result = where(
+        eqz(other, output_mem_config),
+        mul_unary(sign(grad, output_mem_config), std::numeric_limits<float>::infinity(), output_mem_config),
+        grad2_result,
+        output_mem_config);
+    grad2_result =
+        where(eq_unary(other, std::nanf(" "), output_mem_config), std::nanf(" "), grad2_result, output_mem_config);
     grad_tensor.emplace_back(grad2_result);
     return grad_tensor;
 }
-std::vector<Tensor> xlogy_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> xlogy_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _xlogy_bw)(grad, input, other, output_mem_config);
 }

@@ -831,9 +1081,11 @@ name: logaddexp(Tensor self, Tensor other) -> Tensor
 self: grad / (1 + exp(other - self)).conj()
 other: grad / (1 + exp(self - other)).conj()
 */
-std::vector<Tensor> _logaddexp_bw(const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _logaddexp_bw(
+    const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor opexp = add1(exp(sub(other, input_a, std::nullopt, output_mem_config), output_mem_config), output_mem_config);
+    Tensor opexp =
+        add1(exp(sub(other, input_a, std::nullopt, output_mem_config), output_mem_config), output_mem_config);
     Tensor grad_a = mul(grad, recip(opexp, output_mem_config), std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_a);
     opexp = add1(exp(sub(input_a, other, std::nullopt, output_mem_config), output_mem_config), output_mem_config);
@@ -841,8 +1093,8 @@ std::vector<Tensor>
_logaddexp_bw(const Tensor& grad, const Tensor& input_a, con grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector logaddexp_bw(const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector logaddexp_bw( + const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _logaddexp_bw)(grad, input_a, other, output_mem_config); } @@ -853,9 +1105,11 @@ self: grad / (1 + pow(2, other - self)) other: grad / (1 + pow(2, self - other)) */ -std::vector _logaddexp2_bw(const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) { +std::vector _logaddexp2_bw( + const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor oppow = add1(rpow(sub(other, input_a, std::nullopt, output_mem_config), 2, output_mem_config), output_mem_config); + Tensor oppow = + add1(rpow(sub(other, input_a, std::nullopt, output_mem_config), 2, output_mem_config), output_mem_config); Tensor grad_a = mul(grad, recip(oppow, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(grad_a); oppow = add1(rpow(sub(input_a, other, std::nullopt, output_mem_config), 2, output_mem_config), output_mem_config); @@ -863,96 +1117,131 @@ std::vector _logaddexp2_bw(const Tensor& grad, const Tensor& input_a, co grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector logaddexp2_bw(const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) -{ +std::vector logaddexp2_bw( + const Tensor& grad, const Tensor& input_a, const Tensor& other, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _logaddexp2_bw)(grad, input_a, other, output_mem_config); } -std::vector _concat_bw(const Tensor& grad, const Tensor& input, const Tensor& other, int dim, const MemoryConfig& output_mem_config) { +std::vector _concat_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, int dim, const MemoryConfig& output_mem_config) { std::vector grad_tensor; const Shape start_index = {0, 0, 0, 0}; - const Shape end_index = {input.get_legacy_shape()[0] - 1, input.get_legacy_shape()[1] - 1, input.get_legacy_shape()[2] - 1, input.get_legacy_shape()[3] - 1}; + const Shape end_index = { + input.get_legacy_shape()[0] - 1, + input.get_legacy_shape()[1] - 1, + input.get_legacy_shape()[2] - 1, + input.get_legacy_shape()[3] - 1}; Tensor grad_a = unpad(grad, start_index, end_index); grad_tensor.emplace_back(grad_a); Shape start_index_2 = {0, 0, 0, 0}; - if(dim == 0) - { - start_index_2 = {input.get_legacy_shape()[0], 0, 0, 0}; - } - else if(dim == 1) - { + if (dim == 0) { + start_index_2 = {input.get_legacy_shape()[0], 0, 0, 0}; + } else if (dim == 1) { start_index_2 = {input.get_legacy_shape()[0] - 1, input.get_legacy_shape()[1], 0, 0}; - } - else if(dim == 2) - { - start_index_2 = {input.get_legacy_shape()[0] - 1, input.get_legacy_shape()[1] - 1, input.get_legacy_shape()[2], 0}; - } - else if(dim == 3) - { + } else if (dim == 2) { + start_index_2 = { + input.get_legacy_shape()[0] - 1, input.get_legacy_shape()[1] - 1, input.get_legacy_shape()[2], 0}; + } else if (dim == 3) { start_index_2 = {0, 0, 0, input.get_legacy_shape()[3]}; } - const Shape end_index_2 = {grad.get_legacy_shape()[0] - 1, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1, grad.get_legacy_shape()[3] 
- 1}; + const Shape end_index_2 = { + grad.get_legacy_shape()[0] - 1, + grad.get_legacy_shape()[1] - 1, + grad.get_legacy_shape()[2] - 1, + grad.get_legacy_shape()[3] - 1}; Tensor grad_b = unpad(grad, start_index_2, end_index_2); grad_tensor.emplace_back(grad_b); return grad_tensor; } -std::vector concat_bw(const Tensor& grad, const Tensor& input, const Tensor& other, int dim, const MemoryConfig& output_mem_config) -{ +std::vector concat_bw( + const Tensor& grad, const Tensor& input, const Tensor& other, int dim, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _concat_bw)(grad, input, other, dim, output_mem_config); } - - std::vector _hardsigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad_a = where(logical_or(lte_unary(input, -3, output_mem_config), gte_unary(input, 3, output_mem_config), std::nullopt, output_mem_config), zeros_like(input, output_mem_config), mul_unary(grad, 1.0/6), output_mem_config); + Tensor grad_a = where( + logical_or( + lte_unary(input, -3, output_mem_config), + gte_unary(input, 3, output_mem_config), + std::nullopt, + output_mem_config), + zeros_like(input, output_mem_config), + mul_unary(grad, 1.0 / 6), + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector hardsigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector hardsigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _hardsigmoid_bw)(grad, input, output_mem_config); } std::vector _i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; float t_inf = std::numeric_limits::infinity(); - Tensor value = mul_unary(0.5, mul(i0(input, output_mem_config), recip(input, output_mem_config), std::nullopt, output_mem_config), output_mem_config); - Tensor result = where(ltz(input, output_mem_config), mul(grad, sub(neg(i0(input, output_mem_config), output_mem_config), value, std::nullopt, output_mem_config), std::nullopt, output_mem_config), mul(grad, sub(i0(input, output_mem_config), value, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config); - result = where(gte_unary(abs(i0(input, output_mem_config), output_mem_config), 3.4e+38, output_mem_config), t_inf, result, output_mem_config); - result = where(gte_unary(abs(result, output_mem_config), 3.4e+38, output_mem_config), t_inf, result, output_mem_config); + Tensor value = mul_unary( + 0.5, + mul(i0(input, output_mem_config), recip(input, output_mem_config), std::nullopt, output_mem_config), + output_mem_config); + Tensor result = where( + ltz(input, output_mem_config), + mul(grad, + sub(neg(i0(input, output_mem_config), output_mem_config), value, std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + mul(grad, + sub(i0(input, output_mem_config), value, std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + output_mem_config); + result = where( + gte_unary(abs(i0(input, output_mem_config), output_mem_config), 3.4e+38, output_mem_config), + t_inf, + result, + output_mem_config); + result = + where(gte_unary(abs(result, output_mem_config), 3.4e+38, output_mem_config), t_inf, result, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) 
-{ +std::vector i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _i0_bw)(grad, input, output_mem_config); } -std::vector _hardshrink_bw(const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) { +std::vector _hardshrink_bw( + const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor hardshrink_result = hardshrink(input_tensor, lambd, output_mem_config); Tensor result = where(eqz(hardshrink_result, output_mem_config), 0.0f, grad, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector hardshrink_bw(const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config) -{ +std::vector hardshrink_bw( + const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _hardshrink_bw)(grad, input, lambd, output_mem_config); } -//softshrink -// result: torch.where(self < -lambd, grad, torch.where(self > lambd, grad, torch.tensor(0.0))) -std::vector _softshrink_bw(const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) { +// softshrink +// result: torch.where(self < -lambd, grad, torch.where(self > lambd, grad, torch.tensor(0.0))) +std::vector _softshrink_bw( + const Tensor& grad, const Tensor& input_tensor, float lambd, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = where(logical_or(lt(input_tensor, full_like(input_tensor, -lambd, output_mem_config), std::nullopt, output_mem_config), gt(input_tensor, full_like(input_tensor, lambd, output_mem_config), std::nullopt, output_mem_config), std::nullopt, output_mem_config), grad, zeros_like(grad, output_mem_config), output_mem_config); + Tensor result = where( + logical_or( + lt(input_tensor, full_like(input_tensor, -lambd, output_mem_config), std::nullopt, output_mem_config), + gt(input_tensor, full_like(input_tensor, lambd, output_mem_config), std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + grad, + zeros_like(grad, output_mem_config), + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector softshrink_bw(const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config) -{ +std::vector softshrink_bw( + const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _softshrink_bw)(grad, input, lambd, output_mem_config); } @@ -960,25 +1249,37 @@ std::vector softshrink_bw(const Tensor& grad, const Tensor& input, float // result: torch.where(input < -3,0.0,torch.where(input <= 3, grad * ((input / 3) + 0.5), grad),) std::vector _hardswish_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad_result = where(lt(input, full_like(input, -3.0f), std::nullopt, output_mem_config), - 0.0, where(lte(input, full_like(input, 3.0f), std::nullopt, output_mem_config), - mul(grad, add_unary(mul_unary(input, 0.3333f, output_mem_config), 0.5f, output_mem_config), std::nullopt, output_mem_config), grad), output_mem_config); + Tensor grad_result = where( + lt(input, full_like(input, -3.0f), std::nullopt, output_mem_config), + 0.0, + where( + lte(input, full_like(input, 3.0f), std::nullopt, output_mem_config), + mul(grad, + 
add_unary(mul_unary(input, 0.3333f, output_mem_config), 0.5f, output_mem_config), + std::nullopt, + output_mem_config), + grad), + output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector hardswish_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector hardswish_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _hardswish_bw)(grad, input, output_mem_config); } // Softplus -std::vector _softplus_bw(const Tensor& grad, const Tensor& input, float beta, float threshold, const MemoryConfig& output_mem_config) { +std::vector _softplus_bw( + const Tensor& grad, const Tensor& input, float beta, float threshold, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor mul_input_beta = mul_unary(input, beta, output_mem_config); Tensor exp_beta_self = exp(mul_input_beta, output_mem_config); - Tensor sub_result = add_unary(-threshold , mul_input_beta, output_mem_config); - Tensor temp = mul(mul(grad, exp_beta_self, std::nullopt, output_mem_config), recip(add1(exp_beta_self, output_mem_config), output_mem_config), std::nullopt, output_mem_config); + Tensor sub_result = add_unary(-threshold, mul_input_beta, output_mem_config); + Tensor temp = + mul(mul(grad, exp_beta_self, std::nullopt, output_mem_config), + recip(add1(exp_beta_self, output_mem_config), output_mem_config), + std::nullopt, + output_mem_config); Tensor grad_result = where(gtz(sub_result, output_mem_config), grad, temp, output_mem_config); mul_input_beta.deallocate(); exp_beta_self.deallocate(); @@ -987,62 +1288,92 @@ std::vector _softplus_bw(const Tensor& grad, const Tensor& input, float grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector softplus_bw(const Tensor& grad, const Tensor& input, float beta, float threshold, const MemoryConfig& output_mem_config) -{ +std::vector softplus_bw( + const Tensor& grad, const Tensor& input, float beta, float threshold, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _softplus_bw)(grad, input, beta, threshold, output_mem_config); } -std::vector _polygamma_bw(const Tensor& grad, const Tensor& input, int n, const MemoryConfig& output_mem_config) { +std::vector _polygamma_bw( + const Tensor& grad, const Tensor& input, int n, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); float pos_neg = 1.0f; if (n == 2 || n == 4 || n == 6 || n == 8 || n == 10) { pos_neg = -1.0f; } - Tensor grad_a = mul(grad, polygamma(input, (n+1), output_mem_config), std::nullopt, output_mem_config); - grad_a = where(logical_and(lte_unary(input, 0.0, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), t_nan, grad_a, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), mul_unary(full_like(input, -std::numeric_limits::infinity(), output_mem_config), pos_neg, output_mem_config), grad_a, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), mul_unary(full_like(input, std::numeric_limits::infinity(), output_mem_config), pos_neg, output_mem_config), grad_a, output_mem_config); + Tensor grad_a = mul(grad, polygamma(input, (n + 1), output_mem_config), std::nullopt, output_mem_config); + grad_a = where( + 
logical_and( + lte_unary(input, 0.0, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), + t_nan, + grad_a, + output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), + mul_unary( + full_like(input, -std::numeric_limits::infinity(), output_mem_config), pos_neg, output_mem_config), + grad_a, + output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), + mul_unary( + full_like(input, std::numeric_limits::infinity(), output_mem_config), pos_neg, output_mem_config), + grad_a, + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector polygamma_bw(const Tensor& grad, const Tensor& input, int n, const MemoryConfig& output_mem_config) -{ +std::vector polygamma_bw( + const Tensor& grad, const Tensor& input, int n, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _polygamma_bw)(grad, input, n, output_mem_config); } std::vector _atan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - UnaryWithParam op1 {UnaryOpType::SQUARE}; - UnaryWithParam op2 {UnaryOpType::ADD_UNARY_SFPU, 1.0f}; - UnaryWithParam op3 {UnaryOpType::RECIP}; - Tensor grad_a = mul(grad, unary_chain( input, {op1, op2, op3}, output_mem_config), std::nullopt, output_mem_config); + UnaryWithParam op1{UnaryOpType::SQUARE}; + UnaryWithParam op2{UnaryOpType::ADD_UNARY_SFPU, 1.0f}; + UnaryWithParam op3{UnaryOpType::RECIP}; + Tensor grad_a = mul(grad, unary_chain(input, {op1, op2, op3}, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector atan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector atan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _atan_bw)(grad, input, output_mem_config); } std::vector _atanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); float t_inf = std::numeric_limits::infinity(); - UnaryWithParam op1 {UnaryOpType::SQUARE}; - UnaryWithParam op2 {UnaryOpType::SUB_UNARY_SFPU, 1.0f}; - UnaryWithParam op3 {UnaryOpType::NEG}; - UnaryWithParam op4 {UnaryOpType::RECIP}; - Tensor grad_a = mul(grad, unary_chain( input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config); + UnaryWithParam op1{UnaryOpType::SQUARE}; + UnaryWithParam op2{UnaryOpType::SUB_UNARY_SFPU, 1.0f}; + UnaryWithParam op3{UnaryOpType::NEG}; + UnaryWithParam op4{UnaryOpType::RECIP}; + Tensor grad_a = + mul(grad, unary_chain(input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config); grad_a = where(eqz(grad, output_mem_config), t_nan, grad_a, output_mem_config); - grad_a = where(logical_and(eqz(grad, output_mem_config), eqz(input, output_mem_config)), 0, grad_a, output_mem_config); - grad_a = where(logical_and(logical_or(eq_unary(input, 1, output_mem_config), eq_unary(input, -1, output_mem_config), std::nullopt, output_mem_config), nez(grad, output_mem_config)), t_inf, grad_a, output_mem_config); - grad_a = where(logical_and(eq_unary(grad_a, t_inf, output_mem_config), ltz(grad, output_mem_config)), -t_inf, grad_a, output_mem_config); + grad_a = + 
where(logical_and(eqz(grad, output_mem_config), eqz(input, output_mem_config)), 0, grad_a, output_mem_config); + grad_a = where( + logical_and( + logical_or( + eq_unary(input, 1, output_mem_config), + eq_unary(input, -1, output_mem_config), + std::nullopt, + output_mem_config), + nez(grad, output_mem_config)), + t_inf, + grad_a, + output_mem_config); + grad_a = where( + logical_and(eq_unary(grad_a, t_inf, output_mem_config), ltz(grad, output_mem_config)), + -t_inf, + grad_a, + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector atanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector atanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _atanh_bw)(grad, input, output_mem_config); } @@ -1050,11 +1381,12 @@ std::vector atanh_bw(const Tensor& grad, const Tensor& input, const Memo // result: grad * (-self * self + 1).rsqrt() std::vector _asin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - UnaryWithParam op1 {UnaryOpType::SQUARE}; - UnaryWithParam op2 {UnaryOpType::NEG}; - UnaryWithParam op3 {UnaryOpType::ADD_UNARY_SFPU, 1.0f}; - UnaryWithParam op4 {UnaryOpType::RSQRT, true}; - Tensor grad_result = mul(grad, unary_chain( input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config); + UnaryWithParam op1{UnaryOpType::SQUARE}; + UnaryWithParam op2{UnaryOpType::NEG}; + UnaryWithParam op3{UnaryOpType::ADD_UNARY_SFPU, 1.0f}; + UnaryWithParam op4{UnaryOpType::RSQRT, true}; + Tensor grad_result = + mul(grad, unary_chain(input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config); Tensor t_inf = full_like(input, std::numeric_limits::infinity(), output_mem_config); Tensor t_nan = full_like(input, std::nanf(""), output_mem_config); Tensor sub_one = add_unary(-1, input, output_mem_config); @@ -1079,8 +1411,7 @@ std::vector _asin_bw(const Tensor& grad, const Tensor& input, const Memo grad_tensor.emplace_back(result); return grad_tensor; } -std::vector asin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector asin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _asin_bw)(grad, input, output_mem_config); } @@ -1088,15 +1419,15 @@ std::vector asin_bw(const Tensor& grad, const Tensor& input, const Memor // result: grad * (self * self + 1).rsqrt() std::vector _asinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - UnaryWithParam op1 {UnaryOpType::SQUARE}; - UnaryWithParam op2 {UnaryOpType::ADD_UNARY_SFPU, 1.0f}; - UnaryWithParam op3 {UnaryOpType::RSQRT, true}; - Tensor grad_result = mul(grad, unary_chain( input, {op1, op2, op3}, output_mem_config), std::nullopt, output_mem_config); + UnaryWithParam op1{UnaryOpType::SQUARE}; + UnaryWithParam op2{UnaryOpType::ADD_UNARY_SFPU, 1.0f}; + UnaryWithParam op3{UnaryOpType::RSQRT, true}; + Tensor grad_result = + mul(grad, unary_chain(input, {op1, op2, op3}, output_mem_config), std::nullopt, output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector asinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector asinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return 
operation::decorate_as_composite(__func__, _asinh_bw)(grad, input, output_mem_config); } @@ -1105,18 +1436,32 @@ std::vector asinh_bw(const Tensor& grad, const Tensor& input, const Memo std::vector _cosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor t_inf = mul_unary(sign(grad, output_mem_config), std::numeric_limits::infinity(), output_mem_config); - Tensor t_neg_inf = mul_unary(sign(grad, output_mem_config), -std::numeric_limits::infinity(), output_mem_config); - Tensor grad_a = where(gt(input, full_like(input, 88.50, output_mem_config), std::nullopt, output_mem_config), t_inf, - where(lt(input, full_like(input, -88.50, output_mem_config), std::nullopt, output_mem_config), t_neg_inf, - mul(grad, sinh(input, output_mem_config), std::nullopt, output_mem_config), output_mem_config), output_mem_config); + Tensor t_neg_inf = + mul_unary(sign(grad, output_mem_config), -std::numeric_limits::infinity(), output_mem_config); + Tensor grad_a = where( + gt(input, full_like(input, 88.50, output_mem_config), std::nullopt, output_mem_config), + t_inf, + where( + lt(input, full_like(input, -88.50, output_mem_config), std::nullopt, output_mem_config), + t_neg_inf, + mul(grad, sinh(input, output_mem_config), std::nullopt, output_mem_config), + output_mem_config), + output_mem_config); t_neg_inf.deallocate(); t_inf.deallocate(); - grad_a = where(gte_unary(grad_a, 3.4e+38, output_mem_config), std::numeric_limits::infinity(), where(lte_unary(grad_a, -3.4e+38, output_mem_config), -std::numeric_limits::infinity(), grad_a, output_mem_config), output_mem_config); + grad_a = where( + gte_unary(grad_a, 3.4e+38, output_mem_config), + std::numeric_limits::infinity(), + where( + lte_unary(grad_a, -3.4e+38, output_mem_config), + -std::numeric_limits::infinity(), + grad_a, + output_mem_config), + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector cosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector cosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _cosh_bw)(grad, input, output_mem_config); } @@ -1124,12 +1469,12 @@ std::vector cosh_bw(const Tensor& grad, const Tensor& input, const Memor // self: grad * -self.sin() std::vector _cos_bw(const Tensor& grad, const Tensor& input_tensor, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = mul(grad, (neg(sin(input_tensor, output_mem_config), output_mem_config)), std::nullopt, output_mem_config); + Tensor result = + mul(grad, (neg(sin(input_tensor, output_mem_config), output_mem_config)), std::nullopt, output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector cos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector cos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _cos_bw)(grad, input, output_mem_config); } @@ -1138,19 +1483,28 @@ std::vector _acosh_bw(const Tensor& grad, const Tensor& input, const Mem Tensor in_rsqrt = square(input, output_mem_config); in_rsqrt = rsqrt(sub_unary(in_rsqrt, 1.0, output_mem_config), true, output_mem_config); Tensor grad_a = mul(grad, in_rsqrt, std::nullopt, output_mem_config); - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); float t_inf = std::numeric_limits::infinity(); - Tensor cond_result = 
logical_or(lt(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), - gt(input, full_like(input, 1.0, output_mem_config), std::nullopt, output_mem_config), std::nullopt, output_mem_config); + Tensor cond_result = logical_or( + lt(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), + gt(input, full_like(input, 1.0, output_mem_config), std::nullopt, output_mem_config), + std::nullopt, + output_mem_config); grad_a = where(eqz(cond_result, output_mem_config), t_nan, grad_a, output_mem_config); - cond_result = logical_or(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), - eq(input, full_like(input, 1.0, output_mem_config), std::nullopt, output_mem_config), std::nullopt, output_mem_config); - grad_a = where(eq(cond_result, ones_like(input, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config); + cond_result = logical_or( + eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), + eq(input, full_like(input, 1.0, output_mem_config), std::nullopt, output_mem_config), + std::nullopt, + output_mem_config); + grad_a = where( + eq(cond_result, ones_like(input, output_mem_config), std::nullopt, output_mem_config), + t_inf, + grad_a, + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector acosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector acosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _acosh_bw)(grad, input, output_mem_config); } @@ -1159,65 +1513,84 @@ std::vector acosh_bw(const Tensor& grad, const Tensor& input, const Memo std::vector _acos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor neg_in = neg(input, output_mem_config); - Tensor in_rsqrt = rsqrt(add1(mul(neg_in, input, std::nullopt, output_mem_config), output_mem_config), true, output_mem_config); + Tensor in_rsqrt = + rsqrt(add1(mul(neg_in, input, std::nullopt, output_mem_config), output_mem_config), true, output_mem_config); in_rsqrt = neg(in_rsqrt, output_mem_config); Tensor grad_a = mul(grad, in_rsqrt, std::nullopt, output_mem_config); Tensor neg_one = full_like(input, -1.0, output_mem_config); Tensor pos_one = full_like(input, 1.0, output_mem_config); Tensor t_inf = mul_unary(sign(grad, output_mem_config), -std::numeric_limits::infinity(), output_mem_config); - grad_a = where(logical_or(lt(input, neg_one, std::nullopt, output_mem_config), - gt(input, pos_one, std::nullopt, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "), grad_a, output_mem_config); - grad_a = where(eq(input, neg_one, std::nullopt, output_mem_config), t_inf, - where(eq(input, pos_one, std::nullopt, output_mem_config), t_inf, - grad_a, output_mem_config), output_mem_config); + grad_a = where( + logical_or( + lt(input, neg_one, std::nullopt, output_mem_config), + gt(input, pos_one, std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + std::nanf(" "), + grad_a, + output_mem_config); + grad_a = where( + eq(input, neg_one, std::nullopt, output_mem_config), + t_inf, + where(eq(input, pos_one, std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config), + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector acos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& 
output_mem_config) -{ +std::vector acos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _acos_bw)(grad, input, output_mem_config); } // Leaky_Relu // result: torch.where(self > 0, grad_output, grad_output * negative_slope) -std::vector _leaky_relu_bw(const Tensor& grad, const Tensor& input, float negative_slope, const MemoryConfig& output_mem_config) { +std::vector _leaky_relu_bw( + const Tensor& grad, const Tensor& input, float negative_slope, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad_result = where(gtz(input, output_mem_config), grad, mul_unary(grad, negative_slope, output_mem_config), output_mem_config); + Tensor grad_result = where( + gtz(input, output_mem_config), grad, mul_unary(grad, negative_slope, output_mem_config), output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector leaky_relu_bw(const Tensor& grad, const Tensor& input, float negative_slope, const MemoryConfig& output_mem_config) -{ +std::vector leaky_relu_bw( + const Tensor& grad, const Tensor& input, float negative_slope, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _leaky_relu_bw)(grad, input, negative_slope, output_mem_config); } // ELU // result : grad * (torch.where(input >= 0, 1, alpha * torch.exp(input))) -std::vector _elu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { +std::vector _elu_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad_result = where(gez(input, output_mem_config), grad, mul(grad, mul_unary(exp(input, output_mem_config), alpha, output_mem_config), std::nullopt, output_mem_config), output_mem_config); + Tensor grad_result = where( + gez(input, output_mem_config), + grad, + mul(grad, mul_unary(exp(input, output_mem_config), alpha, output_mem_config), std::nullopt, output_mem_config), + output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector elu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) -{ +std::vector elu_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _elu_bw)(grad, input, alpha, output_mem_config); } // Hardtanh // result: torch.where((input <= min) | (input >= max), 0.0, grad) -std::vector _hardtanh_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) { +std::vector _hardtanh_bw( + const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor grad_result = where(lte(input, full_like(input, min), std::nullopt, output_mem_config), - 0.0, where(gte(input, full_like(input, max), std::nullopt, output_mem_config), - 0.0, grad), output_mem_config); + Tensor grad_result = where( + lte(input, full_like(input, min), std::nullopt, output_mem_config), + 0.0, + where(gte(input, full_like(input, max), std::nullopt, output_mem_config), 0.0, grad), + output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector hardtanh_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config) -{ +std::vector hardtanh_bw( + const Tensor& grad, const Tensor& input, float min, float max, const 
MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _hardtanh_bw)(grad, input, min, max, output_mem_config); } @@ -1229,8 +1602,7 @@ std::vector _sin_bw(const Tensor& grad, const Tensor& input_tensor, cons grad_tensor.emplace_back(grad_input); return grad_tensor; } -std::vector sin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector sin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _sin_bw)(grad, input, output_mem_config); } @@ -1239,33 +1611,51 @@ std::vector sin_bw(const Tensor& grad, const Tensor& input, const Memory std::vector _sinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor t_inf = mul_unary(sign(grad, output_mem_config), std::numeric_limits::infinity(), output_mem_config); - Tensor grad_a = where(gt(input, full_like(input, 88.5, output_mem_config), std::nullopt, output_mem_config), t_inf, - where(lt(input, full_like(input, -88.5, output_mem_config), std::nullopt, output_mem_config), t_inf, - mul(grad, cosh(input, output_mem_config), std::nullopt, output_mem_config), output_mem_config), output_mem_config); + Tensor grad_a = where( + gt(input, full_like(input, 88.5, output_mem_config), std::nullopt, output_mem_config), + t_inf, + where( + lt(input, full_like(input, -88.5, output_mem_config), std::nullopt, output_mem_config), + t_inf, + mul(grad, cosh(input, output_mem_config), std::nullopt, output_mem_config), + output_mem_config), + output_mem_config); t_inf.deallocate(); - grad_a = where(gte_unary(grad_a, 3.4e+38, output_mem_config), std::numeric_limits::infinity(), where(lte_unary(grad_a, -3.4e+38, output_mem_config), -std::numeric_limits::infinity(), grad_a, output_mem_config), output_mem_config); + grad_a = where( + gte_unary(grad_a, 3.4e+38, output_mem_config), + std::numeric_limits::infinity(), + where( + lte_unary(grad_a, -3.4e+38, output_mem_config), + -std::numeric_limits::infinity(), + grad_a, + output_mem_config), + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector sinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector sinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _sinh_bw)(grad, input, output_mem_config); } // Celu // result: torch.where((input > 0), grad, grad * torch.exp(input / alpha)) -std::vector _celu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { +std::vector _celu_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor div_result = mul(input, recip(full_like(input, alpha, output_mem_config), output_mem_config), std::nullopt, output_mem_config); + Tensor div_result = mul( + input, recip(full_like(input, alpha, output_mem_config), output_mem_config), std::nullopt, output_mem_config); Tensor exp_result = exp(div_result, output_mem_config); - Tensor grad_result = where(gt(input, zeros_like( input, output_mem_config), std::nullopt, output_mem_config), - grad, mul(grad, exp_result, std::nullopt, output_mem_config), output_mem_config); + Tensor grad_result = where( + gt(input, zeros_like(input, output_mem_config), std::nullopt, output_mem_config), + grad, + mul(grad, exp_result, std::nullopt, output_mem_config), + 
output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector celu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) -{ +std::vector celu_bw( + const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _celu_bw)(grad, input, alpha, output_mem_config); } @@ -1277,8 +1667,7 @@ std::vector _binary_lt_bw(const Tensor& grad, const Tensor& input, const grad_tensor.emplace_back(zero_input); return grad_tensor; } -std::vector binary_lt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector binary_lt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _binary_lt_bw)(grad, input, output_mem_config); } @@ -1287,54 +1676,87 @@ std::vector binary_lt_bw(const Tensor& grad, const Tensor& input, const // for input -1 and 1: grad.sign() * inf, for input > 1 or < -1 : nan std::vector _erfinv_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = mul_unary(0.5, mul(sqrt(full_like(input, M_PI , output_mem_config), output_mem_config), mul(exp(square(erfinv(input, output_mem_config), output_mem_config), output_mem_config), grad, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config); + Tensor result = mul_unary( + 0.5, + mul(sqrt(full_like(input, M_PI, output_mem_config), output_mem_config), + mul(exp(square(erfinv(input, output_mem_config), output_mem_config), output_mem_config), + grad, + std::nullopt, + output_mem_config), + std::nullopt, + output_mem_config), + output_mem_config); Tensor neg_one = full_like(input, -1.0, output_mem_config); Tensor pos_one = full_like(input, 1.0, output_mem_config); Tensor t_inf = mul_unary(sign(grad, output_mem_config), std::numeric_limits::infinity(), output_mem_config); - result = where(logical_or(lt(input, neg_one, std::nullopt, output_mem_config), - gt(input, pos_one, std::nullopt, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "), result, output_mem_config); - result = where(eq(input, neg_one, std::nullopt, output_mem_config), t_inf, - where(eq(input, pos_one, std::nullopt, output_mem_config), t_inf, - result, output_mem_config), output_mem_config); + result = where( + logical_or( + lt(input, neg_one, std::nullopt, output_mem_config), + gt(input, pos_one, std::nullopt, output_mem_config), + std::nullopt, + output_mem_config), + std::nanf(" "), + result, + output_mem_config); + result = where( + eq(input, neg_one, std::nullopt, output_mem_config), + t_inf, + where(eq(input, pos_one, std::nullopt, output_mem_config), t_inf, result, output_mem_config), + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector erfinv_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector erfinv_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _erfinv_bw)(grad, input, output_mem_config); } - // bw(log10(in)) = grad/(in * 2.30258509299404568402) std::vector _log10_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits::infinity(), std::numeric_limits::infinity(), output_mem_config); - Tensor grad_a 
= mul(grad, recip(mul_unary(input, M_LN10, output_mem_config), output_mem_config), std::nullopt, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "), - where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config), output_mem_config); + Tensor t_inf = where( + ltz(grad, output_mem_config), + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + output_mem_config); + Tensor grad_a = mul( + grad, recip(mul_unary(input, M_LN10, output_mem_config), output_mem_config), std::nullopt, output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), + std::nanf(" "), + where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config), + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector log10_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector log10_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _log10_bw)(grad, input, output_mem_config); } - // bw(log1p(in)) = grad/(in + 1) // for -1 = inf std::vector _log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits::infinity(), std::numeric_limits::infinity(), output_mem_config); + Tensor t_inf = where( + ltz(grad, output_mem_config), + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + output_mem_config); Tensor t_inp1 = add1(input, output_mem_config); Tensor grad_a = mul(grad, recip(t_inp1, output_mem_config), std::nullopt, output_mem_config); - grad_a = where(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config); - grad_a = where(logical_and(eqz(t_inp1, output_mem_config), eqz(grad, output_mem_config)), std::nanf(" "), grad_a, output_mem_config); + grad_a = where( + eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), + t_inf, + grad_a, + output_mem_config); + grad_a = where( + logical_and(eqz(t_inp1, output_mem_config), eqz(grad, output_mem_config)), + std::nanf(" "), + grad_a, + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _log1p_bw)(grad, input, output_mem_config); } @@ -1346,70 +1768,88 @@ std::vector _binary_ne_bw(const Tensor& grad, const Tensor& input, const grad_tensor.emplace_back(zero_input); return grad_tensor; } -std::vector binary_ne_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector binary_ne_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _binary_ne_bw)(grad, input, output_mem_config); } std::vector _erf_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = mul_unary(M_2_SQRTPI, mul(exp(neg(square(input, output_mem_config), output_mem_config), output_mem_config), grad, std::nullopt, output_mem_config), 
output_mem_config); + Tensor result = mul_unary( + M_2_SQRTPI, + mul(exp(neg(square(input, output_mem_config), output_mem_config), output_mem_config), + grad, + std::nullopt, + output_mem_config), + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector erf_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector erf_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _erf_bw)(grad, input, output_mem_config); } std::vector _erfc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - Tensor result = mul_unary(-M_2_SQRTPI, mul(exp(neg(square(input, output_mem_config), output_mem_config), output_mem_config), grad, std::nullopt, output_mem_config), output_mem_config); + Tensor result = mul_unary( + -M_2_SQRTPI, + mul(exp(neg(square(input, output_mem_config), output_mem_config), output_mem_config), + grad, + std::nullopt, + output_mem_config), + output_mem_config); grad_tensor.emplace_back(result); return grad_tensor; } -std::vector erfc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector erfc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _erfc_bw)(grad, input, output_mem_config); } std::vector _digamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; float t_inf = std::numeric_limits::infinity(); - float t_nan = std::nanf(""); + float t_nan = std::nanf(""); Tensor grad_a = mul(grad, polygamma(input, 1, output_mem_config), std::nullopt, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), t_nan, grad_a, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, grad_a, output_mem_config); - grad_a = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), + t_nan, + grad_a, + output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), + -t_inf, + grad_a, + output_mem_config); + grad_a = where( + logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), + t_inf, + grad_a, + output_mem_config); grad_tensor.emplace_back(grad_a); return grad_tensor; } -std::vector digamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) -{ +std::vector digamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _digamma_bw)(grad, input, output_mem_config); } std::vector _deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - float M_PI_180 = M_PI/180; + float M_PI_180 = M_PI / 180; Tensor grad_result = mul_unary(grad, M_PI_180, output_mem_config); grad_tensor.emplace_back(grad_result); return grad_tensor; } -std::vector deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& 
 std::vector<Tensor> _deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    float M_PI_180 = M_PI/180;
+    float M_PI_180 = M_PI / 180;
     Tensor grad_result = mul_unary(grad, M_PI_180, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _deg2rad_bw)(grad, input, output_mem_config);
 }
 
 std::vector<Tensor> _rad2deg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    float M_180_PI = 180/M_PI;
+    float M_180_PI = 180 / M_PI;
     Tensor grad_result = mul_unary(grad, M_180_PI, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> rad2deg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> rad2deg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _rad2deg_bw)(grad, input, output_mem_config);
 }
 
@@ -1417,15 +1857,21 @@ std::vector<Tensor> _reciprocal_bw(const Tensor& grad, const Tensor& input, cons
     std::vector<Tensor> grad_tensor;
     Tensor t_inf = full_like(input, std::numeric_limits<float>::infinity(), output_mem_config);
     Tensor t_nan = full_like(input, std::nanf(""), output_mem_config);
-    grad_tensor.emplace_back( where(eqz(input, output_mem_config),
-        where(eqz(grad, output_mem_config),
-            t_nan,
-            mul(t_inf, neg( sign(grad, output_mem_config), output_mem_config), std::nullopt, output_mem_config), output_mem_config),
-        mul(neg(grad, output_mem_config), recip(square(input, output_mem_config), output_mem_config), std::nullopt, output_mem_config), output_mem_config));
+    grad_tensor.emplace_back(where(
+        eqz(input, output_mem_config),
+        where(
+            eqz(grad, output_mem_config),
+            t_nan,
+            mul(t_inf, neg(sign(grad, output_mem_config), output_mem_config), std::nullopt, output_mem_config),
+            output_mem_config),
+        mul(neg(grad, output_mem_config),
+            recip(square(input, output_mem_config), output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        output_mem_config));
     return grad_tensor;
 }
-std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _reciprocal_bw)(grad, input, output_mem_config);
 }
 
@@ -1434,31 +1880,45 @@ std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const Mem
     Tensor zero_tensor = zeros_like(input, output_mem_config);
     Tensor one_tensor = ones_like(input, output_mem_config);
     Tensor six_tensor = full_like(input, 6, output_mem_config);
-    Tensor grad_result = where(lte(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
-    grad_result = where(logical_and(gtz(input, output_mem_config), lt(input , six_tensor, std::nullopt, output_mem_config), std::nullopt, output_mem_config), grad, grad_result, output_mem_config);
-    grad_result = where(gte(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);
+    Tensor grad_result =
+        where(lte(input, zero_tensor, std::nullopt, output_mem_config), zero_tensor, six_tensor, output_mem_config);
+    grad_result = where(
+        logical_and(
+            gtz(input, output_mem_config),
+            lt(input, six_tensor, std::nullopt, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        grad,
+        grad_result,
+        output_mem_config);
+    grad_result =
+        where(gte(input, six_tensor, std::nullopt, output_mem_config), zero_tensor, grad_result, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _relu6_bw)(grad, input, output_mem_config);
 }
 
-std::vector<Tensor> _rpow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _rpow_bw(
+    const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    float t_nan = std::nanf("");
+    float t_nan = std::nanf("");
     Tensor grad_result = zeros_like(input, output_mem_config);
-    if (exponent != 0.0){
-        grad_result = mul(grad, mul_unary(pow(input, exponent - 1, output_mem_config), exponent, output_mem_config), std::nullopt, output_mem_config);
+    if (exponent != 0.0) {
+        grad_result =
+            mul(grad,
+                mul_unary(pow(input, exponent - 1, output_mem_config), exponent, output_mem_config),
+                std::nullopt,
+                output_mem_config);
         grad_result = where(ltz(input, output_mem_config), t_nan, grad_result, output_mem_config);
     }
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> rpow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> rpow_bw(
+    const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _rpow_bw)(grad, input, exponent, output_mem_config);
 }
 
@@ -1467,14 +1927,18 @@ std::vector<Tensor> rpow_bw(const Tensor& grad, const Tensor& input, float expon
 std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_sigmoid = mul(grad, sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
-    Tensor add_sub = add1(mul(sub_unary(1.0f, sigmoid(input, output_mem_config), output_mem_config), input, std::nullopt, output_mem_config), output_mem_config);
+    Tensor add_sub = add1(
+        mul(sub_unary(1.0f, sigmoid(input, output_mem_config), output_mem_config),
+            input,
+            std::nullopt,
+            output_mem_config),
+        output_mem_config);
     Tensor grad_result = mul(grad_sigmoid, add_sub, std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _silu_bw)(grad, input, output_mem_config);
 }
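// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// _silu_bw above implements d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
// i.e. the grad_sigmoid factor times the add_sub factor. Written out once with the
// same composite ops used in this file:
inline Tensor silu_bw_reference_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    Tensor sig = sigmoid(input, cfg);
    // 1 + x * (1 - sigmoid(x)), matching the add_sub term above
    Tensor one_plus = add1(mul(sub_unary(1.0f, sig, cfg), input, std::nullopt, cfg), cfg);
    // grad * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    return mul(mul(grad, sig, std::nullopt, cfg), one_plus, std::nullopt, cfg);
}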
 
@@ -1483,16 +1947,21 @@ std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const Memor
 std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_lambd = mul_unary(grad, 1.0507f, output_mem_config);
-    Tensor grad_result = where(gtz(input, output_mem_config), grad_lambd, mul(mul_unary(grad_lambd, 1.673260f, output_mem_config), exp(input, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+    Tensor grad_result = where(
+        gtz(input, output_mem_config),
+        grad_lambd,
+        mul(mul_unary(grad_lambd, 1.673260f, output_mem_config),
+            exp(input, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _selu_bw)(grad, input, output_mem_config);
 }
-
 std::vector<Tensor> _binary_ge_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor zero_grad = zeros_like(grad, output_mem_config);
@@ -1501,22 +1970,55 @@ std::vector<Tensor> _binary_ge_bw(const Tensor& grad, const Tensor& input, const
     grad_tensor.emplace_back(zero_input);
     return grad_tensor;
 }
-std::vector<Tensor> binary_ge_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> binary_ge_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _binary_ge_bw)(grad, input, output_mem_config);
 }
 
-std::vector<Tensor> _binary_eq_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    Tensor zero_grad = zeros_like(grad, output_mem_config);
-    grad_tensor.emplace_back(zero_grad);
-    Tensor zero_input = zeros_like(input, output_mem_config);
-    grad_tensor.emplace_back(zero_input);
-    return grad_tensor;
+// name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+// self: zeros_like(self)
+// other: zeros_like(other)
+std::vector<std::optional<Tensor>> _binary_eq_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
+    std::vector<std::optional<Tensor>> result;
+
+    if (are_required_outputs.at(0)) {
+        if (input_grad.has_value()) {
+            assign(zeros_like(input, output_mem_config), input_grad.value());
+        } else {
+            input_grad = zeros_like(input, output_mem_config);
+        }
+        result.push_back(input_grad.value());
+    } else {
+        result.push_back(std::nullopt);
+    }
+    if (are_required_outputs.at(1)) {
+        if (other_grad.has_value()) {
+            assign(zeros_like(other, output_mem_config), other_grad.value());
+        } else {
+            other_grad = zeros_like(other, output_mem_config);
+        }
+        result.push_back(other_grad.value());
+    } else {
+        result.push_back(std::nullopt);
+    }
+    return std::move(result);
 }
-std::vector<Tensor> binary_eq_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
-    return operation::decorate_as_composite(__func__, _binary_eq_bw)(grad, input, output_mem_config);
+std::vector<std::optional<Tensor>> binary_eq_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
+    return operation::decorate_as_composite(__func__, _binary_eq_bw)(
+        grad, input, other, output_mem_config, are_required_outputs, input_grad, other_grad);
 }
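// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// How a caller can use the new optional-output form: request only the gradient
// w.r.t. `other` and have it written into a preallocated tensor instead of a
// fresh allocation.
inline std::optional<Tensor> binary_eq_bw_other_only_sketch(
    const Tensor& grad, const Tensor& input, const Tensor& other, Tensor& preallocated, const MemoryConfig& cfg) {
    // Slot 0 (input_grad) is skipped, so result[0] comes back as std::nullopt.
    auto result = binary_eq_bw(grad, input, other, cfg, {false, true}, std::nullopt, preallocated);
    return result[1];  // the preallocated tensor, now filled with zeros_like(other)
}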
 
 std::vector<Tensor> _binary_gt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
@@ -1527,8 +2029,7 @@ std::vector<Tensor> _binary_gt_bw(const Tensor& grad, const Tensor& input, const
     grad_tensor.emplace_back(zero_input);
     return grad_tensor;
 }
-std::vector<Tensor> binary_gt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> binary_gt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _binary_gt_bw)(grad, input, output_mem_config);
 }
 
@@ -1554,9 +2055,12 @@ std::vector<Tensor> _prod_bw(
         prod_result = tt::tt_metal::change_layout_to_tile(prod_result, output_mem_config);
     }
     if (all_dimensions == true) {
-        Tensor temp = mul(prod_result, grad, std::nullopt, output_mem_config); // result is stored in the first position
-        Tensor fill_tensor = tt::numpy::fill_first_val_into_tensor( temp, temp.get_dtype(), temp.get_layout(), temp.device(), output_mem_config);
-        Tensor all_dimension_result = mul(recip(input, output_mem_config), fill_tensor, std::nullopt, output_mem_config);
+        Tensor temp =
+            mul(prod_result, grad, std::nullopt, output_mem_config);  // result is stored in the first position
+        Tensor fill_tensor = tt::numpy::fill_first_val_into_tensor(
+            temp, temp.get_dtype(), temp.get_layout(), temp.device(), output_mem_config);
+        Tensor all_dimension_result =
+            mul(recip(input, output_mem_config), fill_tensor, std::nullopt, output_mem_config);
         grad_tensor.emplace_back(all_dimension_result);
         return grad_tensor;
     }
@@ -1567,7 +2071,8 @@ std::vector<Tensor> _prod_bw(
         std::vector<int64_t> after_permute_dims = {0, 3, 1, 2};
         Tensor required = permute(grad, after_permute_dims, output_mem_config);
         const Shape start_index = {0, 0, 0, 0};
-        const Shape end_index = { grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1};
+        const Shape end_index = {
+            grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1};
         Tensor new_unpad_tensor = unpad(required, start_index, end_index);
         after_permute_dims = {0, 2, 3, 1};
         updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config);
@@ -1583,7 +2088,8 @@ std::vector<Tensor> _prod_bw(
         std::vector<int64_t> after_permute_dims = {0, 2, 1, 3};
         Tensor required = permute(grad, after_permute_dims, output_mem_config);
         const Shape start_index = {0, 0, 0, 0};
-        const Shape end_index = { grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[3] - 1};
+        const Shape end_index = {
+            grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[3] - 1};
         Tensor new_unpad_tensor = unpad(required, start_index, end_index);
         updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config);
         if(updated_grad.get_layout()==Layout::ROW_MAJOR){
@@ -1619,7 +2125,10 @@ std::vector<Tensor> _prod_bw(
         Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
         Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);
         after_permute_dims = {0, 3, 1, 2};
-        Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), after_permute_dims, output_mem_config);
+        Tensor result = permute(
+            bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
+            after_permute_dims,
+            output_mem_config);
         Tensor grad_result = result;
         if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) {
             const Shape start_index = {0, 0, 0, 0};
@@ -1647,7 +2156,10 @@ std::vector<Tensor> _prod_bw(
         std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
         Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
         Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);
-        Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), after_permute_dims, output_mem_config);
+        Tensor result = permute(
+            bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
+            after_permute_dims,
+            output_mem_config);
         Tensor grad_result = result;
         if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) {
             const Shape start_index = {0, 0, 0, 0};
@@ -1674,20 +2186,17 @@ std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const Me
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> square_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> square_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _square_bw)(grad, input, output_mem_config);
 }
-
 std::vector<Tensor> _lgamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = mul(grad, digamma(input, output_mem_config), std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> lgamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> lgamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _lgamma_bw)(grad, input, output_mem_config);
 }
 
@@ -1696,8 +2205,7 @@ std::vector<Tensor> _frac_bw(const Tensor& grad, const Tensor& input, const Memo
     grad_tensor.emplace_back(grad);
     return grad_tensor;
 }
-std::vector<Tensor> frac_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> frac_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _frac_bw)(grad, input, output_mem_config);
 }
 
@@ -1707,8 +2215,7 @@ std::vector<Tensor> _trunc_bw(const Tensor& grad, const Tensor& input, const Mem
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> trunc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> trunc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _trunc_bw)(grad, input, output_mem_config);
 }
 
@@ -1721,17 +2228,16 @@ std::vector<Tensor> _log_sigmoid_bw(const Tensor& grad, const Tensor& input, con
     Tensor in_abs = abs(input, output_mem_config);
     Tensor z = exp(neg(in_abs, output_mem_config), output_mem_config);
-    Tensor mul_z = mul(z, recip((add1(z , output_mem_config)), output_mem_config), std::nullopt, output_mem_config);
+    Tensor mul_z = mul(z, recip((add1(z, output_mem_config)), output_mem_config), std::nullopt, output_mem_config);
     Tensor mul_sign = mul(in_sign, mul_z, std::nullopt, output_mem_config);
     Tensor sub_max = sub(max_deriv, mul_sign, std::nullopt, output_mem_config);
-    Tensor grad_result = mul(grad, sub_max, std::nullopt, output_mem_config);
+    Tensor grad_result = mul(grad, sub_max, std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> log_sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> log_sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _log_sigmoid_bw)(grad, input, output_mem_config);
 }
 
@@ -1743,36 +2249,40 @@ std::vector<Tensor> _tanhshrink_bw(const Tensor& grad, const Tensor& input, cons
     grad_tensor.emplace_back(mul(grad, tanh_res, std::nullopt, output_mem_config));
     return grad_tensor;
 }
-std::vector<Tensor> tanhshrink_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> tanhshrink_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _tanhshrink_bw)(grad, input, output_mem_config);
 }
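// [Editor's sketch, not part of this patch; helper name invented for illustration,
// and the tanh(input, cfg) call is assumed to follow the same signature pattern as
// the other unary composites in this file.] tanhshrink(x) = x - tanh(x), so its
// derivative is 1 - (1 - tanh(x)^2) = tanh(x)^2; that is the tanh_res factor that
// _tanhshrink_bw multiplies grad by above. As a one-liner:
inline Tensor tanhshrink_bw_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    return mul(grad, square(tanh(input, cfg), cfg), std::nullopt, cfg);
}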
 
-//threshold
-//if input <= threshold = 0 else grad
-std::vector<Tensor> _threshold_bw(const Tensor& grad, const Tensor& input, float threshold, float value, const MemoryConfig& output_mem_config) {
+// threshold
+// if input <= threshold = 0 else grad
+std::vector<Tensor> _threshold_bw(
+    const Tensor& grad, const Tensor& input, float threshold, float value, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor result = where(gtz(add_unary(-threshold , input, output_mem_config), output_mem_config), grad, zeros_like( input, output_mem_config), output_mem_config);
+    Tensor result = where(
+        gtz(add_unary(-threshold, input, output_mem_config), output_mem_config),
+        grad,
+        zeros_like(input, output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(result);
     return grad_tensor;
 }
-std::vector<Tensor> threshold_bw(const Tensor& grad, const Tensor& input, float threshold, float value, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> threshold_bw(
+    const Tensor& grad, const Tensor& input, float threshold, float value, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _threshold_bw)(grad, input, threshold, value, output_mem_config);
 }
 
-std::vector<Tensor> _unary_eq_bw(const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _unary_eq_bw(
+    const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor zero_grad = zeros_like(grad, output_mem_config);
     grad_tensor.emplace_back(zero_grad);
     return grad_tensor;
 }
-std::vector<Tensor> unary_eq_bw(const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> unary_eq_bw(
+    const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _unary_eq_bw)(grad, input, other, output_mem_config);
 }
-
 // Torch reference
 // # if eps is not None:
 // #     lo = eps
 // #     hi = 1.0 - lo
 // # return torch.where(
 // #     torch.logical_and(self >= lo, self <= hi),
 // #     grad_output / (self * (1.0 - self)),
 // #     self.new_full((), float("nan")),
 // # )
-std::vector<Tensor> _logiteps_bw(const Tensor& grad, const Tensor& input, float eps, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _logiteps_bw(
+    const Tensor& grad, const Tensor& input, float eps, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     float low, high;
     low = eps;
-    high = 1.0 - low ;
-    Tensor grad_result = mul(grad, recip(mul(input, rsub(input, 1.0f, output_mem_config), std::nullopt, output_mem_config)), std::nullopt, output_mem_config);
+    high = 1.0 - low;
+    Tensor grad_result =
+        mul(grad,
+            recip(mul(input, rsub(input, 1.0f, output_mem_config), std::nullopt, output_mem_config)),
+            std::nullopt,
+            output_mem_config);
     Tensor t_eps = full_like(input, eps, output_mem_config);
     Tensor t_low = full_like(input, low, output_mem_config);
     Tensor t_high = full_like(input, high, output_mem_config);
-    Tensor ltl_gth = logical_or(lt(input, t_low, std::nullopt, output_mem_config),
-        gt(input, t_high, std::nullopt, output_mem_config), std::nullopt, output_mem_config);
-    grad_result = where(eq(ltl_gth, ones_like(input, output_mem_config), std::nullopt, output_mem_config),
-        where(ltz(t_eps, output_mem_config), std::nanf(" "), 0.0, output_mem_config),
-        where(logical_or(eq_unary(input, 0.0, output_mem_config),
-            eq_unary(input, 1.0, output_mem_config), std::nullopt, output_mem_config),
-            mul_unary(sign(grad, output_mem_config),
-            std::numeric_limits<float>::infinity(), output_mem_config), grad_result, output_mem_config), output_mem_config);
+    Tensor ltl_gth = logical_or(
+        lt(input, t_low, std::nullopt, output_mem_config),
+        gt(input, t_high, std::nullopt, output_mem_config),
+        std::nullopt,
+        output_mem_config);
+    grad_result = where(
+        eq(ltl_gth, ones_like(input, output_mem_config), std::nullopt, output_mem_config),
+        where(ltz(t_eps, output_mem_config), std::nanf(" "), 0.0, output_mem_config),
+        where(
+            logical_or(
+                eq_unary(input, 0.0, output_mem_config),
+                eq_unary(input, 1.0, output_mem_config),
+                std::nullopt,
+                output_mem_config),
+            mul_unary(sign(grad, output_mem_config), std::numeric_limits<float>::infinity(), output_mem_config),
+            grad_result,
+            output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> logiteps_bw(const Tensor& grad, const Tensor& input, float eps, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> logiteps_bw(
+    const Tensor& grad, const Tensor& input, float eps, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _logiteps_bw)(grad, input, eps, output_mem_config);
 }
-
 std::vector<Tensor> _logit_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor grad_result = mul(grad, recip(mul(input, rsub(input, 1.0f, output_mem_config), std::nullopt, output_mem_config)), std::nullopt, output_mem_config);
-    Tensor status = logical_and(gte_unary(input, 0.0f, output_mem_config),
-        lte_unary(input, 1.0f, output_mem_config), std::nullopt, output_mem_config);
-    grad_result = where(eq(status, ones_like(input, output_mem_config), std::nullopt, output_mem_config), grad_result, std::nanf(""));
-    grad_result = where(logical_or(eq_unary(input, 0.0, output_mem_config),
-        eq_unary(input, 1.0, output_mem_config), std::nullopt, output_mem_config),
-        mul_unary(sign(grad, output_mem_config),
-        std::numeric_limits<float>::infinity(), output_mem_config), grad_result, output_mem_config);
+    Tensor grad_result =
+        mul(grad,
+            recip(mul(input, rsub(input, 1.0f, output_mem_config), std::nullopt, output_mem_config)),
+            std::nullopt,
+            output_mem_config);
+    Tensor status = logical_and(
+        gte_unary(input, 0.0f, output_mem_config),
+        lte_unary(input, 1.0f, output_mem_config),
+        std::nullopt,
+        output_mem_config);
+    grad_result = where(
+        eq(status, ones_like(input, output_mem_config), std::nullopt, output_mem_config), grad_result, std::nanf(""));
+    grad_result = where(
+        logical_or(
+            eq_unary(input, 0.0, output_mem_config),
+            eq_unary(input, 1.0, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        mul_unary(sign(grad, output_mem_config), std::numeric_limits<float>::infinity(), output_mem_config),
+        grad_result,
+        output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> logit_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> logit_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _logit_bw)(grad, input, output_mem_config);
 }
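// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// logit(x) = log(x / (1 - x)), so d/dx logit(x) = 1 / (x * (1 - x)). The rsub/recip
// chain above computes exactly that before the out-of-domain cases are patched in
// with where(). The in-domain gradient in isolation:
inline Tensor logit_bw_dense_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    // grad / (input * (1 - input)); rsub(input, 1.0f) computes 1 - input
    return mul(grad, recip(mul(input, rsub(input, 1.0f, cfg), std::nullopt, cfg)), std::nullopt, cfg);
}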
 
@@ -1836,15 +2372,15 @@ std::vector<Tensor> logit_bw(const Tensor& grad, const Tensor& input, const Memo
 // softsign
 // result = grad_data / torch.square(1 + torch.abs(input))
 std::vector<Tensor> _softsign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    UnaryWithParam op1 {UnaryOpType::ABS};
-    UnaryWithParam op2 {UnaryOpType::ADD_UNARY_SFPU, 1.0f};
-    UnaryWithParam op3 {UnaryOpType::SQUARE};
-    UnaryWithParam op4 {UnaryOpType::RECIP};
-    grad_tensor.emplace_back( mul(grad, unary_chain( input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config));
+    UnaryWithParam op1{UnaryOpType::ABS};
+    UnaryWithParam op2{UnaryOpType::ADD_UNARY_SFPU, 1.0f};
+    UnaryWithParam op3{UnaryOpType::SQUARE};
+    UnaryWithParam op4{UnaryOpType::RECIP};
+    grad_tensor.emplace_back(
+        mul(grad, unary_chain(input, {op1, op2, op3, op4}, output_mem_config), std::nullopt, output_mem_config));
     return grad_tensor;
 }
-std::vector<Tensor> softsign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> softsign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _softsign_bw)(grad, input, output_mem_config);
 }
 
@@ -1854,8 +2390,7 @@ std::vector<Tensor> _sign_bw(const Tensor& grad, const Tensor& input, const Memo
     grad_tensor.emplace_back(zero_grad);
     return grad_tensor;
 }
-std::vector<Tensor> sign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> sign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _sign_bw)(grad, input, output_mem_config);
 }
 
@@ -1865,23 +2400,29 @@ std::vector<Tensor> _ceil_bw(const Tensor& grad, const Tensor& input, const Memo
     grad_tensor.emplace_back(zero_grad);
     return grad_tensor;
 }
-std::vector<Tensor> ceil_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> ceil_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _ceil_bw)(grad, input, output_mem_config);
 }
 
 // bw(log2(in)) = grad/(in * 0.69314718055994530942)
 std::vector<Tensor> _log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), output_mem_config);
-    Tensor grad_a = mul(grad, recip(mul_unary(input, M_LN2, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
-    grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "),
-        where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config), output_mem_config);
+    Tensor t_inf = where(
+        ltz(grad, output_mem_config),
+        -std::numeric_limits<float>::infinity(),
+        std::numeric_limits<float>::infinity(),
+        output_mem_config);
+    Tensor grad_a = mul(
+        grad, recip(mul_unary(input, M_LN2, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    grad_a = where(
+        logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config),
+        std::nanf(" "),
+        where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(grad_a);
     return grad_tensor;
 }
-std::vector<Tensor> log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _log2_bw)(grad, input, output_mem_config);
 }
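// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// Since log2(x) = ln(x) / ln(2), the gradient is grad / (x * ln(2)); M_LN2 above is
// ln(2), and the zero-input cases are patched separately with where(). Away from
// x == 0 the gradient is just:
inline Tensor log2_bw_dense_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    return mul(grad, recip(mul_unary(input, M_LN2, cfg), cfg), std::nullopt, cfg);
}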
 std::vector<Tensor> _ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
@@ -1890,48 +2431,47 @@ std::vector<Tensor> _ge_bw(const Tensor& grad, const MemoryConfig& output_mem_co
     grad_tensor.emplace_back(t_zero);
     return grad_tensor;
 }
-std::vector<Tensor> ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _ge_bw)(grad, output_mem_config);
 }
-
 std::vector<Tensor> _le_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor t_zero = zeros_like(grad, output_mem_config);
     grad_tensor.emplace_back(t_zero);
     return grad_tensor;
 }
-std::vector<Tensor> le_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> le_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _le_bw)(grad, output_mem_config);
 }
-
-std::vector<Tensor> _unary_fmod_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _unary_fmod_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     grad_tensor.emplace_back(grad);
     return grad_tensor;
 }
-std::vector<Tensor> unary_fmod_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> unary_fmod_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _unary_fmod_bw)(grad, input, scalar, output_mem_config);
 }
 
-std::vector<Tensor> _unary_remainder_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _unary_remainder_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     grad_tensor.emplace_back(grad);
     return grad_tensor;
 }
-std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> unary_remainder_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _unary_remainder_bw)(grad, input, scalar, output_mem_config);
 }
 
-#define CHECK_FOR_COMPLEX(input) do {\
-  TT_ASSERT( utility::is_complex_shape(input), "works for complex shape only"); \
-  /* TT_ASSERT( input.shape()[0] == 1, "tensor should have batch size 1"); */ \
-  } while(0);
+#define CHECK_FOR_COMPLEX(input)                                                      \
+    do {                                                                              \
+        TT_ASSERT(utility::is_complex_shape(input), "works for complex shape only");  \
+        /* TT_ASSERT( input.shape()[0] == 1, "tensor should have batch size 1"); */   \
+    } while (0);
 
 // complex conj
 // self: grad.conj()
@@ -1943,8 +2483,7 @@ std::vector<Tensor> _conj_bw(const Tensor& grad, const Tensor& input, const Memo
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> conj_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> conj_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _conj_bw)(grad, input, output_mem_config);
 }
 
@@ -1954,20 +2493,32 @@ std::vector<Tensor> _complex_recip_bw(const Tensor& grad, const Tensor& input, c
     CHECK_FOR_COMPLEX(input);
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
-    Tensor input_r = real(input,output_mem_config);
-    Tensor input_i = imag(input,output_mem_config);
-    Tensor condition_nan = logical_and(eqz(input_r,output_mem_config), eqz(input_i,output_mem_config), std::nullopt, output_mem_config);
+    Tensor input_r = real(input, output_mem_config);
+    Tensor input_i = imag(input, output_mem_config);
+    Tensor condition_nan =
+        logical_and(eqz(input_r, output_mem_config), eqz(input_i, output_mem_config), std::nullopt, output_mem_config);
     input_r.deallocate();
     input_i.deallocate();
     Tensor nan_flag = mk_complex(condition_nan, condition_nan, output_mem_config);
     condition_nan.deallocate();
-    Tensor grad_result = where(nan_flag, full_like(input, std::nanf(""), output_mem_config), complex_mul(neg(grad, output_mem_config), conj(complex_mul(complex_recip(input, output_mem_config), complex_recip(input, output_mem_config), output_mem_config), output_mem_config), output_mem_config), output_mem_config) ;
+    Tensor grad_result = where(
+        nan_flag,
+        full_like(input, std::nanf(""), output_mem_config),
+        complex_mul(
+            neg(grad, output_mem_config),
+            conj(
+                complex_mul(
+                    complex_recip(input, output_mem_config),
+                    complex_recip(input, output_mem_config),
+                    output_mem_config),
+                output_mem_config),
+            output_mem_config),
+        output_mem_config);
     nan_flag.deallocate();
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> complex_recip_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_recip_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _complex_recip_bw)(grad, input, output_mem_config);
 }
 
@@ -1976,12 +2527,12 @@ std::vector<Tensor> complex_recip_bw(const Tensor& grad, const Tensor& input, co
 std::vector<Tensor> _imag_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     std::vector<Tensor> grad_tensor;
-    Tensor grad_result = mk_complex(zeros_like(real(input, output_mem_config), output_mem_config), grad, output_mem_config) ;
+    Tensor grad_result =
+        mk_complex(zeros_like(real(input, output_mem_config), output_mem_config), grad, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _imag_bw)(grad, input, output_mem_config);
 }
 
@@ -1990,26 +2541,41 @@ std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const Memor
 std::vector<Tensor> _real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     std::vector<Tensor> grad_tensor;
-    Tensor grad_result = mk_complex(grad, zeros_like(imag(input, output_mem_config), output_mem_config), output_mem_config);
+    Tensor grad_result =
+        mk_complex(grad, zeros_like(imag(input, output_mem_config), output_mem_config), output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _real_bw)(grad, input, output_mem_config);
 }
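// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// real() and imag() are pure extractions, so their backward passes simply re-embed
// grad into a complex tensor: for y = real(z) the gradient is grad + 0j, and for
// y = imag(z) it is 0 + grad*j, which is exactly what the mk_complex calls above
// construct. For instance the real() case:
inline Tensor real_bw_embed_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    // grad in the real lane, zeros in the imaginary lane
    return mk_complex(grad, zeros_like(imag(input, cfg), cfg), cfg);
}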
 
 // angle
 // at::where(self == 0.0, at::zeros({}, self.options()), grad * self / self.abs().pow(2)
-std::vector<Tensor> _angle_bw(const Tensor& grad, const Tensor& input, bool is_complextensor, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _angle_bw(
+    const Tensor& grad, const Tensor& input, bool is_complextensor, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    if(is_complextensor){
+    if (is_complextensor) {
         CHECK_FOR_COMPLEX(input);
         Tensor inp_r = real(input, output_mem_config);
         Tensor inp_i = imag(input, output_mem_config);
-        Tensor condition_zero = logical_and(eqz(inp_r,output_mem_config), eqz(inp_i,output_mem_config), std::nullopt, output_mem_config);
-        Tensor abs_squared = recip(add(square(inp_r, output_mem_config), square(inp_i, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
-        Tensor real = where(condition_zero, zeros_like(inp_r, output_mem_config), mul(grad, mul(neg(inp_i, output_mem_config), abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
-        Tensor imag = where(condition_zero, zeros_like(inp_i, output_mem_config), mul(grad, mul(inp_r, abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+        Tensor condition_zero =
+            logical_and(eqz(inp_r, output_mem_config), eqz(inp_i, output_mem_config), std::nullopt, output_mem_config);
+        Tensor abs_squared = recip(
+            add(square(inp_r, output_mem_config), square(inp_i, output_mem_config), std::nullopt, output_mem_config),
+            output_mem_config);
+        Tensor real = where(
+            condition_zero,
+            zeros_like(inp_r, output_mem_config),
+            mul(grad,
+                mul(neg(inp_i, output_mem_config), abs_squared, std::nullopt, output_mem_config),
+                std::nullopt,
+                output_mem_config),
+            output_mem_config);
+        Tensor imag = where(
+            condition_zero,
+            zeros_like(inp_i, output_mem_config),
+            mul(grad, mul(inp_r, abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config),
+            output_mem_config);
         condition_zero.deallocate();
         abs_squared.deallocate();
         inp_r.deallocate();
@@ -2018,15 +2584,14 @@ std::vector<Tensor> _angle_bw(const Tensor& grad, const Tensor& input, bool is_c
         real.deallocate();
         imag.deallocate();
         grad_tensor.emplace_back(grad_result);
-    }
-    else {
+    } else {
         Tensor grad_result = zeros_like(grad, output_mem_config);
         grad_tensor.emplace_back(grad_result);
     }
     return grad_tensor;
 }
-std::vector<Tensor> angle_bw(const Tensor& grad, const Tensor& input, bool is_complextensor, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> angle_bw(
+    const Tensor& grad, const Tensor& input, bool is_complextensor, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _angle_bw)(grad, input, is_complextensor, output_mem_config);
 }
 
@@ -2038,12 +2603,18 @@ std::vector<Tensor> _complex_abs_bw(const Tensor& grad, const Tensor& input, con
     Tensor result = complex_abs(input, output_mem_config);
     result = mk_complex(result, result, output_mem_config);
     Tensor grad_c = mk_complex(grad, grad, output_mem_config);
-    Tensor grad_result = where(eqz(result, output_mem_config), zeros_like(result, output_mem_config), mul(grad_c, mul(input, recip(result, output_mem_config), std::nullopt, output_mem_config),std::nullopt, output_mem_config), output_mem_config );
+    Tensor grad_result = where(
+        eqz(result, output_mem_config),
+        zeros_like(result, output_mem_config),
+        mul(grad_c,
+            mul(input, recip(result, output_mem_config), std::nullopt, output_mem_config),
+            std::nullopt,
+            output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> complex_abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _complex_abs_bw)(grad, input, output_mem_config);
 }
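// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// For r = |z| the gradient w.r.t. z is grad * z / |z|, i.e. grad times the unit
// "sign" of z (and zero where z == 0, which the where() above guards). The dense
// part, mirroring the broadcast-into-both-lanes trick used above:
inline Tensor complex_abs_bw_dense_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    Tensor r = complex_abs(input, cfg);           // |z|, real-valued
    r = mk_complex(r, r, cfg);                    // broadcast |z| into both lanes
    Tensor grad_c = mk_complex(grad, grad, cfg);  // ditto for grad
    // grad * z / |z|, no zero guard
    return mul(grad_c, mul(input, recip(r, cfg), std::nullopt, cfg), std::nullopt, cfg);
}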
 // polar
 // result_mul_1_j = result * torch.tensor(0.0 + 1.0j)
 // grad_angle = torch.real(grad_conj * result_mul_1_j)
 // polar fwd op uses sin and cos hence input_b range is (0, 2*pi)
-std::vector<Tensor> _polar_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _polar_bw(
+    const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
     Tensor result = polar(input_a, input_b, output_mem_config);
     Tensor abs_result = complex_abs(result, output_mem_config);
     abs_result = mk_complex(abs_result, abs_result, output_mem_config);
-    Tensor sgn_result = where(eqz(abs_result, output_mem_config), zeros_like(result, output_mem_config), mul(result, recip(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
+    Tensor sgn_result = where(
+        eqz(abs_result, output_mem_config),
+        zeros_like(result, output_mem_config),
+        mul(result, recip(abs_result, output_mem_config), std::nullopt, output_mem_config),
+        output_mem_config);
     abs_result.deallocate();
-    Tensor grad_abs = real(complex_mul(conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config);
+    Tensor grad_abs =
+        real(complex_mul(conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config);
     sgn_result.deallocate();
-    Tensor flip_tensor = mk_complex(zeros_like(input_a, output_mem_config), full_like(input_b, 1.0, output_mem_config), output_mem_config);
-    Tensor grad_angle = real(complex_mul(conj(grad, output_mem_config), complex_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config);
+    Tensor flip_tensor = mk_complex(
+        zeros_like(input_a, output_mem_config), full_like(input_b, 1.0, output_mem_config), output_mem_config);
+    Tensor grad_angle = real(
+        complex_mul(
+            conj(grad, output_mem_config), complex_mul(result, flip_tensor, output_mem_config), output_mem_config),
+        output_mem_config);
     result.deallocate();
     flip_tensor.deallocate();
     Tensor grad_result = mk_complex(grad_abs, grad_angle, output_mem_config);
@@ -2071,92 +2652,108 @@ std::vector<Tensor> _polar_bw(const Tensor& grad, const Tensor& input_a, const T
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> polar_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> polar_bw(
+    const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _polar_bw)(grad, input_a, input_b, output_mem_config);
 }
 
 // complex div
 // self: grad / other.conj();
 // other: -grad * ((self / other) / other).conj();
-std::vector<Tensor> _complex_div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _complex_div_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     CHECK_FOR_COMPLEX(other);
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
-    Tensor other_r = real(other,output_mem_config);
-    Tensor other_i = imag(other,output_mem_config);
-    Tensor condition_nan = logical_and(eqz(other_r,output_mem_config), eqz(other_i, output_mem_config), std::nullopt, output_mem_config);
+    Tensor other_r = real(other, output_mem_config);
+    Tensor other_i = imag(other, output_mem_config);
+    Tensor condition_nan =
+        logical_and(eqz(other_r, output_mem_config), eqz(other_i, output_mem_config), std::nullopt, output_mem_config);
     other_r.deallocate();
     other_i.deallocate();
     Tensor nan_flag = mk_complex(condition_nan, condition_nan, output_mem_config);
     condition_nan.deallocate();
-    Tensor grad_a = where(nan_flag, full_like(input, std::nanf(""), output_mem_config), complex_div(grad, conj(other,output_mem_config), output_mem_config), output_mem_config);
+    Tensor grad_a = where(
+        nan_flag,
+        full_like(input, std::nanf(""), output_mem_config),
+        complex_div(grad, conj(other, output_mem_config), output_mem_config),
+        output_mem_config);
     grad_tensor.emplace_back(grad_a);
     Tensor result = complex_div(input, other, output_mem_config);
-    Tensor grad_b = where(nan_flag, full_like(input, std::nanf(""), output_mem_config), complex_mul(neg(grad,output_mem_config), conj(complex_div(result, other, output_mem_config ),output_mem_config), output_mem_config), output_mem_config);
+    Tensor grad_b = where(
+        nan_flag,
+        full_like(input, std::nanf(""), output_mem_config),
+        complex_mul(
+            neg(grad, output_mem_config),
+            conj(complex_div(result, other, output_mem_config), output_mem_config),
+            output_mem_config),
+        output_mem_config);
     result.deallocate();
     nan_flag.deallocate();
     grad_tensor.emplace_back(grad_b);
     return grad_tensor;
 }
-std::vector<Tensor> complex_div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_div_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _complex_div_bw)(grad, input, other, output_mem_config);
 }
 
 // complex mul
 // grad_input = grad * other.conj()
 // grad_other = grad * input.conj()
-std::vector<Tensor> _complex_mul_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _complex_mul_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     CHECK_FOR_COMPLEX(other);
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
-    Tensor grad_a = complex_mul(grad, conj(other,output_mem_config), output_mem_config);
+    Tensor grad_a = complex_mul(grad, conj(other, output_mem_config), output_mem_config);
     grad_tensor.emplace_back(grad_a);
-    Tensor grad_b = complex_mul(grad, conj(input,output_mem_config), output_mem_config);
+    Tensor grad_b = complex_mul(grad, conj(input, output_mem_config), output_mem_config);
     grad_tensor.emplace_back(grad_b);
     return grad_tensor;
 }
-std::vector<Tensor> complex_mul_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_mul_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _complex_mul_bw)(grad, input, other, output_mem_config);
 }
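// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// The complex product rule used above: for out = a * b, dL/da = grad * conj(b) and
// dL/db = grad * conj(a) (the conjugate convention PyTorch uses for complex
// autograd). The input_a gradient in isolation:
inline Tensor complex_mul_bw_input_sketch(const Tensor& grad, const Tensor& other, const MemoryConfig& cfg) {
    return complex_mul(grad, conj(other, cfg), cfg);
}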
 
 // complex add
 // self: grad, other: grad * alpha
-std::vector<Tensor> _complex_add_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _complex_add_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     CHECK_FOR_COMPLEX(other);
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
     grad_tensor.emplace_back(grad);
-    Tensor grad_b = mul_unary(grad, alpha, output_mem_config );
+    Tensor grad_b = mul_unary(grad, alpha, output_mem_config);
     grad_tensor.emplace_back(grad_b);
     return grad_tensor;
 }
-std::vector<Tensor> complex_add_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_add_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _complex_add_bw)(grad, input, other, alpha, output_mem_config);
 }
 
 // complex sub
 // self: grad, other: -grad * alpha
-std::vector<Tensor> _complex_sub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _complex_sub_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
     CHECK_FOR_COMPLEX(input);
     CHECK_FOR_COMPLEX(other);
     CHECK_FOR_COMPLEX(grad);
     std::vector<Tensor> grad_tensor;
     grad_tensor.emplace_back(grad);
-    UnaryWithParam op1 {UnaryOpType::NEG};
-    UnaryWithParam op2 {UnaryOpType::MUL_UNARY_SFPU, alpha};
-    Tensor grad_b = unary_chain( grad, {op1, op2}, output_mem_config);
+    UnaryWithParam op1{UnaryOpType::NEG};
+    UnaryWithParam op2{UnaryOpType::MUL_UNARY_SFPU, alpha};
+    Tensor grad_b = unary_chain(grad, {op1, op2}, output_mem_config);
     grad_tensor.emplace_back(grad_b);
     return grad_tensor;
 }
-std::vector<Tensor> complex_sub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> complex_sub_bw(
+    const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _complex_sub_bw)(grad, input, other, alpha, output_mem_config);
 }
 #undef CHECK_FOR_COMPLEX
 
@@ -2164,70 +2761,75 @@ std::vector<Tensor> _multigammaln_bw(const Tensor& grad, const Tensor& input, co
     std::vector<Tensor> grad_tensor;
     Tensor digamma_result = mul(grad, digamma(input, output_mem_config), std::nullopt, output_mem_config);
-    Tensor digamma_result_2 = mul(grad, digamma(add_unary(-0.5 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    Tensor digamma_result_2 = mul(
+        grad, digamma(add_unary(-0.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
     Tensor grad_result = add(digamma_result, digamma_result_2, std::nullopt, output_mem_config);
-    digamma_result = mul(grad, digamma(add_unary(-1.0 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    digamma_result = mul(
+        grad, digamma(add_unary(-1.0, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
     grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);
-    digamma_result = mul(grad, digamma(add_unary(-1.5 , input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    digamma_result = mul(
+        grad, digamma(add_unary(-1.5, input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
     grad_result = add(grad_result, digamma_result, std::nullopt, output_mem_config);
     grad_tensor.emplace_back(grad_result);
     return grad_tensor;
 }
-std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
 }
 
 // Repeat Backward
-std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _repeat_bw(
+    const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     auto shape_wh = input.get_legacy_shape();
-    TT_FATAL( shape_wh[0] == 1 && "input shape[0] should be 1");
+    TT_FATAL(shape_wh[0] == 1 && "input shape[0] should be 1");
     // input.get_legacy_shape()[0]
     // If repeat shape has 0's, it returns zeros of given input
     if (shape[0] == 0 || shape[1] == 0 || shape[2] == 0 || shape[3] == 0) {
         Tensor zero_tensor = zeros_like(input, output_mem_config);
         grad_tensor.emplace_back(zero_tensor);
         return grad_tensor;
-    }
-    else if (shape[0] > 1){
+    } else if (shape[0] > 1) {
         std::vector<int64_t> dim = {0};
-        TT_FATAL( shape[1] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[1], [2], [3] should be 1");
+        TT_FATAL(shape[1] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[1], [2], [3] should be 1");
         Shape required = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
-        Tensor result = tt::operations::primary::moreh_sum(grad, dim, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), output_mem_config);
+        Tensor result = tt::operations::primary::moreh_sum(
+            grad,
+            dim,
+            zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config),
+            output_mem_config);
         grad_tensor.emplace_back(result);
         return grad_tensor;
-    }
-    else if (shape[1] > 1)
-    {
+    } else if (shape[1] > 1) {
        std::vector<int64_t> dim = {1};
-        TT_FATAL( shape[0] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[0], [2], [3] should be 1");
+        TT_FATAL(shape[0] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[0], [2], [3] should be 1");
        Shape required = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
-        Tensor result = tt::operations::primary::moreh_sum(grad, dim, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), output_mem_config);
+        Tensor result = tt::operations::primary::moreh_sum(
+            grad,
+            dim,
+            zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config),
+            output_mem_config);
         grad_tensor.emplace_back(result);
         return grad_tensor;
     }
     return grad_tensor;
-
 }
-std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> repeat_bw(
+    const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
 }
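// [Editor's sketch, not part of this patch; helper name invented for illustration.]
// repeat along a dim copies the input, so its backward sums grad back over the
// repeated dim; that is why _repeat_bw above reduces with moreh_sum into a zeros()
// accumulator of the unrepeated shape. The shape[0] > 1 branch in isolation:
inline Tensor repeat_bw_dim0_sketch(const Tensor& grad, const Tensor& input, const MemoryConfig& cfg) {
    auto shape_wh = input.get_legacy_shape();
    std::vector<int64_t> dim = {0};  // reduce over the repeated dim
    Shape required = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
    return tt::operations::primary::moreh_sum(
        grad, dim, zeros(required, input.get_dtype(), input.get_layout(), input.device(), cfg), cfg);
}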
-
 std::vector<Tensor> _floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor t_zero = zeros_like(grad, output_mem_config);
     grad_tensor.emplace_back(t_zero);
     return grad_tensor;
 }
-std::vector<Tensor> floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _floor_bw)(grad, output_mem_config);
 }
 
@@ -2237,24 +2839,25 @@ std::vector<Tensor> _round_bw(const Tensor& grad, const MemoryConfig& output_mem
     grad_tensor.emplace_back(t_zero);
     return grad_tensor;
 }
-std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _round_bw)(grad, output_mem_config);
 }
 
-std::vector<Tensor> _unary_div_no_nan_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
+std::vector<Tensor> _unary_div_no_nan_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor zeros = zeros_like(grad, output_mem_config);
     Tensor val = full_like(input, scalar, output_mem_config);
-    Tensor result = where(eq_unary(val, 0, output_mem_config), zeros, mul_unary(grad, 1/scalar, output_mem_config), output_mem_config);
+    Tensor result = where(
+        eq_unary(val, 0, output_mem_config), zeros, mul_unary(grad, 1 / scalar, output_mem_config), output_mem_config);
     grad_tensor.emplace_back(result);
     return grad_tensor;
 }
-std::vector<Tensor> unary_div_no_nan_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config)
-{
+std::vector<Tensor> unary_div_no_nan_bw(
+    const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _unary_div_no_nan_bw)(grad, input, scalar, output_mem_config);
 }
 
-}//namespace tt_metal
+}  // namespace tt_metal
 
-}//namespace tt
+}  // namespace tt
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
index f5a4ddce68b..89573265df2 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -15,268 +15,727 @@ namespace tt {
 
 namespace tt_metal {
 
-std::vector<Tensor> addalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> addcmul_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_mul_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_add_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_pow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> addcdiv_bw(const Tensor& grad, const Tensor& input, const Tensor& tensor1, const Tensor& tensor2, float value, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> mul_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> add_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> exp_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> sqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_assign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_assign_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> rdiv_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> max_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> min_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> embedding_bw(const Tensor& grad, const Tensor& input, const Tensor& weight, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+std::vector<std::optional<Tensor>> addalpha_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    float alpha,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
+    std::optional<Tensor> input_grad = std::nullopt,
+    std::optional<Tensor> other_grad = std::nullopt);
+
+std::vector<Tensor> addcmul_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    float value,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_mul_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float scalar,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_add_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float alpha,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_pow_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float exponent,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> addcdiv_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    float value,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<std::optional<Tensor>> mul_bw(
+    const Tensor& grad,
+    const Tensor& input_a,
+    const Tensor& input_b,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
+    std::optional<Tensor> input_a_grad = std::nullopt,
+    std::optional<Tensor> input_b_grad = std::nullopt);
+
+std::vector<std::optional<Tensor>> add_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
+    std::optional<Tensor> input_grad = std::nullopt,
+    std::optional<Tensor> other_grad = std::nullopt);
+
+std::vector<Tensor> exp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> sqrt_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_assign_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> binary_assign_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_div_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float scalar,
+    string round_mode,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> div_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    string round_mode,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> rdiv_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float scalar,
+    string round_mode,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> max_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> min_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> embedding_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& weight,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
 // bw = grad(1 - tanh(x) ** 2)
-std::vector<Tensor> tanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+std::vector<Tensor> tanh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
 // grad(sigmoid) = grad*(1 - sigmoid(x))*sigmoid(x)
-std::vector<Tensor> sigmoid_bw(const Tensor& grad, const Tensor& esinput, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> tan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> where_bw(const Tensor& grad, const Tensor& condition, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> fill_zero_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> fill_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> sub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_sub_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> rsub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> log_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_le_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_abs_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> rsqrt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> neg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> relu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> lt_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> gt_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> ne_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> clamp_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> clamp_min_bw(const Tensor& grad, const Tensor& input, float min, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> clamp_max_bw(const Tensor& grad, const Tensor& input, float max, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> atan2_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> hypot_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> exp2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> gelu_bw(const Tensor& grad, const Tensor& input, string approximate, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> bias_gelu_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, string approximate, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> bias_gelu_unary_bw(const Tensor& grad, const Tensor& input, float bias, string approximate, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> squared_difference_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+std::vector<Tensor> sigmoid_bw(
+    const Tensor& grad,
+    const Tensor& esinput,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> tan_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> where_bw(
+    const Tensor& grad,
+    const Tensor& condition,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> fill_zero_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> fill_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> sub_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_sub_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> rsub_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> log_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> binary_le_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> abs_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_abs_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> rsqrt_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> neg_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> relu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> lt_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> gt_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> ne_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> clamp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float min,
+    float max,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> clamp_min_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float min,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> clamp_max_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float max,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> atan2_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> hypot_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> exp2_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> expm1_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> gelu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    string approximate,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> bias_gelu_bw(
+    const Tensor& grad,
+    const Tensor& input_a,
+    const Tensor& input_b,
+    string approximate,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> bias_gelu_unary_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float bias,
+    string approximate,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> squared_difference_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 // lerp(input, end, weight) = self: grad * (1 - weight), end: grad * weight, weight is float
-std::vector<Tensor> lerp_bw(const Tensor& grad, const Tensor& input, const Tensor& end, float weight, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+std::vector<Tensor> lerp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& end,
+    float weight,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 // lerp(input, end, weight) = self: grad * (1 - weight), end: grad * weight, weight is tensor
-std::vector<Tensor> lerp_bw(const Tensor& grad, const Tensor& input, const Tensor& end, const Tensor& weight, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> ldexp_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> xlogy_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> logaddexp_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> logaddexp2_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> concat_bw(const Tensor& grad, const Tensor& input, const Tensor& other, int dim, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> hardsigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> hardshrink_bw(const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> softshrink_bw(const Tensor& grad, const Tensor& input, float lambd, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> hardswish_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> softplus_bw(const Tensor& grad, const Tensor& input, float beta=1.0, float threshold=20.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> polygamma_bw(const Tensor& grad, const Tensor& input, int n, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> atan_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> atanh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> asin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> asinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> cosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> cos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> acosh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> acos_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> erfinv_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> leaky_relu_bw(const Tensor& grad, const Tensor& input, float negative_slope, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> elu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> hardtanh_bw(const Tensor& grad, const Tensor& input, float min, float max, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> angle_bw(const Tensor& grad, const Tensor& input, bool is_complextensor = true, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> sin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> sinh_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> celu_bw(const Tensor& grad, const Tensor& input, float alpha, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_lt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> subalpha_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> log10_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_ne_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> erf_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> erfc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> digamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> deg2rad_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> rad2deg_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> reciprocal_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> relu6_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> rpow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> silu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> selu_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_ge_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_eq_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> binary_gt_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> square_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> lgamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> frac_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> trunc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> prod_bw(const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> log_sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> tanhshrink_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> threshold_bw(const Tensor& grad, const Tensor& input, float threshold, float value, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_eq_bw(const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> logit_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> logiteps_bw(const Tensor& grad, const Tensor& input, float eps=0.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> softsign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> sign_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> ceil_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> ge_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> le_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_fmod_bw(const Tensor& grad, const Tensor& input, float eps=0.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_remainder_bw(const Tensor& grad, const Tensor& input, float eps=0.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> conj_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_recip_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_mul_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> polar_bw(const Tensor& grad, const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_add_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> complex_sub_bw(const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);
-
-std::vector<Tensor> floor_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> round_bw(const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-std::vector<Tensor> unary_div_no_nan_bw(const Tensor& grad, const Tensor& input, float scalar=1.0f, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-} //namespace tt_metal
-
-} //namespace tt
+std::vector<Tensor> lerp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& end,
+    const Tensor& weight,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> ldexp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> xlogy_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> logaddexp_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> logaddexp2_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> concat_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    int dim,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> hardsigmoid_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> i0_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> hardshrink_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float lambd,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> softshrink_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float lambd,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> hardswish_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> softplus_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float beta = 1.0,
+    float threshold = 20.0,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> polygamma_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    int n,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> atan_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> atanh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> asin_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> asinh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> cosh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> cos_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> acosh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> acos_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> erfinv_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> leaky_relu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float negative_slope,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> elu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float alpha,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> hardtanh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float min,
+    float max,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> angle_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    bool is_complextensor = true,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> sin_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> sinh_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> celu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float alpha,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> binary_lt_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> subalpha_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    float alpha = 1.0,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> log10_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> log1p_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> binary_ne_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> erf_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> erfc_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> digamma_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> deg2rad_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> rad2deg_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> reciprocal_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> relu6_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> rpow_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float exponent,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> silu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> selu_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> binary_ge_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<std::optional<Tensor>> binary_eq_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
+    std::optional<Tensor> input_grad = std::nullopt,
+    std::optional<Tensor> other_grad = std::nullopt);
+
+std::vector<Tensor> binary_gt_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> square_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> lgamma_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> frac_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> trunc_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> prod_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    bool all_dimensions,
+    int64_t dim,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> log_sigmoid_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> tanhshrink_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> threshold_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float threshold,
+    float value,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_eq_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> logit_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> logiteps_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float eps = 0.0f,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> softsign_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> sign_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> ceil_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> log2_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> ge_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> le_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_fmod_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float eps = 0.0f,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_remainder_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float eps = 0.0f,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> conj_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_recip_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> imag_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> real_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_mul_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_div_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> polar_bw(
+    const Tensor& grad,
+    const Tensor& input_a,
+    const Tensor& input_b,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_add_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    float alpha = 1.0,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> complex_sub_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    float alpha = 1.0,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> multigammaln_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> repeat_bw(
+    const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);
+
+std::vector<Tensor> floor_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> round_bw(
+    const Tensor& grad, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+std::vector<Tensor> unary_div_no_nan_bw(
+    const Tensor& grad,
+    const Tensor& input,
+    float scalar = 1.0f,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
+}  // namespace tt_metal
+
+}  // namespace tt
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
index cbd207f7314..dfe9a71320a 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -10,7 +10,7 @@ namespace tt::tt_metal::detail{

 void TensorModuleBackwardOPs( py::module & m_tensor){
     m_tensor.def("addalpha_bw", &tt::tt_metal::addalpha_bw,
-            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("are_required_outputs").noconvert() = std::vector<bool>{true, true}, py::arg("input_grad").noconvert() = std::nullopt, py::arg("other_grad").noconvert() = std::nullopt, R"doc(
        Performs backward operations for multiplication of ``input_b`` and ``alpha`` tensors with given ``grad``.

        Input tensor must have BFLOAT16 data type.
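A minimal usage sketch of the optional-output keywords that the hunk above adds to addalpha_bw; the module path tt_lib.tensor and the tensor names grad_tensor, input_tensor_a, and input_tensor_b are assumptions for illustration:

    # Request only input_a's gradient; the second slot of the returned list
    # stays None because input_b's gradient was not marked as required.
    grads = tt_lib.tensor.addalpha_bw(
        grad_tensor,
        input_tensor_a,
        input_tensor_b,
        alpha=2.0,
        are_required_outputs=[True, False],
    )
    input_grad = grads[0]

The intent of are_required_outputs is to let callers ask for only the gradients they need instead of always materializing both.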
@@ -25,6 +25,9 @@ namespace tt::tt_metal::detail{
            "input_b", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "alpha", "Alpha value", "float", "default to 1.0f", "No"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+            "are_required_outputs", "Boolean values for the required outputs: input_grad, other_grad", "List of bool", "Default value is [True, True]", "No"
+            "input_grad", "Optional Output Tensor for input gradient", "Tensor", "Default value is None", "No"
+            "other_grad", "Optional Output Tensor for other gradient", "Tensor", "Default value is None", "No"
    )doc");

    m_tensor.def("conj_bw", py::overload_cast<const Tensor&, const Tensor&, const MemoryConfig&>(&conj_bw),
@@ -77,7 +80,7 @@ namespace tt::tt_metal::detail{
    )doc");

    m_tensor.def("mul_bw", &tt::tt_metal::mul_bw,
-            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("are_required_outputs").noconvert() = std::vector<bool>{true, true}, py::arg("input_a_grad").noconvert() = std::nullopt, py::arg("input_b_grad").noconvert() = std::nullopt, R"doc(
        Performs backward operations for multiplication of two input tensors with given ``grad``

        Input tensors must have BFLOAT16 data type.
@@ -91,6 +94,9 @@ namespace tt::tt_metal::detail{
            "input_a", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "input_b", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+            "are_required_outputs", "Boolean values for the required outputs: input_a_grad, input_b_grad", "List of bool", "Default value is [True, True]", "No"
+            "input_a_grad", "Optional Output Tensor for input_a gradient", "Tensor", "Default value is None", "No"
+            "input_b_grad", "Optional Output Tensor for input_b gradient", "Tensor", "Default value is None", "No"
    )doc");

    m_tensor.def("exp_bw", &tt::tt_metal::exp_bw,
@@ -266,7 +272,7 @@ namespace tt::tt_metal::detail{
    )doc");

    m_tensor.def("add_bw", &tt::tt_metal::add_bw,
-            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("are_required_outputs").noconvert() = std::vector<bool>{true, true}, py::arg("input_grad").noconvert() = std::nullopt, py::arg("other_grad").noconvert() = std::nullopt, R"doc(
        Performs backward operations for addition of ``input_b`` tensors with given ``grad``.

        Input tensor must have BFLOAT16 data type.
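add_bw and mul_bw gain the same mechanism, plus optional preallocated output tensors (note that mul_bw names its keywords input_a_grad and input_b_grad). A sketch under the same assumptions as above, with prealloc_input_grad and prealloc_other_grad standing in for device tensors of matching shape:

    # Both gradients are requested; the preallocated tensors serve as the
    # optional outputs and come back in the returned list.
    grads = tt_lib.tensor.add_bw(
        grad_tensor,
        input_tensor_a,
        input_tensor_b,
        are_required_outputs=[True, True],
        input_grad=prealloc_input_grad,
        other_grad=prealloc_other_grad,
    )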
@@ -280,6 +286,9 @@ namespace tt::tt_metal::detail{
            "input_a", "Tensor add is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "input_b", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+            "are_required_outputs", "Boolean values for the required outputs: input_grad, other_grad", "List of bool", "Default value is [True, True]", "No"
+            "input_grad", "Optional Output Tensor for input_a gradient", "Tensor", "Default value is None", "No"
+            "other_grad", "Optional Output Tensor for input_b gradient", "Tensor", "Default value is None", "No"
    )doc");

    m_tensor.def("relu_bw", &tt::tt_metal::relu_bw,
@@ -1637,7 +1646,7 @@ namespace tt::tt_metal::detail{
    )doc");

    m_tensor.def("binary_eq_bw", &tt::tt_metal::binary_eq_bw,
-            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+            py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("other").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("are_required_outputs").noconvert() = std::vector<bool>{true, true}, py::arg("input_grad").noconvert() = std::nullopt, py::arg("other_grad").noconvert() = std::nullopt, R"doc(
        Returns an tensor of zeros like ``grad`` tensor and ``input`` tensor.

        Input tensors must have BFLOAT16 data type.
@@ -1649,7 +1658,11 @@ namespace tt::tt_metal::detail{
            "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "other", "Other Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+            "are_required_outputs", "Boolean values for the required outputs: input_grad, other_grad", "List of bool", "Default value is [True, True]", "No"
+            "input_grad", "Optional Output Tensor for input gradient", "Tensor", "Default value is None", "No"
+            "other_grad", "Optional Output Tensor for other gradient", "Tensor", "Default value is None", "No"
    )doc");

    m_tensor.def("binary_gt_bw", &tt::tt_metal::binary_gt_bw,