Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sindhu/bfloat16 support #399

Merged
merged 29 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
694e63c
initial commit
sindhu-nervana Dec 11, 2019
2ab87c7
add bfloat16 test
sindhu-nervana Dec 16, 2019
b267ba1
Shrestha/var in compute (#388)
Dec 20, 2019
3ffb02e
disable the test
sindhu-nervana Dec 20, 2019
367d3db
Kanvi/Add asserts in some python tests (#398)
kanvi-nervana Dec 20, 2019
453a304
Merge branch 'master' into sindhu/bfloat16_support
sindhu-nervana Dec 20, 2019
22e7755
Merge branch 'master' into sindhu/bfloat16_support
Dec 24, 2019
c6220b7
Merge branch 'master' into sindhu/bfloat16_support
kanvi-nervana Dec 31, 2019
bea7c4d
Merge branch 'master' into sindhu/bfloat16_support
Jan 17, 2020
4cfb27f
added test
Jan 21, 2020
266b24a
changes
Jan 22, 2020
062a3c3
added another test
Jan 24, 2020
f00e298
added another bfloat test. encapsulate always assigned device CPU
Jan 24, 2020
5644eb6
Merge remote-tracking branch 'origin/master' into sindhu/bfloat16_sup…
Jan 24, 2020
0a4ffdd
removed couts, rearranged the tests
Jan 24, 2020
80c46f8
device checks
Jan 25, 2020
eb145c7
fix by registering dummy bfloat kernel
Jan 28, 2020
4d91711
Merge remote-tracking branch 'origin/master' into sindhu/bfloat16_sup…
Jan 28, 2020
5f08083
hanging include
Jan 28, 2020
e50323a
changes
Jan 28, 2020
e35892d
minor
Jan 28, 2020
a95c92f
Register Stub Kernels
Jan 29, 2020
5d313e3
fix bazel build
Jan 29, 2020
f636278
update comment
Jan 29, 2020
d2a161f
added comments to the test
Jan 29, 2020
1e4923c
corrected the macros
Jan 29, 2020
0bb58e0
fix template
Jan 29, 2020
957bf01
Merge remote-tracking branch 'origin/master' into sindhu/bfloat16_sup…
Jan 29, 2020
9fce56c
incorporate review comments
Jan 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bazel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cc_library(
"ngraph_bridge/ngraph_partial_shapes.h",
"ngraph_bridge/ngraph_prefetch_shared_data.h",
"ngraph_bridge/ngraph_pipelined_tensors.h",
"ngraph_bridge/ngraph_register_stub_kernels.h",
"ngraph_bridge/ngraph_rewrite_for_tracking.h",
"ngraph_bridge/ngraph_tensor_manager.h",
"ngraph_bridge/ngraph_timer.h",
Expand Down Expand Up @@ -89,6 +90,7 @@ cc_library(
"ngraph_bridge/ngraph_mark_for_clustering.cc",
"ngraph_bridge/ngraph_partial_shapes.cc",
"ngraph_bridge/ngraph_pipelined_tensors.cc",
"ngraph_bridge/ngraph_register_stub_kernels.cc",
"ngraph_bridge/ngraph_rewrite_for_tracking.cc",
"ngraph_bridge/ngraph_tensor_manager.cc",
"ngraph_bridge/ngraph_tracked_variable.cc",
Expand Down
1 change: 1 addition & 0 deletions ngraph_bridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ set(SRC
ngraph_freshness_tracker.cc
ngraph_mark_for_clustering.cc
ngraph_partial_shapes.cc
ngraph_register_stub_kernels.cc
ngraph_rewrite_for_tracking.cc
ngraph_rewrite_pass.cc
ngraph_tensor_manager.cc
Expand Down
134 changes: 14 additions & 120 deletions ngraph_bridge/enable_variable_ops/ngraph_variable_modifiers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,133 +33,27 @@
#include "ngraph_bridge/ngraph_utils.h"
#include "ngraph_bridge/ngraph_var.h"

#include "ngraph_bridge/ngraph_register_stub_kernels.h"

using namespace std;
namespace ng = ngraph;

namespace tensorflow {

namespace ngraph_bridge {
/* -------------------------------------------------
//
// NGraphApplyMomentumOp
//
---------------------------------------------------*/

class NGraphApplyMomentumOp : public OpKernel {
private:
public:
explicit NGraphApplyMomentumOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

//---------------------------------------------------------------------------
// ~NGraphApplyMomentumOp()
//---------------------------------------------------------------------------
~NGraphApplyMomentumOp() override {}

// This will never be called
void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
} // end of compute function
}; // end of NGraphApplyGradientDescent class definition

REGISTER_KERNEL_BUILDER(Name("NGraphApplyMomentum").Device(DEVICE_CPU),
NGraphApplyMomentumOp);
/* -------------------------------------------------
//
// NGraphApplyGradientDescentOp
//
---------------------------------------------------*/

class NGraphApplyGradientDescentOp : public OpKernel {
private:
public:
explicit NGraphApplyGradientDescentOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

//---------------------------------------------------------------------------
// ~NGraphApplyGradientDescentOp()
//---------------------------------------------------------------------------
~NGraphApplyGradientDescentOp() override {}

// This will never be called
void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
} // end of compute function
}; // end of NGraphApplyGradientDescent class definition

REGISTER_KERNEL_BUILDER(Name("NGraphApplyGradientDescent").Device(DEVICE_CPU),
NGraphApplyGradientDescentOp);

/* -------------------------------------------------
//
// NGraphAssignSubOp
//
---------------------------------------------------*/

// Computes *input[0] = *input[0] - input[1]
class NGraphAssignSubOp : public OpKernel {
private:
// bool use_exclusive_lock_; //TF op has this
~NGraphAssignSubOp() override {}

public:
explicit NGraphAssignSubOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
}
};

REGISTER_KERNEL_BUILDER(Name("NGraphAssignSub").Device(DEVICE_CPU),
NGraphAssignSubOp);

/* -------------------------------------------------
//
// NGraphAssignAddOp
//
---------------------------------------------------*/

// Computes *input[0] = *input[0] + input[1]
class NGraphAssignAddOp : public OpKernel {
public:
explicit NGraphAssignAddOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES(context, false,
errors::Internal("This constructor should not get called",
name(), "\n"));
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, false,
errors::Internal("This kernel should not get called", name(), "\n"));
}

private:
~NGraphAssignAddOp() override {}
};

REGISTER_KERNEL_BUILDER(Name("NGraphAssignAdd").Device(DEVICE_CPU),
NGraphAssignAddOp);
// Register NGraphOptimizers here
// These Optimizer Ops are replaced by a TF computational subgraph
// in ReplaceModifiers Rewrite Pass. Hence, these Stub Kernels/Op will never get
// called

// Keep them in alphabetical order
REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyGradientDescent");
REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyMomentum");
REGISTER_NGRAPH_STUB_KERNEL(
"NGraphAssignAdd"); //*input[0] = *input[0] + input[1]
REGISTER_NGRAPH_STUB_KERNEL(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

REGISTER_NGRAPH_STUB_KERNEL registers only for bfloat type? since its defn is:

#define REGISTER_NGRAPH_STUB_KERNEL(optype)                          \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), \
      NGStubOp);

in that case is this replacement ok?

Copy link
Contributor

@shresthamalik shresthamalik Jan 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed the macros. the names were not matching the definition.

"NGraphAssignSub"); //*input[0] = *input[0] - input[1]

} // namespace ngraph_bridge

Expand Down
70 changes: 70 additions & 0 deletions ngraph_bridge/ngraph_register_stub_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "ngraph_bridge/ngraph_register_stub_kernels.h"

using namespace std;

namespace tensorflow {

namespace ngraph_bridge {

/* -------------------------------------------------
//
// NGraphStubOp
//
---------------------------------------------------*/
// Constructor
NGStubOp::NGStubOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES(
context, false,
errors::Internal("The constructor for OpType ", type_string(),
"should not get called. This Op is expected to have "
"been encapsulated or replaced by other ops. Op Name: ",
name(), "\n"));
}
// Compute
void NGStubOp::Compute(OpKernelContext* context) {
OP_REQUIRES(
context, false,
errors::Internal("This kernel for OpType ", type_string(),
"should not get called. This Op is expected to have "
"been encapsulated or replaced by other ops. Op Name: ",
name(), "\n"));
}
// Destructor
NGStubOp::~NGStubOp() {}

/* ------------------------------------------------- */

// Register Bfloat Stub Kernels

// TF Ops that work on bfloat DataType get assigned Device XLA_CPU
// Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub
// bfloat16 kernels here. The expectation is when we register the stub kernels
// for bfloat16 TF is going to assign DEVICE_CPU to the respective Ops and
// we will encapsulate them
// These Stub Kernels/Op will never get called

// Keep them in alphabetical order
REGISTER_NGRAPH_STUB_BFLOAT_KERNEL("Conv2D")

} // namespace ngraph_bridge

} // namespace tensorflow
56 changes: 56 additions & 0 deletions ngraph_bridge/ngraph_register_stub_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
#define NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace std;

namespace tensorflow {

namespace ngraph_bridge {

/* -------------------------------------------------
//
// NGStubOp
//
---------------------------------------------------*/

class NGStubOp : public OpKernel {
public:
explicit NGStubOp(OpKernelConstruction* context);

void Compute(OpKernelContext* context) override;

private:
~NGStubOp() override;
};

#define REGISTER_NGRAPH_STUB_KERNEL(optype) \
REGISTER_KERNEL_BUILDER(Name(optype).Device(DEVICE_CPU), NGStubOp);

#define REGISTER_NGRAPH_STUB_BFLOAT_KERNEL(optype) \
REGISTER_KERNEL_BUILDER( \
Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), \
NGStubOp);

} // namespace ngraph_bridge

} // namespace tensorflow

#endif // NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_
18 changes: 13 additions & 5 deletions ngraph_bridge/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ Status TensorToStream(std::ostream& ostream, const Tensor& tensor) {
case DT_BOOL:
TensorDataToStream<bool>(ostream, n_elements, data);
break;
case DT_BFLOAT16:
return errors::Internal(
"TensorToStream got data type bfloat16. No compatible standard C++ "
"data type.");
break;
default:
return errors::Internal("TensorToStream got unsupported data type ",
DataType_Name(tensor.dtype()));
Expand Down Expand Up @@ -272,6 +277,8 @@ Status TFDataTypeToNGraphElementType(DataType tf_dt,
break;
case DataType::DT_QINT32:
*ng_et = ng::element::i32;
case DataType::DT_BFLOAT16:
*ng_et = ng::element::bf16;
break;
default:
return errors::Unimplemented("Unsupported TensorFlow data type: ",
Expand Down Expand Up @@ -322,15 +329,16 @@ void print_node_histogram(const std::unordered_map<string, int>& histogram,

const gtl::ArraySlice<DataType>& NGraphDTypes() {
static gtl::ArraySlice<DataType> result{
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT32, DT_UINT64, DT_BOOL, DT_QINT8, DT_QUINT8};
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16};
return result;
}

const gtl::ArraySlice<DataType>& NGraphNumericDTypes() {
static gtl::ArraySlice<DataType> result{
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32,
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64,
DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16};
return result;
}

Expand Down Expand Up @@ -358,7 +366,7 @@ const gtl::ArraySlice<DataType>& NGraphSupportedQuantizedDTypes() {
}

const gtl::ArraySlice<DataType>& NGraphRealDTypes() {
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE};
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16};
return result;
}

Expand Down
Loading