diff --git a/README.md b/README.md
index d328841a0..c6a9dff82 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ Once TensorFlow's dependencies are installed, clone the `ngraph-bridge` repo:
 
    git clone https://github.com/tensorflow/ngraph-bridge.git
    cd ngraph-bridge
-   git checkout v0.19.0-rc3
+   git checkout v0.19.0-rc4
 
 Run the following Python script to build TensorFlow, nGraph, and the bridge. Use Python 3.5:
diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc
index 8dea2da32..51adb9c12 100644
--- a/ngraph_bridge/ngraph_builder.cc
+++ b/ngraph_bridge/ngraph_builder.cc
@@ -375,8 +375,9 @@ static Status TranslateUnaryOp(
         create_unary_op) {
   shared_ptr<ng::Node> ng_input;
   TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input));
-  SaveNgOp(ng_op_map, op->name(), create_unary_op(ng_input));
-
+  auto ng_node = create_unary_op(ng_input);
+  ng_node->add_provenance_tag(op->name());
+  SaveNgOp(ng_op_map, op->name(), ng_node);
   return Status::OK();
 }
@@ -435,8 +436,13 @@ static Status TranslateBinaryOp(
   std::tie(ng_lhs, ng_rhs) =
       ng::builder::numpy_broadcast(std::make_pair(ng_lhs, ng_rhs));
+  ng_lhs->add_provenance_tag(op->name());
+  ng_rhs->add_provenance_tag(op->name());
-  SaveNgOp(ng_op_map, op->name(), create_binary_op(ng_lhs, ng_rhs));
+  auto ng_node = create_binary_op(ng_lhs, ng_rhs);
+  ng_node->add_provenance_tag(op->name());
+
+  SaveNgOp(ng_op_map, op->name(), ng_node);
   return Status::OK();
 }
@@ -497,6 +503,7 @@ static Status TranslateQuantizedPoolOp(const Node* op,
                          ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_kernel_shape);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
@@ -526,8 +533,10 @@ static Status TranslateQuantizedPoolOp(const Node* op,
         ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
         ng_padding_above, dummy_min, dummy_max);
   }
+  ng_quant_pool->add_provenance_tag(op->name());
   BatchToTensorflow(is_nhwc, ng_quant_pool);
+  ng_quant_pool->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_quant_pool);
   // For QuantizedAvgPool and QuantizedMaxPool input min-max remains unchanged
   // and is just propagated along
@@ -633,6 +642,7 @@ static Status TranslateAvgPoolOp(const Node* op,
   BatchedOpParamToNGraph(is_nhwc, ng_input->get_shape(), ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_kernel_shape);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -652,6 +662,7 @@ static Status TranslateAvgPoolOp(const Node* op,
       ng_padding_above, false);
   BatchToTensorflow(is_nhwc, ng_avgpool);
+  ng_avgpool->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "avgpool outshape: {" << ng::join(ng_avgpool->get_shape())
                  << "}";
@@ -702,6 +713,7 @@ static Status TranslateAvgPoolGradOp(
   BatchedOpParamReshape(is_nhwc, ng_orig_input_shape, ng_forward_arg_shape);
   BatchToNGraph(is_nhwc, ng_grad);
+  ng_grad->add_provenance_tag(op->name());
   BatchedOpParamToNGraph(is_nhwc, tf_strides, ng_strides);
   BatchedOpParamToNGraph(is_nhwc, ng_orig_input_shape, ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_window_shape);
@@ -727,6 +739,7 @@ static Status TranslateAvgPoolGradOp(
       ng_strides, ng_padding_below, ng_padding_above, false);
   BatchToTensorflow(is_nhwc, ng_avgpool_backprop);
+  ng_avgpool_backprop->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "avgpoolbackprop outshape: {"
                  << ng::join(ng_avgpool_backprop->get_shape()) << "}";
@@ -785,6 +798,7 @@ static Status TranslateBatchMatMulOp(
     ng_lhs_axes.push_back(n_dims - 1);
     ng_lhs_axes.push_back(n_dims - 2);
     ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
+    ng_lhs->add_provenance_tag(op->name());
     ng_lhs_shape = ng_lhs->get_shape();
   } else {
     ng_lhs_axes.push_back(n_dims - 2);
@@ -795,6 +809,7 @@ static Status TranslateBatchMatMulOp(
     ng_rhs_axes.push_back(n_dims - 1);
     ng_rhs_axes.push_back(n_dims - 2);
     ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
+    ng_rhs->add_provenance_tag(op->name());
     ng_rhs_shape = ng_rhs->get_shape();
   } else {
     ng_rhs_axes.push_back(n_dims - 2);
@@ -838,15 +853,18 @@ static Status TranslateBatchMatMulOp(
     ng_lhs_axes.push_back(n_dims - 1);
     ng_lhs_axes.push_back(n_dims - 2);
     ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng_lhs_axes);
+    ng_lhs->add_provenance_tag(op->name());
   }
   if (tf_adj_y) {
     ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
     ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
     ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
+    ng_rhs->add_provenance_tag(op->name());
   } else {
     ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 1);
     ng_rhs_axes.insert(ng_rhs_axes.begin(), n_dims - 2);
     ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng_rhs_axes);
+    ng_rhs->add_provenance_tag(op->name());
   }
   ng_lhs_shape = ng_lhs->get_shape();
@@ -1232,6 +1250,7 @@ static Status TranslateConv2DOp(const Node* op,
   BatchedOpParamToNGraph(is_nhwc, ng_input->get_shape(), ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_dilations, ng_dilations);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
@@ -1241,6 +1260,7 @@ static Status TranslateConv2DOp(const Node* op,
   ng_kernel_shape[0] = ng_filter_shape[0];
   ng_kernel_shape[1] = ng_filter_shape[1];
   Reshape<3, 2, 0, 1>(ng_filter);
+  ng_filter->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -1256,6 +1276,7 @@ static Status TranslateConv2DOp(const Node* op,
       ng_padding_below, ng_padding_above);
   BatchToTensorflow(is_nhwc, ng_conv);
+  ng_conv->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_conv);
   return Status::OK();
 }
@@ -1327,6 +1348,7 @@ static Status TranslateConv2DBackpropFilterOp(
   // nGraph Padding Above [f]
   // nGraph Dilation Stride [f]
   BatchToNGraph(is_nhwc, ng_data_batch);
+  ng_data_batch->add_provenance_tag(op->name());
   // tf_filter shape :
   // [filter_height, filter_width, in_channels, out_channels]
   // reshape for nGraph
@@ -1335,6 +1357,7 @@ static Status TranslateConv2DBackpropFilterOp(
       static_cast(tf_filter_sizes[0]),
       static_cast(tf_filter_sizes[1])};
   BatchToNGraph(is_nhwc, ng_output_delta);
+  ng_output_delta->add_provenance_tag(op->name());
   BatchedOpParamToNGraph(is_nhwc, tf_strides,
                          ng_window_movement_strides_forward);
   BatchedOpParamToNGraph(is_nhwc, tf_dilations,
@@ -1376,6 +1399,7 @@ static Status TranslateConv2DBackpropFilterOp(
   // Reshape the output to tf format : [filter_height, filter_width,
   // in_channels, out_channels]
   Reshape<2, 3, 1, 0>(ng_back_prop_filter);
+  ng_back_prop_filter->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_back_prop_filter);
   return Status::OK();
@@ -1431,6 +1455,7 @@ static Status TranslateConv2DBackpropInputOp(
   BatchedOpParamToNGraph(is_nhwc, tf_input_sizes, ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_dilations, ng_dilations);
   BatchToNGraph(is_nhwc, ng_out_backprop);
+  ng_out_backprop->add_provenance_tag(op->name());
   if (is_nhwc) {
     ng_batch_shape = {static_cast(tf_input_sizes[0]),
                       static_cast(tf_input_sizes[3]),
@@ -1451,6 +1476,7 @@ static Status TranslateConv2DBackpropInputOp(
   ng_kernel_shape[0] = ng_filter_shape[0];
   ng_kernel_shape[1] = ng_filter_shape[1];
   Reshape<3, 2, 0, 1>(ng_filter);
+  ng_filter->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -1468,6 +1494,7 @@ static Status TranslateConv2DBackpropInputOp(
       ng::Strides(ng_batch_shape.size() - 2, 1));
   BatchToTensorflow(is_nhwc, ng_data);
+  ng_data->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_data);
   return Status::OK();
@@ -1519,6 +1546,7 @@ static Status TranslateConv3DOp(const Node* op,
   BatchedOpParam3DToNGraph(is_ndhwc, ng_input->get_shape(), ng_image_shape);
   BatchedOpParam3DToNGraph(is_ndhwc, tf_dilations, ng_dilations);
   BatchToNGraph3D(is_ndhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
@@ -1529,6 +1557,7 @@ static Status TranslateConv3DOp(const Node* op,
   ng_kernel_shape[1] = ng_filter_shape[1];
   ng_kernel_shape[2] = ng_filter_shape[2];
   Reshape3D<4, 3, 0, 1, 2>(ng_filter);
+  ng_filter->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -1544,6 +1573,7 @@ static Status TranslateConv3DOp(const Node* op,
       ng_padding_below, ng_padding_above);
   BatchToTensorflow3D(is_ndhwc, ng_conv);
+  ng_conv->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_conv);
   return Status::OK();
 }
@@ -1697,6 +1727,7 @@ static Status TranslateDepthToSpaceOp(const Node* op,
   auto transposed =
       ng::builder::numpy_transpose(reshaped, ng_transpose_permutation);
+  transposed->add_provenance_tag(op->name());
   ng::AxisVector ng_axis_order_second_reshape(transposed->get_shape().size());
   std::iota(ng_axis_order_second_reshape.begin(),
@@ -1744,6 +1775,7 @@ static Status TranslateDepthwiseConv2dNativeOp(
   BatchedOpParamToNGraph(is_nhwc, tf_strides, ng_strides);
   BatchedOpParamToNGraph(is_nhwc, tf_dilations, ng_dilations);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
@@ -1753,6 +1785,7 @@ static Status TranslateDepthwiseConv2dNativeOp(
   ng_kernel_shape[0] = ng_filter_shape[0];
   ng_kernel_shape[1] = ng_filter_shape[1];
   Reshape<3, 2, 0, 1>(ng_filter);
+  ng_filter->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -1779,8 +1812,8 @@ static Status TranslateDepthwiseConv2dNativeOp(
     const std::vector f_lower_bound{0, i, 0, 0};
     const std::vector f_upper_bound{filter_shape[0], i + 1, filter_shape[2],
                                     filter_shape[3]};
-    auto ng_sliced_filter =
-        make_shared(ng_filter, f_lower_bound, f_upper_bound);
+    auto ng_sliced_filter = ConstructNgNode(
+        op->name(), ng_filter, f_lower_bound, f_upper_bound);
     NGRAPH_VLOG(3) << "depthwise conv 2d.";
     NGRAPH_VLOG(3) << "sliced shape " << ng::join(ng_sliced_input->get_shape());
@@ -1798,6 +1831,7 @@ static Status TranslateDepthwiseConv2dNativeOp(
       op->name(), ng_args, ng_concatenation_axis);
   BatchToTensorflow(is_nhwc, ng_concat);
+  ng_concat->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_concat);
   return Status::OK();
 }
@@ -1927,6 +1961,7 @@ static Status TranslateFusedBatchNormOp(
   NGRAPH_VLOG(3) << "epsilon: " << tf_epsilon;
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   std::shared_ptr<ng::Node> ng_batch_norm;
@@ -1955,6 +1990,7 @@ static Status TranslateFusedBatchNormOp(
         Bessel_scale);
     BatchToTensorflow(is_nhwc, ng_y);
+    ng_y->add_provenance_tag(op->name());
     SaveNgOp(ng_op_map, op->name(), ng_y);
     SaveNgOp(ng_op_map, op->name(), ng_mean);
@@ -1979,6 +2015,7 @@ static Status TranslateFusedBatchNormOp(
         op->name(), tf_epsilon, ng_scale, ng_offset, ng_input, ng_mean,
         ng_variance);
     BatchToTensorflow(is_nhwc, ng_batch_norm);
+    ng_batch_norm->add_provenance_tag(op->name());
     SaveNgOp(ng_op_map, op->name(), ng_batch_norm);
     if (is_v3) {
       SaveNgOp(ng_op_map, op->name(), ng_mean);
@@ -2058,7 +2095,9 @@ static Status TranslateFusedBatchNormGradOp(const Node* op,
       std::vector{ng::shape_size(ng_scale->get_shape()), "0"});
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   BatchToNGraph(is_nhwc, ng_delta);
+  ng_delta->add_provenance_tag(op->name());
   std::shared_ptr<ng::Node> ng_batch_norm_backprop;
@@ -2077,6 +2116,7 @@ static Status TranslateFusedBatchNormGradOp(const Node* op,
       ng_batch_norm_backprop, 2);
   BatchToTensorflow(is_nhwc, ng_input_delta_op);
+  ng_input_delta_op->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_input_delta_op);
   SaveNgOp(ng_op_map, op->name(), ng_scale_delta_op);
@@ -2149,10 +2189,12 @@ static Status TranslateFusedMatMulOp(const Node* op,
   if (GetNodeAttr(op->attrs(), "transpose_a", &transpose_a) == Status::OK() &&
       transpose_a) {
     ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng::AxisVector{1, 0});
+    ng_lhs->add_provenance_tag(op->name());
   }
   if (GetNodeAttr(op->attrs(), "transpose_b", &transpose_b) == Status::OK() &&
       transpose_b) {
     ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng::AxisVector{1, 0});
+    ng_rhs->add_provenance_tag(op->name());
   }
   // The default axis count for nGraph's Dot op is 1, which is just what
@@ -2320,6 +2362,7 @@ static Status TranslateFusedConv2DOp(const Node* op,
     BatchedOpParamToNGraph(is_nhwc, ng_input->get_shape(), ng_image_shape);
     BatchedOpParamToNGraph(is_nhwc, tf_dilations, ng_dilations);
     BatchToNGraph(is_nhwc, ng_input);
+    ng_input->add_provenance_tag(op->name());
     NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
     NGRAPH_VLOG(3) << "ng_dilations: " << ng::join(ng_dilations);
@@ -2329,6 +2372,7 @@ static Status TranslateFusedConv2DOp(const Node* op,
     ng_kernel_shape[0] = ng_filter_shape[0];
     ng_kernel_shape[1] = ng_filter_shape[1];
     Reshape<3, 2, 0, 1>(ng_filter);
+    ng_filter->add_provenance_tag(op->name());
     NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -2373,6 +2417,7 @@ static Status TranslateFusedConv2DOp(const Node* op,
     TF_RETURN_IF_ERROR(CreateNgConv(ng_input, ng_filter, ng_conv));
     BatchToTensorflow(is_nhwc, ng_conv);
+    ng_conv->add_provenance_tag(op->name());
     auto ng_conv_shape = ng_conv->get_shape();
     auto ng_bias_shape = ng_bias->get_shape();
@@ -2434,6 +2479,7 @@ static Status TranslateFusedConv2DOp(const Node* op,
         ng_offset, ng_conv, ng_mean, ng_variance);
     BatchToTensorflow(is_nhwc, ng_batch_norm);
+    ng_batch_norm->add_provenance_tag(op->name());
     if (VecStrCmp(fused_ops, {"FusedBatchNorm", "Relu"})) {
       SaveNgOp(ng_op_map, op->name(),
@@ -2543,10 +2589,12 @@ static Status TranslateMatMulOp(const Node* op,
   if (GetNodeAttr(op->attrs(), "transpose_a", &transpose_a) == Status::OK() &&
       transpose_a) {
     ng_lhs = ng::builder::numpy_transpose(ng_lhs, ng::AxisVector{1, 0});
+    ng_lhs->add_provenance_tag(op->name());
   }
   if (GetNodeAttr(op->attrs(), "transpose_b", &transpose_b) == Status::OK() &&
       transpose_b) {
     ng_rhs = ng::builder::numpy_transpose(ng_rhs, ng::AxisVector{1, 0});
+    ng_rhs->add_provenance_tag(op->name());
   }
   // The default axis count for nGraph's Dot op is 1, which is just what
@@ -2591,6 +2639,7 @@ static Status TranslateMaxPoolOp(const Node* op,
   BatchedOpParamToNGraph(is_nhwc, ng_input->get_shape(), ng_image_shape);
   BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_kernel_shape);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -2610,6 +2659,7 @@ static Status TranslateMaxPoolOp(const Node* op,
       ng_padding_above);
   BatchToTensorflow(is_nhwc, ng_maxpool);
+  ng_maxpool->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "maxpool outshape: {" << ng::join(ng_maxpool->get_shape())
                  << "}";
@@ -2653,6 +2703,7 @@ static Status TranslateMaxPool3DOp(const Node* op,
   BatchedOpParam3DToNGraph(is_ndhwc, ng_input->get_shape(), ng_image_shape);
   BatchedOpParam3DToNGraph(is_ndhwc, tf_ksize, ng_kernel_shape);
   BatchToNGraph3D(is_ndhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
   NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);
@@ -2672,6 +2723,7 @@ static Status TranslateMaxPool3DOp(const Node* op,
       ng_padding_above);
   BatchToTensorflow3D(is_ndhwc, ng_maxpool);
+  ng_maxpool->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "maxpool outshape: {" << ng::join(ng_maxpool->get_shape())
                  << "}";
@@ -2714,8 +2766,11 @@ static Status TranslateMaxPoolGradOp(const Node* op,
   BatchedOpParamToNGraph(is_nhwc, tf_strides, ng_strides);
   BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_kernel_shape);
   BatchToNGraph(is_nhwc, ng_input);
+  ng_input->add_provenance_tag(op->name());
   BatchToNGraph(is_nhwc, ng_grad);
+  ng_grad->add_provenance_tag(op->name());
   BatchToNGraph(is_nhwc, ng_fwd);
+  ng_fwd->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
   NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
@@ -2732,6 +2787,7 @@ static Status TranslateMaxPoolGradOp(const Node* op,
       op->name(), ng_input, ng_grad, ng_fwd, ng_kernel_shape, ng_strides,
       ng_padding_below, ng_padding_above);
   BatchToTensorflow(is_nhwc, ng_maxpool_backprop);
+  ng_maxpool_backprop->add_provenance_tag(op->name());
   NGRAPH_VLOG(3) << "maxpoolbackprop outshape: {"
                  << ng::join(ng_maxpool_backprop->get_shape()) << "}";
   SaveNgOp(ng_op_map, op->name(), ng_maxpool_backprop);
@@ -2842,6 +2898,7 @@ static Status TranslateReduceOp(
   std::shared_ptr<ng::Node> ng_node =
       create_ng_node(ng_input, ng_reduction_axes);
+  ng_node->add_provenance_tag(op->name());
   // If keep_dims is specified we need to reshape to put back the reduced
   // axes, with length 1.
@@ -2867,11 +2924,15 @@ static Status TranslateReduceOp(
 static Status TranslateMeanOp(
     const Node* op, const std::vector& static_input_map,
     Builder::OpMap& ng_op_map) {
-  return TranslateReduceOp(
-      op, static_input_map, ng_op_map,
-      [](std::shared_ptr<ng::Node> ng_input, ng::AxisSet ng_reduction_axes) {
-        return ng::builder::mean(ng_input, ng_reduction_axes);
-      });
+  string op_name = op->name();
+  return TranslateReduceOp(op, static_input_map, ng_op_map,
+                           [&op_name](std::shared_ptr<ng::Node> ng_input,
+                                      ng::AxisSet ng_reduction_axes) {
+                             auto mean_node =
+                                 ng::builder::mean(ng_input, ng_reduction_axes);
+                             mean_node->add_provenance_tag(op_name);
+                             return mean_node;
+                           });
 }
 
 template
@@ -2935,8 +2996,12 @@ static Status TranslateOneHotOp(
   // broadcast to make all tensors same shape, as required by ngraph select op
   std::tie(ng_onehot_bool, ng_on) =
       ng::builder::numpy_broadcast(std::make_pair(ng_onehot_bool, ng_on));
+  ng_onehot_bool->add_provenance_tag(op->name());
+  ng_on->add_provenance_tag(op->name());
   std::tie(ng_onehot_bool, ng_off) =
       ng::builder::numpy_broadcast(std::make_pair(ng_onehot_bool, ng_off));
+  ng_onehot_bool->add_provenance_tag(op->name());
+  ng_off->add_provenance_tag(op->name());
   auto ng_onehot =
       ConstructNgNode(op->name(), ng_onehot_bool, ng_on, ng_off);
@@ -3283,15 +3348,15 @@ static Status TranslateQuantizedConcatOpHelper(
     all_mins[idx] = min_tmp[0];
     all_maxs[idx] = max_tmp[0];
-    auto min_node =
-        make_shared(ng::element::f32, ng::Shape{}, min_tmp);
-    auto max_node =
-        make_shared(ng::element::f32, ng::Shape{}, max_tmp);
+    auto min_node = ConstructNgNode(
+        op->name(), ng::element::f32, ng::Shape{}, min_tmp);
+    auto max_node = ConstructNgNode(
+        op->name(), ng::element::f32, ng::Shape{}, max_tmp);
-    ng_all_mins.push_back(std::make_shared(
-        min_node, ngraph::AxisVector{}, ngraph::Shape{1}));
-    ng_all_maxs.push_back(std::make_shared(
-        max_node, ngraph::AxisVector{}, ngraph::Shape{1}));
+    ng_all_mins.push_back(ConstructNgNode(
+        op->name(), min_node, ngraph::AxisVector{}, ngraph::Shape{1}));
+    ng_all_maxs.push_back(ConstructNgNode(
+        op->name(), max_node, ngraph::AxisVector{}, ngraph::Shape{1}));
   }
   // return the min among the input_mins, and the max among the input_maxs
@@ -3304,13 +3369,14 @@ static Status TranslateQuantizedConcatOpHelper(
       1, *std::max_element(all_maxs.begin(), all_maxs.end()));
   // construct output_min and output_max
-  shared_ptr<ng::Node> ng_min_of_mins =
-      make_shared(ng::element::f32, ng::Shape{}, min_of_mins);
-  shared_ptr<ng::Node> ng_max_of_maxs =
-      make_shared(ng::element::f32, ng::Shape{}, max_of_maxs);
+  shared_ptr<ng::Node> ng_min_of_mins = ConstructNgNode(
+      op->name(), ng::element::f32, ng::Shape{}, min_of_mins);
+  shared_ptr<ng::Node> ng_max_of_maxs = ConstructNgNode(
+      op->name(), ng::element::f32, ng::Shape{}, max_of_maxs);
   auto ng_qconcat = ng::builder::ScaledQuantizedConcat(
       ng_args, size_t(concat_axis), ng_all_mins, ng_all_maxs);
+  ng_qconcat->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_qconcat);
   SaveNgOp(ng_op_map, op->name(), ng_min_of_mins);
@@ -3363,14 +3429,17 @@ static Status TranslateQuantizedConv(
   BatchedOpParamToNGraph(is_nhwc, tf_dilations, ng_dilations);
   // Generally, the mapping is: 0->input, 1->filter, 2->bias, 3->sum input
   BatchToNGraph(is_nhwc, node_inps[0]);
+  node_inps[0]->add_provenance_tag(op->name());
   // QconvBiasAdd variants
   if (num_node_inputs == 12) {
     BatchToNGraph(is_nhwc, node_inps[9]);
+    node_inps[9]->add_provenance_tag(op->name());
   }
   auto& ng_filter_shape = node_inps[1]->get_shape();
   ng_kernel_shape[0] = ng_filter_shape[0];
   ng_kernel_shape[1] = ng_filter_shape[1];
   Reshape<3, 2, 0, 1>(node_inps[1]);
+  node_inps[1]->add_provenance_tag(op->name());
   ng::CoordinateDiff ng_padding_below{0, 0};
   ng::CoordinateDiff ng_padding_above{0, 0};
   Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
@@ -3384,8 +3453,10 @@ static Status TranslateQuantizedConv(
   std::shared_ptr<ng::Node> ng_quant_conv_bias = create_quantized_conv_node(
       node_inps, ng_strides, ng_dilations, ng_padding_below, ng_padding_above,
       ng_data_dilations);
+  ng_quant_conv_bias->add_provenance_tag(op->name());
   BatchToTensorflow(is_nhwc, ng_quant_conv_bias);
+  ng_quant_conv_bias->add_provenance_tag(op->name());
   SaveNgOp(ng_op_map, op->name(), ng_quant_conv_bias);
   // QconvBiasAdd variants have summand and its min/max as the last input
   // nodes
@@ -3401,15 +3472,18 @@ template
 static Status TranslateQuantizedConv2DWithBiasMaybeReluAndRequantizeOp(
     const Node* op, const std::vector&, Builder::OpMap& ng_op_map) {
-  auto create_quantized_conv_node = [](
+  string op_name = op->name();
+  auto create_quantized_conv_node = [&op_name](
       std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
       ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
       ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
-    return ng::builder::ScaledQuantizedConvolutionBias(
+    auto ng_node = ng::builder::ScaledQuantizedConvolutionBias(
         node_inps[0], node_inps[1], node_inps[2], ng_strides, ng_dilations,
         ng_padding_below, ng_padding_above, ng_data_dilations, node_inps[3],
         node_inps[4], node_inps[5], node_inps[6], node_inps[7], node_inps[8],
         IsRelu);
+    ng_node->add_provenance_tag(op_name);
+    return ng_node;
   };
   return TranslateQuantizedConv(op, ng_op_map, create_quantized_conv_node);
 }
@@ -3417,15 +3491,18 @@ static Status TranslateQuantizedConv2DWithBiasSumAndReluAndRequantizeOp(
     const Node* op, const std::vector&, Builder::OpMap& ng_op_map) {
-  auto create_quantized_conv_node = [](
+  string op_name = op->name();
+  auto create_quantized_conv_node = [&op_name](
       std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
       ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
       ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
-    return ng::builder::ScaledQuantizedConvolutionBiasAdd(
+    auto ng_node = ng::builder::ScaledQuantizedConvolutionBiasAdd(
        node_inps[0], node_inps[1], node_inps[2], node_inps[9], ng_strides,
        ng_dilations, ng_padding_below, ng_padding_above, ng_data_dilations,
        node_inps[3], node_inps[4], node_inps[5], node_inps[6], node_inps[7],
        node_inps[8], node_inps[10], node_inps[11], true);
+    ng_node->add_provenance_tag(op_name);
+    return ng_node;
   };
   return TranslateQuantizedConv(op, ng_op_map, create_quantized_conv_node);
 }
@@ -3433,15 +3510,18 @@ static Status TranslateQuantizedConv2DWithBiasSignedSumAndReluAndRequantizeOp(
     const Node* op, const std::vector&, Builder::OpMap& ng_op_map) {
-  auto create_quantized_conv_node = [](
+  string op_name = op->name();
+  auto create_quantized_conv_node = [&op_name](
      std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
      ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
      ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
-    return ng::builder::ScaledQuantizedConvolutionBiasSignedAdd(
+    auto ng_node = ng::builder::ScaledQuantizedConvolutionBiasSignedAdd(
        node_inps[0], node_inps[1], node_inps[2], node_inps[9], ng_strides,
        ng_dilations, ng_padding_below, ng_padding_above, ng_data_dilations,
        node_inps[3], node_inps[4], node_inps[5], node_inps[6], node_inps[7],
        node_inps[8], node_inps[10], node_inps[11], true);
+    ng_node->add_provenance_tag(op_name);
+    return ng_node;
   };
   return TranslateQuantizedConv(op, ng_op_map, create_quantized_conv_node);
 }
@@ -3470,9 +3550,10 @@ static Status TranslateQuantizeV2Op(const Node* op,
   ng::op::Quantize::RoundMode ng_round_mode =
       ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;
-  SaveNgOp(ng_op_map, op->name(),
-           ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
-                                       ng::AxisSet(), ng_round_mode));
+  auto ng_node = ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
+                                             ng::AxisSet(), ng_round_mode);
+  ng_node->add_provenance_tag(op->name());
+  SaveNgOp(ng_op_map, op->name(), ng_node);
   SaveNgOp(ng_op_map, op->name(), ng_min);
   SaveNgOp(ng_op_map, op->name(), ng_max);
@@ -3486,9 +3567,10 @@ static Status TranslateDequantizeOp(const Node* op,
   TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input, &ng_min, &ng_max));
   // TF only dequantizes to fp32
-  SaveNgOp(ng_op_map, op->name(),
-           ng::builder::ScaledDequantize(ng_input, ng_min, ng_max,
-                                         ng::element::f32, ng::AxisSet()));
+  auto ng_node = ng::builder::ScaledDequantize(ng_input, ng_min, ng_max,
+                                               ng::element::f32, ng::AxisSet());
+  ng_node->add_provenance_tag(op->name());
+  SaveNgOp(ng_op_map, op->name(), ng_node);
   return Status::OK();
 }
@@ -3674,12 +3756,15 @@ static Status TranslateSigmoidGradOp(const Node* op,
   shared_ptr<ng::Node> ng_delta;
   TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input, &ng_delta));
-  auto ng_mul = ng_input * ng_delta;
-  auto ng_subtract = ConstructNgNode(
-                         op->name(), ng_input->get_element_type(),
-                         ng_input->get_shape(), std::vector({1})) -
-                     ng_input;
-  auto ng_result = ng_mul * ng_subtract;
+  auto ng_mul =
+      ConstructNgNode(op->name(), ng_input, ng_delta);
+  auto ng_subtract = ConstructNgNode(
+      op->name(), ConstructNgNode(
+                      op->name(), ng_input->get_element_type(),
+                      ng_input->get_shape(), std::vector({1})),
+      ng_input);
+  auto ng_result =
+      ConstructNgNode(op->name(), ng_mul, ng_subtract);
   SaveNgOp(ng_op_map, op->name(), ng_result);
   return Status::OK();
@@ -4658,8 +4743,9 @@ static Status TranslateTransposeOp(
   NGRAPH_VLOG(3) << ng::join(ng_axis_order);
-  SaveNgOp(ng_op_map, op->name(),
-           ng::builder::numpy_transpose(ng_input, ng_axis_order));
+  auto ng_node = ng::builder::numpy_transpose(ng_input, ng_axis_order);
+  ng_node->add_provenance_tag(op->name());
+  SaveNgOp(ng_op_map, op->name(), ng_node);
   return Status::OK();
 }
@@ -4767,8 +4853,12 @@ static Status TranslateSelectOp(const Node* op,
   std::tie(ng_input1, ng_input2) = ng::builder::numpy_broadcast(
       std::make_pair(length != 0 ? ng_input_new : ng_input1, ng_input2));
+  ng_input1->add_provenance_tag(op->name());
+  ng_input2->add_provenance_tag(op->name());
   std::tie(ng_input2, ng_input3) =
       ng::builder::numpy_broadcast(std::make_pair(ng_input2, ng_input3));
+  ng_input2->add_provenance_tag(op->name());
+  ng_input3->add_provenance_tag(op->name());
   ng_select =
       ConstructNgNode(op->name(), ng_input1, ng_input2, ng_input3);
@@ -5056,6 +5146,23 @@ Status Builder::TranslateGraph(
     result->set_needs_default_layout(true);
   }
 
+  auto check_if_result_or_parameter = [](shared_ptr<ng::Node> n) {
+    // Pointer will cast to nullptr if this node is not a Result
+    auto ng_node = dynamic_pointer_cast<ngraph::op::Result>(n);
+    bool is_result = (ng_node != nullptr);
+    return n->is_parameter() || is_result;
+  };
+
+  for (auto n : ng_function->get_ordered_ops()) {
+    // Results and Parameters are not expected to have provenance tags
+    if (!check_if_result_or_parameter(n)) {
+      if (n->get_provenance_tags().size() == 0) {
+        return errors::Internal("Found ngraph node ", n->get_name(),
+                                " which does not have provenance tag set");
+      }
+    }
+  }
+
   return Status::OK();
 }
diff --git a/ngraph_bridge/version.cc b/ngraph_bridge/version.cc
index 2fa7cc9fd..da3c6e649 100644
--- a/ngraph_bridge/version.cc
+++ b/ngraph_bridge/version.cc
@@ -32,7 +32,7 @@
 // candidate such as v0.7.0-rc0
 // The code in master will always have the last released version number
 // with a suffix of '-master'
-#define NG_TF_VERSION_SUFFIX "-rc3"
+#define NG_TF_VERSION_SUFFIX "-rc4"
 
 #define VERSION_STR_HELPER(x) #x
 #define VERSION_STR(x) VERSION_STR_HELPER(x)
diff --git a/python/setup.in.py b/python/setup.in.py
index e91174d52..f71d432ad 100644
--- a/python/setup.in.py
+++ b/python/setup.in.py
@@ -59,7 +59,7 @@ def get_tag(self):
 setup(
     name='ngraph_tensorflow_bridge',
-    version='0.19.0rc3',
+    version='0.19.0rc4',
     description='Intel nGraph compiler and runtime for TensorFlow',
     long_description=long_description,
     long_description_content_type="text/markdown",
diff --git a/test/ci/buildkite/ngtf-cpu_ubuntu-bin-build.yaml b/test/ci/buildkite/ngtf-cpu_ubuntu-bin-build.yaml
index a5ae188af..28c7e721f 100644
--- a/test/ci/buildkite/ngtf-cpu_ubuntu-bin-build.yaml
+++ b/test/ci/buildkite/ngtf-cpu_ubuntu-bin-build.yaml
@@ -5,7 +5,7 @@ label: ":gear: Setup"
   timeout_in_minutes: 30
   agents:
-    - "queue=cpu-centos"
+    - "queue=cpu"
   parallelism: 1
 
 - wait
@@ -17,7 +17,7 @@ label: ":hammer_and_wrench: Build"
   timeout_in_minutes: 60
   agents:
-    - "queue=cpu-centos"
+    - "queue=cpu"
   parallelism: 1
 
 - wait
@@ -31,7 +31,7 @@ label: ":bazel: Bazel Build"
   timeout_in_minutes: 30
   agents:
-    - "queue=cpu-centos"
+    - "queue=cpu"
 
 - wait
@@ -44,7 +44,7 @@ label: ":bar_chart: ResNet50"
   timeout_in_minutes: 30
   agents:
-    - "queue=cpu-centos"
+    - "queue=cpu"
 
 - wait: ~
   continue_on_failure: true
@@ -52,5 +52,5 @@ rm -rf /localdisk/buildkite/artifacts/$BUILDKITE_BUILD_ID
   label: ":wastebasket: Cleanup"
   agents:
-    - "queue=cpu-centos"
+    - "queue=cpu"
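
Note on the pattern this patch applies: every nGraph node created during graph translation is stamped with the originating TensorFlow op name via `add_provenance_tag`, and `Builder::TranslateGraph` now fails if any node other than a `Parameter` or `Result` ends up untagged. The sketch below is a minimal, self-contained illustration of that create-tag-validate flow; the `Node`, `Parameter`, `Add`, and `ConstructTaggedNode` names are invented stand-ins for this example only and are not the nGraph or ngraph-bridge API.

// Self-contained sketch (not the nGraph API): mimics the tagging pattern used
// throughout this patch -- tag each node with the TF op name at creation time,
// then reject any untagged node in a final validation pass.
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for an nGraph node; only the pieces the pattern needs.
class Node {
 public:
  explicit Node(std::string type) : type_(std::move(type)) {}
  virtual ~Node() = default;

  void add_provenance_tag(const std::string& tag) { tags_.insert(tag); }
  const std::set<std::string>& get_provenance_tags() const { return tags_; }
  const std::string& type() const { return type_; }
  // Parameters (and Results) are exempt from tagging, as in the patch.
  virtual bool needs_tag() const { return true; }

 private:
  std::string type_;
  std::set<std::string> tags_;
};

class Parameter : public Node {
 public:
  Parameter() : Node("Parameter") {}
  bool needs_tag() const override { return false; }
};

class Add : public Node {
 public:
  Add() : Node("Add") {}
};

// Analogue of the ConstructNgNode(op->name(), ...) helper the patch converts
// several make_shared calls into: build the node, tag it, return it.
template <typename T, typename... Args>
std::shared_ptr<Node> ConstructTaggedNode(const std::string& tf_op_name,
                                          Args&&... args) {
  auto node = std::make_shared<T>(std::forward<Args>(args)...);
  node->add_provenance_tag(tf_op_name);
  return node;
}

// Analogue of the check added to Builder::TranslateGraph: every node that is
// not exempt must carry at least one provenance tag.
bool ValidateProvenance(const std::vector<std::shared_ptr<Node>>& ordered_ops) {
  for (const auto& n : ordered_ops) {
    if (n->needs_tag() && n->get_provenance_tags().empty()) {
      std::cerr << "Node of type " << n->type() << " has no provenance tag\n";
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<std::shared_ptr<Node>> graph;
  graph.push_back(std::make_shared<Parameter>());         // no tag required
  graph.push_back(ConstructTaggedNode<Add>("tf_add_1"));  // tagged at creation
  std::cout << (ValidateProvenance(graph) ? "ok" : "missing tags") << "\n";
  return 0;
}

Tagging at construction time, as the `ConstructNgNode` conversions in this patch do, keeps the tag next to the node that owns it; the standalone `add_provenance_tag` calls after helpers such as `BatchToNGraph`, `numpy_transpose`, and the `ng::builder` factories cover nodes that are created outside such a helper.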