[Mobile] QNN failed to finalize QNN graph for attention layer #21221

Open
Novelfor opened this issue Jul 1, 2024 · 2 comments
Labels
ep:QNN (issues related to QNN execution provider), platform:mobile (issues related to ONNX Runtime mobile; typically submitted using template), quantization (issues related to quantization)

Comments


Novelfor commented Jul 1, 2024

Describe the issue

I trained a QAT self-attention model with PyTorch FX. The model runs with libQnnCpu.so but fails with libQnnHtp.so. I am running the model on Linux x86.
QNN: 2.20.0.240223
ERROR Message:

2024-07-01 21:56:45.190463253 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] graph_prepare.cc:205:ERROR:could not create op: q::flat_from_vtcm

2024-07-01 21:56:45.190499590 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] graph_prepare.cc:1168:ERROR:Op 0x1aa00000005b preparation failed with err:-1

2024-07-01 21:56:45.190521298 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] QnnDsp <E> "/MatMul" generated: could not create op

2024-07-01 21:56:45.190525810 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] QnnDsp <E> RouterX86 graph prepare failed 12

2024-07-01 21:56:45.190532281 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] QnnDsp <E> Failed to finalize graph (id: 1) with err 1002

2024-07-01 21:56:45.190539347 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] QnnDsp <V> Wake up free backend 1 thread(s)

2024-07-01 21:56:45.190549138 [V:onnxruntime:test, qnn_backend_manager.cc:249 QnnLogging] QnnDsp <I> QnnGraph_finalize done. status 0x3ea

2024-07-01 21:56:45.190555828 [E:onnxruntime:, qnn_model.cc:181 FinalizeGraphs] Failed to finalize QNN graph.

To reproduce

I wrote a minimal reproduction: the PyTorch code below generates "test_int8.onnx", and the C++ code runs it.

I have only tested this on Linux x86, but I expect the behavior to be the same on Android.

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, f'dim {dim} should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_mask = None

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

def to_static_shape_onnx(model_file):
    mp = onnx.load(model_file)
    mp.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1    
    mp.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 1    
    mp = onnx.shape_inference.infer_shapes(mp)
    onnx.save(mp, model_file)

from torch.ao.quantization import (
    get_default_qconfig,
    get_default_qat_qconfig,
    QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx

attn_model = Attention(320, 10)
attn_model(torch.randn(1, 10, 320))

qconfig_mapping = QConfigMapping()\
    .set_global(get_default_qat_qconfig("qnnpack"))

model_prepared = quantize_fx.prepare_qat_fx(
    attn_model, qconfig_mapping, (torch.randn(1, 10, 320), ))

for i in range(10):
    model_prepared(torch.randn(1, 10, 320))
    
quant_model = quantize_fx.convert_fx(model_prepared)
torch.onnx.export(quant_model, torch.randn(1, 10, 320), "test_int8.onnx", input_names=['input'], output_names=['output'])
to_static_shape_onnx("test_int8.onnx")
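
As a sanity check that is not part of the original report, the exported model can first be run on the default CPU EP (assuming the onnxruntime Python package is installed) and compared against the PyTorch quantized model; the input and output names match the export call above:

import numpy as np
import onnxruntime as ort

dummy = torch.randn(1, 10, 320)
sess = ort.InferenceSession("test_int8.onnx", providers=["CPUExecutionProvider"])
ort_out = sess.run(["output"], {"input": dummy.numpy()})[0]

with torch.no_grad():
    torch_out = quant_model(dummy).numpy()

# Small numeric differences are expected after export of the quantized model.
print("max abs diff:", np.abs(ort_out - torch_out).max())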

Initialize this model in C++:

#include <iostream>
#include <vector>
#include "session/onnxruntime_cxx_api.h"

template <class T>
inline std::vector<float> forward(Ort::Session& session, std::vector<T>& inputs, 
    const std::vector<int64_t>& input_shape_, const std::vector<const char*> input_names, const std::vector<const char*> output_names) {
    std::vector<float> temp_data;
    temp_data.resize(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
        temp_data[i] = inputs[i];
    }
    
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    auto input_tensor_ = Ort::Value::CreateTensor<float>(memory_info, temp_data.data(), temp_data.size(), input_shape_.data(), input_shape_.size());
    float* input_buffer = input_tensor_.GetTensorMutableData<float>();
    Ort::Value output_tensor_(nullptr);
    session.Run(Ort::RunOptions{nullptr}, input_names.data(), &input_tensor_, input_names.size(), output_names.data(), &output_tensor_, output_names.size());
    float* output_buffer = output_tensor_.GetTensorMutableData<float>();
    std::vector<float> output_data(output_buffer, output_buffer + output_tensor_.GetTensorTypeAndShapeInfo().GetElementCount());
    return output_data;
}

bool CheckStatus(const OrtApi* g_ort, OrtStatus* status) {
  if (status != nullptr) {
    const char* msg = g_ort->GetErrorMessage(status);
    std::cerr << msg << std::endl;
    g_ort->ReleaseStatus(status);
    throw Ort::Exception(msg, OrtErrorCode::ORT_EP_FAIL);
  }
  return true;
}

int main() {
    const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    OrtEnv* env;
    CheckStatus(g_ort, g_ort->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, "test", &env));

    OrtSessionOptions* session_options;
    CheckStatus(g_ort, g_ort->CreateSessionOptions(&session_options));
    CheckStatus(g_ort, g_ort->SetIntraOpNumThreads(session_options, 1));
    CheckStatus(g_ort, g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_BASIC));

    std::vector<const char*> options_keys = {"backend_path"};
    std::vector<const char*> options_values = {"libQnnHtp.so"};

    CheckStatus(g_ort, g_ort->SessionOptionsAppendExecutionProvider(session_options, "QNN", options_keys.data(), options_values.data(), options_keys.size()));
    

    OrtSession* session;
    CheckStatus(g_ort, g_ort->CreateSession(env, "test_int8.onnx", session_options, &session));
  
    std::cout << "FINISH" << std::endl;
    return 0;
}
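
For reference, and assuming an onnxruntime build with the QNN EP and Python bindings, the same session can be created from Python; swapping backend_path between libQnnCpu.so and libQnnHtp.so reproduces the pass/fail difference described above:

import numpy as np
import onnxruntime as ort

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess = ort.InferenceSession(
    "test_int8.onnx",
    sess_options=so,
    providers=[("QNNExecutionProvider", {"backend_path": "libQnnHtp.so"})],  # or libQnnCpu.so
)
out = sess.run(["output"], {"input": np.random.randn(1, 10, 320).astype(np.float32)})[0]
print(out.shape)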

Urgency

No response

Platform

Android

OS Version

linux

ONNX Runtime Installation

Built from Source

Compiler Version (if 'Built from Source')

gcc9

Package Name (if 'Released Package')

None

ONNX Runtime Version or Commit ID

8c26898

ONNX Runtime API

C++/C

Architecture

X64

Execution Provider

Other / Unknown

Execution Provider Library Version

QNN: 2.20.0.240223

@Novelfor Novelfor added the platform:mobile issues related to ONNX Runtime mobile; typically submitted using template label Jul 1, 2024
@github-actions github-actions bot added the quantization issues related to quantization label Jul 1, 2024
@Novelfor Novelfor changed the title [Mobile] QNN Model finalize QNN graph for attention layer [Mobile] QNN failed to finalize QNN graph for attention layer Jul 1, 2024
@yf711 yf711 added the ep:QNN issues related to QNN execution provider label Jul 2, 2024
@Unicorncosmos

Hi @Novelfor, did you find any fix? I am facing the same issue when running the QNN model on the HTP: it runs on the CPU backend but not on the HTP.
 0.0ms [ ERROR ] graph_prepare.cc:203:ERROR:could not create op: q::flat_from_vtcm

 0.0ms [ ERROR ] graph_prepare.cc:1187:ERROR:Op 0x1a47e400000026 preparation failed with err:-1

 0.0ms [ ERROR ] QnnDsp <E> "_encoder_backbone_backbone_0_Conv" generated: could not create op

 0.0ms [ ERROR ] QnnDsp <E> RouterX86 graph prepare failed 12

 0.0ms [ ERROR ] QnnDsp <E> Failed to finalize graph (id: 1) with err 1002

 0.0ms [VERBOSE] QnnDsp <V> Wake up free backend 1 thread(s)

 0.0ms [  INFO ] QnnDsp <I> QnnGraph_finalize done. status 0x3ea

43309.1ms [ ERROR ] Finalize Graph for Idx = 0 failed with error = 1002
Graph Finalize failure
0.0ms [VERBOSE] QnnDsp Final cleanup: free backend id 0x1

 0.0ms [WARNING] QnnDsp <W> Backend 1 free cleanup called during process exit

 0.0ms [VERBOSE] QnnDsp <V> Terminated backend 0x1 successfully in backendLifecycleManager

 0.0ms [VERBOSE] QnnDsp <V> Final context cleanup: contextId = 1!

 0.0ms [VERBOSE] QnnDsp <V> qnnOpPackageManager: unloading OpPackages...

 0.0ms [VERBOSE] QnnDsp <V> qnnOpPackageManager: OpPackge already unloaded.

 0.0ms [VERBOSE] QnnDsp <V> RouterNative tryUnLoadPrepare Disabled


Novelfor commented Aug 9, 2024

I tried using AIMET (https://github.com/quic/aimet) to quantize the transformer, and it works.
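
For anyone landing here, a rough sketch of the AIMET flow referred to above (the names follow AIMET's documented QuantizationSimModel API, but exact arguments may differ between AIMET versions, and this has not been verified against this specific model):

import torch
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

model = Attention(320, 10).eval()
dummy = torch.randn(1, 10, 320)

sim = QuantizationSimModel(model, dummy_input=dummy,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_param_bw=8, default_output_bw=8)

# Calibrate activation encodings with representative data.
def calibrate(m, _):
    with torch.no_grad():
        for _ in range(10):
            m(torch.randn(1, 10, 320))

sim.compute_encodings(forward_pass_callback=calibrate, forward_pass_callback_args=None)

# Export an ONNX model plus encodings that the QNN tooling can consume.
sim.export(path=".", filename_prefix="attn_aimet_int8", dummy_input=dummy)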
