From 69ee26f48b4cd2872bd911dc802867e4570c94a3 Mon Sep 17 00:00:00 2001 From: Pranshu-S Date: Sun, 24 Mar 2024 12:07:09 +0530 Subject: [PATCH] Fix cum_sum complex-type logic --- .../tensorflow_common/src/op/cumsum.cpp | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/cumsum.cpp b/src/frontends/tensorflow_common/src/op/cumsum.cpp index c0e7d1a79df42b..b267d852a860e4 100644 --- a/src/frontends/tensorflow_common/src/op/cumsum.cpp +++ b/src/frontends/tensorflow_common/src/op/cumsum.cpp @@ -5,11 +5,9 @@ #include "common_op_table.hpp" #include "openvino/op/cum_sum.hpp" #include "helper_ops/complex_type_mark.hpp" -#include "openvino/op/greater.hpp" -#include "openvino/op/equal.hpp" -#include "openvino/op/logical_or.hpp" #include "openvino/op/select.hpp" -#include "openvino/op/add.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/less.hpp" using namespace std; using namespace ov::op; @@ -20,7 +18,7 @@ namespace tensorflow { namespace op { OutputVector translate_cumsum_op(const NodeContext& node) { - default_op_checks(node, 2, {"Cumsum"}); + default_op_checks(node, 2, {"Cumsum"}, true); auto x = node.get_input(0); auto axis = node.get_input(1); auto exclusive = node.get_attribute("exclusive", false); @@ -30,22 +28,19 @@ OutputVector translate_cumsum_op(const NodeContext& node) { if (complex_type_mark) { x = complex_type_mark->input_value(0); auto zero = create_same_type_const_scalar(x, 0); - - auto is_zero = make_shared(axis, zero); - auto greater_than_zero = make_shared(axis, zero); - - auto logical_or = make_shared(is_zero, greater_than_zero); - + auto less_than_zero = make_shared(axis, zero); auto const_one = make_shared(element::i32, Shape{}, 1); - auto const_minus_one = make_shared(element::i32, Shape{}, -1); - auto axis_update = make_shared(logical_or, const_one, const_minus_one); - - auto new_axis = make_shared(axis, axis_update); + auto axis_update = make_shared(less_than_zero, const_one, zero); + auto new_axis = make_shared(axis, axis_update); } auto cum_sum = make_shared(x, axis, exclusive, reverse); set_node_name(node.get_name(), cum_sum); + if (complex_type_mark){ + auto cum_sum_complex = make_shared(cum_sum, complex_type_mark->get_complex_part_type()); + return {cum_sum_complex}; + } return cum_sum->outputs(); } } // namespace op