Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] JNI refactor for OrtJniUtil #12516

Merged
merged 8 commits into from
Sep 9, 2022

Conversation

Craigacp
Copy link
Contributor

@Craigacp Craigacp commented Aug 9, 2022

Description:

Following on from #12013, #12281 and #12496 this PR fixes the JNI error handling in OrtJniUtil. The refactor of all the JNI code should be complete now. I'll revise the sparse tensor PR (#10653) after this has been merged as it touches many of the same parts of the code.

This change is independent of #12496.

Motivation and Context

@fs-eire
Copy link
Contributor

fs-eire commented Aug 11, 2022

after a sequence of changes, the exception handling is getting correct. however, before closing the issue we need to add some test cases ( which intentionally trigger exceptions ) and we need to check if the exception handling is working as expected. As we are doing changes in different PRs (#12013, #12281, #12496 and this one). You can do it separately. but we need the tests to make sure the code changes work and also to ensure it is not broken by future changes.

@Craigacp
Copy link
Contributor Author

Craigacp commented Aug 11, 2022

after a sequence of changes, the exception handling is getting correct. however, before closing the issue we need to add some test cases ( which intentionally trigger exceptions ) and we need to check if the exception handling is working as expected. As we are doing changes in different PRs (#12013, #12281, #12496 and this one). You can do it separately. but we need the tests to make sure the code changes work and also to ensure it is not broken by future changes.

I would do if I could replicate the issue in Java, but I can't. When I call session.run with a tensor with the incorrect shape I get the expected exception back and the JVM continues running. I think it's something about how Android implements the JNI spec, but I don't have an Android environment to test it in.

There are already tests which supply the wrong type or number of inputs into session.run and check that the appropriate exception is thrown.

I ran a quick test in jshell on both Linux and macOS x86_64 using v1.12.1 & v1.11.0 and for the following program:

import ai.onnxruntime.*;
var env = OrtEnvironment.getEnvironment();
var session = env.createSession("<path-to-an-mnist-cnn-model>");
//jshell> session.getInputInfo()
//==> {input_image=NodeInfo(name=input_image,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[-1, 1, 28, 28]))}
float[] input = new float[56];
var tensor = OnnxTensor.createTensor(env,input);
var outputs = session.run(Map.of("input_image",tensor));

I get the expected exception:

|  Exception ai.onnxruntime.OrtException: Error code - ORT_INVALID_ARGUMENT - message: Invalid rank for input: input_image Got: 1 Expected: 4 Please fix either the inputs or the model.
|        at OrtSession.run (Native Method)
|        at OrtSession.run (OrtSession.java:295)
|        at OrtSession.run (OrtSession.java:238)
|        at OrtSession.run (OrtSession.java:207)

and if I change the tensor so it's the right rank, but still has 56 elements rather than 784 I get:

|  Exception ai.onnxruntime.OrtException: Error code - ORT_INVALID_ARGUMENT - message: Got invalid dimensions for input: input_image for the following indices
 index: 1 Got: 2 Expected: 1
 index: 2 Got: 4 Expected: 28
 index: 3 Got: 7 Expected: 28
 Please fix either the inputs or the model.
|        at OrtSession.run (Native Method)
|        at OrtSession.run (OrtSession.java:295)
|        at OrtSession.run (OrtSession.java:238)
|        at OrtSession.run (OrtSession.java:207)

again, as expected.

@fs-eire
Copy link
Contributor

fs-eire commented Aug 11, 2022

When I call session.run with a tensor with the incorrect shape I get the expected exception back and the JVM continues running. I think it's something about how Android implements the JNI spec, but I don't have an Android environment to test it in.

Could you please help to add test cases in react_native\android\src\androidTest\java\ai\onnxruntime\reactnative\OnnxruntimeModuleTest.java ? If you don't have the Android environment, we can leverage the CI to test it.

@Craigacp
Copy link
Contributor Author

When I call session.run with a tensor with the incorrect shape I get the expected exception back and the JVM continues running. I think it's something about how Android implements the JNI spec, but I don't have an Android environment to test it in.

Could you please help to add test cases in react_native\android\src\androidTest\java\ai\onnxruntime\reactnative\OnnxruntimeModuleTest.java ? If you don't have the Android environment, we can leverage the CI to test it.

Sure. The Java code in there looks a little odd, is there some guidance on which bits of the JDK I can expect to work? It looks like it's using different Map and List implementations.

@fs-eire
Copy link
Contributor

fs-eire commented Aug 12, 2022

When I call session.run with a tensor with the incorrect shape I get the expected exception back and the JVM continues running. I think it's something about how Android implements the JNI spec, but I don't have an Android environment to test it in.

Could you please help to add test cases in react_native\android\src\androidTest\java\ai\onnxruntime\reactnative\OnnxruntimeModuleTest.java ? If you don't have the Android environment, we can leverage the CI to test it.

Sure. The Java code in there looks a little odd, is there some guidance on which bits of the JDK I can expect to work? It looks like it's using different Map and List implementations.

There are some types import from React Native data bridge:

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;

Those are defined in react native for inter-op between Java and JavaScript.
https://github.com/facebook/react-native/tree/main/ReactAndroid/src/main/java/com/facebook/react/bridge

I am using OpenJDK 11 + Android Studio Bumblebee

@Craigacp
Copy link
Contributor Author

Sure. The Java code in there looks a little odd, is there some guidance on which bits of the JDK I can expect to work? It looks like it's using different Map and List implementations.

There are some types import from React Native data bridge:

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;

Those are defined in react native for inter-op between Java and JavaScript. https://github.com/facebook/react-native/tree/main/ReactAndroid/src/main/java/com/facebook/react/bridge

I am using OpenJDK 11 + Android Studio Bumblebee

Ok, once this has been merged in I'll work up a PR to mirror the existing exception tests over to react, and add a new test checking specifically for input size & shape to both Java & react. I might need some help with the react one though, but we can discuss that in that PR.

@Craigacp
Copy link
Contributor Author

I made a PR with some new tests for the react-native portion #12659.

@Craigacp
Copy link
Contributor Author

@yuslepukhin please could you review this PR? It's the last one to finish off the JNI refactor.

@yuslepukhin yuslepukhin self-assigned this Aug 30, 2022
// length + 1 as we need to write out the final offset
size_t * offsets;
checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(size_t)*(length+1),(void**)&offsets));
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray) {
Copy link
Member

@yuslepukhin yuslepukhin Sep 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allocator

Not needed #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@yuslepukhin
Copy link
Member

/azp run MacOS CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-python-checks-ci-pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 6 pipeline(s).

@yuslepukhin
Copy link
Member

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 6 pipeline(s).

@yuslepukhin
Copy link
Member

/azp run orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@yuslepukhin
Copy link
Member

/azp run onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 1 pipeline(s).

yuslepukhin
yuslepukhin previously approved these changes Sep 7, 2022
@yuslepukhin
Copy link
Member

/azp run Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux Nuphar CI Pipeline,Linux OpenVINO CI Pipeline

@yuslepukhin
Copy link
Member

/azp run MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-python-checks-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines
Copy link

Azure Pipelines successfully started running 9 pipeline(s).

@Craigacp
Copy link
Contributor Author

Craigacp commented Sep 7, 2022

I can replicate that Windows failure locally, I get an error out of the JVM saying The block at <hex-address> was not allocated by _aligned routines, use free(). It's only the sequence of maps tests which fail in this way, so I'll have a look through the changes from this PR and figure it out.

@yuslepukhin
Copy link
Member

I can replicate that Windows failure locally, I get an error out of the JVM saying The block at <hex-address> was not allocated by _aligned routines, use free(). It's only the sequence of maps tests which fail in this way, so I'll have a look through the changes from this PR and figure it out.

Our default CPU allocator uses aligned alloc and free. Can it be that something that was allocated by malloc is being deallocated by aligned free?

@Craigacp
Copy link
Contributor Author

Craigacp commented Sep 7, 2022

I can replicate that Windows failure locally, I get an error out of the JVM saying The block at <hex-address> was not allocated by _aligned routines, use free(). It's only the sequence of maps tests which fail in this way, so I'll have a look through the changes from this PR and figure it out.

Our default CPU allocator uses aligned alloc and free. Can it be that something that was allocated by malloc is being deallocated by aligned free?

Found it. In two places I was using AllocatorFree instead of ReleaseValue to free OrtValues. Which apparently runs just fine on clang & gcc, but Windows/MSVC does not like.

Fixing that makes the crash disappear on my Windows box.

@yuslepukhin
Copy link
Member

yuslepukhin commented Sep 7, 2022

I can replicate that Windows failure locally, I get an error out of the JVM saying The block at <hex-address> was not allocated by _aligned routines, use free(). It's only the sequence of maps tests which fail in this way, so I'll have a look through the changes from this PR and figure it out.

Our default CPU allocator uses aligned alloc and free. Can it be that something that was allocated by malloc is being deallocated by aligned free?

Found it. In two places I was using AllocatorFree instead of ReleaseValue to free OrtValues. Which apparently runs just fine on clang & gcc, but Windows/MSVC does not like.

It would be wrong everywhere. One needs to destroy the objects first and only then deallocate memory. Simply deallocating memory for OrtValue does not deallocate tensor memory which can be huge. ReleaseValue is a destructor, not memory management. On top of that, OrtValues are allocated from system memory, not from the ORT Allocators pool which can lead to memory corruption. We are lucky that Windows caught on a technicality.

@Craigacp
Copy link
Contributor Author

Craigacp commented Sep 7, 2022

I can replicate that Windows failure locally, I get an error out of the JVM saying The block at <hex-address> was not allocated by _aligned routines, use free(). It's only the sequence of maps tests which fail in this way, so I'll have a look through the changes from this PR and figure it out.

Our default CPU allocator uses aligned alloc and free. Can it be that something that was allocated by malloc is being deallocated by aligned free?

Found it. In two places I was using AllocatorFree instead of ReleaseValue to free OrtValues. Which apparently runs just fine on clang & gcc, but Windows/MSVC does not like.

It would be wrong everywhere. One needs to destroy the objects first and only then deallocate memory. Simply deallocating memory for OrtValue does not deallocate tensor memory which can be huge. ReleaseValue is a destructor, not memory management. On top of that, OrtValues are allocated from system memory, not from the ORT Allocators pool which can lead to memory corruption. We are lucky that Windows caught on a technicality.

I agree it was wrong everywhere, I'm surprised it didn't crash on the other platforms too. I've checked through the code and I don't think I'm using it incorrectly elsewhere.

@yuslepukhin
Copy link
Member

/azp run Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux Nuphar CI Pipeline,Linux OpenVINO CI Pipeline

@yuslepukhin
Copy link
Member

/azp run MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-python-checks-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines
Copy link

Azure Pipelines successfully started running 9 pipeline(s).

@Craigacp
Copy link
Contributor Author

Craigacp commented Sep 8, 2022

That TensorRT test failure looks independent of all the changes in this PR, it's in the core library tests rather than any of the Java ones.

@yuslepukhin
Copy link
Member

That TensorRT test failure looks independent of all the changes in this PR, it's in the core library tests rather than any of the Java ones.

I will see to it.

Copy link
Member

@yuslepukhin yuslepukhin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@yuslepukhin yuslepukhin merged commit 5d55b07 into microsoft:main Sep 9, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

JAVA API does not handle exceptions correctly - causing crash or potential memory leak
3 participants