diff --git a/bazel/BUILD b/bazel/BUILD
index f0674337a..b9e714a52 100644
--- a/bazel/BUILD
+++ b/bazel/BUILD
@@ -44,6 +44,7 @@ cc_library(
         "ngraph_bridge/ngraph_partial_shapes.h",
         "ngraph_bridge/ngraph_prefetch_shared_data.h",
         "ngraph_bridge/ngraph_pipelined_tensors.h",
+        "ngraph_bridge/ngraph_register_stub_kernels.h",
         "ngraph_bridge/ngraph_rewrite_for_tracking.h",
         "ngraph_bridge/ngraph_tensor_manager.h",
         "ngraph_bridge/ngraph_timer.h",
@@ -89,6 +90,7 @@ cc_library(
         "ngraph_bridge/ngraph_mark_for_clustering.cc",
         "ngraph_bridge/ngraph_partial_shapes.cc",
         "ngraph_bridge/ngraph_pipelined_tensors.cc",
+        "ngraph_bridge/ngraph_register_stub_kernels.cc",
         "ngraph_bridge/ngraph_rewrite_for_tracking.cc",
         "ngraph_bridge/ngraph_tensor_manager.cc",
         "ngraph_bridge/ngraph_tracked_variable.cc",
diff --git a/ngraph_bridge/CMakeLists.txt b/ngraph_bridge/CMakeLists.txt
index 178d09536..ff85b7843 100644
--- a/ngraph_bridge/CMakeLists.txt
+++ b/ngraph_bridge/CMakeLists.txt
@@ -53,6 +53,7 @@ set(SRC
    ngraph_freshness_tracker.cc
    ngraph_mark_for_clustering.cc
    ngraph_partial_shapes.cc
+   ngraph_register_stub_kernels.cc
    ngraph_rewrite_for_tracking.cc
    ngraph_rewrite_pass.cc
    ngraph_tensor_manager.cc
diff --git a/ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc b/ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc
index 4035fbf2e..066f8b18b 100644
--- a/ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc
+++ b/ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc
@@ -33,133 +33,27 @@
 #include "ngraph_bridge/ngraph_utils.h"
 #include "ngraph_bridge/ngraph_var.h"
 
+#include "ngraph_bridge/ngraph_register_stub_kernels.h"
+
 using namespace std;
 namespace ng = ngraph;
 
 namespace tensorflow {
 
 namespace ngraph_bridge {
 
-/* -------------------------------------------------
-//
-// NGraphApplyMomentumOp
-//
----------------------------------------------------*/
-
-class NGraphApplyMomentumOp : public OpKernel {
- private:
- public:
-  explicit NGraphApplyMomentumOp(OpKernelConstruction* context)
-      : OpKernel(context) {
-    OP_REQUIRES(context, false,
-                errors::Internal("This constructor should not get called",
-                                 name(), "\n"));
-  }
-
-  //---------------------------------------------------------------------------
-  //  ~NGraphApplyMomentumOp()
-  //---------------------------------------------------------------------------
-  ~NGraphApplyMomentumOp() override {}
-
-  // This will never be called
-  void Compute(OpKernelContext* context) override {
-    OP_REQUIRES(
-        context, false,
-        errors::Internal("This kernel should not get called", name(), "\n"));
-  }  // end of compute function
-};  // end of NGraphApplyGradientDescent class definition
-
-REGISTER_KERNEL_BUILDER(Name("NGraphApplyMomentum").Device(DEVICE_CPU),
-                        NGraphApplyMomentumOp);
-/* -------------------------------------------------
-//
-// NGraphApplyGradientDescentOp
-//
----------------------------------------------------*/
-
-class NGraphApplyGradientDescentOp : public OpKernel {
- private:
- public:
-  explicit NGraphApplyGradientDescentOp(OpKernelConstruction* context)
-      : OpKernel(context) {
-    OP_REQUIRES(context, false,
-                errors::Internal("This constructor should not get called",
-                                 name(), "\n"));
-  }
-
-  //---------------------------------------------------------------------------
-  //  ~NGraphApplyGradientDescentOp()
-  //---------------------------------------------------------------------------
-  ~NGraphApplyGradientDescentOp() override {}
-
-  // This will never be called
-  void Compute(OpKernelContext* context) override {
-    OP_REQUIRES(
-        context, false,
-        errors::Internal("This kernel should not get called", name(), "\n"));
-  }  // end of compute function
-};  // end of NGraphApplyGradientDescent class definition
-
-REGISTER_KERNEL_BUILDER(Name("NGraphApplyGradientDescent").Device(DEVICE_CPU),
-                        NGraphApplyGradientDescentOp);
-
-/* -------------------------------------------------
-//
-// NGraphAssignSubOp
-//
----------------------------------------------------*/
-
-// Computes *input[0] = *input[0] - input[1]
-class NGraphAssignSubOp : public OpKernel {
- private:
-  // bool use_exclusive_lock_;  //TF op has this
-  ~NGraphAssignSubOp() override {}
-
- public:
-  explicit NGraphAssignSubOp(OpKernelConstruction* context)
-      : OpKernel(context) {
-    OP_REQUIRES(context, false,
-                errors::Internal("This constructor should not get called",
-                                 name(), "\n"));
-  }
-
-  void Compute(OpKernelContext* context) override {
-    OP_REQUIRES(
-        context, false,
-        errors::Internal("This kernel should not get called", name(), "\n"));
-  }
-};
-
-REGISTER_KERNEL_BUILDER(Name("NGraphAssignSub").Device(DEVICE_CPU),
-                        NGraphAssignSubOp);
-
-/* -------------------------------------------------
-//
-// NGraphAssignAddOp
-//
----------------------------------------------------*/
-
-// Computes *input[0] = *input[0] + input[1]
-class NGraphAssignAddOp : public OpKernel {
- public:
-  explicit NGraphAssignAddOp(OpKernelConstruction* context)
-      : OpKernel(context) {
-    OP_REQUIRES(context, false,
-                errors::Internal("This constructor should not get called",
-                                 name(), "\n"));
-  }
-
-  void Compute(OpKernelContext* context) override {
-    OP_REQUIRES(
-        context, false,
-        errors::Internal("This kernel should not get called", name(), "\n"));
-  }
-
- private:
-  ~NGraphAssignAddOp() override {}
-};
-REGISTER_KERNEL_BUILDER(Name("NGraphAssignAdd").Device(DEVICE_CPU),
-                        NGraphAssignAddOp);
+// Register stub kernels for the nGraph optimizer ops here.
+// These optimizer ops are replaced by a TF computational subgraph
+// in the ReplaceModifiers rewrite pass, so these stub kernels
+// will never be called.
+
+// Keep them in alphabetical order
+REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyGradientDescent");
+REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyMomentum");
+REGISTER_NGRAPH_STUB_KERNEL(
+    "NGraphAssignAdd");  // *input[0] = *input[0] + input[1]
+REGISTER_NGRAPH_STUB_KERNEL(
+    "NGraphAssignSub");  // *input[0] = *input[0] - input[1]
 
 }  // namespace ngraph_bridge
diff --git a/ngraph_bridge/ngraph_register_stub_kernels.cc b/ngraph_bridge/ngraph_register_stub_kernels.cc
new file mode 100644
index 000000000..aa0dca466
--- /dev/null
+++ b/ngraph_bridge/ngraph_register_stub_kernels.cc
@@ -0,0 +1,70 @@
+/*******************************************************************************
+ * Copyright 2019-2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "ngraph_bridge/ngraph_register_stub_kernels.h"
+
+using namespace std;
+
+namespace tensorflow {
+
+namespace ngraph_bridge {
+
+/* -------------------------------------------------
+//
+// NGStubOp
+//
+---------------------------------------------------*/
+// Constructor
+NGStubOp::NGStubOp(OpKernelConstruction* context) : OpKernel(context) {
+  OP_REQUIRES(
+      context, false,
+      errors::Internal("The constructor for OpType ", type_string(),
+                       " should not get called. This Op is expected to have "
+                       "been encapsulated or replaced by other ops. Op Name: ",
+                       name(), "\n"));
+}
+// Compute
+void NGStubOp::Compute(OpKernelContext* context) {
+  OP_REQUIRES(
+      context, false,
+      errors::Internal("The kernel for OpType ", type_string(),
+                       " should not get called. This Op is expected to have "
+                       "been encapsulated or replaced by other ops. Op Name: ",
+                       name(), "\n"));
+}
+// Destructor
+NGStubOp::~NGStubOp() {}
+
+/* ------------------------------------------------- */
+
+// Register Bfloat Stub Kernels
+
+// TF ops that work on the bfloat16 DataType get assigned device XLA_CPU.
+// Since nGraph-bridge ops work on TF DEVICE_CPU, we register stub
+// bfloat16 kernels here. The expectation is that once the stub kernels
+// for bfloat16 are registered, TF will assign DEVICE_CPU to the
+// respective ops and we will encapsulate them.
+// These stub kernels will never get called.
+
+// Keep them in alphabetical order
+REGISTER_NGRAPH_STUB_BFLOAT_KERNEL("Conv2D")
+
+}  // namespace ngraph_bridge
+
+}  // namespace tensorflow
diff --git a/ngraph_bridge/ngraph_register_stub_kernels.h b/ngraph_bridge/ngraph_register_stub_kernels.h
new file mode 100644
index 000000000..543e503d1
--- /dev/null
+++ b/ngraph_bridge/ngraph_register_stub_kernels.h
@@ -0,0 +1,56 @@
+/*******************************************************************************
+ * Copyright 2019-2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+#ifndef NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
+#define NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace std;
+
+namespace tensorflow {
+
+namespace ngraph_bridge {
+
+/* -------------------------------------------------
+//
+// NGStubOp
+//
+---------------------------------------------------*/
+
+class NGStubOp : public OpKernel {
+ public:
+  explicit NGStubOp(OpKernelConstruction* context);
+
+  void Compute(OpKernelContext* context) override;
+
+ private:
+  ~NGStubOp() override;
+};
+
+#define REGISTER_NGRAPH_STUB_KERNEL(optype) \
+  REGISTER_KERNEL_BUILDER(Name(optype).Device(DEVICE_CPU), NGStubOp);
+
+#define REGISTER_NGRAPH_STUB_BFLOAT_KERNEL(optype)                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"),  \
+      NGStubOp);
+
+}  // namespace ngraph_bridge
+
+}  // namespace tensorflow
+
+#endif  // NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
diff --git a/ngraph_bridge/ngraph_utils.cc b/ngraph_bridge/ngraph_utils.cc
index 24f80362e..f2b177267 100644
--- a/ngraph_bridge/ngraph_utils.cc
+++ b/ngraph_bridge/ngraph_utils.cc
@@ -223,6 +223,11 @@ Status TensorToStream(std::ostream& ostream, const Tensor& tensor) {
     case DT_BOOL:
       TensorDataToStream<bool>(ostream, n_elements, data);
       break;
+    case DT_BFLOAT16:
+      return errors::Internal(
+          "TensorToStream got data type bfloat16. No compatible standard C++ "
+          "data type.");
+      break;
     default:
       return errors::Internal("TensorToStream got unsupported data type ",
                               DataType_Name(tensor.dtype()));
@@ -272,6 +277,9 @@ Status TFDataTypeToNGraphElementType(DataType tf_dt,
       break;
     case DataType::DT_QINT32:
       *ng_et = ng::element::i32;
+      break;
+    case DataType::DT_BFLOAT16:
+      *ng_et = ng::element::bf16;
       break;
     default:
       return errors::Unimplemented("Unsupported TensorFlow data type: ",
@@ -322,15 +330,16 @@ void print_node_histogram(const std::unordered_map<std::string, int>& histogram,
 
 const gtl::ArraySlice<DataType>& NGraphDTypes() {
   static gtl::ArraySlice<DataType> result{
-      DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8,
-      DT_UINT16, DT_UINT32, DT_UINT64, DT_BOOL, DT_QINT8, DT_QUINT8};
+      DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
+      DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
+      DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16};
   return result;
 }
 
 const gtl::ArraySlice<DataType>& NGraphNumericDTypes() {
   static gtl::ArraySlice<DataType> result{
-      DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
-      DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
+      DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64,
+      DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16};
   return result;
 }
 
@@ -358,7 +367,7 @@ const gtl::ArraySlice<DataType>& NGraphSupportedQuantizedDTypes() {
 }
 
 const gtl::ArraySlice<DataType>& NGraphRealDTypes() {
-  static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE};
+  static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16};
   return result;
 }
 
diff --git a/test/python/test_bfloat16.py b/test/python/test_bfloat16.py
new file mode 100644
index 000000000..61fae7d67
--- /dev/null
+++ b/test/python/test_bfloat16.py
@@ -0,0 +1,131 @@
+# ==============================================================================
+# Copyright 2019-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""nGraph TensorFlow bridge bfloat16 operation tests (MatMul and Conv2D)
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pytest
+import numpy as np
+
+import tensorflow as tf
+import os
+import sys
+from common import NgraphTest
+
+np.random.seed(5)
+
+
+class TestBfloat16(NgraphTest):
+
+    @pytest.mark.skip(
+        reason="CPU backend does not support dtype bf16 for MatMul/Dot Op")
+    def test_matmul_bfloat16(self):
+        a = tf.placeholder(tf.bfloat16, [2, 3], name='a')
+        x = tf.placeholder(tf.bfloat16, [3, 4], name='x')
+        a_inp = np.random.rand(2, 3)
+        x_inp = np.random.rand(3, 4)
+        out = tf.matmul(a, x)
+
+        def run_test(sess):
+            return sess.run((out,), feed_dict={a: a_inp, x: x_inp})
+
+        assert np.allclose(self.with_ngraph(run_test), self.without_ngraph(run_test))
+
+    # For testing, we usually run the same graph on TF by disabling the nGraph rewrites.
+    # However, in this case, because we register CPU bfloat16 stub kernels, TF assigns
+    # device CPU to the bfloat16 ops and hits the asserts in the stub kernels.
+    # So we compare against precomputed expected values instead.
+    # For an ideal TF reference run, we would need vanilla TF without importing ngraph-bridge.
+    def test_conv2d_bfloat16(self):
+        # Graph
+        input_shape_nhwc = (1, 4, 4, 1)
+        filter_shape_hwio = (3, 3, 1, 2)
+        input_pl = tf.placeholder(tf.bfloat16, input_shape_nhwc, name="inp_pl")
+        filter_shape_pl = tf.placeholder(
+            tf.bfloat16, filter_shape_hwio, name="filter_pl")
+        input_values = np.arange(16).reshape(
+            input_shape_nhwc)  # np.random.rand(*input_shape_nhwc)
+        filter_values = np.arange(18).reshape(
+            filter_shape_hwio)  # np.random.rand(*filter_shape_hwio)
+        padding = "VALID"
+        strides = [1, 1, 1, 1]
+        conv_op = tf.nn.conv2d(
+            input_pl,
+            filter_shape_pl,
+            strides,
+            padding,
+            data_format='NHWC',
+            dilations=None,
+            name=None)
+
+        def run_test(sess):
+            return sess.run((conv_op,),
+                            feed_dict={
+                                input_pl: input_values,
+                                filter_shape_pl: filter_values
+                            })
+
+        ng_val = self.with_ngraph(run_test)
+        expected_val = np.reshape(
+            np.array([516, 560, 588, 640, 804, 884, 876, 968]), (1, 2, 2, 2))
+        assert np.allclose(ng_val, expected_val)
+
+    # For testing, we usually run the same graph on TF by disabling the nGraph rewrites.
+    # However, in this case, because we register CPU bfloat16 stub kernels, TF assigns
+    # device CPU to the bfloat16 ops and hits the asserts in the stub kernels.
+    # So we compare against precomputed expected values instead.
+    # For an ideal TF reference run, we would need vanilla TF without importing ngraph-bridge.
+    def test_conv2d_cast_bfloat16(self):
+        # Graph
+        input_shape_nhwc = (1, 4, 4, 1)
+        filter_shape_hwio = (3, 3, 1, 2)
+        input_pl = tf.placeholder(tf.float32, input_shape_nhwc, name="inp_pl")
+        filter_shape_pl = tf.placeholder(
+            tf.float32, filter_shape_hwio, name="filter_pl")
+        input_values = np.arange(16).reshape(
+            input_shape_nhwc)  # np.random.rand(*input_shape_nhwc)
+        filter_values = np.arange(18).reshape(
+            filter_shape_hwio)  # np.random.rand(*filter_shape_hwio)
+        # cast the inputs to bfloat16
+        input_cast = tf.cast(input_pl, dtype=tf.bfloat16)
+        filter_cast = tf.cast(filter_shape_pl, dtype=tf.bfloat16)
+        padding = "VALID"
+        strides = [1, 1, 1, 1]
+        conv_op = tf.nn.conv2d(
+            input_cast,
+            filter_cast,
+            strides,
+            padding,
+            data_format='NHWC',
+            dilations=None,
+            name=None)
+        # cast the result back to float32
+        out = tf.cast(conv_op, dtype=tf.float32)
+
+        def run_test(sess):
+            return sess.run((out,),
+                            feed_dict={
+                                input_pl: input_values,
+                                filter_shape_pl: filter_values
+                            })
+
+        ng_val = self.with_ngraph(run_test)
+        expected_val = np.reshape(
+            np.array([516, 560, 588, 640, 804, 884, 876, 968]), (1, 2, 2, 2))
+        assert np.allclose(ng_val, expected_val)
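
Reviewer note (not part of the patch): the sketch below shows how the two new
registration macros from ngraph_bridge/ngraph_register_stub_kernels.h are
intended to be used from another translation unit. The op names
"NGraphExampleOp" and "MatMul" are placeholders chosen for illustration only;
this patch itself registers stubs for the optimizer ops above and a bfloat16
Conv2D stub.

    // Illustrative usage sketch (assumed op names, not part of this patch).
    #include "ngraph_bridge/ngraph_register_stub_kernels.h"

    namespace tensorflow {
    namespace ngraph_bridge {

    // CPU stub kernel for a hypothetical bridge op. The stub asserts in both
    // the constructor and Compute(), so the op must be rewritten or
    // encapsulated by the bridge before the graph ever executes this kernel.
    REGISTER_NGRAPH_STUB_KERNEL("NGraphExampleOp");

    // bfloat16-constrained CPU stub kernel for a standard TF op, mirroring the
    // Conv2D registration in this patch: it makes TF place the bfloat16 node
    // on DEVICE_CPU so the bridge can encapsulate it.
    REGISTER_NGRAPH_STUB_BFLOAT_KERNEL("MatMul")

    }  // namespace ngraph_bridge
    }  // namespace tensorflow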