Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/rn] Supoort New Architecture #16669

Open
wants to merge 48 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2354db1
TS: define NativeOnnxruntimeSpec
jhen0409 Jul 6, 2023
209b91f
setup codegen
jhen0409 Jul 6, 2023
efce8ed
Android: Split module implementation from oldarch/newarch
jhen0409 Jul 8, 2023
f176211
Bump @types/react-native to 0.71.3
jhen0409 Jul 8, 2023
70daa8d
Ignore unicorn/filename-case for NativeOnnxruntimeSpec
jhen0409 Jul 8, 2023
39dab48
Android: Fix incorrect dir path
jhen0409 Jul 8, 2023
8485286
E2E: Upgrade React Native to v0.71
jhen0409 Jul 10, 2023
1df0795
Android: Fix build
jhen0409 Jul 10, 2023
16cff1e
TS: Fix native spec
jhen0409 Jul 11, 2023
cff8ff2
Android: Still use legacy module for JSIHelper
jhen0409 Jul 11, 2023
e6f5bff
iOS: Support New Architecture
jhen0409 Jul 11, 2023
d99af7f
Ignore more paths
jhen0409 Jul 11, 2023
e0813d3
Format
jhen0409 Jul 11, 2023
0ce934a
Android: Fix path of E2E package name
jhen0409 Jul 11, 2023
b12e78d
Android: Fix OnnxruntimeSpec
jhen0409 Jul 11, 2023
a2f98f6
Android: Use Light style for E2E project
jhen0409 Jul 11, 2023
7dde3a6
Android: Fix path of E2E package name
jhen0409 Jul 11, 2023
e7a1c9d
Android: Always set blob module on check
jhen0409 Jul 11, 2023
c9e1714
Use v0.69 for native unit tests
jhen0409 Jul 11, 2023
61ff7ab
Android: Disable new-arch in E2E project for passed detox test
jhen0409 Jul 11, 2023
e4e389b
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 11, 2023
1d55dc4
TS: Move Binding types to native module spec (options remain {})
jhen0409 Jul 12, 2023
77f54a3
Revert unnecessary changes
jhen0409 Jul 12, 2023
f964488
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 12, 2023
ca38c81
iOS: Revert removed comments
jhen0409 Jul 12, 2023
5ba46e3
Format
jhen0409 Jul 12, 2023
200f14b
E2E: Remove local package links
jhen0409 Jul 12, 2023
7ea8b98
TS: Un-ban {} type only for NativeOnnxruntime spec
jhen0409 Jul 12, 2023
27d36df
Android: Revert rn_edit_text_material
jhen0409 Jul 12, 2023
b02c8e2
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 13, 2023
3babf7e
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 29, 2023
075770f
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Aug 7, 2023
9479f46
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Aug 28, 2023
99817d6
Doc: Remove unnecessary keygen step
jhen0409 Aug 28, 2023
4560820
Android: Move more duplicated code to Onnxruntime class
jhen0409 Aug 28, 2023
b50162a
Android: Use class name as 2nd arg for init ReactModuleInfo
jhen0409 Aug 28, 2023
ba5dc5d
Android: Fix tests
jhen0409 Aug 28, 2023
0b6479e
Android: Remove unnecessary code
jhen0409 Aug 28, 2023
8599948
iOS: Remove unnecessary patch
jhen0409 Aug 28, 2023
4fd5831
Android: Use 0.71 for unit tests & fix gradle build
jhen0409 Aug 29, 2023
93c195b
iOS: Update Podfile
jhen0409 Sep 7, 2023
110b381
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Sep 7, 2023
48e0e73
Android: Fix react-android dep for RN < 0.71
jhen0409 Sep 10, 2023
0a1988a
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Sep 28, 2023
5e08a9e
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Oct 19, 2023
e1a5b8d
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Nov 12, 2023
446b12b
Revert unnecessary deps change
jhen0409 Nov 12, 2023
c409719
Remove dep that added hash in lockfile
jhen0409 Nov 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion js/.eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,22 @@ module.exports = {
}
}, {
files: ['react_native/lib/**/*.ts'], rules: {
'@typescript-eslint/naming-convention': 'off'
'@typescript-eslint/naming-convention': 'off',
}
}, {
files: ['react_native/lib/NativeOnnxruntime.ts'], rules: {
'unicorn/filename-case': 'off',
'@typescript-eslint/ban-types': [
'error',
{
types: {
// NOTE: We got issue like https://github.com/facebook/react-native/issues/36431
// So we have to use `{}` type here.
'{}': false,
},
extendDefaults: true,
}
]
}
}, {
files: ['react_native/scripts/**/*.ts'], rules: {
Expand Down
6 changes: 0 additions & 6 deletions js/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,6 @@ From ORT v1.13 onwards the 'full' ONNX Runtime package is used. It supports both

4. Test Android and iOS apps. In Windows, open Android Emulator first.

`debug.keystore` must be generated ahead for Android example.

```sh
keytool -genkey -v -keystore <ORT_ROOT>/js/react_native/e2e/android/debug.keystore -alias androiddebugkey -storepass android -keypass android -keyalg RSA -keysize 2048 -validity 999999 -dname "CN=Android Debug,O=Android,C=US"
```

From `<ORT_ROOT>/js/react_native,

```sh
Expand Down
2 changes: 2 additions & 0 deletions js/react_native/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ DerivedData
*.ipa
*.xcuserstate
project.xcworkspace
xcshareddata

# Android/IJ
#
.idea
.gradle
local.properties
android.iml
.cxx

# Cocoapods
#
Expand Down
34 changes: 32 additions & 2 deletions js/react_native/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@ buildscript {
}

dependencies {
classpath 'com.android.tools.build:gradle:4.1.2'
classpath 'com.android.tools.build:gradle:7.3.1'
// noinspection DifferentKotlinGradleVersion
}
}

apply plugin: 'com.android.library'

def isNewArchitectureEnabled() {
return rootProject.hasProperty("newArchEnabled") && rootProject.getProperty("newArchEnabled") == "true"
}

if (isNewArchitectureEnabled()) {
apply plugin: "com.facebook.react"
}

def getExtOrDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : project.properties['OnnxruntimeModule_' + name]
}
Expand Down Expand Up @@ -90,6 +98,7 @@ android {
abiFilters (*reactNativeArchitectures())
}
}
buildConfigField "boolean", "IS_NEW_ARCHITECTURE_ENABLED", isNewArchitectureEnabled().toString()
}

if (rootProject.hasProperty("ndkPath")) {
Expand All @@ -115,6 +124,7 @@ android {
"META-INF",
"META-INF/**",
"**/libjsi.so",
"**/libc++_shared.so",
]
}

Expand All @@ -139,6 +149,12 @@ android {
} else {
java.exclude '**/OnnxruntimeExtensionsEnabled.java'
}

if (isNewArchitectureEnabled()) {
java.srcDirs += ['src/newarch']
} else {
java.srcDirs += ['src/oldarch']
}
}
}
}
Expand Down Expand Up @@ -219,7 +235,13 @@ repositories {
}

dependencies {
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
if (REACT_NATIVE_MINOR_VERSION >= 71) {
// REACT_NATIVE_VERSION >= 0.71.x use react-android (https://mvnrepository.com/artifact/com.facebook.react/react-android)
// See also https://github.com/facebook/react-native/blob/0.71-stable/android/README.md
api "com.facebook.react:react-android:" + REACT_NATIVE_VERSION
} else {
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
}
api "org.mockito:mockito-core:2.28.2"

androidTestImplementation "androidx.test:runner:1.1.0"
Expand All @@ -238,3 +260,11 @@ dependencies {
implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar"
}
}

if (isNewArchitectureEnabled()) {
react {
jsRootDir = file("../lib/")
libraryName = "OnnxruntimeSpec"
codegenJavaPackageName = "ai.onnxruntime.reactnative"
}
}
5 changes: 5 additions & 0 deletions js/react_native/android/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ OnnxruntimeModule_buildToolsVersion=29.0.2
OnnxruntimeModule_compileSdkVersion=31
OnnxruntimeModule_minSdkVersion=21
OnnxruntimeModule_targetSdkVersion=31

# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
# Default value: -Xmx512m -XX:MaxMetaspaceSize=256m
org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=7faa7198769f872826c8ef4f1450f839ec27f0b4d5d1e51bade63667cbccd205
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
distributionSha256Sum=f6b8596b10cce501591e92f229816aa4046424f3b24d771751b06779d58c8ec4
distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void setUp() {
@Test
public void getName() throws Exception {
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
ortModule.getOnnxruntime().setBlobModule(blobModule);
String name = "Onnxruntime";
Assert.assertEquals(ortModule.getName(), name);
}
Expand All @@ -71,7 +71,8 @@ public void onnxruntime_module() throws Exception {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
ortModule.getOnnxruntime().setBlobModule(blobModule);
ortModule.getOnnxruntime().checkBlobModule();
String sessionKey = "";

// test loadModel()
Expand All @@ -82,7 +83,7 @@ public void onnxruntime_module() throws Exception {

JavaOnlyMap options = new JavaOnlyMap();
try {
ReadableMap resultMap = ortModule.loadModel(modelBuffer, options);
ReadableMap resultMap = ortModule.getOnnxruntime().loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");
ReadableArray inputNames = resultMap.getArray("inputNames");
ReadableArray outputNames = resultMap.getArray("outputNames");
Expand Down Expand Up @@ -132,7 +133,7 @@ public void onnxruntime_module() throws Exception {
options.putBoolean("encodeTensorData", true);

try {
ReadableMap resultMap = ortModule.run(sessionKey, inputDataMap, outputNames, options);
ReadableMap resultMap = ortModule.getOnnxruntime().run(sessionKey, inputDataMap, outputNames, options);

ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Expand All @@ -151,7 +152,7 @@ public void onnxruntime_module() throws Exception {
}

// test dispose
ortModule.dispose(sessionKey);
ortModule.getOnnxruntime().dispose(sessionKey);
} finally {
mockSession.finishMocking();
}
Expand All @@ -165,7 +166,8 @@ public void onnxruntime_module_append_nnapi() throws Exception {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
ortModule.getOnnxruntime().setBlobModule(blobModule);
ortModule.getOnnxruntime().checkBlobModule();
String sessionKey = "";

// test loadModel() with nnapi ep options
Expand All @@ -182,7 +184,7 @@ public void onnxruntime_module_append_nnapi() throws Exception {
options.putArray("executionProviders", epArray);

try {
ReadableMap resultMap = ortModule.loadModel(modelBuffer, options);
ReadableMap resultMap = ortModule.getOnnxruntime().loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");
ReadableArray inputNames = resultMap.getArray("inputNames");
ReadableArray outputNames = resultMap.getArray("outputNames");
Expand All @@ -195,7 +197,7 @@ public void onnxruntime_module_append_nnapi() throws Exception {
Assert.fail(e.getMessage());
}
}
ortModule.dispose(sessionKey);
ortModule.getOnnxruntime().dispose(sessionKey);
} finally {
mockSession.finishMocking();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@
import android.net.Uri;
import android.os.Build;
import android.util.Log;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.LifecycleEventListener;
import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReadableType;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
Expand All @@ -46,7 +42,7 @@
import java.util.stream.Stream;

@RequiresApi(api = Build.VERSION_CODES.N)
public class OnnxruntimeModule extends ReactContextBaseJavaModule implements LifecycleEventListener {
public class Onnxruntime implements LifecycleEventListener {
private static ReactApplicationContext reactContext;

private static OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment();
Expand All @@ -58,27 +54,20 @@ private static String getNextSessionKey() {
nextSessionId = nextSessionId.add(BigInteger.valueOf(1));
return key;
}
private BlobModule blobModule;

protected BlobModule blobModule;
public Onnxruntime(ReactApplicationContext context) { reactContext = context; }

public OnnxruntimeModule(ReactApplicationContext context) {
super(context);
reactContext = context;
}

@NonNull
@Override
public String getName() {
return "Onnxruntime";
}
protected void setBlobModule(BlobModule blobModule) { this.blobModule = blobModule; }

public void checkBlobModule() {
if (blobModule == null) {
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
blobModule = reactContext.getNativeModule(BlobModule.class);
if (blobModule == null) {
throw new RuntimeException("BlobModule is not initialized");
}
}
setBlobModule(blobModule);
}

/**
Expand All @@ -90,7 +79,6 @@ public void checkBlobModule() {
* @note the value provided to `promise` includes a key representing the session.
* when run() is called, the key must be passed into the first parameter.
*/
@ReactMethod
public void loadModel(String uri, ReadableMap options, Promise promise) {
try {
WritableMap resultMap = loadModel(uri, options);
Expand All @@ -109,7 +97,6 @@ public void loadModel(String uri, ReadableMap options, Promise promise) {
* @note the value provided to `promise` includes a key representing the session.
* when run() is called, the key must be passed into the first parameter.
*/
@ReactMethod
public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) {
try {
checkBlobModule();
Expand All @@ -129,7 +116,6 @@ public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise pro
* @param key session key representing a session given at loadModel()
* @param promise output returning back to react native js
*/
@ReactMethod
public void dispose(String key, Promise promise) {
try {
dispose(key);
Expand All @@ -148,9 +134,9 @@ public void dispose(String key, Promise promise) {
* @param options onnxruntime run options
* @param promise output returning back to react native js
*/
@ReactMethod
public void run(String key, ReadableMap input, ReadableArray output, ReadableMap options, Promise promise) {
try {
checkBlobModule();
WritableMap resultMap = run(key, input, output, options);
promise.resolve(resultMap);
} catch (Exception e) {
Expand Down Expand Up @@ -259,8 +245,6 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read

RunOptions runOptions = parseRunOptions(options);

checkBlobModule();

long startTime = System.currentTimeMillis();
Map<String, OnnxTensor> feed = new HashMap<>();
Iterator<String> iterator = ortSession.getInputNames().iterator();
Expand Down
Loading