diff --git a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java index 76fd608e4362b..72518488e6682 100644 --- a/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java +++ b/js/react_native/android/src/androidTest/java/ai/onnxruntime/reactnative/TensorHelperTest.java @@ -238,6 +238,34 @@ public void createInputTensor_double() throws Exception { outputTensor.close(); } + @Test + public void createInputTensor_bool() throws Exception { + OnnxTensor outputTensor = OnnxTensor.createTensor(ortEnvironment, new boolean[] {false, true}); + + JavaOnlyMap inputTensorMap = new JavaOnlyMap(); + + JavaOnlyArray dims = new JavaOnlyArray(); + dims.pushInt(2); + inputTensorMap.putArray("dims", dims); + + inputTensorMap.putString("type", TensorHelper.JsTensorTypeBool); + + ByteBuffer dataByteBuffer = ByteBuffer.allocate(2); + dataByteBuffer.put((byte)0); + dataByteBuffer.put((byte)1); + inputTensorMap.putMap("data", blobModule.testCreateData(dataByteBuffer.array())); + + OnnxTensor inputTensor = TensorHelper.createInputTensor(blobModule, inputTensorMap, ortEnvironment); + + Assert.assertEquals(inputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); + Assert.assertEquals(outputTensor.getInfo().onnxType, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL); + Assert.assertEquals(inputTensor.toString(), outputTensor.toString()); + Assert.assertArrayEquals(inputTensor.getByteBuffer().array(), outputTensor.getByteBuffer().array()); + + inputTensor.close(); + outputTensor.close(); + } + @Test public void createOutputTensor_bool() throws Exception { MockitoSession mockSession = mockitoSession().mockStatic(Arguments.class).startMocking(); diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java index d9c2e3bac5d9b..63cddace36640 100644 --- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java @@ -174,7 +174,11 @@ private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType tensorType tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); break; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + ByteBuffer buffer = values; + tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.BOOL); + break; + } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: