Implement All op for GNMT training (#240)
kanvi-nervana authored and avijit-nervana committed Oct 1, 2018
1 parent 920e0ad commit de71d65
Showing 4 changed files with 147 additions and 6 deletions.
69 changes: 63 additions & 6 deletions src/ngraph_builder.cc
@@ -447,6 +447,62 @@ static Status TranslateAddNOp(
return Status::OK();
}

static Status TranslateAllOp(const Node* op,
const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
shared_ptr<ng::Node> ng_input, ng_axes_op;
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input, &ng_axes_op));

bool tf_keep_dims;
if (GetNodeAttr(op->attrs(), "keep_dims", &tf_keep_dims) != Status::OK()) {
tf_keep_dims = false;
}
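// A missing "keep_dims" attribute falls back to false above, matching
// tf.reduce_all's default.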

std::vector<int64> all_axes;
TF_RETURN_IF_ERROR(GetStaticInputVector(op, 1, static_input_map, &all_axes));

ng::Shape input_shape = ng_input->get_shape();
size_t input_rank = ng_input->get_shape().size();

TF_RETURN_IF_ERROR(CheckAxisDimInRange(all_axes, input_rank));

std::vector<size_t> ng_reduction_axes_vect(all_axes.size());
std::transform(
all_axes.begin(), all_axes.end(), ng_reduction_axes_vect.begin(),
[input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);
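// e.g. with input_rank == 2, axis -1 normalizes to 1, while 0 and 1 pass
// through unchanged.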

std::vector<bool> init_val = {true};
auto arg_init = make_shared<ng::op::Constant>(ng_input->get_element_type(),
ng::Shape(0), init_val);
auto f_A = make_shared<ng::op::Parameter>(ng::element::boolean, ng::Shape{});
auto f_B = make_shared<ng::op::Parameter>(ng::element::boolean, ng::Shape{});
auto ng_and = make_shared<ng::Function>(make_shared<ng::op::And>(f_A, f_B),
ng::op::ParameterVector{f_A, f_B});
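// ng::op::Reduce below folds the input elements with this AND function,
// seeded by the scalar 'true' constant (the AND identity) in arg_init.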

shared_ptr<ng::Node> ng_all = make_shared<ng::op::Reduce>(
ng_input, arg_init, ng_and, ng_reduction_axes);

// If keep_dims is specified we need to reshape to put back the reduced
// axes, with length 1.
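// e.g. reducing a {2,3} input over axis 1 yields {2} without keep_dims,
// but {2,1} with it.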
if (tf_keep_dims) {
ng::Shape ng_result_shape_with_keep(input_rank);

for (size_t i = 0; i < input_rank; i++) {
ng_result_shape_with_keep[i] =
ng_reduction_axes.count(i) == 0 ? input_shape[i] : 1;
}

ng::AxisVector ng_axis_order(ng_all->get_shape().size());
std::iota(ng_axis_order.begin(), ng_axis_order.end(), 0);
ng_all = make_shared<ng::op::Reshape>(ng_all, ng_axis_order,
ng_result_shape_with_keep);
}

SaveNgOp(ng_op_map, op->name(), ng_all);
return Status::OK();
}
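
For reference, a minimal standalone sketch of the reduction being implemented, assuming All mirrors tf.reduce_all semantics (logical AND over the requested axes); the code below is illustrative only:

#include <iostream>
#include <vector>

int main() {
  // 2x3 boolean input: {{true, true, true}, {true, false, false}}
  std::vector<std::vector<bool>> input = {{true, true, true},
                                          {true, false, false}};
  // Reduce over axis 1: each row folds to a single bool, starting from
  // the AND identity 'true' -- the same role arg_init plays above.
  for (const auto& row : input) {
    bool all = true;
    for (bool v : row) all = all && v;
    std::cout << std::boolalpha << all << "\n";  // prints: true, then false
  }
  return 0;
}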

static Status TranslateArgMaxOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
@@ -1626,7 +1682,7 @@ static Status TranslateMaxOp(const Node* op,
std::vector<size_t> ng_reduction_axes_vect(max_axes.size());
std::transform(
max_axes.begin(), max_axes.end(), ng_reduction_axes_vect.begin(),
-    [input_rank](int idx) { return idx + (idx < 0 ? input_rank : 0); });
+    [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_max =
@@ -1794,7 +1850,7 @@ static Status TranslateMeanOp(
std::vector<size_t> ng_reduction_axes_vect(mean_axes.size());
std::transform(
mean_axes.begin(), mean_axes.end(), ng_reduction_axes_vect.begin(),
-    [input_rank](int idx) { return idx + (idx < 0 ? input_rank : 0); });
+    [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_mean =
@@ -1843,7 +1899,7 @@ static Status TranslateMinOp(const Node* op,
std::vector<size_t> ng_reduction_axes_vect(min_axes.size());
std::transform(
min_axes.begin(), min_axes.end(), ng_reduction_axes_vect.begin(),
-    [input_rank](int idx) { return idx + (idx < 0 ? input_rank : 0); });
+    [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_min =
@@ -1985,7 +2041,7 @@ static Status TranslateProdOp(
std::vector<size_t> ng_reduction_axes_vect(prod_axes.size());
std::transform(
prod_axes.begin(), prod_axes.end(), ng_reduction_axes_vect.begin(),
-    [input_rank](int idx) { return idx + (idx < 0 ? input_rank : 0); });
+    [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_prod =
@@ -2324,7 +2380,7 @@ static Status TranslateSparseSoftmaxCrossEntropyWithLogitsOp(
" while building op ", op->type_string());
}

-  // Logits/Featues and Labels must have the same first dimension
+  // Logits/Features and Labels must have the same first dimension
if (ng_labels_shape[0] != ng_features_shape[0]) {
return errors::InvalidArgument(
" Logits/Features and Labels must have the same first dimension, got "
@@ -2677,7 +2733,7 @@ static Status TranslateSumOp(const Node* op,
std::vector<size_t> ng_reduction_axes_vect(sum_axes.size());
std::transform(
sum_axes.begin(), sum_axes.end(), ng_reduction_axes_vect.begin(),
-    [input_rank](int idx) { return idx + (idx < 0 ? input_rank : 0); });
+    [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
ng::AxisSet ng_reduction_axes(ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_sum =
@@ -2859,6 +2915,7 @@ const static std::map<
{"Abs", TranslateUnaryOp<ngraph::op::Abs>},
{"Add", TranslateBinaryOp<ngraph::op::Add>},
{"AddN", TranslateAddNOp},
{"All", TranslateAllOp},
{"ArgMax", TranslateArgMaxOp},
{"AvgPool", TranslateAvgPoolOp},
{"AvgPoolGrad", TranslateAvgPoolGradOp},
2 changes: 2 additions & 0 deletions src/ngraph_mark_for_clustering.cc
@@ -209,6 +209,7 @@ Status MarkForClustering(Graph* graph) {
type_constraint_map["Abs"]["T"] = NGraphNumericDTypes();
type_constraint_map["Add"]["T"] = NGraphNumericDTypes();
type_constraint_map["AddN"]["T"] = NGraphNumericDTypes();
type_constraint_map["All"]["Tidx"] = NGraphIndexDTypes();
type_constraint_map["ArgMax"]["T"] = NGraphNumericDTypes();
type_constraint_map["ArgMax"]["Tidx"] = NGraphIndexDTypes();
type_constraint_map["AvgPool"]["T"] = NGraphNumericDTypes();
@@ -309,6 +310,7 @@ Status MarkForClustering(Graph* graph) {
confirmation_functions["Abs"] = SimpleConfirmationFunction();
confirmation_functions["Add"] = SimpleConfirmationFunction();
confirmation_functions["AddN"] = SimpleConfirmationFunction();
confirmation_functions["All"] = SimpleConfirmationFunction({1});
confirmation_functions["ArgMax"] = SimpleConfirmationFunction({1});
confirmation_functions["AvgPool"] = SimpleConfirmationFunction();
confirmation_functions["AvgPoolGrad"] = SimpleConfirmationFunction({0});
3 changes: 3 additions & 0 deletions test/opexecuter.cpp
@@ -160,6 +160,9 @@ void OpExecuter::CompareNGraphAndTF() {
case DT_INT64:
AssertTensorEquals<int64>(tf_outputs_[i], ngraph_outputs_[i]);
break;
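// All produces boolean outputs, so the harness needs a DT_BOOL comparison.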
case DT_BOOL:
AssertTensorEquals<bool>(tf_outputs_[i], ngraph_outputs_[i]);
break;
default:
EXPECT_TRUE(false)
<< "Could not find the corresponding function for the "
79 changes: 79 additions & 0 deletions test/test_math_ops.cpp
@@ -140,6 +140,85 @@ TEST(MathOps, AddN) {
opexecuter.RunTest();
} // end of test op AddN

// Test op: All
// All with attribute KeepDims set to true
TEST(MathOps, AllKeepDims) {
Scope root = Scope::NewRootScope();
int dim1 = 2;
int dim2 = 2;

std::vector<bool> v = {true, true, true, false};
Tensor A(DT_BOOL, TensorShape({dim1, dim2}));
auto keep_dims = ops::All::Attrs().KeepDims(true);

AssignInputValuesFromVector<bool>(A, v);

// axis along which to reduce; must satisfy -rank <= axis < rank
int axis = 0;

vector<int> static_input_indexes = {1};
vector<DataType> output_datatypes = {DT_BOOL};

auto R = ops::All(root, A, axis, keep_dims);
std::vector<Output> sess_run_fetchoutputs = {R};
OpExecuter opexecuter(root, "All", static_input_indexes, output_datatypes,
sess_run_fetchoutputs);

opexecuter.RunTest();
}

TEST(MathOps, AllNegativeAxis) {
Scope root = Scope::NewRootScope();
int dim1 = 2;
int dim2 = 3;

std::vector<bool> v = {true, true, true, true, false, false};
Tensor A(DT_BOOL, TensorShape({dim1, dim2}));

AssignInputValuesFromVector<bool>(A, v);

// axis along which to reduce; must satisfy -rank <= axis < rank
int axis = -1;

vector<int> static_input_indexes = {1};
vector<DataType> output_datatypes = {DT_BOOL};

auto R = ops::All(root, A, axis);
std::vector<Output> sess_run_fetchoutputs = {R};
OpExecuter opexecuter(root, "All", static_input_indexes, output_datatypes,
sess_run_fetchoutputs);

opexecuter.RunTest();
}

TEST(MathOps, AllPositiveAxis) {
Scope root = Scope::NewRootScope();
int dim1 = 3;
int dim2 = 3;

std::vector<bool> v = {true, true, true, true, false,
false, true, false, false};
Tensor A(DT_BOOL, TensorShape({dim1, dim2}));

AssignInputValuesFromVector<bool>(A, v);

// axis along which to reduce; must satisfy -rank <= axis < rank
int axis = 1;

vector<int> static_input_indexes = {1};
vector<DataType> output_datatypes = {DT_BOOL};

auto R = ops::All(root, A, axis);
std::vector<Output> sess_run_fetchoutputs = {R};
OpExecuter opexecuter(root, "All", static_input_indexes, output_datatypes,
sess_run_fetchoutputs);

opexecuter.RunTest();
} // end of test op All

// Test op: BatchMatMul
TEST(MathOps, BatchMatMul2D) {
Scope root = Scope::NewRootScope();
