Skip to content

Commit

Permalink
Add FlattenAndUnpad Op
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Oct 11, 2023
1 parent 2ef6ee6 commit 69db77a
Show file tree
Hide file tree
Showing 16 changed files with 433 additions and 107 deletions.
17 changes: 10 additions & 7 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -755,13 +755,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {

IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef("Reshape"),
{GO(0), O(1)},
{IA("GO_reshaped")}),
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
{IA("GO_reshaped"), I(1)},
{GI(0)},
SrcNodeAttributes())};
NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
{GO(0), I(1)},
{GI(0), IA("No_use")})};
}

IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
{GO(0), I(1), O(1)},
{GI(0)})};
}

IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
Expand Down
24 changes: 21 additions & 3 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4737,14 +4737,12 @@ Return true if all elements are true and false otherwise.
" unflatten_dims: [2, 3], shape is [2]"

" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
" shape is [2, 3, 4]"
" flatten_output_shape: [6, 4], shape is [2]")
" shape is [2, 3, 4]")
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INDEX")

Check notice on line 4743 in orttraining/orttraining/core/graph/training_op_defs.cc

View workflow job for this annotation

GitHub Actions / misspell

[misspell] orttraining/orttraining/core/graph/training_op_defs.cc#L4743

"behaviour" is a misspelling of "behavior"
Raw output
./orttraining/orttraining/core/graph/training_op_defs.cc:4743:49: "behaviour" is a misspelling of "behavior"
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
Expand All @@ -4758,6 +4756,26 @@ Return true if all elements are true and false otherwise.
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");

ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(
"FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices."
"This is used by padding elimination graph transformers.")
.Input(0, "input", "input data of rank N, shape is [M1, M2, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INT")
.Output(0, "output", "output data of rank N-1, [d1, d2, ..., dN]", "T")
.Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices and shape to integer tensors.")
.TypeConstraint(
"T",
{"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.");

Check warning on line 4777 in orttraining/orttraining/core/graph/training_op_defs.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/core/graph/training_op_defs.cc#L4777

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
orttraining/orttraining/core/graph/training_op_defs.cc:4777:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
return new_expand_node->MutableOutputDefs()[0];
}

// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
// Insert FlattenAndUnpad to flatten the in_index-th input of node.
// The gather_index_arg is the indices of the elements that are not padding.
NodeArg* InsertFlattenPatternForInput(Graph& graph,
Node& node,
uint32_t in_index,
NodeArg* gather_index_arg,
const logging::Logger& logger) {
InlinedVector<NodeArg*> reshape_input_args;
reshape_input_args.reserve(2);
reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
std::vector<int64_t> new_shape;
new_shape.push_back(-1); // only support flatten 0 and 1 dims
auto input_shape = node.InputDefs()[in_index]->Shape();
ORT_ENFORCE(input_shape->dim_size() >= 2);
ONNX_NAMESPACE::TensorShapeProto flattened_shape;
if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
} else {
std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
flattened_shape.add_dim()->set_dim_param(token_dim_name);
}
for (int k = 2; k < input_shape->dim_size(); k++) {
ORT_ENFORCE(input_shape->dim(k).has_dim_value());
new_shape.push_back(input_shape->dim(k).dim_value());
flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
}
ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
new_shape_const_tensor.add_dims(new_shape.size());
new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
reshape_input_args.push_back(new_shape_arg);

InlinedVector<NodeArg*> reshape_output_args;
reshape_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
node.MutableInputDefs()[in_index]->TypeAsProto()));

Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("Reshape"),
"Reshape",
"Reshape node to filter invalid tokens.",
reshape_input_args,
reshape_output_args,
{},
"",
logger);
InlinedVector<NodeArg*> unpad_input_args;
unpad_input_args.reserve(2);
unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
unpad_input_args.push_back(gather_index_arg);

new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];

reshape_out_arg->SetShape(flattened_shape);

InlinedVector<NodeArg*> gather_input_args;
gather_input_args.reserve(2);
gather_input_args.push_back(reshape_output_args[0]);
gather_input_args.push_back(gather_index_arg);

InlinedVector<NodeArg*> gather_output_args;
gather_output_args.push_back(
InlinedVector<NodeArg*> unpad_output_args;
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
reshape_out_arg->TypeAsProto()));
nullptr));
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
nullptr));

Node* new_gather_node = InsertIntermediateNodeOnDestInput(
Node* unpad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("PaddingFilter"),
"ShrunkenGather",
"ShrunkenGather node to filter invalid tokens.",
gather_input_args,
gather_output_args,
"FlattenAndUnpad",
"FlattenAndUnpad node to filter invalid tokens.",
unpad_input_args,
unpad_output_args,
{},
kMSDomain,
logger);

new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
return gather_out_arg;
unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
return unpad_out_arg;
}

// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
Expand All @@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
nullptr));
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
nullptr));

Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
Expand Down
3 changes: 1 addition & 2 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};

TensorInfo padded_out_info({5, 2, 3}, true);
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
Expand All @@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
#endif

ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
{padded_out_info, out_shape_info}, &max_error,
{padded_out_info}, &max_error,
x_datas, {}, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5901,9 +5901,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
if case >= 2:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")

def find_input_node_type(model, arg):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning test

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
// Licensed under the MIT License.

#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

#if defined(USE_CUDA) || defined(USE_ROCM)


TEST(FlattenAndUnpadTest, Int32Type1D) {
std::vector<int32_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};

std::vector<int32_t> output = {1, 2, 3, 4, 5, 6};
std::vector<int64_t> unflatten_dims = {5, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<int32_t>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, Int32Type2D) {
std::vector<int32_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};

std::vector<int32_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int64_t> unflatten_dims = {2, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<int32_t>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, Int64Type1D) {
std::vector<int64_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};

std::vector<int64_t> output = {1, 2, 3, 4, 5, 6};
std::vector<int64_t> unflatten_dims = {5, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<int64_t>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, Int64Type2D) {
std::vector<int64_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};

std::vector<int64_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int64_t> unflatten_dims = {2, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<int64_t>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, FloatType1D) {
std::vector<float> input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};

std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> unflatten_dims = {5, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {5, 3}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<float>("output", {6}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, FloatType2D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};

std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> unflatten_dims = {2, 3};

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {2, 3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<float>("output", {3, 3}, output);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, MLFloat16Type1D) {
std::vector<float> input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};

std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
std::vector<int64_t> unflatten_dims = {5, 3};

std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));

Check warning on line 119 in orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc#L119

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc:119:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));

Check warning on line 122 in orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc#L122

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc:122:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {5, 3}, input_half);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddOutput<MLFloat16>("output", {6}, output_half);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};

std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
std::vector<int64_t> unflatten_dims = {2, 3};

std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));

Check warning on line 142 in orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc#L142

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc:142:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]
std::vector<MLFloat16> output_half;
output_half.resize(output.size());
ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));

Check warning on line 145 in orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc#L145

Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
Raw output
orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc:145:  Using deprecated casting style.  Use static_cast<int>(...) instead  [readability/casting] [4]

OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", {2, 3, 3}, input_half);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddOutput<MLFloat16>("output", {3, 3}, output_half);
test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.Run();
}

#endif

} // namespace test
} // namespace onnxruntime
Loading

0 comments on commit 69db77a

Please sign in to comment.