From 2405e060b3b448a3b6f5f2de740698256a60959a Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 15 Dec 2020 22:59:19 -0800 Subject: [PATCH 1/4] Fix Size translation to support dynamic inputs Size translation calls get_shape() which doesn't work on dynamic inputs. Fix the Size translation by computing size at runtime. --- ngraph_bridge/ngraph_builder.cc | 18 ++++++------------ ngraph_bridge/ngraph_mark_for_clustering.cc | 4 +++- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index c1aabd0ed..02ad7459b 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2122,21 +2122,15 @@ static Status TranslateSizeOp(const Node* op, const std::vector&, DataType dtype; TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "out_type", &dtype)); - - // Size has an attribute to specify output, int32 or int64 ng::element::Type type; TF_RETURN_IF_ERROR(util::TFDataTypeToNGraphElementType(dtype, &type)); - auto ng_input_shape = ng_input.get_shape(); - int64 result = 1; - for (auto dim : ng_input_shape) { - result *= dim; - } - - // make a scalar with value equals to result - auto ng_result = ConstructNgNode( - op->name(), type, ng::Shape(0), std::vector({result})); - + auto ng_input_shape = + ConstructNgNode(op->name(), ng_input, type); + auto ng_axis = ConstructNgNode( + op->name(), ngraph::element::i64, ngraph::Shape{1}, vector{0}); + auto ng_result = + ConstructNgNode(op->name(), ng_input_shape, ng_axis); SaveNgOp(ng_op_map, op->name(), ng_result); return Status::OK(); } diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index e825f0221..952531a0b 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -736,7 +736,9 @@ GetTFToNgOpMap() { {"Sigmoid", {std::make_shared()}}, {"Sin", {std::make_shared()}}, {"Sinh", {std::make_shared()}}, - {"Size", {constant}}, + {"Size", + {std::make_shared(), + std::make_shared()}}, {"Sign", {std::make_shared()}}, {"Slice", {constant, std::make_shared()}}, {"Snapshot", {}}, From 08bd3fb8568f8c3499926e528790c152e2f7f94f Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 15 Dec 2020 23:04:24 -0800 Subject: [PATCH 2/4] Avoid fixing negative dims in Split since it's handled by IE --- ngraph_bridge/ngraph_builder.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 02ad7459b..02412eda9 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2252,15 +2252,11 @@ static Status TranslateSplitOp( int32 num_split; TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "num_split", &num_split)); - ng::Shape shape = ng_input.get_shape(); - int rank = shape.size(); - std::vector split_dim_vec; TF_RETURN_IF_ERROR( GetStaticInputVector(op, 0, static_input_map, &split_dim_vec)); - int split_dim = split_dim_vec[0] + (split_dim_vec[0] < 0 ? (int64)rank : 0); auto ng_split_dim = ConstructNgNode( - op->name(), ng::element::u64, ng::Shape{}, split_dim); + op->name(), ng::element::u64, ng::Shape{}, split_dim_vec[0]); auto ng_split = make_shared(ng_input, ng_split_dim, num_split); for (int i = 0; i < num_split; ++i) { From ef52dda95a3925bb2101843a43a74b2a1023fc32 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Wed, 16 Dec 2020 17:52:13 -0800 Subject: [PATCH 3/4] Compute shape statically if input is static --- ngraph_bridge/ngraph_builder.cc | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 02412eda9..6badb32e1 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2125,12 +2125,24 @@ static Status TranslateSizeOp(const Node* op, const std::vector&, ng::element::Type type; TF_RETURN_IF_ERROR(util::TFDataTypeToNGraphElementType(dtype, &type)); - auto ng_input_shape = - ConstructNgNode(op->name(), ng_input, type); - auto ng_axis = ConstructNgNode( - op->name(), ngraph::element::i64, ngraph::Shape{1}, vector{0}); - auto ng_result = - ConstructNgNode(op->name(), ng_input_shape, ng_axis); + ngraph::Output ng_result; + if (ng_input.get_partial_shape().is_static()) { + auto ng_input_shape = ng_input.get_shape(); + int64 result = 1; + for (auto dim : ng_input_shape) { + result *= dim; + } + // make a scalar with value equals to result + ng_result = ConstructNgNode(op->name(), type, ng::Shape(0), + std::vector({result})); + } else { + auto ng_input_shape = + ConstructNgNode(op->name(), ng_input, type); + auto ng_axis = ConstructNgNode( + op->name(), ngraph::element::i64, ngraph::Shape{1}, vector{0}); + ng_result = + ConstructNgNode(op->name(), ng_input_shape, ng_axis); + } SaveNgOp(ng_op_map, op->name(), ng_result); return Status::OK(); } From fbcab2a39bb5676b5222d3612ff556a7e1ea754a Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Wed, 16 Dec 2020 18:08:59 -0800 Subject: [PATCH 4/4] Handle negative axis for split --- ngraph_bridge/ngraph_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 6badb32e1..7e1b2954c 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2268,7 +2268,7 @@ static Status TranslateSplitOp( TF_RETURN_IF_ERROR( GetStaticInputVector(op, 0, static_input_map, &split_dim_vec)); auto ng_split_dim = ConstructNgNode( - op->name(), ng::element::u64, ng::Shape{}, split_dim_vec[0]); + op->name(), ng::element::i32, ng::Shape{}, split_dim_vec[0]); auto ng_split = make_shared(ng_input, ng_split_dim, num_split); for (int i = 0; i < num_split; ++i) {