Skip to content

Commit

Permalink
Updating react native tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Oct 12, 2023
1 parent 34787aa commit 15d5283
Showing 1 changed file with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,14 @@ public void throwWrongSizeInput() {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
String sessionKey = null;

try (InputStream modelStream =
reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float)) {
JavaOnlyMap options = new JavaOnlyMap();
ReadableMap loadMap = ortModule.loadModel("test", modelStream, options);
byte[] modelBuffer = getInputModelBuffer(modelStream);
ReadableMap loadMap = ortModule.loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");

int[] dims = new int[] {1, 7};
float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE, 5f, -6f};
Expand Down Expand Up @@ -253,6 +256,7 @@ public void throwWrongSizeInput() {
Assert.assertTrue(e.getMessage().contains("Got invalid dimensions for input"));
}
} finally {
ortModule.dispose(sessionKey);
mockSession.finishMocking();
}
}
Expand All @@ -265,11 +269,14 @@ public void throwWrongRankInput() {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
String sessionKey = null;

JavaOnlyMap options = new JavaOnlyMap();
try (InputStream modelStream =
reactContext.getResources().openRawResource(ai.onnxruntime.reactnative.test.R.raw.test_types_float)){
ReadableMap loadMap = ortModule.loadModel("test", modelStream, options);
byte[] modelBuffer = getInputModelBuffer(modelStream);
ReadableMap loadMap = ortModule.loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");

int[] dims = new int[] {1, 1, 7};
float[] inputData = new float[] {1.0f, 2.0f, -3.0f, Float.MIN_VALUE, Float.MAX_VALUE, 5f, -6f};
Expand Down Expand Up @@ -308,6 +315,7 @@ public void throwWrongRankInput() {
Assert.assertTrue(e.getMessage().contains("Invalid rank for input"));
}
} finally {
ortModule.dispose(sessionKey);
mockSession.finishMocking();
}
}
Expand Down

0 comments on commit 15d5283

Please sign in to comment.