Skip to content

Commit

Permalink
Java: Add support for scalar Tensors with dtype = TF_STRING
Browse files Browse the repository at this point in the history
This change adds support for string scalars used by
popular operations like DecodeJpeg. Vectors, matrices
and higher-dimensional string tensors are left as an
exercise for the future.

Note that the TF_STRING data type in TensorFlow corresponds to a byte[] in
Java (not a String) since it is used for arbitrary bytes, not just
strings of unicode characters.

One more step in the journey that is zplizzi#5
Change: 141221096
  • Loading branch information
asimshankar authored and tensorflower-gardener committed Dec 7, 2016
1 parent b0b9de3 commit 4d51b55
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 4 deletions.
7 changes: 7 additions & 0 deletions tensorflow/java/src/main/java/org/tensorflow/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ public enum DataType {
/** 32-bit signed integer. */
INT32(3),

/**
* A sequence of bytes.
*
* <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes.
*/
STRING(7),

/** 64-bit signed integer. */
INT64(9),

Expand Down
40 changes: 36 additions & 4 deletions tensorflow/java/src/main/java/org/tensorflow/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,17 @@ public static Tensor create(Object obj) {
t.dtype = dataTypeOf(obj);
t.shapeCopy = new long[numDimensions(obj)];
fillShape(obj, 0, t.shapeCopy);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy);
setValue(t.nativeHandle, obj);
if (t.dtype != DataType.STRING) {
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy);
setValue(t.nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
throw new UnsupportedOperationException(
String.format(
"non-scalar DataType.STRING tensors are not supported yet (version %s). Please file a feature request at https://github.com/tensorflow/tensorflow/issues/new",
TensorFlow.version()));
} else {
t.nativeHandle = allocateScalarBytes((byte[]) obj);
}
return t;
}

Expand Down Expand Up @@ -160,6 +169,15 @@ public boolean booleanValue() {
return scalarBoolean(nativeHandle);
}

/**
* Returns the value in a scalar {@link DataType#STRING} tensor.
*
* <p>The returned array is newly allocated on every call and holds the decoded bytes of the
* tensor's single element.
*
* @throws IllegalStateException if the Tensor is not a scalar.
* @throws IllegalArgumentException if the Tensor does not represent a string/bytes scalar, or if
*     its encoded contents cannot be decoded.
*/
public byte[] bytesValue() {
return scalarBytes(nativeHandle);
}

/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
Expand Down Expand Up @@ -224,7 +242,12 @@ private static DataType dataTypeOf(Object o) {
if (Array.getLength(o) == 0) {
throw new IllegalArgumentException("cannot create Tensors with a 0 dimension");
}
return dataTypeOf(Array.get(o, 0));
// byte[] is a DataType.STRING scalar.
Object e = Array.get(o, 0);
if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
return DataType.STRING;
}
return dataTypeOf(e);
}
if (Float.class.isInstance(o) || float.class.isInstance(o)) {
return DataType.FLOAT;
Expand All @@ -243,7 +266,12 @@ private static DataType dataTypeOf(Object o) {

/**
 * Counts the tensor dimensions implied by a (possibly nested) Java array.
 *
 * <p>Only the first element at each nesting level is inspected. An array whose first element is
 * a byte is treated as a single {@code DataType.STRING} scalar, so it contributes no dimension.
 * Any non-array object has zero dimensions.
 */
private static int numDimensions(Object o) {
  int rank = 0;
  Object cursor = o;
  while (cursor.getClass().isArray()) {
    Object first = Array.get(cursor, 0);
    // byte[] is a DataType.STRING scalar, so stop counting once bytes are reached.
    // (Array.get boxes primitive bytes, so checking for Byte covers byte[] as well.)
    if (first instanceof Byte) {
      return rank;
    }
    rank++;
    cursor = first;
  }
  return rank;
}
Expand Down Expand Up @@ -291,6 +319,8 @@ private void throwExceptionIfTypeIsIncompatible(Object o) {

private static native long allocate(int dtype, long[] shape);

// Implemented in tensor_jni.cc: allocates a scalar TF_STRING tensor that holds an encoded copy
// of `value` and returns its native handle. Returns 0 with a Java exception pending if the
// bytes could not be encoded.
private static native long allocateScalarBytes(byte[] value);

private static native void delete(long handle);

private static native int dtype(long handle);
Expand All @@ -309,6 +339,8 @@ private void throwExceptionIfTypeIsIncompatible(Object o) {

private static native boolean scalarBoolean(long handle);

// Implemented in tensor_jni.cc: decodes the single element of the scalar TF_STRING tensor at
// `handle` into a newly allocated byte[]. The native code raises an exception if the tensor is
// not a scalar, is not of type TF_STRING, or has a malformed encoding.
private static native byte[] scalarBytes(long handle);

private static native void readNDArray(long handle, Object value);

static {
Expand Down
77 changes: 77 additions & 0 deletions tensorflow/java/src/main/native/tensor_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
Expand Down Expand Up @@ -262,6 +263,39 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
return reinterpret_cast<jlong>(t);
}

// Creates a scalar TF_STRING tensor whose single element is a copy of the
// bytes in `value`.
//
// Returns the TF_Tensor* as a jlong handle. On failure returns 0 with a Java
// exception pending; the partially built tensor is released in that case so
// that no native memory is leaked.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
    JNIEnv* env, jclass clazz, jbyteArray value) {
  // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
  // TF_StringEncode-encoded bytes.
  const size_t src_len = static_cast<size_t>(env->GetArrayLength(value));
  const size_t dst_len = TF_StringEncodedSize(src_len);
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
  char* dst = static_cast<char*>(TF_TensorData(t));
  memset(dst, 0, 8);  // The offset table

  // jbyte is a signed char, while the C standard doesn't require char and
  // signed char to be the same. As a result, static_cast<char*>(src) will
  // complain. Copy the string instead. sigh!
  jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
  if (jsrc == nullptr) {
    // The JVM could not expose the array contents. Make sure an exception is
    // pending for the caller and release the tensor allocated above.
    if (!env->ExceptionCheck()) {
      throwException(env, kIllegalStateException,
                     "unable to access the contents of the byte array");
    }
    TF_DeleteTensor(t);
    return 0;
  }
  std::unique_ptr<char[]> src(new char[src_len]);
  static_assert(sizeof(jbyte) == sizeof(char),
                "Cannot convert Java byte to a C char");
  memcpy(src.get(), jsrc, src_len);
  // JNI_ABORT: the elements were only read, so nothing needs copying back.
  env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);

  TF_Status* status = TF_NewStatus();
  TF_StringEncode(src.get(), src_len, dst + 8, dst_len, status);
  if (TF_GetCode(status) != TF_OK) {
    // TODO(ashankar): Replace with throwExceptionIfNotOK() being added to
    // exception_jni.h in another change.
    //
    // The message is passed as an argument rather than as the format string so
    // that any '%' it contains is not interpreted as a conversion specifier.
    throwException(env, kIllegalStateException, "%s", TF_Message(status));
    TF_DeleteStatus(status);
    TF_DeleteTensor(t);  // Do not leak the tensor on the error path.
    return 0;
  }
  TF_DeleteStatus(status);
  return reinterpret_cast<jlong>(t);
}

JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
jclass clazz,
jlong handle) {
Expand Down Expand Up @@ -337,6 +371,49 @@ DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
#undef DEFINE_GET_SCALAR_METHOD

// Decodes the single element of a scalar TF_STRING tensor into a new Java
// byte[].
//
// Returns nullptr with a Java exception pending when:
//   - the handle is rejected by requireHandle,
//   - the tensor is not a scalar (IllegalStateException),
//   - the tensor is not TF_STRING or its encoding is malformed
//     (IllegalArgumentException),
//   - the JVM fails to allocate the result array.
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
    JNIEnv* env, jclass clazz, jlong handle) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return nullptr;
  if (TF_NumDims(t) != 0) {
    throwException(env, kIllegalStateException, "Tensor is not a scalar");
    return nullptr;
  }
  if (TF_TensorType(t) != TF_STRING) {
    throwException(env, kIllegalArgumentException,
                   "Tensor is not a string/bytes scalar");
    return nullptr;
  }

  // Layout: an 8-byte offset table (one entry for a scalar) followed by the
  // TF_StringEncode-encoded bytes.
  const size_t kOffsetTableBytes = sizeof(uint64_t);
  const char* data = static_cast<const char*>(TF_TensorData(t));
  const size_t total_len = TF_TensorByteSize(t);
  // Guard the subtraction below: size_t arithmetic would otherwise wrap around
  // for a buffer too small to even hold the offset table.
  if (data == nullptr || total_len < kOffsetTableBytes) {
    throwException(env, kIllegalArgumentException,
                   "invalid tensor encoding: missing offset table");
    return nullptr;
  }
  const size_t encoded_len = total_len - kOffsetTableBytes;
  uint64_t offset = 0;
  memcpy(&offset, data, sizeof(offset));
  if (offset >= encoded_len) {
    throwException(env, kIllegalArgumentException,
                   "invalid tensor encoding: bad offsets");
    return nullptr;
  }
  // Honor the validated offset when locating the element (it is 0 for tensors
  // produced by allocateScalarBytes).
  const char* src = data + kOffsetTableBytes + offset;
  const size_t src_len = encoded_len - static_cast<size_t>(offset);

  jbyteArray ret = nullptr;
  const char* dst = nullptr;
  size_t dst_len = 0;
  TF_Status* status = TF_NewStatus();
  TF_StringDecode(src, src_len, &dst, &dst_len, status);
  if (TF_GetCode(status) != TF_OK) {
    // TODO(ashankar): Replace with throwExceptionIfNotOK introduced into
    // exception_jni.h by another change.
    throwException(env, kIllegalArgumentException,
                   "invalid tensor encoding: %s", TF_Message(status));
  } else {
    static_assert(sizeof(jbyte) == sizeof(char),
                  "Cannot convert a C char to a Java byte");
    ret = env->NewByteArray(static_cast<jsize>(dst_len));
    // NewByteArray returns nullptr (with an exception pending) if the array
    // cannot be allocated; only copy when allocation succeeded.
    if (ret != nullptr) {
      env->SetByteArrayRegion(ret, 0, static_cast<jsize>(dst_len),
                              reinterpret_cast<const jbyte*>(dst));
    }
  }
  TF_DeleteStatus(status);
  return ret;
}

JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
jclass clazz,
jlong handle,
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/java/src/main/native/tensor_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
jint, jlongArray);

/*
* Class: org_tensorflow_Tensor
* Method: allocateScalarBytes
* Signature: ([B)J
*/
JNIEXPORT jlong JNICALL
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv *, jclass, jbyteArray);

/*
* Class: org_tensorflow_Tensor
* Method: delete
Expand Down Expand Up @@ -108,6 +116,15 @@ JNIEXPORT jboolean JNICALL Java_org_tensorflow_Tensor_scalarBoolean(JNIEnv *,
jclass,
jlong);

/*
* Class: org_tensorflow_Tensor
* Method: scalarBytes
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(JNIEnv *,
jclass,
jlong);

/*
* Class: org_tensorflow_Tensor
* Method: readNDArray
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ public void scalars() {
assertEquals(0, t.shape().length);
assertTrue(t.booleanValue());
}

final byte[] bytes = {1,2,3,4};
try (Tensor t = Tensor.create(bytes)) {
assertEquals(DataType.STRING, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertArrayEquals(bytes, t.bytesValue());
}
}

@Test
Expand Down

0 comments on commit 4d51b55

Please sign in to comment.