Skip to content

Commit

Permalink
Java: Bugfix: Avoid segfault when a Tensor is used after being close()d.
Browse files Browse the repository at this point in the history
There were two bugs:
- The string constant kNullPointerException had a typo in its value
- The scalar value accessors weren't checking for a valid handle

Fixed those, added tests and moved the utility function throwException
to its own file since I plan to use that in JNI files for the other
Java classes in the future.

Another step in the journey that is zplizzi#5
Change: 141091613
  • Loading branch information
asimshankar authored and tensorflower-gardener committed Dec 5, 2016
1 parent bdfb8ff commit 502675b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 29 deletions.
4 changes: 3 additions & 1 deletion tensorflow/java/src/main/java/org/tensorflow/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
/**
* A typed multi-dimensional array.
*
* Instances of a Tensor are <b>not</b> thread-safe.
*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
* try-with-resources block like:
Expand Down Expand Up @@ -78,7 +80,7 @@ public static Tensor create(Object obj) {
*
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
*
* <p>The Tensor object is no longer usable after {@code close} is invoked.
* <p>The Tensor object is no longer usable after {@code close} returns.
*/
@Override
public void close() {
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/java/src/main/native/exception_jni.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdarg.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"

const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
const char kIllegalStateException[] = "java/lang/IllegalStateException";
const char kNullPointerException[] = "java/lang/NullPointerException";

void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
char* message = nullptr;
if (vasprintf(&message, fmt, args) >= 0) {
env->ThrowNew(env->FindClass(clazz), message);
} else {
env->ThrowNew(env->FindClass(clazz), "");
}
va_end(args);
}
34 changes: 34 additions & 0 deletions tensorflow/java/src/main/native/exception_jni.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_JAVA_EXCEPTION_JNI_H_
#define TENSORFLOW_JAVA_EXCEPTION_JNI_H_

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

extern const char kIllegalArgumentException[];
extern const char kIllegalStateException[];
extern const char kNullPointerException[];

void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_EXCEPTION_JNI_H_
42 changes: 14 additions & 28 deletions tensorflow/java/src/main/native/tensor_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,22 @@ limitations under the License.
#include "tensorflow/java/src/main/native/tensor_jni.h"

#include <assert.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>

#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"

namespace {

const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
const char kIllegalStateException[] = "java/lang/IllegalStateException";
const char kNullPointerException[] = "java/lang/kNullPointerException";

void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
char* message = nullptr;
if (vasprintf(&message, fmt, args) >= 0) {
env->ThrowNew(env->FindClass(clazz), message);
} else {
env->ThrowNew(env->FindClass(clazz), "");
TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kNullPointerException,
"close() was called on the Tensor");
return nullptr;
}
va_end(args);
return reinterpret_cast<TF_Tensor*>(handle);
}

size_t elemByteSize(TF_DataType dtype) {
Expand Down Expand Up @@ -272,18 +265,16 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
jclass clazz,
jlong handle) {
if (handle == 0) {
return;
}
if (handle == 0) return;
TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
}

JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
assert(handle != 0);
TF_Tensor* t = reinterpret_cast<TF_Tensor*>(handle);
TF_Tensor* t = requireHandle(env, handle);
if (t == nullptr) return;
int num_dims = TF_NumDims(t);
TF_DataType dtype = TF_TensorType(t);
void* data = TF_TensorData(t);
Expand All @@ -299,8 +290,9 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
#define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix) \
JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix( \
JNIEnv* env, jclass clazz, jlong handle) { \
TF_Tensor* t = reinterpret_cast<TF_Tensor*>(handle); \
jtype ret = 0; \
TF_Tensor* t = requireHandle(env, handle); \
if (t == nullptr) return ret; \
if (TF_NumDims(t) != 0) { \
throwException(env, kIllegalStateException, "Tensor is not a scalar"); \
} else if (TF_TensorType(t) != dtype) { \
Expand All @@ -322,14 +314,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
// The exceptions thrown use "copyTo()" since readNDArray is a private
// function meant to serve the public Java copyTo() method.
if (handle == 0) {
throwException(env, kNullPointerException,
"copyTo() cannot be called after close()");
return;
}
TF_Tensor* t = reinterpret_cast<TF_Tensor*>(handle);
TF_Tensor* t = requireHandle(env, handle);
if (t == nullptr) return;
int num_dims = TF_NumDims(t);
TF_DataType dtype = TF_TensorType(t);
const void* data = TF_TensorData(t);
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,16 @@ public void failOnZeroDimension() {
// The expected exception.
}
}

@Test
public void useAfterClose() {
int n = 4;
Tensor t = Tensor.create(n);
t.close();
try {
t.intValue();
} catch (NullPointerException e) {
// The expected exception.
}
}
}

0 comments on commit 502675b

Please sign in to comment.