
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/rot_dim
wangyems committed Dec 18, 2023
2 parents c604e67 + ea6186e commit f50b583
Showing 6 changed files with 104 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/stale.yml
@@ -29,7 +29,7 @@ jobs:
# Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale
stale-issue-label: "stale"
# Comment that you want to add to issues that are labeled by the actions/stale action
- stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
+ stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
# Comment that you want to add to issues that are closed by the actions/stale action
close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed."
# If you never want this action to label PRs, set this value to -1
7 changes: 5 additions & 2 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -1035,8 +1035,11 @@ class PlannerImpl {
std::function<void(NodeIndex)> dfs = [&](NodeIndex curr) {
if (dependents.find(curr) == dependents.end()) {
dependents.insert(curr);
- for (NodeIndex dep : dependence_graph_[curr]) {
- dfs(dep);
+ auto dep_graph_iter = dependence_graph_.find(curr);
+ if (dep_graph_iter != dependence_graph_.end()) {
+ for (NodeIndex dep : dep_graph_iter->second) {
+ dfs(dep);
+ }
}
}
};
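Note: the change above guards the dependence_graph_ lookup with find() so that a node with no recorded dependencies is simply skipped, rather than relying on operator[] (which would insert an empty entry into a mutable map and does not exist on a const one). A minimal, self-contained sketch of the same pattern — the type aliases here are illustrative, not the planner's actual definitions:

#include <cstddef>
#include <functional>
#include <set>
#include <unordered_map>
#include <vector>

using NodeIndex = std::size_t;
using DependenceGraph = std::unordered_map<NodeIndex, std::vector<NodeIndex>>;

std::set<NodeIndex> CollectDependents(NodeIndex root, const DependenceGraph& graph) {
  std::set<NodeIndex> dependents;
  std::function<void(NodeIndex)> dfs = [&](NodeIndex curr) {
    if (dependents.insert(curr).second) {
      // find() leaves the map untouched when curr has no entry; operator[]
      // would not compile on a const map and would insert on a mutable one.
      auto it = graph.find(curr);
      if (it != graph.end()) {
        for (NodeIndex dep : it->second) {
          dfs(dep);
        }
      }
    }
  };
  dfs(root);
  return dependents;
}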
35 changes: 23 additions & 12 deletions onnxruntime/core/graph/graph.cc
@@ -582,6 +582,17 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
onnx_function_proto = *func_template_->onnx_func_proto_;
return true;
} else if (op_) {
auto get_opset_version = [op = op_](Graph* graph) -> std::optional<int> {
if (op->domain() == kOnnxDomain) {
const auto& domain_to_version = graph->DomainToVersionMap();
const auto iter = domain_to_version.find(kOnnxDomain);
if (iter != domain_to_version.cend()) {
return iter->second;
}
}
return {};
};

// Check if this node has a schema defined function proto.
if (op_->HasContextDependentFunction()) {
NodeProto node_proto;
@@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
} else
input_types.emplace_back();
}

auto requested_opset_version = get_opset_version(graph_);
if (!requested_opset_version.has_value()) {
requested_opset_version = SinceVersion();
}
ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
- return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
+ return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version);
} else if (op_->HasFunction()) {
const FunctionProto* function_ptr = nullptr;
// We need to get a function-body suitable for the ONNX opset used by the model.
@@ -605,17 +621,12 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
// as the default-version, which is incorrect in the case of functions belonging to
// non-onnx domains, like MSDOMAIN.

- // We use the following as a temporary hack.
- function_ptr = op_->GetFunction(SinceVersion(), false);
-
- // TODO: Switch to following, once ONNX issue is fixed.
- // auto& map = graph_->DomainToVersionMap();
- // const auto iter = map.find(kOnnxDomain);
- // if (iter != map.end()) {
- // function_ptr = op_->GetFunction(iter->second, true);
- // } else {
- // function_ptr = op_->GetFunction();
- // }
+ auto requested_opset_version = get_opset_version(graph_);
+ if (requested_opset_version.has_value()) {
+ function_ptr = op_->GetFunction(*requested_opset_version, true);
+ } else {
+ function_ptr = op_->GetFunction(SinceVersion(), false);
+ }

if (function_ptr != nullptr) {
onnx_function_proto = *function_ptr;
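Note: the net effect of the graph.cc changes is a single opset-resolution rule — prefer the ONNX opset version the model actually imports, and fall back to the operator schema's SinceVersion() when the ONNX domain is not imported — applied to both BuildContextDependentFunction and GetFunction. A condensed sketch of that rule in isolation; the helper names and the plain map type are illustrative, not ONNX Runtime's actual API:

#include <optional>
#include <string>
#include <unordered_map>

using DomainToVersionMap = std::unordered_map<std::string, int>;

// In ONNX Runtime the ONNX domain is the empty string.
constexpr const char* kOnnxDomainKey = "";

std::optional<int> GetImportedOnnxOpset(const DomainToVersionMap& domain_to_version) {
  const auto iter = domain_to_version.find(kOnnxDomainKey);
  if (iter != domain_to_version.cend()) {
    return iter->second;  // opset the model imports for the ONNX domain
  }
  return std::nullopt;  // model does not import the ONNX domain
}

int ResolveFunctionOpset(const DomainToVersionMap& domain_to_version, int since_version) {
  // Prefer the model's imported opset; otherwise fall back to the version at
  // which the operator schema was introduced (SinceVersion()).
  return GetImportedOnnxOpset(domain_to_version).value_or(since_version);
}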
25 changes: 18 additions & 7 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2506,7 +2506,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
- trt_config->setMaxWorkspaceSize(max_workspace_size_);
+ trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);

// Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow
if (fp16_enable_ && layer_norm_fp32_fallback_) {
@@ -2723,13 +2723,24 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
}

// enable builder heuristics
#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
if (build_heuristics_enable_) {
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled."
+ << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics.";
}
- #if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8
+ #elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
// for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2
if (build_heuristics_enable_) {
if (builder_optimization_level_ == 2) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards.";
} else {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics.";
}
}
#endif

#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
// switch optimization level
if (builder_optimization_level_ != 3) {
trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
@@ -3125,7 +3136,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_state->context->reset();
trt_state->engine->reset();
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
- trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr));
+ trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr));
for (auto trt_profile : trt_profiles) {
trt_config->addOptimizationProfile(trt_profile);
}
@@ -3166,7 +3177,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
}
- #if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8
+ #if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
// switch optimization level
if (trt_state->builder_optimization_level != 3) {
trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
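Note: two TensorRT API migrations appear in this file — setMemoryPoolLimit(kWORKSPACE, …) replacing the deprecated setMaxWorkspaceSize(), and builder optimization level 2 superseding the kENABLE_TACTIC_HEURISTIC flag on TRT 8.6 and later. A rough standalone sketch of both calls, assuming TensorRT 8.6+ headers; the logger and values are placeholders and error handling is omitted:

#include <cstddef>
#include <memory>
#include "NvInfer.h"

// Placeholder logger; a real application would forward messages somewhere useful.
class SilentLogger : public nvinfer1::ILogger {
 public:
  void log(Severity /*severity*/, const char* /*msg*/) noexcept override {}
};

void ConfigureBuilder(std::size_t max_workspace_size, int optimization_level) {
  SilentLogger logger;
  auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
  auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());

  // TRT 8.x replacement for the deprecated setMaxWorkspaceSize().
  config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size);

  // On TRT 8.6 and later, level 2 enables the heuristic-based tactic selection
  // previously requested via BuilderFlag::kENABLE_TACTIC_HEURISTIC.
  config->setBuilderOptimizationLevel(optimization_level);
}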
54 changes: 54 additions & 0 deletions onnxruntime/test/framework/function_test.cc
@@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) {
AsSpan(output_names), &fetches, 0));
}

/// <summary>
/// This test covers the issues:
/// https://github.com/microsoft/onnxruntime/issues/16438
/// https://github.com/microsoft/onnxruntime/issues/18781
/// </summary>
TEST(FunctionTest, Test_GH_issue_16438) {
const char* code = R"(
<
ir_version: 8,
opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
producer_name: "pytorch",
producer_version: "2.1.0"
>
torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) {
_val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax <dim: int = 2, dtype: int = 11> (input_0)
}
<
domain: "pkg.onnxscript.torch_lib",
opset_import: ["" : 18]
>
aten_special_log_softmax <dim, dtype>(self) => (result_8)
{
tmp = Shape(self)
tmp_0 = Size(tmp)
int64_0 = Constant<value : tensor = int64 int64_0{0}> ()
int64_0_cast = CastLike(int64_0, tmp_0)
self_is_scalar = Equal(tmp_0, int64_0_cast)
self_4 = If(self_is_scalar) <then_branch : graph = thenGraph_8() => (self_2) {
tmp_1 = Constant<value_ints : ints = [0]> ()
self_2 = Unsqueeze(self, tmp_1)
}, else_branch : graph = elseGraph_8() => (self_3) {
self_3 = Identity(self)
}>
result = LogSoftmax<axis : int = @dim>(self_4)
result_5 = Cast<to : int = @dtype>(result)
result_8 = If(self_is_scalar) <then_branch : graph = thenGraph_12() => (result_6) {
result_6 = Squeeze(result_5)
}, else_branch : graph = elseGraph_12() => (result_7) {
result_7 = Identity(result_5)
}>
}
)";

std::string serialized_model;
ParseOnnxSource(code, serialized_model);
SessionOptions session_options;
InferenceSession session_object{session_options, GetEnvironment()};

std::stringstream sstr(serialized_model);
auto status = session_object.Load(sstr);
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
}
} // namespace test
} // namespace onnxruntime
@@ -242,6 +242,7 @@ stages:
runTests: ${{ parameters.RunOnnxRuntimeTests }}
buildJava: true
java_artifact_id: onnxruntime_gpu
CudaVersion: 11.8

# CUDA with Tensorrt
- template: templates/win-ci.yml
@@ -253,10 +254,11 @@
buildArch: x64
msbuildPlatform: x64
packageName: x64-tensorrt
- buildparameter: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80"
+ buildparameter: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80"
runTests: ${{ parameters.RunOnnxRuntimeTests }}
buildJava: true
java_artifact_id: onnxruntime_gpu
CudaVersion: 11.8
UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }}

# ROCm

