
[Performance] OpenVINO EP Produces incorrect inference results #20357

Closed

henxing opened this issue Apr 17, 2024 · 1 comment
Labels
- ep:OpenVINO — issues related to OpenVINO execution provider
- performance — issues related to performance regressions
- quantization — issues related to quantization
- stale — issues that have not been addressed in a while; categorized by a bot


henxing commented Apr 17, 2024

Describe the issue

When using onnxruntime-openvino==1.16.0 on Python 3.8, I'm seeing inference results that, for certain models, do not match what the model produces in PyTorch. I'm running with the OpenVINOExecutionProvider on the GPU device. This appears to be a regression: when I run with version 1.14.0, I do not see the difference in scores.

This is possibly related to #19975, where I previously left a comment.

To reproduce

See below for a script that reproduces the issue. The outputs of BrokenModel will disagree between PyTorch and ONNX Runtime, while the other two models produce results that are very close, if not identical.

import numpy as np
import onnxruntime as rt
import torch
from torch import nn


class BrokenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv_1(x)
        output = self.conv_2(x)
        return output.mean(dim=(1, 2, 3))


class BatchMeanModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv_1(x)
        output = self.conv_2(x)
        return output.mean(dim=(1, 2, 3)), output.mean()


class FewChannelModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(3, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv_1(x)
        output = self.conv_2(x)
        return output.mean(dim=(1, 2, 3))


def run_model_pytorch_onnxruntime(arch, path):
    model = arch()
    model.eval()
    print("=" * 80)
    print(model)

    data = torch.ones(2, 3, 224, 224)
    data[0] *= 0

    print("Torch:")
    for _ in range(2):
        result = model(data)
        print(result)
    print()

    torch.onnx.export(
        model,
        data,
        path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
        dynamic_axes={name: {0: "batch_size"} for name in ("input", "output")},
        verbose=False,
    )

    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL

    print("Onnxruntime:")
    rt_sess = rt.InferenceSession(
        path, sess_options, providers=["OpenVINOExecutionProvider"], provider_options=[{"device_id": "GPU"}]
    )
    for _ in range(2):
        outputs = rt_sess.run(None, {"input": data.numpy()})
        print(outputs)
    print()


if __name__ == "__main__":
    run_model_pytorch_onnxruntime(BrokenModel, "broken_model.onnx")
    print()
    run_model_pytorch_onnxruntime(BatchMeanModel, "batch_mean_model.onnx")
    print()
    run_model_pytorch_onnxruntime(FewChannelModel, "few_channel_model.onnx")

You'll need to install torch, onnxruntime-openvino, and numpy to run this script.
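Not part of the original report, but as a hypothetical way to quantify "results that do not match", one could compare the PyTorch and ONNX Runtime outputs with NumPy tolerances instead of eyeballing the printed values. The helper names and the sample numbers below are illustrative only:

```python
import numpy as np


def max_abs_diff(reference: np.ndarray, candidate: np.ndarray) -> float:
    """Largest element-wise absolute difference between the two outputs."""
    return float(np.max(np.abs(reference - candidate)))


def outputs_agree(reference, candidate, rtol=1e-4, atol=1e-5):
    """True if the ORT output matches the PyTorch reference within tolerance."""
    return bool(np.allclose(reference, candidate, rtol=rtol, atol=atol))


# Illustrative values only: a PyTorch reference vs. a matching and a
# disagreeing EP result, standing in for model(data) and rt_sess.run(...).
ref = np.array([0.1234, -0.5678], dtype=np.float32)
ok = np.array([0.1234, -0.5678], dtype=np.float32)
bad = np.array([0.9999, -0.0001], dtype=np.float32)

print(outputs_agree(ref, ok))    # → True
print(outputs_agree(ref, bad))   # → False, the kind of mismatch BrokenModel shows
print(max_abs_diff(ref, bad))    # magnitude of the disagreement
```

In the repro script above, `ref` would come from `model(data).detach().numpy()` and `bad` from the OpenVINO EP session's `run` output.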

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

OpenVINO

Execution Provider Library Version

bundled with onnxruntime-openvino

Model File

No response

Is this a quantized model?

Yes

@github-actions github-actions bot added ep:OpenVINO issues related to OpenVINO execution provider quantization issues related to quantization labels Apr 17, 2024
@sophies927 sophies927 added the performance issues related to performance regressions label Apr 25, 2024
github-actions bot commented:

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label May 28, 2024
@henxing henxing closed this as completed Aug 27, 2024