From a36692066da2b9b4d1d6a91f45630af5c62d3288 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 5 May 2024 03:16:55 -0400 Subject: [PATCH] [java] CUDA & TensorRT options fix (#20549) ### Description I misunderstood how UpdateCUDAProviderOptions and UpdateTensorRTProviderOptions work in the C API. I had assumed that they updated the options struct in place; however, they re-initialize the struct to the defaults and then apply only the values in the update. I've rewritten the Java bindings for those classes so that they aggregate all the updates and apply them in one go. I also updated the C API documentation to note that these functions have this behaviour. I've not checked whether any of the other providers with an options struct have this behaviour, as we only expose CUDA's and TensorRT's options in Java. There's a small unrelated update to add a private constructor to the Fp16Conversions classes to remove a documentation warning (they shouldn't be instantiated anyway as they are utility classes containing static methods). ### Motivation and Context Fixes #20544.
--- .../core/session/onnxruntime_c_api.h | 4 +- .../onnxruntime/platform/Fp16Conversions.java | 4 +- .../ai/onnxruntime/OrtProviderOptions.java | 9 +++- .../main/java/ai/onnxruntime/OrtSession.java | 6 ++- .../providers/OrtCUDAProviderOptions.java | 13 +++-- .../providers/OrtTensorRTProviderOptions.java | 13 +++-- .../StringConfigProviderOptions.java | 31 ++++++++---- .../onnxruntime/platform/Fp16Conversions.java | 2 + ...runtime_providers_OrtCUDAProviderOptions.c | 47 +++++++++++++++---- ...ime_providers_OrtTensorRTProviderOptions.c | 47 +++++++++++++++---- .../java/ai/onnxruntime/InferenceTest.java | 5 +- .../providers/ProviderOptionsTest.java | 41 +++++++++------- 12 files changed, 158 insertions(+), 64 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index de3013484b1ab..b4be501d3f00a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -2937,7 +2937,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="trt_max_workspace_size" and value="2147483648" * @@ -3433,7 +3433,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. 
* * For example, key="device_id" and value="0" * diff --git a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java index dd7dd07fc1f5d..c5ee8aa5b4648 100644 --- a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java @@ -17,7 +17,9 @@ /** * Conversions between fp16, bfloat16 and fp32. */ public final class Fp16Conversions { private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); - + + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). * diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 70af10ff8cd79..ca7bf2f317ce4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -53,6 +53,13 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Applies the Java side configuration to the native side object. + * + * @throws OrtException If the native call failed. + */ + protected abstract void applyToNative() throws OrtException; + /** * Is the native object closed? * diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index fbea13d155507..8ab4a1cb26bb1 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. 
+ * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException { public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractCUDA()) { + // Cast is to make the compiler pick the right overload. + ((OrtProviderOptions) cudaOpts).applyToNative(); addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle); } else { throw new OrtException( @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException { public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractTensorRT()) { + // Cast is to make the compiler pick the right overload. + ((OrtProviderOptions) tensorRTOpts).applyToNative(); addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle); } else { throw new OrtException( diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java index b7a83708a2314..6c1e8f02e90af 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. */ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java index 958d3a9e18f9b..0a69f0b72415b 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. */ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 961163035c9a6..8abc227d23aef 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime.providers; @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException { Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); - add(getApiHandle(), nativeHandle, key, value); } /** @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException { public void parseOptionsString(String serializedForm) throws OrtException { String[] options = serializedForm.split(";"); for (String o : options) { - if (!o.isEmpty() && o.contains("=")) { + if (o.contains("=")) { String[] curOption = o.split("="); if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) { add(curOption[0], curOption[1]); @@ -76,15 +75,31 @@ public String getOptionsString() { .collect(Collectors.joining(";", "", ";")); } + @Override + protected void applyToNative() throws OrtException { + if (!options.isEmpty()) { + String[] keys = new String[options.size()]; + String[] values = new String[options.size()]; + int i = 0; + for (Map.Entry<String, String> e : options.entrySet()) { + keys[i] = e.getKey(); + values[i] = e.getValue(); + i++; + } + + applyToNative(getApiHandle(), this.nativeHandle, keys, values); + } + } + /** - * Adds an option to this options instance. + * Add all the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. 
*/ - protected abstract void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected abstract void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; } diff --git a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java index fce872688aa1f..451c0d9848586 100644 --- a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java @@ -54,6 +54,8 @@ public final class Fp16Conversions { fp32ToFp16 = tmp32; } + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). * diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c index 22907fc65c16c..46df515c2e235 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include <jni.h> @@ -24,19 +24,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_cre /* * Class: ai_onnxruntime_providers_OrtCUDAProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free((void*)keys); + } + if (values != NULL) { + free((void*)values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. 
+ for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength)); + // Release the UTF-8 copies, pairing each buffer with the jstring it was read from. + for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + free((void*)keys); + free((void*)values); + } } /* diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c index 9146e7dd589aa..404a80f118306 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include <jni.h> @@ -23,19 +23,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions /* * Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free((void*)keys); + } + if (values != NULL) { + free((void*)values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. 
+ for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength)); + // Release the UTF-8 copies, pairing each buffer with the jstring it was read from. + for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + free((void*)keys); + free((void*)values); + } } /* diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index ac65cbab146bf..3340a2e5e9f3a 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime; @@ -678,6 +678,9 @@ private void runProvider(OrtProvider provider) throws OrtException { if (provider == OrtProvider.CORE_ML) { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); + } else if (provider == OrtProvider.CUDA) { + // CUDA gives slightly different answers on a H100 with CUDA 12.2 + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } else { assertArrayEquals(expectedOutput, resultArray, 1e-5f); } diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 0e3bc15ba9c70..8dfea92c9ff10 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime.providers; @@ -41,40 +41,49 @@ public void testCUDAOptions() throws OrtException { OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); sessionOpts.addCUDA(cudaOpts); runProvider(OrtProvider.CUDA, sessionOpts); + sessionOpts.close(); + cudaOpts.close(); // Test invalid device num throws assertThrows(IllegalArgumentException.class, () -> new OrtCUDAProviderOptions(-1)); // Test invalid key name throws - OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0); - assertThrows(OrtException.class, () -> invalidValueOpts.add("gpu_mem_limit", "not a number")); + try (OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0)) { + invalidValueOpts.add("gpu_mem_limit", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") public void testTensorRT() throws OrtException { // Test standard options - OrtTensorRTProviderOptions cudaOpts = new OrtTensorRTProviderOptions(0); - cudaOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); + OrtTensorRTProviderOptions rtOpts = new OrtTensorRTProviderOptions(0); + rtOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); - sessionOpts.addTensorrt(cudaOpts); + sessionOpts.addTensorrt(rtOpts); runProvider(OrtProvider.TENSOR_RT, sessionOpts); + sessionOpts.close(); + rtOpts.close(); // Test invalid device num throws 
assertThrows(IllegalArgumentException.class, () -> new OrtTensorRTProviderOptions(-1)); // Test invalid key name throws - OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidValueOpts.add("trt_max_workspace_size", "not a number")); + try (OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0)) { + invalidValueOpts.add("trt_max_workspace_size", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } private static void runProvider(OrtProvider provider, OrtSession.SessionOptions options) @@ -96,7 +105,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-5f); + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); }