Skip to content

Commit

Permalink
Java: Construct a Tensor from a handle to the C TF_Tensor object.
Browse files Browse the repository at this point in the history
This change introduces a package private static method to create
a Java Tensor object from the handle to the native TF_Tensor object.
This will be needed to create the output Tensor objects produced
by graph execution (when the equivalent of a Session and it's run()
method are introduced in Java).

Yet another step in the journey that is zplizzi#5
Change: 141096422
  • Loading branch information
asimshankar authored and tensorflower-gardener committed Dec 5, 2016
1 parent 1f7526b commit 005b59a
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tensorflow/java/src/main/java/org/tensorflow/DataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,14 @@ public enum DataType {
int c() {
return value;
}

static DataType fromC(int c) {
for (DataType t : DataType.values()) {
if (t.c() == c) {
return t;
}
}
throw new IllegalArgumentException(
"DataType " + c + " is not recognized in Java (version " + TensorFlow.version() + ")");
}
}
23 changes: 22 additions & 1 deletion tensorflow/java/src/main/java/org/tensorflow/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/**
* A typed multi-dimensional array.
*
* Instances of a Tensor are <b>not</b> thread-safe.
* <p>Instances of a Tensor are <b>not</b> thread-safe.
*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
Expand Down Expand Up @@ -196,6 +196,23 @@ public String toString() {
return String.format("%s tensor with shape %s", dtype.toString(), Arrays.toString(shape()));
}

/**
* Create a Tensor object from a handle to the C TF_Tensor object.
*
* <p>Takes ownership of the handle.
*/
static Tensor fromHandle(long handle) {
Tensor t = new Tensor();
t.dtype = DataType.fromC(dtype(handle));
t.shapeCopy = shape(handle);
t.nativeHandle = handle;
return t;
}

long getNativeHandle() {
return nativeHandle;
}

private long nativeHandle;
private DataType dtype;
private long[] shapeCopy = null;
Expand Down Expand Up @@ -276,6 +293,10 @@ private void throwExceptionIfTypeIsIncompatible(Object o) {

private static native void delete(long handle);

private static native int dtype(long handle);

private static native long[] shape(long handle);

private static native void setValue(long handle, Object value);

private static native float scalarFloat(long handle);
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/java/src/main/native/tensor_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,33 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
}

JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
static_assert(sizeof(jint) >= sizeof(TF_DataType),
"TF_DataType in C cannot be represented as an int in Java");
TF_Tensor* t = requireHandle(env, handle);
if (t == nullptr) return 0;
return static_cast<jint>(TF_TensorType(t));
}

JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
jclass clazz,
jlong handle) {
TF_Tensor* t = requireHandle(env, handle);
if (t == nullptr) return nullptr;
static_assert(sizeof(jlong) == sizeof(int64_t),
"Java long is not compatible with the TensorFlow C API");
const jsize num_dims = TF_NumDims(t);
jlongArray ret = env->NewLongArray(num_dims);
jlong* dims = env->GetLongArrayElements(ret, nullptr);
for (int i = 0; i < num_dims; ++i) {
dims[i] = static_cast<jlong>(TF_Dim(t, i));
}
env->ReleaseLongArrayElements(ret, dims, 0);
return ret;
}

JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
jclass clazz,
jlong handle,
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/java/src/main/native/tensor_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Tensor
* Method: dtype
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Tensor
* Method: shape
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv *, jclass,
jlong);

/*
* Class: org_tensorflow_Tensor
* Method: setValue
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,22 @@ public void useAfterClose() {
// The expected exception.
}
}

@Test
public void fromHandle() {
// fromHandle is a package-visible method intended for use when the C TF_Tensor object has been
// created indepdently of the Java code. In practice, two Tensor instances MUST NOT have the
// same native handle.
//
// An exception is made for this test, where the pitfalls of this is avoided by not calling
// close() on both Tensors.
final float[][] matrix = {{1, 2, 3}, {4, 5, 6}};
try (Tensor src = Tensor.create(matrix)) {
Tensor cpy = Tensor.fromHandle(src.getNativeHandle());
assertEquals(src.dataType(), cpy.dataType());
assertEquals(src.numDimensions(), cpy.numDimensions());
assertArrayEquals(src.shape(), cpy.shape());
assertArrayEquals(matrix, cpy.copyTo(new float[2][3]));
}
}
}

0 comments on commit 005b59a

Please sign in to comment.