inference_session.cc Exception during initialization: invalid unordered_map<K, T> key #18885

Closed
quic-zhanweiw opened this issue Dec 20, 2023 · 4 comments
Labels
ep:DML issues related to the DirectML execution provider model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. platform:windows issues related to the Windows platform

Comments

@quic-zhanweiw

quic-zhanweiw commented Dec 20, 2023

Describe the issue

I compiled ONNX Runtime ('main' branch) on a Lenovo X13s (ARM64 Windows 11) device. The build command is below:

build.bat --use_dml --config Release --build_wheel --parallel --skip_tests --compile_no_warning_as_error --skip_submodule_sync

Since I compiled onnxruntime with the '--use_dml' parameter, the resulting package name is 'onnxruntime-directml'.

When I load the 'text_encoder' model with the Python code below, I get an error. I'm using the model from 'https://huggingface.co/tlwu/stable-diffusion-v1-5/tree/fp16', which is mentioned in this article: 'https://medium.com/microsoftazure/accelerating-stable-diffusion-inference-with-onnx-runtime-203bd7728540'

providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
ort_session = onnxruntime.InferenceSession(os.path.join(model_dir, "text_encoder/model.onnx"), providers=providers)

Error log:

    ort_session = onnxruntime.InferenceSession(os.path.join(model_dir, "text_encoder/model.onnx"), providers=providers)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Programs\Python\Python311-arm64\Lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "C:\Programs\Python\Python311-arm64\Lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 483, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: invalid unordered_map<K, T> key

The 'unet' and 'vae_decoder' models load successfully.

By adding logging to debug this issue, I found the problem is in 'BuildGraphDesc()' in 'GraphDescBuilder.cpp'; it crashes at the line 'const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());':

            printf("BuildGraphDesc 1.1\n");
            printf("BuildGraphDesc nameToNodeAndIndexMap 2 %s\n", graphOutput->Name().c_str());
            const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
            printf("BuildGraphDesc 1.2\n");

Log:

BuildGraphDesc 1
BuildGraphDesc 1.1
BuildGraphDesc nameToNodeAndIndexMap 2 /text_model/Gather_3_output_0_CUDAExecutionProvider
2023-12-20 15:22:24.0923759 [E:onnxruntime:, inference_session.cc:1872 onnxruntime::InferenceSession::Initialize::<lambda_64f75e57d93441df5f487a54bcee363c>::operator ()] Exception during initialization: invalid unordered_map<K, T> key

Here 'graphOutput->Name()' is '/text_model/Gather_3_output_0_CUDAExecutionProvider'. My device doesn't support CUDA, and my code doesn't use 'CUDAExecutionProvider', yet the output name refers to 'CUDAExecutionProvider'.
Should we conclude that this model has a problem running with 'DmlExecutionProvider' or 'CPUExecutionProvider'?
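As a sanity check, here is a minimal sketch (the path is illustrative) that inspects the graph with the onnx package; EP-specific suffixes like '_CUDAExecutionProvider' baked into value names would suggest the file was serialized after CUDA-specific optimization:

    import onnx

    # Illustrative path; point this at the downloaded text_encoder model.
    model = onnx.load("stable-diffusion-v1-5/text_encoder/model.onnx")

    # Print graph outputs and any value names carrying an EP suffix.
    for output in model.graph.output:
        print("graph output:", output.name)

    for node in model.graph.node:
        for name in node.output:
            if "ExecutionProvider" in name:
                print("EP-specific value:", node.op_type, name)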

Also, by adding logging to the code below, I can see this name is never added to 'nameToNodeAndIndexMap':

            for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges)
            {
                const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex];
                if (arg->Exists())
                {
                    nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex {
                        operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex],
                        operatorGraphOutputEdge.FromNodeOutputIndex
                    };
                    printf("BuildGraphDesc nameToNodeAndIndexMap 1::%s\n", arg->Name().c_str());

                    nodeOutputShapes[arg->Name()] = outputShapes;
                }
            }

To reproduce

Compile onnxruntime and run it with ARM64 Python 3.11.5.

Urgency

No response

Platform

Windows

OS Version

22H2

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.17.0 (main branch)

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

DirectML

Execution Provider Library Version

1.17.0

@github-actions github-actions bot added ep:CUDA issues related to the CUDA execution provider ep:DML issues related to the DirectML execution provider platform:windows issues related to the Windows platform model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. labels Dec 20, 2023
@quic-zhanweiw
Author

It seems the usage of 'std::unordered_map<std::string, NodeAndIndex> nameToNodeAndIndexMap;' in the 'BuildGraphDesc()' function is problematic.

Before this line, no code has added any items to 'nameToNodeAndIndexMap':
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp#L350

How can we access 'nameToNodeAndIndexMap' with the 'at' method here? If the key is missing, 'at' throws 'std::out_of_range' (which MSVC's STL reports as 'invalid unordered_map<K, T> key', exactly the error above), right?

const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name());

Thanks for your support!

@PatriceVignola
Contributor

Hi @quic-zhanweiw ,

The model you were trying to use is specific to CUDA and was converted using the CUDA EP, so DML is not able to run it. You can convert the model yourself by following the steps over here.
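For context, EP-specific value names like '_CUDAExecutionProvider' typically appear when a graph is serialized after EP-level optimization. Below is a minimal sketch of producing a DML-specific model the same way, assuming you start from an unoptimized export (paths are illustrative); note that models saved like this embed EP-specific nodes and names and are not portable to other execution providers:

    import onnxruntime

    so = onnxruntime.SessionOptions()
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Serialize the graph as optimized for the providers below.
    so.optimized_model_filepath = "text_encoder/model_dml.onnx"

    # Creating the session runs the optimizations and writes the file.
    onnxruntime.InferenceSession(
        "text_encoder/model.onnx",  # unoptimized source model
        sess_options=so,
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )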

@quic-zhanweiw
Author

Thanks @PatriceVignola for your support. I'll try the Olive script.

@nums11
Contributor

nums11 commented Mar 26, 2024

Closing as resolved.
