-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sindhu/bfloat16 support #399
Changes from 24 commits
694e63c
2ab87c7
b267ba1
3ffb02e
367d3db
453a304
22e7755
c6220b7
bea7c4d
4cfb27f
266b24a
062a3c3
f00e298
5644eb6
0a4ffdd
80c46f8
eb145c7
4d91711
5f08083
e50323a
e35892d
a95c92f
5d313e3
f636278
d2a161f
1e4923c
0bb58e0
957bf01
9fce56c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/******************************************************************************* | ||
* Copyright 2019-2020 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
|
||
#include "ngraph_bridge/ngraph_register_stub_kernels.h" | ||
|
||
using namespace std; | ||
|
||
namespace tensorflow { | ||
|
||
namespace ngraph_bridge { | ||
|
||
/* ------------------------------------------------- | ||
// | ||
// NGraphStubOp | ||
// | ||
---------------------------------------------------*/ | ||
// Constructor | ||
NGStubOp::NGStubOp(OpKernelConstruction* context) : OpKernel(context) { | ||
OP_REQUIRES( | ||
context, false, | ||
errors::Internal("The constructor for OpType ", type_string(), | ||
"should not get called. This Op is expected to have " | ||
"been encapsulated or replaced by other ops. Op Name: ", | ||
name(), "\n")); | ||
} | ||
// Compute | ||
void NGStubOp::Compute(OpKernelContext* context) { | ||
OP_REQUIRES( | ||
context, false, | ||
errors::Internal("This kernel for OpType ", type_string(), | ||
"should not get called. This Op is expected to have " | ||
"been encapsulated or replaced by other ops. Op Name: ", | ||
name(), "\n")); | ||
} | ||
// Destructor | ||
NGStubOp::~NGStubOp() {} | ||
|
||
/* ------------------------------------------------- */ | ||
|
||
// Register Bfloat Stub Kernels | ||
|
||
// TF Ops that work on bfloat DataType get assigned Device XLA_CPU | ||
// Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub | ||
// bfloat16 | ||
// kernels here. The expectation is when we register the stub kernels for | ||
// bfloat16 | ||
// TF is going to assign DEVICE_CPU to the respective Ops and we will | ||
// encapsulate them | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. // Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub |
||
// These Stub Kernels/Op will never get called | ||
|
||
// Keep them in alphabetical order | ||
REGISTER_NGRAPH_STUB_KERNEL("Conv2D") | ||
|
||
} // namespace ngraph_bridge | ||
|
||
} // namespace tensorflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/******************************************************************************* | ||
* Copyright 2019-2020 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
#ifndef NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_ | ||
#define NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_ | ||
|
||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
|
||
using namespace std; | ||
|
||
namespace tensorflow { | ||
|
||
namespace ngraph_bridge { | ||
|
||
/* ------------------------------------------------- | ||
// | ||
// NGStubOp | ||
// | ||
---------------------------------------------------*/ | ||
|
||
class NGStubOp : public OpKernel { | ||
public: | ||
explicit NGStubOp(OpKernelConstruction* context); | ||
|
||
void Compute(OpKernelContext* context) override; | ||
|
||
private: | ||
~NGStubOp() override; | ||
}; | ||
|
||
#define REGISTER_NGRAPH_STUB_KERNEL(optype) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), \ | ||
NGStubOp); | ||
|
||
#define REGISTER_NGRAPH_STUB_BFLOAT_KERNEL(optype) \ | ||
REGISTER_KERNEL_BUILDER(Name(optype).Device(DEVICE_CPU), NGStubOp); | ||
|
||
} // namespace ngraph_bridge | ||
|
||
} // namespace tensorflow | ||
|
||
#endif // NGRAPH_TF_BRIDGE_REGISTER_STUB_KERNELS_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -223,6 +223,9 @@ Status TensorToStream(std::ostream& ostream, const Tensor& tensor) { | |
case DT_BOOL: | ||
TensorDataToStream<bool>(ostream, n_elements, data); | ||
break; | ||
case DT_BFLOAT16: | ||
TensorDataToStream<bool>(ostream, n_elements, data); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It says There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. Not sure what the corresponding data type for bfloat is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can throw an error or return a bad status for now I guess There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
break; | ||
default: | ||
return errors::Internal("TensorToStream got unsupported data type ", | ||
DataType_Name(tensor.dtype())); | ||
|
@@ -272,6 +275,8 @@ Status TFDataTypeToNGraphElementType(DataType tf_dt, | |
break; | ||
case DataType::DT_QINT32: | ||
*ng_et = ng::element::i32; | ||
case DataType::DT_BFLOAT16: | ||
*ng_et = ng::element::bf16; | ||
break; | ||
default: | ||
return errors::Unimplemented("Unsupported TensorFlow data type: ", | ||
|
@@ -322,15 +327,16 @@ void print_node_histogram(const std::unordered_map<string, int>& histogram, | |
|
||
const gtl::ArraySlice<DataType>& NGraphDTypes() { | ||
static gtl::ArraySlice<DataType> result{ | ||
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, | ||
DT_UINT16, DT_UINT32, DT_UINT64, DT_BOOL, DT_QINT8, DT_QUINT8}; | ||
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, | ||
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, | ||
DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16}; | ||
return result; | ||
} | ||
|
||
const gtl::ArraySlice<DataType>& NGraphNumericDTypes() { | ||
static gtl::ArraySlice<DataType> result{ | ||
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, | ||
DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64}; | ||
DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, | ||
DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16}; | ||
return result; | ||
} | ||
|
||
|
@@ -358,7 +364,7 @@ const gtl::ArraySlice<DataType>& NGraphSupportedQuantizedDTypes() { | |
} | ||
|
||
const gtl::ArraySlice<DataType>& NGraphRealDTypes() { | ||
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE}; | ||
static gtl::ArraySlice<DataType> result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}; | ||
return result; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
REGISTER_NGRAPH_STUB_KERNEL
registers only for bfloat type? since its defn is:in that case is this replacement ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed the macros. the names were not matching the definition.