diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index c1aabd0ed..7e1b2954c 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2122,21 +2122,27 @@ static Status TranslateSizeOp(const Node* op, const std::vector&, DataType dtype; TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "out_type", &dtype)); - - // Size has an attribute to specify output, int32 or int64 ng::element::Type type; TF_RETURN_IF_ERROR(util::TFDataTypeToNGraphElementType(dtype, &type)); - auto ng_input_shape = ng_input.get_shape(); - int64 result = 1; - for (auto dim : ng_input_shape) { - result *= dim; + ngraph::Output ng_result; + if (ng_input.get_partial_shape().is_static()) { + auto ng_input_shape = ng_input.get_shape(); + int64 result = 1; + for (auto dim : ng_input_shape) { + result *= dim; + } + // make a scalar with value equals to result + ng_result = ConstructNgNode(op->name(), type, ng::Shape(0), + std::vector({result})); + } else { + auto ng_input_shape = + ConstructNgNode(op->name(), ng_input, type); + auto ng_axis = ConstructNgNode( + op->name(), ngraph::element::i64, ngraph::Shape{1}, vector{0}); + ng_result = + ConstructNgNode(op->name(), ng_input_shape, ng_axis); } - - // make a scalar with value equals to result - auto ng_result = ConstructNgNode( - op->name(), type, ng::Shape(0), std::vector({result})); - SaveNgOp(ng_op_map, op->name(), ng_result); return Status::OK(); } @@ -2258,15 +2264,11 @@ static Status TranslateSplitOp( int32 num_split; TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "num_split", &num_split)); - ng::Shape shape = ng_input.get_shape(); - int rank = shape.size(); - std::vector split_dim_vec; TF_RETURN_IF_ERROR( GetStaticInputVector(op, 0, static_input_map, &split_dim_vec)); - int split_dim = split_dim_vec[0] + (split_dim_vec[0] < 0 ? (int64)rank : 0); auto ng_split_dim = ConstructNgNode( - op->name(), ng::element::u64, ng::Shape{}, split_dim); + op->name(), ng::element::i32, ng::Shape{}, split_dim_vec[0]); auto ng_split = make_shared(ng_input, ng_split_dim, num_split); for (int i = 0; i < num_split; ++i) { diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index e825f0221..952531a0b 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -736,7 +736,9 @@ GetTFToNgOpMap() { {"Sigmoid", {std::make_shared()}}, {"Sin", {std::make_shared()}}, {"Sinh", {std::make_shared()}}, - {"Size", {constant}}, + {"Size", + {std::make_shared(), + std::make_shared()}}, {"Sign", {std::make_shared()}}, {"Slice", {constant, std::make_shared()}}, {"Snapshot", {}},