diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index e675b55c8af8f..22dcf4eb92411 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
   ArgDef grad = GO(0);
   if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     if (attributes.find("axes") != attributes.end()) {
       std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
       grad = IA("Unqueezed_Grad");
@@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
         result.push_back(axes_values_node);
         result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13},
                                  {GO(0), axes_values_node.output_args[0]}, {grad}));
       }
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
     }
   }
@@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
   }
 
   ArgDef grad = GO(0);
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
-    grad = IA("Unsqueezed_Grad");
-    result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      grad = IA("Unsqueezed_Grad");
-    result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+      result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unsqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")}));
+    }
     result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
   } else {
     result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
@@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
   ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
   result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));
 
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
-    result.emplace_back(
-        NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def}));
+    }
   }
 
   result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));
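Note on the gradient_builder.cc changes: all three builders gain the same dispatch. Through opset 17 the reduction ops carry `axes` as a node attribute; opset 18 moves it to an optional second input, so the builders now fall back to checking `GetSrcNodeInputSize()` and wire `I(1)` directly into the generated `Unsqueeze` nodes. Below is a minimal standalone sketch of that dispatch; the enum and function are hypothetical stand-ins, not onnxruntime's real `NodeDef`/`ArgDef` machinery.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for the three-way choice the builders above make.
enum class AxesSource { kAttribute, kOptionalInput, kNone };

// Opset <= 17: 'axes' arrives as an attribute. Opset 18: it is the optional
// second input, so a source node with two inputs exposes it as I(1). With
// neither present, the op reduces over all axes and no Unsqueeze is emitted.
AxesSource ResolveAxesSource(const std::map<std::string, std::string>& attributes,
                             std::size_t num_inputs) {
  if (attributes.count("axes") != 0) return AxesSource::kAttribute;
  if (num_inputs == 2) return AxesSource::kOptionalInput;
  return AxesSource::kNone;
}
```

ReduceLogSumExp applies the choice twice because its gradient reuses the forward output (dX = dY * exp(X - Y)), so both GO(0) and O(0) must be unsqueezed back to the input rank.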
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index feca94ae27c13..94ca96c68f2ce 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) {
   OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
   RunReductionTests(op_def_opset13);
+
+  // axes is input from opset 18.
+  OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18};
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
@@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
   OpDef op_def_13{"ReduceSum", kOnnxDomain, 13};
   RunReductionTests(op_def_13, true, true);
+
+  OpDef op_def_18{"ReduceSum", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceL2Grad) {
@@ -641,6 +649,11 @@ TEST(GradientCheckerTest, ReduceL2Grad) {
                                                            {MakeAttribute("axes", axes)}));
     EXPECT_IS_TINY(max_error);
   }
+
+  // axes is input from opset 18
+  OpDef op_def_18{"ReduceL2", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
@@ -648,6 +661,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
   OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
   RunReductionTests(op_def);
+
+  OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReluGrad) {
@@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) {
   ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
                                                          {MakeAttribute("axis", int64_t(0))}));
   EXPECT_IS_TINY(max_error);
+
+  // opset18 test
+  OpDef op_def_18{"Split", kOnnxDomain, 18};
+  ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
+                                                         {MakeAttribute("axis", int64_t(0)),
+                                                          MakeAttribute("num_outputs", int64_t(3))}));
+  EXPECT_IS_TINY(max_error);
 }
 
 template <typename T>
@@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) {
 TEST(GradientCheckerTest, PadGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"Pad", kOnnxDomain, 11};
+  OpDef op_def{"Pad", kOnnxDomain, 18};
 
   {
     TensorInfo x_info({2, 4}, true);
@@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) {
 TEST(GradientCheckerTest, ScatterNDGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"ScatterND", kOnnxDomain, 11};
+  OpDef op_def{"ScatterND", kOnnxDomain, 18};
 
   {
     TensorInfo data_info({8}, true);
@@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) {
 TEST(GradientCheckerTest, ScatterElementsGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"ScatterElements", kOnnxDomain, 13};
+  OpDef op_def{"ScatterElements", kOnnxDomain, 18};
 
   {  // without axis
     TensorInfo data_info({3, 3}, true);
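For intuition about what the reduction tests check numerically: with keepdims = 0 the reduced axis is dropped from the output, so the incoming gradient has to be unsqueezed back to the input rank before broadcasting. A small self-contained illustration (not test code from this PR) for ReduceMean over axis 1 of a 2x3 input:

```cpp
#include <array>
#include <cstdio>

int main() {
  // Forward: y_i = mean_j x_ij, shape {2, 3} -> {2} with keepdims = 0.
  // Backward: dY has shape {2}; it is conceptually unsqueezed to {2, 1} and
  // broadcast, giving dX_ij = dY_i / 3 (the reduced axis has length 3).
  const std::array<float, 2> dY = {1.0f, 2.0f};
  float dX[2][3];
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      dX[i][j] = dY[i] / 3.0f;
  for (int i = 0; i < 2; ++i)
    std::printf("%.4f %.4f %.4f\n", dX[i][0], dX[i][1], dX[i][2]);
  return 0;
}
```

The gradient checker compares this analytic gradient against central-difference estimates, which is why each opset-18 case above only needs the op definition and the axes configuration flags.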