From 4d51b55cf8c2f58f49b3c89bf0dbd32b5875f1fb Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 6 Dec 2016 13:49:10 -0800 Subject: [PATCH] Java: Add support for scalar Tensors with dtype = TF_STRING This change adds support for string scalars used by popular operations like DecodeJpeg. Vectors, matrices and higher-dimensional string tensors are left as an exercise for the future. Note that the TF_STRING data type in TensorFlow corresponds to a byte[] in Java (not a String) since it is used for arbitrary bytes, not just strings of Unicode characters. One more step in the journey that is #5 Change: 141221096 --- .../main/java/org/tensorflow/DataType.java | 7 ++ .../src/main/java/org/tensorflow/Tensor.java | 40 +++++++++- tensorflow/java/src/main/native/tensor_jni.cc | 77 +++++++++++++++++++ tensorflow/java/src/main/native/tensor_jni.h | 17 ++++ .../test/java/org/tensorflow/TensorTest.java | 8 ++ 5 files changed, 145 insertions(+), 4 deletions(-) diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 08d4e14df9be76..a7c6e12b41a069 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -26,6 +26,13 @@ public enum DataType { /** 32-bit signed integer. */ INT32(3), + /** + * A sequence of bytes. + * + *

TensorFlow uses the STRING type for an arbitrary sequence of bytes. + */ + STRING(7), + /** 64-bit signed integer. */ INT64(9), diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index d09a71a758ad29..5478bb85e9bdd0 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -70,8 +70,17 @@ public static Tensor create(Object obj) { t.dtype = dataTypeOf(obj); t.shapeCopy = new long[numDimensions(obj)]; fillShape(obj, 0, t.shapeCopy); - t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy); - setValue(t.nativeHandle, obj); + if (t.dtype != DataType.STRING) { + t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy); + setValue(t.nativeHandle, obj); + } else if (t.shapeCopy.length != 0) { + throw new UnsupportedOperationException( + String.format( + "non-scalar DataType.STRING tensors are not supported yet (version %s). Please file a feature request at https://github.com/tensorflow/tensorflow/issues/new", + TensorFlow.version())); + } else { + t.nativeHandle = allocateScalarBytes((byte[]) obj); + } return t; } @@ -160,6 +169,15 @@ public boolean booleanValue() { return scalarBoolean(nativeHandle); } + /** + * Returns the value in a scalar {@link DataType#STRING} tensor. + * + * @throws IllegalArgumentException if the Tensor does not represent a {@link DataType#STRING} scalar. + */ + public byte[] bytesValue() { + return scalarBytes(nativeHandle); + } + /** * Copies the contents of the tensor to {@code dst} and returns {@code dst}. * @@ -224,7 +242,12 @@ private static DataType dataTypeOf(Object o) { if (Array.getLength(o) == 0) { throw new IllegalArgumentException("cannot create Tensors with a 0 dimension"); } - return dataTypeOf(Array.get(o, 0)); + // byte[] is a DataType.STRING scalar. 
+ Object e = Array.get(o, 0); + if (Byte.class.isInstance(e) || byte.class.isInstance(e)) { + return DataType.STRING; + } + return dataTypeOf(e); } if (Float.class.isInstance(o) || float.class.isInstance(o)) { return DataType.FLOAT; @@ -243,7 +266,12 @@ private static DataType dataTypeOf(Object o) { private static int numDimensions(Object o) { if (o.getClass().isArray()) { - return 1 + numDimensions(Array.get(o, 0)); + // byte[] is a DataType.STRING scalar. + Object e = Array.get(o, 0); + if (Byte.class.isInstance(e) || byte.class.isInstance(e)) { + return 0; + } + return 1 + numDimensions(e); } return 0; } @@ -291,6 +319,8 @@ private void throwExceptionIfTypeIsIncompatible(Object o) { private static native long allocate(int dtype, long[] shape); + private static native long allocateScalarBytes(byte[] value); + private static native void delete(long handle); private static native int dtype(long handle); @@ -309,6 +339,8 @@ private void throwExceptionIfTypeIsIncompatible(Object o) { private static native boolean scalarBoolean(long handle); + private static native byte[] scalarBytes(long handle); + private static native void readNDArray(long handle, Object value); static { diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc index d5850b0d4236b1..c98d6807acc345 100644 --- a/tensorflow/java/src/main/native/tensor_jni.cc +++ b/tensorflow/java/src/main/native/tensor_jni.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/c/c_api.h" #include "tensorflow/java/src/main/native/exception_jni.h" @@ -262,6 +263,39 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env, return reinterpret_cast(t); } +JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes( + JNIEnv* env, jclass clazz, jbyteArray value) { + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by + // TF_StringEncode-encoded bytes. 
+ size_t src_len = static_cast(env->GetArrayLength(value)); + size_t dst_len = TF_StringEncodedSize(src_len); + TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len); + char* dst = static_cast(TF_TensorData(t)); + memset(dst, 0, 8); // The offset table + + // jbyte is a signed char, while the C standard doesn't require char and + // signed char to be the same. As a result, static_cast(src) will + // complain. Copy the string instead. sigh! + jbyte* jsrc = env->GetByteArrayElements(value, nullptr); + std::unique_ptr src(new char[src_len]); + static_assert(sizeof(jbyte) == sizeof(char), + "Cannot convert Java byte to a C char"); + memcpy(src.get(), jsrc, src_len); + env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT); + + TF_Status* status = TF_NewStatus(); + TF_StringEncode(src.get(), src_len, dst + 8, dst_len, status); + if (TF_GetCode(status) != TF_OK) { + // TODO(ashankar): Replace with throwExceptionIfNotOK() being added to + // exception_jni.h in another change. + throwException(env, kIllegalStateException, TF_Message(status)); + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return reinterpret_cast(t); +} + JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env, jclass clazz, jlong handle) { @@ -337,6 +371,49 @@ DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long); DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean); #undef DEFINE_GET_SCALAR_METHOD +JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes( + JNIEnv* env, jclass clazz, jlong handle) { + TF_Tensor* t = requireHandle(env, handle); + if (t == nullptr) return nullptr; + if (TF_NumDims(t) != 0) { + throwException(env, kIllegalStateException, "Tensor is not a scalar"); + return nullptr; + } + if (TF_TensorType(t) != TF_STRING) { + throwException(env, kIllegalArgumentException, + "Tensor is not a string/bytes scalar"); + return nullptr; + } + const char* data = static_cast(TF_TensorData(t)); + const char* src = data + 8; + size_t src_len = 
TF_TensorByteSize(t) - 8; + uint64_t offset = 0; + memcpy(&offset, data, sizeof(offset)); + if (offset >= src_len) { + throwException(env, kIllegalArgumentException, + "invalid tensor encoding: bad offsets"); + return nullptr; + } + jbyteArray ret = nullptr; + const char* dst = nullptr; + size_t dst_len = 0; + TF_Status* status = TF_NewStatus(); + TF_StringDecode(src, src_len, &dst, &dst_len, status); + if (TF_GetCode(status) != TF_OK) { + // TODO(ashankar): Replace with throwExceptionIfNotOK introduced into + // exception_jni.h by another change. + throwException(env, kIllegalArgumentException, + "invalid tensor encoding: %s", TF_Message(status)); + } else { + ret = env->NewByteArray(dst_len); + jbyte* cpy = env->GetByteArrayElements(ret, nullptr); + memcpy(cpy, dst, dst_len); + env->ReleaseByteArrayElements(ret, cpy, 0); + } + TF_DeleteStatus(status); + return ret; +} + JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env, jclass clazz, jlong handle, diff --git a/tensorflow/java/src/main/native/tensor_jni.h b/tensorflow/java/src/main/native/tensor_jni.h index 72779088290cb7..ea0dfc819efdb4 100644 --- a/tensorflow/java/src/main/native/tensor_jni.h +++ b/tensorflow/java/src/main/native/tensor_jni.h @@ -30,6 +30,14 @@ extern "C" { JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass, jint, jlongArray); +/* + * Class: org_tensorflow_Tensor + * Method: allocateScalarBytes + * Signature: ([B)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv *, jclass, jbyteArray); + /* * Class: org_tensorflow_Tensor * Method: delete @@ -108,6 +116,15 @@ JNIEXPORT jboolean JNICALL Java_org_tensorflow_Tensor_scalarBoolean(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Tensor + * Method: scalarBytes + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(JNIEnv *, + jclass, + jlong); + /* * Class: org_tensorflow_Tensor * Method: readNDArray diff --git 
a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 65540f9f20a775..ec1c8551a71586 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -63,6 +63,14 @@ public void scalars() { assertEquals(0, t.shape().length); assertTrue(t.booleanValue()); } + + final byte[] bytes = {1,2,3,4}; + try (Tensor t = Tensor.create(bytes)) { + assertEquals(DataType.STRING, t.dataType()); + assertEquals(0, t.numDimensions()); + assertEquals(0, t.shape().length); + assertArrayEquals(bytes, t.bytesValue()); + } } @Test