Merge branch 'main' into misc
mergify[bot] authored Sep 27, 2024
2 parents 62255cc + 4ed97bf commit 1b0a8e2
Showing 20 changed files with 146 additions and 38 deletions.
1 change: 1 addition & 0 deletions python/tflite_micro/BUILD
@@ -94,6 +94,7 @@ py_test(
],
deps = [
":runtime",
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
"//tensorflow/lite/micro/examples/recipes:add_four_numbers",
5 changes: 3 additions & 2 deletions python/tflite_micro/runtime_test.py
@@ -25,6 +25,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tflite_micro.python.tflite_micro import runtime
@@ -199,10 +200,10 @@ def testCompareWithTFLite(self):
tflm_interpreter = runtime.Interpreter.from_bytes(model_data)

# TFLite interpreter
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_content=model_data,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
tflite_interpreter.allocate_tensors()
tflite_output_details = tflite_interpreter.get_output_details()[0]
tflite_input_details = tflite_interpreter.get_input_details()[0]
7 changes: 7 additions & 0 deletions tensorflow/lite/core/c/c_api_types.h
@@ -110,6 +110,13 @@ typedef enum TfLiteStatus {
// TODO(b/250636993): Cancellation triggered by `SetCancellationFunction`
// should also return this status code.
kTfLiteCancelled = 8,

// This status is returned by Prepare when the output shape cannot be
// determined but the size of the output tensor is known. For example, the
// output of reshape is always the same size as the input. This means that
// such ops may be done in place.
kTfLiteOutputShapeNotKnown = 9,
} TfLiteStatus;

/// Types supported by tensor
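A hedged sketch of how a kernel's Prepare might surface this new status; ReshapeLikePrepare and its bool parameter are illustrative stand-ins, not code from this commit:

#include "tensorflow/lite/core/c/c_api_types.h"

// Hypothetical Prepare logic for a reshape-like op. The output element
// count (and therefore byte size) always equals the input's, so the tensor
// can be sized -- and the op potentially run in place -- even before the
// concrete output shape is resolved.
TfLiteStatus ReshapeLikePrepare(bool output_shape_is_static) {
  if (!output_shape_is_static) {
    return kTfLiteOutputShapeNotKnown;  // size known, shape not yet known
  }
  return kTfLiteOk;
}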
4 changes: 4 additions & 0 deletions tensorflow/lite/kernels/internal/reference/comparisons.cc
@@ -15,6 +15,10 @@ limitations under the License.

#include "tensorflow/lite/kernels/internal/reference/comparisons.h"

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"

namespace tflite {
namespace reference_ops {

2 changes: 2 additions & 0 deletions tensorflow/lite/kernels/internal/reference/comparisons.h
@@ -16,7 +16,9 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
1 change: 1 addition & 0 deletions tensorflow/lite/micro/examples/hello_world/BUILD
@@ -53,6 +53,7 @@ py_binary(
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@absl_py//absl/logging",
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
"//python/tflite_micro:runtime",
5 changes: 3 additions & 2 deletions tensorflow/lite/micro/examples/hello_world/evaluate.py
@@ -16,6 +16,7 @@
import tensorflow as tf
from absl import app
from absl import flags
from ai_edge_litert import interpreter as litert_interpreter
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.platform import resource_loader
@@ -92,9 +93,9 @@ def get_tflm_prediction(model_path, x_values):
# returns the prediction of the interpreter.
def get_tflite_prediction(model_path, x_values):
# TFLite interpreter
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=model_path,
experimental_op_resolver_type=tf.lite.experimental.OpResolverType.
experimental_op_resolver_type=litert_interpreter.OpResolverType.
BUILTIN_REF,
)
tflite_interpreter.allocate_tensors()
1 change: 1 addition & 0 deletions tensorflow/lite/micro/examples/mnist_lstm/BUILD
@@ -6,6 +6,7 @@ py_binary(
srcs = ["train.py"],
srcs_version = "PY3",
deps = [
requirement("ai-edge-litert"),
requirement("numpy"),
requirement("tensorflow"),
],
5 changes: 3 additions & 2 deletions tensorflow/lite/micro/examples/mnist_lstm/evaluate_test.py
@@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@@ -43,10 +44,10 @@ def testInputErrHandling(self):
evaluate.predict_image(self.tflm_interpreter, wrong_size_image_path)

def testCompareWithTFLite(self):
tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=self.model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
tflite_interpreter.allocate_tensors()
tflite_output_details = tflite_interpreter.get_output_details()[0]
tflite_input_details = tflite_interpreter.get_input_details()[0]
19 changes: 14 additions & 5 deletions tensorflow/lite/micro/kernels/activations.cc
@@ -75,6 +75,7 @@ void* Relu6Init(TfLiteContext* context, const char* buffer, size_t length) {

TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);

const Relu6OpData& data = *(static_cast<const Relu6OpData*>(node->user_data));

const TfLiteEvalTensor* input =
@@ -92,11 +93,19 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
case kTfLiteInt8: {
Relu6Quantized(data.zero_int8, data.six_int8,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
Relu6Quantized<int8_t>(data.zero, data.six,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
case kTfLiteInt16: {
Relu6Quantized<int16_t>(data.zero, data.six,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
return kTfLiteOk;
}
default: {
18 changes: 13 additions & 5 deletions tensorflow/lite/micro/kernels/activations.h
@@ -32,8 +32,8 @@ struct ReluOpData {
};

struct Relu6OpData {
int8_t six_int8;
int8_t zero_int8;
int32_t six;
int32_t zero;
};

void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape,
@@ -50,9 +50,17 @@ void ReluFloat(const RuntimeShape& input_shape, const float* input_data,
void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data);

void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& output_shape,
int8_t* output_data);
template <typename T>
void Relu6Quantized(T lower, T upper, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& output_shape,
T* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const T val = input_data[i];
const T clamped = val > upper ? upper : val < lower ? lower : val;
output_data[i] = clamped;
}
}

TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node);

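The template above replaces the int8-only implementation that previously lived in activations_common.cc. A standalone sketch of the same clamp, minus the RuntimeShape plumbing, instantiated for int16 with the scale-0.5, zero-point-0 parameters the new unit test uses:

#include <cstdint>
#include <cstdio>

// Clamp each quantized value to [lower, upper], where lower and upper are
// the quantized representations of 0.0f and 6.0f.
template <typename T>
void Relu6QuantizedSketch(T lower, T upper, const T* input, int size,
                          T* output) {
  for (int i = 0; i < size; ++i) {
    const T val = input[i];
    output[i] = val > upper ? upper : val < lower ? lower : val;
  }
}

int main() {
  // With scale 0.5 and zero point 0, real values {4.0, 7.0, -2.0} quantize
  // to {8, 14, -4}; the bounds 0.0f and 6.0f quantize to 0 and 12.
  const int16_t input[] = {8, 14, -4};
  int16_t output[3];
  Relu6QuantizedSketch<int16_t>(/*lower=*/0, /*upper=*/12, input, 3, output);
  for (int i = 0; i < 3; ++i) printf("%d ", output[i]);  // prints: 8 12 0
  return 0;
}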
24 changes: 10 additions & 14 deletions tensorflow/lite/micro/kernels/activations_common.cc
@@ -102,17 +102,6 @@ void Relu6Float(const RuntimeShape& input_shape, const float* input_data,
}
}

void Relu6Quantized(int8_t lower, int8_t upper, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& output_shape,
int8_t* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const int8_t val = input_data[i];
const int8_t clamped = val > upper ? upper : val < lower ? lower : val;
output_data[i] = clamped;
}
}

TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
@@ -137,6 +126,7 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);

Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);

MicroContext* micro_context = GetMicroContext(context);
@@ -145,9 +135,15 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input != nullptr);

if (input->type == kTfLiteInt8) {
data->six_int8 = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
input->params.zero_point);
data->zero_int8 = input->params.zero_point;
data->zero = input->params.zero_point;
data->six = FloatToQuantizedType<int8_t>(6.0f, input->params.scale,
input->params.zero_point);
TF_LITE_ENSURE(context, data->six >= INT8_MIN && data->six <= INT8_MAX);
} else if (input->type == kTfLiteInt16) {
data->zero = input->params.zero_point;
data->six = FloatToQuantizedType<int16_t>(6.0f, input->params.scale,
input->params.zero_point);
TF_LITE_ENSURE(context, data->six >= INT16_MIN && data->six <= INT16_MAX);
}

micro_context->DeallocateTempTfLiteTensor(input);
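Widening six and zero to int32_t, together with the new TF_LITE_ENSURE checks, guards against quantization parameters whose clamp bound does not fit the tensor type. A minimal sketch, assuming FloatToQuantizedType computes the usual affine mapping q = round(v / scale) + zero_point (QuantizeFloat below is an illustrative stand-in):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Assumed affine quantization, matching what FloatToQuantizedType is
// expected to compute: q = round(v / scale) + zero_point.
int32_t QuantizeFloat(float v, float scale, int32_t zero_point) {
  return static_cast<int32_t>(std::lround(v / scale)) + zero_point;
}

int main() {
  // Typical case: 6.0 / 0.5 = 12, comfortably inside int16's range.
  printf("%d\n", static_cast<int>(QuantizeFloat(6.0f, 0.5f, 0)));     // 12
  // Pathological scale: 6.0 / 0.0001 = 60000 > INT16_MAX (32767). Storing
  // the bound in int32_t lets Relu6Prepare detect this and fail the
  // TF_LITE_ENSURE range check instead of silently wrapping.
  printf("%d\n", static_cast<int>(QuantizeFloat(6.0f, 0.0001f, 0)));  // 60000
  return 0;
}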
62 changes: 62 additions & 0 deletions tensorflow/lite/micro/kernels/activations_test.cc
@@ -169,6 +169,46 @@ void TestRelu6Int8(int* input_dims_data, const float* input_data,
}
}

void TestRelu6Int16(int* input_dims_data, const float* input_data,
int16_t* input_data_quantized, const float input_scale,
const int input_zero_point, const float* golden,
int16_t* golden_quantized, int* output_dims_data,
const float output_scale, const int output_zero_point,
int16_t* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_elements_count = ElementCount(*output_dims);
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input_data, input_data_quantized, input_dims,
input_scale, input_zero_point),
CreateQuantizedTensor(output_data, output_dims, output_scale,
output_zero_point),
};

int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

const TFLMRegistration registration = Register_RELU6();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
/*builtin_data=*/nullptr);

TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());

Quantize(golden, golden_quantized, output_elements_count, output_scale,
output_zero_point);

for (int i = 0; i < output_elements_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]);
}
}

} // namespace
} // namespace testing
} // namespace tflite
@@ -247,4 +287,26 @@ TF_LITE_MICRO_TEST(SimpleRelu6TestInt8) {
output_zero_point, output_data);
}

TF_LITE_MICRO_TEST(SimpleRelu6TestInt16) {
const int elements_count = 10;

int input_shape[] = {2, 1, 5};
const float input_data[] = {4, 5, 6, 7, 8, -1, -2, -3, -4, -5};
int16_t input_quantized[elements_count];
int output_shape[] = {2, 1, 5};
const float golden[] = {4, 5, 6, 6, 6, 0, 0, 0, 0, 0};
int16_t golden_quantized[elements_count];
int16_t output_data[elements_count];

const float input_scale = 0.5f;
const int input_zero_point = 0;
const float output_scale = 0.5f;
const int output_zero_point = 0;

tflite::testing::TestRelu6Int16(input_shape, input_data, input_quantized,
input_scale, input_zero_point, golden,
golden_quantized, output_shape, output_scale,
output_zero_point, output_data);
}

TF_LITE_MICRO_TESTS_END
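As a sanity check on the new test vectors, a short sketch reproducing the goldens: apply the float Relu6 clamp to the inputs, then quantize with the test's output scale of 0.5 and zero point of 0:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float input[] = {4, 5, 6, 7, 8, -1, -2, -3, -4, -5};
  for (float v : input) {
    const float clamped = v > 6.f ? 6.f : (v < 0.f ? 0.f : v);  // float Relu6
    const int16_t q = static_cast<int16_t>(std::lround(clamped / 0.5f));
    printf("%d ", q);  // prints: 8 10 12 12 12 0 0 0 0 0
  }
  return 0;
}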
@@ -34,9 +34,10 @@ tflite::BufferPlan* CreateBufferPlan() {
// Some targets do not support dynamic memory (i.e., no malloc or new); thus,
// the test needs to place non-transient memory in static variables. This is
// safe because tests are guaranteed to run serially.
static int8_t buffer_plan_buffer[tflite::SizeOfBufferPlan(kBufferCnt)];
alignas(tflite::BufferPlan) static int8_t
buffer_plan_buffer[tflite::SizeOfBufferPlan(kBufferCnt)];
tflite::BufferPlan* buffer_plan_ptr =
reinterpret_cast<tflite::BufferPlan*>(buffer_plan_buffer);
new (buffer_plan_buffer) tflite::BufferPlan();
buffer_plan_ptr->buffer_count = kBufferCnt;
buffer_plan_ptr->buffer_plan_entries[0].offset = kBuffer0Offset;
buffer_plan_ptr->buffer_plan_entries[1].offset = kBuffer1Offset;
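The change above swaps a bare reinterpret_cast of a static byte array for aligned storage plus placement new. A minimal sketch of the same pattern; BufferPlanLike is an illustrative stand-in for tflite::BufferPlan (whose entry array is really sized via SizeOfBufferPlan):

#include <cstdint>
#include <new>

// Stand-in for tflite::BufferPlan.
struct BufferPlanLike {
  int32_t buffer_count;
  int32_t offsets[2];
};

BufferPlanLike* CreatePlan() {
  // alignas guarantees the raw storage satisfies the struct's alignment;
  // placement new actually begins the object's lifetime there. A plain
  // reinterpret_cast of the byte array provides neither guarantee.
  alignas(BufferPlanLike) static uint8_t storage[sizeof(BufferPlanLike)];
  BufferPlanLike* plan = new (storage) BufferPlanLike();
  plan->buffer_count = 2;
  plan->offsets[0] = 0;
  plan->offsets[1] = 64;
  return plan;
}

int main() {
  BufferPlanLike* plan = CreatePlan();
  return plan->buffer_count == 2 ? 0 : 1;
}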
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/tools/BUILD
@@ -34,6 +34,7 @@ py_library(
srcs_version = "PY3",
visibility = ["//:__subpackages__"],
deps = [
requirement("ai-edge-litert"),
"//tensorflow/lite/python:schema_py",
],
)
@@ -208,6 +209,7 @@ py_binary(
":model_transforms_utils",
"@absl_py//absl:app",
"@absl_py//absl/flags",
requirement("ai-edge-litert"),
requirement("tensorflow"),
"//python/tflite_micro:runtime",
"//tensorflow/lite/tools:flatbuffer_utils",
9 changes: 5 additions & 4 deletions tensorflow/lite/micro/tools/generate_test_for_model.py
@@ -18,6 +18,7 @@
import numpy as np
import tensorflow as tf

from ai_edge_litert import interpreter as litert_interpreter
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb


@@ -103,9 +104,9 @@ def generate_golden_single_in_single_out(self):
if (len(self.model_paths) != 1):
raise RuntimeError(f'Single model expected')
model_path = self.model_paths[0]
interpreter = tf.lite.Interpreter(model_path=model_path,
interpreter = litert_interpreter.Interpreter(model_path=model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)

interpreter.allocate_tensors()

@@ -140,10 +141,10 @@ def generate_goldens(self, builtin_operator):

for model_path in self.model_paths:
# Load model and run a single inference with random inputs.
interpreter = tf.lite.Interpreter(
interpreter = litert_interpreter.Interpreter(
model_path=model_path,
experimental_op_resolver_type=\
tf.lite.experimental.OpResolverType.BUILTIN_REF)
litert_interpreter.OpResolverType.BUILTIN_REF)
interpreter.allocate_tensors()
input_tensor = interpreter.tensor(
interpreter.get_input_details()[0]['index'])
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/tools/layer_by_layer_debugger.py
@@ -20,6 +20,7 @@
from absl import app
from absl import flags
from absl import logging
from ai_edge_litert import interpreter as litert_interpreter
import numpy as np
import tensorflow as tf

@@ -194,7 +195,7 @@ def main(_) -> None:
intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors,
)

tflite_interpreter = tf.lite.Interpreter(
tflite_interpreter = litert_interpreter.Interpreter(
model_path=_INPUT_TFLITE_FILE.value,
experimental_preserve_all_tensors=True,
)