Handling compatibility with ONNX Runtime #74

Open · MatejUrbanQC opened this issue Sep 12, 2024 · 1 comment

@MatejUrbanQC (Contributor)

ONNX Runtime is missing support for certain data types in some operations. ndonnx currently handles this by casting the operand to a supported type before the operation and casting the result back afterwards.

The problem is that this silently generates an inefficient ONNX graph: the user gets extra Cast nodes around every affected operation without being aware of it.
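To make the pattern concrete, the per-operation workaround looks roughly like this at the Spox level (a simplified sketch of the pattern, not the actual ndonnx internals; the function name add_uint64 is mine):

import numpy as np
import spox.opset.ai.onnx.v19 as op

def add_uint64(a, b):
    # ONNX Runtime has no uint64 Add kernel, so cast both operands to
    # int64, perform the Add, and cast the result back to uint64.
    a64 = op.cast(a, to=np.int64)
    b64 = op.cast(b, to=np.int64)
    return op.cast(op.add(a64, b64), to=np.uint64)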

Example

Here's example code using uint64, for which ONNX Runtime doesn't implement Add.
In the "Implicit cast" case, I let ndonnx insert the conversion around each operation.
In the "Explicit cast" case, I cast to int64 once at the start of the whole computation and cast back once at the end.

import time

import numpy as np
import onnx
import onnxruntime as ort

import ndonnx as ndx

# BUILD

def compute(x):
    # Chain of five Adds; each one triggers the uint64 workaround.
    return x + x + x + x + x + x

x = ndx.array(shape=('N',), dtype=ndx.uint64)

# Implicit cast: ndonnx wraps every Add in a Cast to int64 and back.
y1 = compute(x)

model1 = ndx.build({'x': x}, {'y': y1})
onnx.save(model1, "model1.onnx")

# Explicit cast: a single Cast at the start and one at the end.
y2 = compute(x.astype(ndx.int64)).astype(ndx.uint64)

model2 = ndx.build({'x': x}, {'y': y2})
onnx.save(model2, "model2.onnx")

# RUN

input1 = np.arange(100_000_000, dtype=np.uint64)

options1 = ort.SessionOptions()
options1.optimized_model_filepath = "optimized1.onnx"  # dump the optimized graph
session1 = ort.InferenceSession(model1.SerializeToString(), options1)
start = time.perf_counter()
session1.run(None, {'x': input1})
end = time.perf_counter()
print(f"Implicit cast: {end - start:.3f}s")

options2 = ort.SessionOptions()
options2.optimized_model_filepath = "optimized2.onnx"
session2 = ort.InferenceSession(model2.SerializeToString(), options2)
start = time.perf_counter()
session2.run(None, {'x': input1})
end = time.perf_counter()
print(f"Explicit cast: {end - start:.3f}s")

Result:

Implicit cast: 0.634s
Explicit cast: 0.305s

Models:

Implicit: optimized1.onnx
Explicit: optimized2.onnx
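To see where the time goes, one can count the node types in each optimized graph with the standard onnx API (a snippet I'm adding here for illustration; the repeated Cast pairs in the implicit variant account for the extra runtime):

from collections import Counter

import onnx

# Compare the node composition of the two optimized graphs.
for path in ("optimized1.onnx", "optimized2.onnx"):
    model = onnx.load(path)
    print(path, dict(Counter(node.op_type for node in model.graph.node)))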

Suggested changes

There are at least four ways to handle these casts:

  • Cast Silent - the current behaviour
  • Cast with Warn - emit a warning each time such a cast is inserted
  • Error - don't insert any casts and fail if there is no ONNX Runtime implementation for the given dtype
  • No Cast - proceed without casts, producing a graph that ONNX Runtime cannot run

I suggest "Cast with Warn" as the default mode. This lets the user know that we are producing an inefficient graph; they can then choose a different dtype or simply ignore the warning. These warnings can also serve as a signal to upstream an implementation to ONNX Runtime.

I also suggest providing a flag to switch between the other modes.
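As a rough sketch of what such a flag could look like (cast_mode and CastWarning are hypothetical names, not existing ndonnx API; shown only to make the proposal concrete):

import warnings

class CastWarning(UserWarning):
    pass

def resolve_cast(op_name, dtype, cast_mode="warn"):
    # Returns True when the cast-based workaround should be applied.
    if cast_mode == "error":
        raise TypeError(f"ONNX Runtime has no {op_name} kernel for {dtype}")
    if cast_mode == "warn":
        warnings.warn(
            f"inserting Cast around {op_name} for unsupported dtype {dtype}",
            CastWarning,
        )
    # "no_cast" skips the workaround, leaving the graph unrunnable on
    # ONNX Runtime; "silent" and "warn" apply it.
    return cast_mode in ("silent", "warn")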

@adityagoel4512 (Member)

I agree with the general direction here, which is to give the user greater control over the generated graph. Warnings would certainly let users know when casting is occurring, but they could do nothing about it other than step out of ndonnx and into Spox. This issue seems tightly connected to #42, and I think a better solution might be to let users specify the minimum version of ORT the generated graph must be executable on, and then only apply the casts appropriate for that version. This way a) we give the user some control over the compatibility characteristics they want out of the generated graph, and b) we don't end up making executive decisions about when it is reasonable to drop a given dtype workaround (things stay versioned all the way through).
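Roughly, I imagine something like the following sketch, where the names and the version-table entries are purely illustrative:

# Hypothetical sketch: record the ORT release in which each missing
# (op, dtype) kernel first appeared, and only insert the cast-based
# workaround when targeting an older runtime.
KERNEL_ADDED_IN = {
    ("Add", "uint64"): (1, 99, 0),  # illustrative: no release supports it yet
}

def needs_cast(op_name, dtype, target_ort_version):
    added_in = KERNEL_ADDED_IN.get((op_name, dtype))
    return added_in is not None and target_ort_version < added_in

# e.g. needs_cast("Add", "uint64", (1, 17, 0)) -> True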
