diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index de3013484b1ab..b4be501d3f00a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -2937,7 +2937,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="trt_max_workspace_size" and value="2147483648" * @@ -3433,7 +3433,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="device_id" and value="0" * diff --git a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java index dd7dd07fc1f5d..c5ee8aa5b4648 100644 --- a/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java @@ -17,7 +17,9 @@ /** * Conversions between fp16, bfloat16 and fp32. */ public final class Fp16Conversions { private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); - + + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). 
* diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 70af10ff8cd79..ca7bf2f317ce4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -53,6 +53,13 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Applies the Java side configuration to the native side object. + * + * @throws OrtException If the native call failed. + */ + protected abstract void applyToNative() throws OrtException; + /** * Is the native object closed? * diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index fbea13d155507..8ab4a1cb26bb1 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException { public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractCUDA()) { + // Cast is to make the compiler pick the right overload. 
+ ((OrtProviderOptions) cudaOpts).applyToNative(); addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle); } else { throw new OrtException( @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException { public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException { checkClosed(); if (OnnxRuntime.extractTensorRT()) { + // Cast is to make the compiler pick the right overload. + ((OrtProviderOptions) tensorRTOpts).applyToNative(); addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle); } else { throw new OrtException( diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java index b7a83708a2314..6c1e8f02e90af 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtCUDAProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. 
*/ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java index 958d3a9e18f9b..0a69f0b72415b 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtTensorRTProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException { String id = "" + deviceId; this.options.put("device_id", id); - add(getApiHandle(), this.nativeHandle, "device_id", id); } @Override @@ -59,17 +58,17 @@ public OrtProvider getProvider() { private static native long create(long apiHandle) throws OrtException; /** - * Adds an option to this options instance. + * Adds the options to this options instance. * * @param apiHandle The api pointer. * @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values. * @throws OrtException If the addition failed. */ @Override - protected native void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected native void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; /** * Closes this options instance. 
diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 961163035c9a6..8abc227d23aef 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException { Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); - add(getApiHandle(), nativeHandle, key, value); } /** @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException { public void parseOptionsString(String serializedForm) throws OrtException { String[] options = serializedForm.split(";"); for (String o : options) { - if (!o.isEmpty() && o.contains("=")) { + if (o.contains("=")) { String[] curOption = o.split("="); if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) { add(curOption[0], curOption[1]); @@ -76,15 +75,31 @@ public String getOptionsString() { .collect(Collectors.joining(";", "", ";")); } + @Override + protected void applyToNative() throws OrtException { + if (!options.isEmpty()) { + String[] keys = new String[options.size()]; + String[] values = new String[options.size()]; + int i = 0; + for (Map.Entry<String, String> e : options.entrySet()) { + keys[i] = e.getKey(); + values[i] = e.getValue(); + i++; + } + + applyToNative(getApiHandle(), this.nativeHandle, keys, values); + } + } + /** - * Adds an option to this options instance. + * Adds all the options to this options instance. * * @param apiHandle The api pointer. 
* @param nativeHandle The native options pointer. - * @param key The option key. - * @param value The option value. + * @param keys The option keys. + * @param values The option values, parallel to {@code keys}. * @throws OrtException If the addition failed. */ - protected abstract void add(long apiHandle, long nativeHandle, String key, String value) - throws OrtException; + protected abstract void applyToNative( + long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException; } diff --git a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java index fce872688aa1f..451c0d9848586 100644 --- a/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java +++ b/java/src/main/jvm/ai/onnxruntime/platform/Fp16Conversions.java @@ -54,6 +54,8 @@ public final class Fp16Conversions { fp32ToFp16 = tmp32; } + private Fp16Conversions() {} + /** * Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). * diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c index 22907fc65c16c..46df515c2e235 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtCUDAProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include <jni.h> @@ -24,19 +24,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_cre /* * Class: ai_onnxruntime_providers_OrtCUDAProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtCUDAProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtCUDAProviderOptionsV2* opts = (OrtCUDAProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free((void*)keys); + } + if (values != NULL) { + free((void*)values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. 
+ for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateCUDAProviderOptions(opts, keys, values, keyLength)); + // Release each UTF-8 buffer against the jstring it was obtained from. + for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + free((void*)keys); + free((void*)values); + } } /* diff --git a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c index 9146e7dd589aa..404a80f118306 100644 --- a/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c +++ b/java/src/main/native/ai_onnxruntime_providers_OrtTensorRTProviderOptions.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include <jni.h> @@ -23,19 +23,46 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions /* * Class: ai_onnxruntime_providers_OrtTensorRTProviderOptions - * Method: add - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + * Method: applyToNative + * Signature: (JJ[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_add - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring key, jstring value) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_providers_OrtTensorRTProviderOptions_applyToNative + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jobjectArray jKeyArr, jobjectArray jValueArr) { (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; OrtTensorRTProviderOptionsV2* opts = (OrtTensorRTProviderOptionsV2*) optionsHandle; - const char* keyStr = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); - const char* valueStr = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); - checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, &keyStr, &valueStr, 1)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keyStr); - (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,valueStr); + + jsize keyLength = (*jniEnv)->GetArrayLength(jniEnv, jKeyArr); + const char** keys = (const char**) allocarray(keyLength, sizeof(const char*)); + const char** values = (const char**) allocarray(keyLength, sizeof(const char*)); + if ((keys == NULL) || (values == NULL)) { + if (keys != NULL) { + free((void*)keys); + } + if (values != NULL) { + free((void*)values); + } + throwOrtException(jniEnv, 1, "Not enough memory"); + } else { + // Copy out strings into UTF-8. 
+ for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + keys[i] = (*jniEnv)->GetStringUTFChars(jniEnv, key, NULL); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + values[i] = (*jniEnv)->GetStringUTFChars(jniEnv, value, NULL); + } + // Write to the provider options. + checkOrtStatus(jniEnv,api,api->UpdateTensorRTProviderOptions(opts, keys, values, keyLength)); + // Release each UTF-8 buffer against the jstring it was obtained from. + for (jsize i = 0; i < keyLength; i++) { + jobject key = (*jniEnv)->GetObjectArrayElement(jniEnv, jKeyArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,key,keys[i]); + jobject value = (*jniEnv)->GetObjectArrayElement(jniEnv, jValueArr, i); + (*jniEnv)->ReleaseStringUTFChars(jniEnv,value,values[i]); + } + free((void*)keys); + free((void*)values); + } } /* diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index ac65cbab146bf..3340a2e5e9f3a 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime; @@ -678,6 +678,9 @@ private void runProvider(OrtProvider provider) throws OrtException { if (provider == OrtProvider.CORE_ML) { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); + } else if (provider == OrtProvider.CUDA) { + // CUDA gives slightly different answers on a H100 with CUDA 12.2 + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } else { assertArrayEquals(expectedOutput, resultArray, 1e-5f); } diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 0e3bc15ba9c70..8dfea92c9ff10 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime.providers; @@ -41,40 +41,49 @@ public void testCUDAOptions() throws OrtException { OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); sessionOpts.addCUDA(cudaOpts); runProvider(OrtProvider.CUDA, sessionOpts); + sessionOpts.close(); + cudaOpts.close(); // Test invalid device num throws assertThrows(IllegalArgumentException.class, () -> new OrtCUDAProviderOptions(-1)); // Test invalid key name throws - OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtCUDAProviderOptions invalidKeyOpts = new OrtCUDAProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0); - assertThrows(OrtException.class, () -> invalidValueOpts.add("gpu_mem_limit", "not a number")); + try (OrtCUDAProviderOptions invalidValueOpts = new OrtCUDAProviderOptions(0)) { + invalidValueOpts.add("gpu_mem_limit", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") public void testTensorRT() throws OrtException { // Test standard options - OrtTensorRTProviderOptions cudaOpts = new OrtTensorRTProviderOptions(0); - cudaOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); + OrtTensorRTProviderOptions rtOpts = new OrtTensorRTProviderOptions(0); + rtOpts.add("trt_max_workspace_size", "" + (512 * 1024 * 1024)); OrtSession.SessionOptions sessionOpts = new OrtSession.SessionOptions(); - sessionOpts.addTensorrt(cudaOpts); + sessionOpts.addTensorrt(rtOpts); runProvider(OrtProvider.TENSOR_RT, sessionOpts); + sessionOpts.close(); + rtOpts.close(); // Test invalid device num throws 
assertThrows(IllegalArgumentException.class, () -> new OrtTensorRTProviderOptions(-1)); // Test invalid key name throws - OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidKeyOpts.add("not_a_real_provider_option", "not a number")); + try (OrtTensorRTProviderOptions invalidKeyOpts = new OrtTensorRTProviderOptions(0)) { + invalidKeyOpts.add("not_a_real_provider_option", "not a number"); + assertThrows(OrtException.class, invalidKeyOpts::applyToNative); + } // Test invalid value throws - OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0); - assertThrows( - OrtException.class, () -> invalidValueOpts.add("trt_max_workspace_size", "not a number")); + try (OrtTensorRTProviderOptions invalidValueOpts = new OrtTensorRTProviderOptions(0)) { + invalidValueOpts.add("trt_max_workspace_size", "not a number"); + assertThrows(OrtException.class, invalidValueOpts::applyToNative); + } } private static void runProvider(OrtProvider provider, OrtSession.SessionOptions options) @@ -96,7 +105,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-5f); + assertArrayEquals(expectedOutput, resultArray, 1e-3f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); }