Skip to content

Commit

Permalink
Fix cum_sum complex-type logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Pranshu-S committed Mar 24, 2024
1 parent 57e2178 commit 69ee26f
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions src/frontends/tensorflow_common/src/op/cumsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
#include "common_op_table.hpp"
#include "openvino/op/cum_sum.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/less.hpp"

using namespace std;
using namespace ov::op;
Expand All @@ -20,7 +18,7 @@ namespace tensorflow {
namespace op {

OutputVector translate_cumsum_op(const NodeContext& node) {
default_op_checks(node, 2, {"Cumsum"});
default_op_checks(node, 2, {"Cumsum"}, true);
auto x = node.get_input(0);
auto axis = node.get_input(1);
auto exclusive = node.get_attribute<bool>("exclusive", false);
Expand All @@ -30,22 +28,19 @@ OutputVector translate_cumsum_op(const NodeContext& node) {
if (complex_type_mark) {
x = complex_type_mark->input_value(0);
auto zero = create_same_type_const_scalar<int32_t>(x, 0);

auto is_zero = make_shared<v1::Equal>(axis, zero);
auto greater_than_zero = make_shared<v1::Greater>(axis, zero);

auto logical_or = make_shared<v1::LogicalOr>(is_zero, greater_than_zero);

auto less_than_zero = make_shared<v1::Less>(axis, zero);
auto const_one = make_shared<v0::Constant>(element::i32, Shape{}, 1);
auto const_minus_one = make_shared<v0::Constant>(element::i32, Shape{}, -1);

auto axis_update = make_shared<v1::Select>(logical_or, const_one, const_minus_one);

auto new_axis = make_shared<v1::Add>(axis, axis_update);
auto axis_update = make_shared<v1::Select>(less_than_zero, const_one, zero);
auto new_axis = make_shared<v1::Subtract>(axis, axis_update);
}

auto cum_sum = make_shared<v0::CumSum>(x, axis, exclusive, reverse);
set_node_name(node.get_name(), cum_sum);
if (complex_type_mark){
auto cum_sum_complex = make_shared<ComplexTypeMark>(cum_sum, complex_type_mark->get_complex_part_type());
return {cum_sum_complex};
}
return cum_sum->outputs();
}
} // namespace op
Expand Down

0 comments on commit 69ee26f

Please sign in to comment.