Sindhu/bfloat16 support #399
Conversation
- Enabled --var build to use parallel executor integrating weights-on-device and data pipelining - moved ngraph_var files outside the var build
367d3db to d74ac48
Will your work help with issue #447?

REGISTER_NGRAPH_STUB_KERNEL("NGraphApplyMomentum");
REGISTER_NGRAPH_STUB_KERNEL(
    "NGraphAssignAdd");  // *input[0] = *input[0] + input[1]
REGISTER_NGRAPH_STUB_KERNEL(
Does REGISTER_NGRAPH_STUB_KERNEL register only for the bfloat16 type? Since its definition is:
#define REGISTER_NGRAPH_STUB_KERNEL(optype) \
REGISTER_KERNEL_BUILDER( \
Name(optype).Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), \
NGStubOp);
In that case, is this replacement OK?
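For illustration, the effect of the `TypeConstraint<bfloat16>("T")` clause in the macro can be sketched in plain Python. This is a hypothetical stand-in registry, not the actual TensorFlow kernel registry; the point is only that the stub is registered under a single (op, device, dtype) key, so it applies to bfloat16 alone:

```python
# Hypothetical sketch: a kernel registry keyed by (op, device, dtype),
# mimicking how TypeConstraint<bfloat16>("T") limits the stub registration
# to exactly one dtype.
KERNELS = {}

def register_ngraph_stub_kernel(op_type, device="CPU", dtype="bfloat16"):
    # Register the stub under one (op, device, dtype) key only.
    KERNELS[(op_type, device, dtype)] = "NGStubOp"

register_ngraph_stub_kernel("NGraphAssignAdd")

print(("NGraphAssignAdd", "CPU", "bfloat16") in KERNELS)  # True
print(("NGraphAssignAdd", "CPU", "float32") in KERNELS)   # False
```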
Changed the macros; the names were not matching the definition.
ngraph_bridge/ngraph_utils.cc
Outdated
@@ -223,6 +223,9 @@ Status TensorToStream(std::ostream& ostream, const Tensor& tensor) {
    case DT_BOOL:
      TensorDataToStream<bool>(ostream, n_elements, data);
      break;
    case DT_BFLOAT16:
      TensorDataToStream<bool>(ostream, n_elements, data);
It says <bool> in the template. Copy-paste error, perhaps?
Good catch. Not sure what the corresponding data type for bfloat16 is.
We can throw an error or return a bad status for now I guess
Done
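As background for why streaming bfloat16 elements needs its own type rather than reusing `<bool>`: bfloat16 is the top 16 bits of an IEEE-754 float32 (same sign and exponent, truncated mantissa). This is a general property of the format, not something stated in this PR. A minimal round-trip sketch in plain Python, with hypothetical helper names:

```python
import struct

def float_to_bfloat16_bits(x):
    # Reinterpret float32 as uint32 and keep the top 16 bits
    # (truncation; real implementations often round-to-nearest-even).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16

def bfloat16_bits_to_float(b):
    # Widen back to float32 by zero-filling the low 16 mantissa bits.
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# Values whose mantissa fits in 7 bits round-trip exactly.
print(bfloat16_bits_to_float(float_to_bfloat16_bits(1.5)))  # 1.5
```

Because an element is 16 bits of float data, printing it through a `bool` specialization would decode the bytes with the wrong width and meaning, which is why returning a bad status (as done here) is safer than the copy-pasted template argument.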
Minor comments
// Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub
// bfloat16
// kernels here. The expectation is when we register the stub kernels for
// bfloat16
// TF is going to assign DEVICE_CPU to the respective Ops and we will
// encapsulate them
// Since nGraph-bridge OPs work on TF DEVICE_CPU we are registering stub
// bfloat16 kernels here. The expectation is when we register the stub kernels
// for bfloat16 TF is going to assign DEVICE_CPU to the respective Ops and
// we will encapsulate them
test/python/test_bfloat16.py
Outdated
@@ -0,0 +1,131 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
Copyright 2019-2020 Intel Corporation
LGTM
Add support in the bridge for TF graphs with Ops that take bfloat16 data type inputs.