add kernel tests for ops that changed in opset18 #19767

Merged (11 commits) on Mar 19, 2024
37 changes: 28 additions & 9 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {

ArgDef grad = GO(0);
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unqueezed_Grad");
@@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
result.push_back(axes_values_node);
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
}
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
grad = IA("Unqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
}
}
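The pattern added in this hunk (and repeated for ReduceLogSumExp and ReduceL2 below) is a three-way dispatch on where 'axes' comes from: an attribute in older opsets, an optional second input from opset 18 for these reduction ops, or absent entirely. A minimal standalone sketch of that dispatch, using plain std containers in place of ORT's attribute and NodeDef types (the enum and function names below are illustrative assumptions, not part of the codebase):

```cpp
// Illustrative sketch only: a standalone model of the axes-source dispatch,
// using std containers instead of ORT's attribute and NodeDef types.
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

enum class AxesSource {
  kAttribute,      // older opsets: 'axes' stored as a node attribute
  kOptionalInput,  // opset 18 for these ops: 'axes' supplied as optional input I(1)
  kNone            // keepdims, or no axes at all: nothing to unsqueeze
};

AxesSource ClassifyAxes(bool keepdims,
                        const std::map<std::string, std::vector<int64_t>>& attributes,
                        std::size_t num_inputs) {
  if (keepdims) return AxesSource::kNone;  // gradient already carries the reduced dims
  if (attributes.count("axes") != 0) return AxesSource::kAttribute;
  if (num_inputs == 2) return AxesSource::kOptionalInput;  // mirrors GetSrcNodeInputSize() == 2
  return AxesSource::kNone;
}
```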

@@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
}

ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unsqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unsqueezed_Grad");

result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));

result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
grad = IA("Unsqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));

result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")}));
}
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
} else {
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
@@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));

if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
if (!keepdims) {
size_t numInputs = GetSrcNodeInputSize();
scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
result.emplace_back(
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
result.emplace_back(
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
} else if (numInputs == 2) { // optional input 'axes' is available as input I(1)
result.emplace_back(
NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def}));
}
}

result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));
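The Unsqueeze in all three builders is pure shape bookkeeping. As a concrete example (illustrative numbers, not taken from the PR): reducing an input of shape {2, 3, 4} over axes = {1} with keepdims = 0 produces an output, and hence an incoming gradient, of shape {2, 4}; a size-1 dim has to be re-inserted at axis 1 to get {2, 1, 4} before the gradient can broadcast back against {2, 3, 4}. A minimal sketch of that bookkeeping, assuming non-negative, ascending axes:

```cpp
// Illustrative helper only: re-inserts size-1 dims at the reduced axes,
// mimicking what the emitted Unsqueeze does to the gradient's shape.
#include <cstdint>
#include <vector>

std::vector<int64_t> UnsqueezeShape(std::vector<int64_t> shape,
                                    const std::vector<int64_t>& axes) {
  // Assumes axes are non-negative and sorted ascending, as in the example above.
  for (int64_t axis : axes) {
    shape.insert(shape.begin() + axis, int64_t{1});
  }
  return shape;
}

// UnsqueezeShape({2, 4}, {1}) == {2, 1, 4}, which broadcasts against {2, 3, 4}.
```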
30 changes: 27 additions & 3 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) {

OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
RunReductionTests(op_def_opset13);

// axes is input from opset 18.
OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18};
RunReductionTests(op_def_opset18, true, true);
}

TEST(GradientCheckerTest, ReduceSumGrad) {
@@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
OpDef op_def_13{"ReduceSum", kOnnxDomain, 13};

RunReductionTests(op_def_13, true, true);

OpDef op_def_18{"ReduceSum", kOnnxDomain, 18};

RunReductionTests(op_def_18, true, true);
}

TEST(GradientCheckerTest, ReduceL2Grad) {
@@ -641,13 +649,22 @@ TEST(GradientCheckerTest, ReduceL2Grad) {
{MakeAttribute("axes", axes)}));
EXPECT_IS_TINY(max_error);
}

// axes is input from opset 18
OpDef op_def_18{"ReduceL2", kOnnxDomain, 18};

RunReductionTests(op_def_18, true, true);
}

TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};

RunReductionTests(op_def);

OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18};

RunReductionTests(op_def_opset18, true, true);
}

TEST(GradientCheckerTest, ReluGrad) {
@@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) {
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
{MakeAttribute("axis", int64_t(0))}));
EXPECT_IS_TINY(max_error);

// opset 18: either the optional 'split' input or the 'num_outputs' attribute must be given; use num_outputs here
OpDef op_def_18{"Split", kOnnxDomain, 18};
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
{MakeAttribute("axis", int64_t(0)),
MakeAttribute("num_outputs", int64_t(3))}));
EXPECT_IS_TINY(max_error);
}

template <typename T>
@@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) {
TEST(GradientCheckerTest, PadGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"Pad", kOnnxDomain, 11};
OpDef op_def{"Pad", kOnnxDomain, 18};

{
TensorInfo x_info({2, 4}, true);
@@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) {
TEST(GradientCheckerTest, ScatterNDGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"ScatterND", kOnnxDomain, 11};
OpDef op_def{"ScatterND", kOnnxDomain, 18};

{
TensorInfo data_info({8}, true);
@@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) {
TEST(GradientCheckerTest, ScatterElementsGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"ScatterElements", kOnnxDomain, 13};
OpDef op_def{"ScatterElements", kOnnxDomain, 18};

{ // without axis
TensorInfo data_info({3, 3}, true);