Skip to content

Commit

Permalink
Sindhu/bfloat16 support (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhu-nervana authored Jan 30, 2020
1 parent 8d69c68 commit 310ca25
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 125 deletions.
2 changes: 2 additions & 0 deletions bazel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cc_library(
"ngraph_bridge/ngraph_partial_shapes.h",
"ngraph_bridge/ngraph_prefetch_shared_data.h",
"ngraph_bridge/ngraph_pipelined_tensors.h",
"ngraph_bridge/ngraph_register_stub_kernels.h",
"ngraph_bridge/ngraph_rewrite_for_tracking.h",
"ngraph_bridge/ngraph_tensor_manager.h",
"ngraph_bridge/ngraph_timer.h",
Expand Down Expand Up @@ -89,6 +90,7 @@ cc_library(
"ngraph_bridge/ngraph_mark_for_clustering.cc",
"ngraph_bridge/ngraph_partial_shapes.cc",
"ngraph_bridge/ngraph_pipelined_tensors.cc",
"ngraph_bridge/ngraph_register_stub_kernels.cc",
"ngraph_bridge/ngraph_rewrite_for_tracking.cc",
"ngraph_bridge/ngraph_tensor_manager.cc",
"ngraph_bridge/ngraph_tracked_variable.cc",
Expand Down
1 change: 1 addition & 0 deletions ngraph_bridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ set(SRC
ngraph_freshness_tracker.cc
ngraph_mark_for_clustering.cc
ngraph_partial_shapes.cc
ngraph_register_stub_kernels.cc
ngraph_rewrite_for_tracking.cc
ngraph_rewrite_pass.cc
ngraph_tensor_manager.cc
Expand Down
134 changes: 14 additions & 120 deletions ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,133 +33,27 @@
#include "ngraph_bridge/ngraph_utils.h"
#include "ngraph_bridge/ngraph_var.h"

#include "ngraph_bridge/ngraph_register_stub_kernels.h"

using namespace std;
namespace ng = ngraph;

namespace tensorflow {

namespace ngraph_bridge {
/* -------------------------------------------------
//
// NGraphApplyMomentumOp
//
---------------------------------------------------*/

// Stub kernel for NGraphApplyMomentum. The op is replaced by a TF
// computational subgraph in the ReplaceModifiers rewrite pass, so neither
// the constructor nor Compute() should ever run; both fail loudly with an
// Internal error if they do.
class NGraphApplyMomentumOp : public OpKernel {
 public:
  explicit NGraphApplyMomentumOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // Fix: add separating space so the message does not run into the op name.
    OP_REQUIRES(context, false,
                errors::Internal("This constructor should not get called ",
                                 name(), "\n"));
  }

  //---------------------------------------------------------------------------
  // ~NGraphApplyMomentumOp()
  //---------------------------------------------------------------------------
  ~NGraphApplyMomentumOp() override {}

  // This will never be called.
  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(
        context, false,
        errors::Internal("This kernel should not get called ", name(), "\n"));
  }  // end of compute function
};   // end of NGraphApplyMomentumOp class definition

REGISTER_KERNEL_BUILDER(Name("NGraphApplyMomentum").Device(DEVICE_CPU),
                        NGraphApplyMomentumOp);
/* -------------------------------------------------
//
// NGraphApplyGradientDescentOp
//
---------------------------------------------------*/

// Stub kernel for NGraphApplyGradientDescent. The op is replaced by a TF
// computational subgraph in the ReplaceModifiers rewrite pass, so neither
// the constructor nor Compute() should ever run; both fail loudly with an
// Internal error if they do.
class NGraphApplyGradientDescentOp : public OpKernel {
 public:
  explicit NGraphApplyGradientDescentOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // Fix: add separating space so the message does not run into the op name.
    OP_REQUIRES(context, false,
                errors::Internal("This constructor should not get called ",
                                 name(), "\n"));
  }

  //---------------------------------------------------------------------------
  // ~NGraphApplyGradientDescentOp()
  //---------------------------------------------------------------------------
  ~NGraphApplyGradientDescentOp() override {}

  // This will never be called.
  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(
        context, false,
        errors::Internal("This kernel should not get called ", name(), "\n"));
  }  // end of compute function
};   // end of NGraphApplyGradientDescent class definition

REGISTER_KERNEL_BUILDER(Name("NGraphApplyGradientDescent").Device(DEVICE_CPU),
                        NGraphApplyGradientDescentOp);

/* -------------------------------------------------
//
// NGraphAssignSubOp
//
---------------------------------------------------*/

// Computes *input[0] = *input[0] - input[1]
class NGraphAssignSubOp : public OpKernel {
private:
// bool use_exclusive_lock_; //TF op has this
~NGraphAssignSubOp() override {}

public:
explicit NGraphAssignSubOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
}
};

REGISTER_KERNEL_BUILDER(Name("NGraphAssignSub").Device(DEVICE_CPU),
NGraphAssignSubOp);

/* -------------------------------------------------
//
// NGraphAssignAddOp
//
---------------------------------------------------*/

// Computes *input[0] = *input[0] + input[1]
class NGraphAssignAddOp : public OpKernel {
public:
explicit NGraphAssignAddOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
}

private:
~NGraphAssignAddOp() override {}
};

REGISTER_KERNEL_BUILDER(Name("NGraphAssignAdd").Device(DEVICE_CPU),
NGraphAssignAddOp);
// Register stub kernels for the NGraph optimizer ops here.
// These optimizer ops are replaced by a TF computational subgraph in the
// ReplaceModifiers rewrite pass, so these stub kernels/ops will never get
// called (see NGStubOp in ngraph_register_stub_kernels.h).

// Keep them in alphabetical order
REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyGradientDescent");
REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyMomentum");
REGISTER_NGRAPH_STUB_KERNEL(
    "NGraphAssignAdd");  // *input[0] = *input[0] + input[1]
REGISTER_NGRAPH_STUB_KERNEL(
    "NGraphAssignSub");  // *input[0] = *input[0] - input[1]

} // namespace ngraph_bridge

Expand Down
70 changes: 70 additions & 0 deletions ngraph_bridge/ngraph_register_stub_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "ngraph_bridge/ngraph_register_stub_kernels.h"

using namespace std;

namespace tensorflow {

namespace ngraph_bridge {

/* -------------------------------------------------
//
// NGraphStubOp
//
---------------------------------------------------*/
// Constructor: a stub op must never be constructed at runtime — it is
// expected to have been encapsulated or replaced by the rewrite passes
// before execution. Fail with an Internal error identifying the op.
NGStubOp::NGStubOp(OpKernelConstruction* context) : OpKernel(context) {
  // Fix: leading space on " should..." — without it the message renders as
  // e.g. "OpType Conv2Dshould not get called".
  OP_REQUIRES(
      context, false,
      errors::Internal("The constructor for OpType ", type_string(),
                       " should not get called. This Op is expected to have "
                       "been encapsulated or replaced by other ops. Op Name: ",
                       name(), "\n"));
}
// Compute: likewise unreachable; fail loudly if TF ever invokes it.
void NGStubOp::Compute(OpKernelContext* context) {
  OP_REQUIRES(
      context, false,
      errors::Internal("This kernel for OpType ", type_string(),
                       " should not get called. This Op is expected to have "
                       "been encapsulated or replaced by other ops. Op Name: ",
                       name(), "\n"));
}
// Destructor
NGStubOp::~NGStubOp() {}

/* ------------------------------------------------- */

// Register Bfloat Stub Kernels

// TF Ops that work on bfloat DataType get assigned Device XLA_CPU
// Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub
// bfloat16 kernels here. The expectation is when we register the stub kernels
// for bfloat16 TF is going to assign DEVICE_CPU to the respective Ops and
// we will encapsulate them
// These Stub Kernels/Op will never get called

// Keep them in alphabetical order
REGISTER_NGRAPH_STUB_BFLOAT_KERNEL("Conv2D")

} // namespace ngraph_bridge

} // namespace tensorflow
56 changes: 56 additions & 0 deletions ngraph_bridge/ngraph_register_stub_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
#define NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace std;

namespace tensorflow {

namespace ngraph_bridge {

/* -------------------------------------------------
//
// NGStubOp
//
---------------------------------------------------*/

// Stub kernel assigned to ops that the nGraph bridge encapsulates or
// replaces before execution. Construction or Compute() of this kernel
// therefore signals a rewrite-pass failure; both abort with
// errors::Internal (definitions in ngraph_register_stub_kernels.cc).
class NGStubOp : public OpKernel {
 public:
  explicit NGStubOp(OpKernelConstruction* context);

  void Compute(OpKernelContext* context) override;

 private:
  // Private destructor: OpKernels are reference-counted by TF and must not
  // be destroyed directly by callers.
  ~NGStubOp() override;
};

// Registers NGStubOp as the DEVICE_CPU kernel for `optype`.
// NOTE(review): the trailing semicolon is supplied inside the macro, so
// call sites may omit it — some existing callers rely on this; do not
// remove it without updating them.
#define REGISTER_NGRAPH_STUB_KERNEL(optype) \
  REGISTER_KERNEL_BUILDER(Name(optype).Device(DEVICE_CPU), NGStubOp);

// Same, constrained to T == bfloat16: registering a CPU stub makes TF
// place bfloat16 ops on DEVICE_CPU so the bridge can encapsulate them.
#define REGISTER_NGRAPH_STUB_BFLOAT_KERNEL(optype)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), \
      NGStubOp);

} // namespace ngraph_bridge

} // namespace tensorflow

#endif // NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
18 changes: 13 additions & 5 deletions ngraph_bridge/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ Status TensorToStream(std::ostream& ostream, const Tensor& tensor) {
case DT_BOOL:
TensorDataToStream<bool>(ostream, n_elements, data);
break;
case DT_BFLOAT16:
return errors::Internal(
"TensorToStream got data type bfloat16. No compatible standard C++ "
"data type.");
break;
default:
return errors::Internal("TensorToStream got unsupported data type ",
DataType_Name(tensor.dtype()));
Expand Down Expand Up @@ -272,6 +277,8 @@ Status TFDataTypeToNGraphElementType(DataType tf_dt,
break;
    case DataType::DT_QINT32:
      *ng_et = ng::element::i32;
      break;  // BUG FIX: this break is missing in the committed diff — the new
              // DT_BFLOAT16 case was inserted before it, so DT_QINT32 falls
              // through and is wrongly mapped to bf16.
    case DataType::DT_BFLOAT16:
      *ng_et = ng::element::bf16;
      break;
default:
return errors::Unimplemented("Unsupported TensorFlow data type: ",
Expand Down Expand Up @@ -322,15 +329,16 @@ void print_node_histogram(const std::unordered_map<string, int>& histogram,

const gtl::ArraySlice<DataType>& NGraphDTypes() {
static gtl::ArraySlice<DataType> result{
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT32, DT_UINT64, DT_BOOL, DT_QINT8, DT_QUINT8};
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16};
return result;
}

const gtl::ArraySlice<DataType>& NGraphNumericDTypes() {
static gtl::ArraySlice<DataType> result{
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64,
DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16};
return result;
}

Expand Down Expand Up @@ -358,7 +366,7 @@ const gtl::ArraySlice<DataType>& NGraphSupportedQuantizedDTypes() {
}

const gtl::ArraySlice<DataType>& NGraphRealDTypes() {
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE};
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16};
return result;
}

Expand Down
Loading

0 comments on commit 310ca25

Please sign in to comment.