Skip to content

Commit

Permalink
nGraph v0.26.0.rc.0 (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhu-nervana authored and avijit-nervana committed Sep 19, 2019
1 parent 8531aa1 commit 878d285
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 48 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ if (NOT USE_PRE_BUILT_NGRAPH)
ExternalProject_Add(
ext_ngraph
GIT_REPOSITORY https://github.com/NervanaSystems/ngraph
GIT_TAG v0.25.1-rc.2
GIT_TAG v0.26.0-rc.0
CMAKE_ARGS
-DNGRAPH_DISTRIBUTED_ENABLE=${NGRAPH_DISTRIBUTED_ENABLE}
-DNGRAPH_INSTALL_PREFIX=${NGRAPH_ARTIFACTS_DIR}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Once TensorFlow's dependencies are installed, clone the `ngraph-bridge` repo:

git clone https://github.com/tensorflow/ngraph-bridge.git
cd ngraph-bridge
git checkout v0.19.0-rc1
git checkout v0.26.0-rc.0

Run the following Python script to build TensorFlow, nGraph, and the bridge. Use Python 3.5:

Expand Down
8 changes: 4 additions & 4 deletions bazel/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ tf_workspace(path_prefix = "", tf_repo_name = "org_tensorflow")
http_archive(
name = "ngraph",
build_file = "//:bazel/ngraph.BUILD",
sha256 = "d3e0ebf807dd4179c91a41c80639b75f5ea8ca39cdfe3a5697da56a7df894f11",
strip_prefix = "ngraph-0.25.1-rc.2",
sha256 = "d6832d9f923027aa5cc9f49c319fe537d73102600576e4e69be19a0e37d13a43",
strip_prefix = "ngraph-0.26.0-rc.0",
urls = [
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.2.tar.gz",
"https://github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.2.tar.gz"
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.26.0-rc.0.tar.gz",
"https://github.com/NervanaSystems/ngraph/archive/v0.26.0-rc.0.tar.gz"
],
)

Expand Down
14 changes: 6 additions & 8 deletions bazel/ngraph.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ cc_library(
"src/ngraph/op/experimental/dyn_reshape.cpp",
"src/ngraph/op/experimental/dyn_slice.cpp",
"src/ngraph/op/experimental/generate_mask.cpp",
"src/ngraph/op/experimental/quantized_avg_pool.cpp",
"src/ngraph/op/experimental/quantized_concat.cpp",
"src/ngraph/op/experimental/quantized_conv.cpp",
"src/ngraph/op/experimental/quantized_conv_bias.cpp",
"src/ngraph/op/experimental/quantized_conv_relu.cpp",
"src/ngraph/op/experimental/quantized_max_pool.cpp",
"src/ngraph/op/experimental/shape_of.cpp",
"src/ngraph/op/experimental/range.cpp",
"src/ngraph/op/experimental/quantized_dot.cpp",
Expand Down Expand Up @@ -84,7 +82,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.2\\"',
'-D NGRAPH_VERSION=\\"v0.26.0-rc.0\\"',
"-D NGRAPH_DEX_ONLY",
'-D PROJECT_ROOT_DIR=\\"\\"',
'-D NGRAPH_STATIC_LIB_ENABLE'
Expand Down Expand Up @@ -117,7 +115,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.2\\"',
'-D NGRAPH_VERSION=\\"v0.26.0-rc.0\\"',
"-D NGRAPH_DEX_ONLY",
'-D PROJECT_ROOT_DIR=\\"\\"',
] + CXX_ABI,
Expand All @@ -129,6 +127,9 @@ cc_library(
visibility = ["//visibility:public"],
alwayslink = 1,
)
# TODO: If we update to mkl_dnn v1.0 in future, we should include
# the source file "src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.cpp"
# Currently we use legacy mkl_dnn, NGRAPH_USE_LEGACY_MKLDNN is set to TRUE by default

cc_library(
name = 'cpu_backend',
Expand Down Expand Up @@ -188,12 +189,10 @@ cc_library(
"src/ngraph/runtime/cpu/builder/reduce_function.cpp",
"src/ngraph/runtime/cpu/builder/replace_slice.cpp",
"src/ngraph/runtime/cpu/builder/quantization.cpp",
"src/ngraph/runtime/cpu/builder/quantized_avg_pool.cpp",
"src/ngraph/runtime/cpu/builder/quantized_conv.cpp",
"src/ngraph/runtime/cpu/builder/quantized_concat.cpp",
"src/ngraph/runtime/cpu/builder/quantized_dot.cpp",
"src/ngraph/runtime/cpu/builder/quantized_matmul.cpp",
"src/ngraph/runtime/cpu/builder/quantized_max_pool.cpp",
"src/ngraph/runtime/cpu/builder/reshape.cpp",
"src/ngraph/runtime/cpu/builder/reverse.cpp",
"src/ngraph/runtime/cpu/builder/reverse_sequence.cpp",
Expand Down Expand Up @@ -245,7 +244,6 @@ cc_library(
"src/ngraph/runtime/cpu/pass/cpu_mat_fusion.cpp",
"src/ngraph/runtime/cpu/pass/cpu_memory_assignment.cpp",
"src/ngraph/runtime/cpu/pass/cpu_memory_optimization.cpp",
"src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.cpp",
"src/ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.cpp",
"src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp",
"src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.cpp",
Expand All @@ -269,7 +267,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"0.25.1-rc.2\\"',
'-D NGRAPH_VERSION=\\"v0.26.0-rc.0\\"',
"-D NGRAPH_DEX_ONLY",
"-D NGRAPH_TBB_ENABLE",
'-D PROJECT_ROOT_DIR=\\"\\"',
Expand Down
2 changes: 1 addition & 1 deletion build_ngtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main():
'''

# Component versions
ngraph_version = "v0.25.1-rc.2"
ngraph_version = "v0.26.0-rc.0"
tf_version = "v1.14.0"

# Command line parser options
Expand Down
33 changes: 17 additions & 16 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
#include "tensorflow/core/lib/core/errors.h"

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/dequantize_builder.hpp"
#include "ngraph/builder/numpy_transpose.hpp"
#include "ngraph/builder/quantization.hpp"
#include "ngraph/builder/quantize_builder.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/util/logical_reduction.hpp"
Expand Down Expand Up @@ -513,14 +514,14 @@ static Status TranslateQuantizedPoolOp(const Node* op,
if (is_quantizedAvgPool) {
// QuantizeAvgPool
// TF doesn't include padding in avg calculation
ng_quant_pool = ng::builder::ScaledQuantizedAvgPool(
ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above, false, dummy_min, dummy_max);
ng_quant_pool = ConstructNgNode<ng::op::AvgPool>(
op->name(), ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above, false);
} else {
// QuantizeMaxPool
ng_quant_pool = ng::builder::ScaledQuantizedMaxPool(
ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above, dummy_min, dummy_max);
ng_quant_pool = ConstructNgNode<ng::op::MaxPool>(
op->name(), ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above);
}

BatchToTensorflow(is_nhwc, ng_quant_pool);
Expand Down Expand Up @@ -3305,7 +3306,7 @@ static Status TranslateQuantizedConcatOpHelper(
shared_ptr<ng::Node> ng_max_of_maxs =
make_shared<ng::op::Constant>(ng::element::f32, ng::Shape{}, max_of_maxs);

auto ng_qconcat = ng::builder::ScaledQuantizedConcat(
auto ng_qconcat = ng::builder::QuantizedConcatBuilder(
ng_args, size_t(concat_axis), ng_all_mins, ng_all_maxs);

SaveNgOp(ng_op_map, op->name(), ng_qconcat);
Expand Down Expand Up @@ -3373,7 +3374,7 @@ static Status TranslateQuantizedConv(
ng_strides, ng_dilations, ng_padding_below,
ng_padding_above);

// It is expected by ScaledQuantizedConvolutionBias (and other builder
// It is expected by QuantizedConvolutionBiasBuilder (and other builder
// functions) that the min max inputs be constant nodes
// Hence declaring them static, reading their values and converting to
// constant nodes
Expand Down Expand Up @@ -3401,7 +3402,7 @@ static Status TranslateQuantizedConv2DWithBiasMaybeReluAndRequantizeOp(
std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
return ng::builder::ScaledQuantizedConvolutionBias(
return ng::builder::QuantizedConvolutionBiasBuilder(
node_inps[0], node_inps[1], node_inps[2], ng_strides, ng_dilations,
ng_padding_below, ng_padding_above, ng_data_dilations, node_inps[3],
node_inps[4], node_inps[5], node_inps[6], node_inps[7], node_inps[8],
Expand All @@ -3417,7 +3418,7 @@ static Status TranslateQuantizedConv2DWithBiasSumAndReluAndRequantizeOp(
std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
return ng::builder::ScaledQuantizedConvolutionBiasAdd(
return ng::builder::QuantizedConvolutionBiasAddBuilder(
node_inps[0], node_inps[1], node_inps[2], node_inps[9], ng_strides,
ng_dilations, ng_padding_below, ng_padding_above, ng_data_dilations,
node_inps[3], node_inps[4], node_inps[5], node_inps[6], node_inps[7],
Expand All @@ -3433,7 +3434,7 @@ static Status TranslateQuantizedConv2DWithBiasSignedSumAndReluAndRequantizeOp(
std::vector<std::shared_ptr<ng::Node>> node_inps, ng::Strides ng_strides,
ng::Strides ng_dilations, ng::CoordinateDiff ng_padding_below,
ng::CoordinateDiff ng_padding_above, ng::Strides ng_data_dilations) {
return ng::builder::ScaledQuantizedConvolutionBiasSignedAdd(
return ng::builder::QuantizedConvolutionBiasSignedAddBuilder(
node_inps[0], node_inps[1], node_inps[2], node_inps[9], ng_strides,
ng_dilations, ng_padding_below, ng_padding_above, ng_data_dilations,
node_inps[3], node_inps[4], node_inps[5], node_inps[6], node_inps[7],
Expand Down Expand Up @@ -3467,8 +3468,8 @@ static Status TranslateQuantizeV2Op(const Node* op,
ng::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN;

SaveNgOp(ng_op_map, op->name(),
ng::builder::ScaledQuantize(ng_input, ng_min, ng_max, ng_et,
ng::AxisSet(), ng_round_mode));
ng::builder::QuantizeBuilder(ng_input, ng_min, ng_max, ng_et,
ng::AxisSet(), ng_round_mode));
SaveNgOp(ng_op_map, op->name(), ng_min);
SaveNgOp(ng_op_map, op->name(), ng_max);

Expand All @@ -3483,8 +3484,8 @@ static Status TranslateDequantizeOp(const Node* op,

// TF only dequantizes to fp32
SaveNgOp(ng_op_map, op->name(),
ng::builder::ScaledDequantize(ng_input, ng_min, ng_max,
ng::element::f32, ng::AxisSet()));
ng::builder::DequantizeBuilder(ng_input, ng_min, ng_max,
ng::element::f32, ng::AxisSet()));
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion ngraph_bridge/ngraph_encapsulate_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ Status NGraphEncapsulateImpl::AllocateNGInputTensors(
std::unique_ptr<ngraph::Event> event_copy_input_next(
new ngraph::Event(event_name, m_name, ""));
current_ng_tensor->write(
current_src_ptr, 0,
current_src_ptr,
current_ng_tensor->get_element_count() * ng_element_type.size());

event_copy_input_next->Stop();
Expand Down
8 changes: 4 additions & 4 deletions ngraph_bridge/ngraph_encapsulate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,8 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
"Output_" + to_string(i) + "_" + to_string(copy_size);
std::unique_ptr<ngraph::Event> event_copy_output_next(
new ngraph::Event(event_name, name(), ""));
dst_ng_tensor->read(dst_ptr, 0, dst_ng_tensor->get_element_count() *
ng_element_type.size());
dst_ng_tensor->read(dst_ptr, dst_ng_tensor->get_element_count() *
ng_element_type.size());
event_copy_output_next->Stop();
output_copy_events.push_back(std::move(event_copy_output_next));
}
Expand All @@ -648,8 +648,8 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
std::to_string(dst_ng_tensor->get_element_count() *
ng_element_type.size())),
name(), ""));
dst_ng_tensor->read(dst_ptr, 0, dst_ng_tensor->get_element_count() *
ng_element_type.size());
dst_ng_tensor->read(dst_ptr, dst_ng_tensor->get_element_count() *
ng_element_type.size());
event_copy_output_next->Stop();
output_copy_events.push_back(std::move(event_copy_output_next));
}
Expand Down
8 changes: 4 additions & 4 deletions ngraph_bridge/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ void ReadNGTensor(shared_ptr<ng::runtime::Tensor> ng_tensor,
Tensor* tf_tensor) {
ngraph::Event event_sync_ng_tf_tensors("Tensor Read D2H", "", "");
void* tf_src_ptr = (void*)DMAHelper::base(tf_tensor);
ng_tensor->read(tf_src_ptr, 0, ng_tensor->get_element_count() *
ng_tensor->get_element_type().size());
ng_tensor->read(tf_src_ptr, ng_tensor->get_element_count() *
ng_tensor->get_element_type().size());
event_sync_ng_tf_tensors.Stop();
ngraph::Event::write_trace(event_sync_ng_tf_tensors);
}
Expand All @@ -100,8 +100,8 @@ void WriteNGTensor(shared_ptr<ng::runtime::Tensor> ng_tensor,
Tensor* tf_tensor) {
ngraph::Event event_sync_ng_tf_tensors("Tensor Write H2D", "", "");
void* tf_src_ptr = (void*)DMAHelper::base(tf_tensor);
ng_tensor->write(tf_src_ptr, 0, ng_tensor->get_element_count() *
ng_tensor->get_element_type().size());
ng_tensor->write(tf_src_ptr, ng_tensor->get_element_count() *
ng_tensor->get_element_type().size());
event_sync_ng_tf_tensors.Stop();
ngraph::Event::write_trace(event_sync_ng_tf_tensors);
}
Expand Down
2 changes: 1 addition & 1 deletion ngraph_bridge/ngraph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ template <typename T>
T GetScalarFromTensor(const std::shared_ptr<ngraph::runtime::Tensor>& t,
size_t element_offset = 0) {
T result;
t->read(&result, element_offset * sizeof(T), sizeof(T));
t->read(&result, sizeof(T));
return result;
}

Expand Down
4 changes: 2 additions & 2 deletions ngraph_bridge/version.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
// nGraph-TensorFlow bridge uses semantic versioning: see http://semver.org/

#define NG_TF_MAJOR_VERSION 0
#define NG_TF_MINOR_VERSION 19
#define NG_TF_MINOR_VERSION 20
#define NG_TF_PATCH_VERSION 0

// The version suffix is used for pre-release version numbers
// For example before v0.7.0 we may do a pre-release i.e., a release
// candidate such as v0.7.0-rc0
// The code in master will always have the last released version number
// with a suffix of '-master'
#define NG_TF_VERSION_SUFFIX "-rc1"
#define NG_TF_VERSION_SUFFIX "-rc0"

#define VERSION_STR_HELPER(x) #x
#define VERSION_STR(x) VERSION_STR_HELPER(x)
Expand Down
2 changes: 1 addition & 1 deletion python/setup.in.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_tag(self):

setup(
name='ngraph_tensorflow_bridge',
version='0.19.0rc1',
version='0.20.0rc0',
description='Intel nGraph compiler and runtime for TensorFlow',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
4 changes: 2 additions & 2 deletions test/graph_exec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ TEST(graph_exec, axpy) {

auto t_x = backend->create_tensor(ng::element::f32, ng_shape_x);
float v_x[2][3] = {{1, 1, 1}, {1, 1, 1}};
t_x->write(&v_x, 0, sizeof(v_x));
t_x->write(&v_x, sizeof(v_x));

auto t_y = backend->create_tensor(ng::element::f32, ng_shape_y);
t_y->write(&v_x, 0, sizeof(v_x));
t_y->write(&v_x, sizeof(v_x));

// Allocate tensor for the result(s)
vector<shared_ptr<ng::runtime::Tensor>> outputs;
Expand Down
4 changes: 2 additions & 2 deletions test/opexecuter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,
std::shared_ptr<ngraph::runtime::Tensor> result;
if (ng_backend_type != "CPU") {
result = backend->create_tensor(ng_et, ng_shape);
result->write(src_ptr, 0, result->get_element_count() * ng_et.size());
result->write(src_ptr, result->get_element_count() * ng_et.size());
} else {
result = backend->create_tensor(ng_et, ng_shape, src_ptr);
}
Expand Down Expand Up @@ -433,7 +433,7 @@ void OpExecuter::ExecuteOnNGraph(vector<Tensor>& ngraph_outputs,
// Convert to tf tensor
Tensor output_tensor(expected_output_datatypes_[i], tf_op_shapes[i]);
void* dst_ptr = DMAHelper::base(&output_tensor);
ng_op_tensors[i]->read(dst_ptr, 0, output_tensor.TotalBytes());
ng_op_tensors[i]->read(dst_ptr, output_tensor.TotalBytes());
ngraph_outputs.push_back(output_tensor);
NGRAPH_VLOG(5) << " NGRAPH op " << i << ngraph_outputs[i].DebugString();
}
Expand Down

0 comments on commit 878d285

Please sign in to comment.