Add CMSIS-NN int8 and int16 batch matmul #2669

Merged: 3 commits, Sep 9, 2024
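Note: the diff shown below covers the reference kernel refactor that this PR rests on: the shared tensor-index constants and the per-op data struct move out of batch_matmul.cc into a new batch_matmul.h, so that an optimized CMSIS-NN int8/int16 kernel can reuse them. The CMSIS-NN kernel source itself is presumably in files not captured in this excerpt.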
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/BUILD
@@ -300,6 +300,7 @@ tflm_kernel_cc_library(
hdrs = [
"activations.h",
"add.h",
"batch_matmul.h",
"circular_buffer.h",
"conv.h",
"depthwise_conv.h",
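Note (assumed sketch, not shown in this excerpt): the batch_matmul.h header added to the kernel library above presumably carries the declarations that the diff below deletes from batch_matmul.cc, under their new kBatchMatmul*/OpDataBatchMatmul names, so the reference and CMSIS-NN kernels can share them. Roughly:

// Assumed contents of tensorflow/lite/micro/kernels/batch_matmul.h; the exact
// header is not part of this excerpt, names are taken from the renamed symbols
// in the diff below.
#include <cstdint>

#include "tensorflow/lite/c/common.h"

namespace tflite {

constexpr int kBatchMatmulInputLhsTensor = 0;
constexpr int kBatchMatmulInputRhsTensor = 1;
constexpr int kBatchMatmulOutputTensor = 0;

struct QuantizationOpDataBatchMatmul {
  // Input-to-output scale as a fixed-point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;  // exponent
  // Fused activation clamp range (e.g. -128..127 for kNone with int8_t).
  int32_t output_activation_min;
  int32_t output_activation_max;
  int32_t lhs_zero_point;
  int32_t rhs_zero_point;
  int32_t output_zero_point;
};

struct OpDataBatchMatmul {
  QuantizationOpDataBatchMatmul* quantization;
  // Transpose tensors and state.
  TfLiteEvalTensor* lhs_transposed_tensor;
  TfLiteEvalTensor* rhs_transposed_tensor;
  bool rhs_is_transposed;
  bool lhs_is_constant_tensor;
  bool rhs_is_constant_tensor;
};

}  // namespace tflite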
159 changes: 22 additions & 137 deletions tensorflow/lite/micro/kernels/batch_matmul.cc
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -24,60 +24,31 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/batch_matmul.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

constexpr int kInputLhsTensor = 0;
constexpr int kInputRhsTensor = 1;
constexpr int kOutputTensor = 0;

struct QuantizationOpData {
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift; // exponent

// The range of the fused activation layer. For example for kNone and
// int8_t these would be -128 and 127.
int32_t output_activation_min;
int32_t output_activation_max;

int32_t lhs_zero_point;
int32_t rhs_zero_point;
int32_t output_zero_point;
};
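Note (illustrative, not part of this diff): the "fixed point multiplier plus a left shift" above is the standard TFLite quantization scheme. A minimal sketch of how these fields are typically filled in Prepare(), assuming the usual QuantizeMultiplier helper from tensorflow/lite/kernels/internal/quantization_util.h, with lhs, rhs, and output being the temporary TfLiteTensor pointers and quantization the struct instance:

// Sketch only: real multiplier = (lhs_scale * rhs_scale) / output_scale,
// stored as a fixed-point multiplier plus a power-of-two exponent.
const double real_multiplier = static_cast<double>(lhs->params.scale) *
                               static_cast<double>(rhs->params.scale) /
                               static_cast<double>(output->params.scale);
QuantizeMultiplier(real_multiplier, &quantization->output_multiplier,
                   &quantization->output_shift);
quantization->lhs_zero_point = lhs->params.zero_point;
quantization->rhs_zero_point = rhs->params.zero_point;
quantization->output_zero_point = output->params.zero_point;
// With no fused activation, int8_t clamps to the full [-128, 127] range.
quantization->output_activation_min = -128;
quantization->output_activation_max = 127;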

struct OpData {
QuantizationOpData* quantization;

// Transpose tensors and state
TfLiteEvalTensor* lhs_transposed_tensor;
TfLiteEvalTensor* rhs_transposed_tensor;
bool rhs_is_transposed;
bool lhs_is_constant_tensor;
bool rhs_is_constant_tensor;
};

struct OpContext {
OpContext(TfLiteContext* context, TfLiteNode* node)
: params(static_cast<TfLiteBatchMatMulParams*>(node->builtin_data)),
op_data(static_cast<OpData*>(node->user_data)) {}
op_data(static_cast<OpDataBatchMatmul*>(node->user_data)) {}

TfLiteBatchMatMulParams* params;
OpData* op_data;
OpDataBatchMatmul* op_data;
};

struct PrepareOpContext : OpContext {
PrepareOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node),
micro_context_(GetMicroContext(context)),
lhs(micro_context_->AllocateTempInputTensor(node, kInputLhsTensor)),
rhs(micro_context_->AllocateTempInputTensor(node, kInputRhsTensor)),
output(micro_context_->AllocateTempOutputTensor(node, kOutputTensor)) {}
lhs(micro_context_->AllocateTempInputTensor(
node, kBatchMatmulInputLhsTensor)),
rhs(micro_context_->AllocateTempInputTensor(
node, kBatchMatmulInputRhsTensor)),
output(micro_context_->AllocateTempOutputTensor(
node, kBatchMatmulOutputTensor)) {}

~PrepareOpContext() {
if (lhs != nullptr) {
Expand All @@ -103,56 +74,18 @@ struct PrepareOpContext : OpContext {
struct EvalOpContext : OpContext {
EvalOpContext(TfLiteContext* context, TfLiteNode* node)
: OpContext(context, node),
lhs(tflite::micro::GetEvalInput(context, node, kInputLhsTensor)),
rhs(tflite::micro::GetEvalInput(context, node, kInputRhsTensor)),
output(tflite::micro::GetEvalOutput(context, node, kOutputTensor)) {}
lhs(tflite::micro::GetEvalInput(context, node,
kBatchMatmulInputLhsTensor)),
rhs(tflite::micro::GetEvalInput(context, node,
kBatchMatmulInputRhsTensor)),
output(tflite::micro::GetEvalOutput(context, node,
kBatchMatmulOutputTensor)) {}

const TfLiteEvalTensor* lhs;
const TfLiteEvalTensor* rhs;
TfLiteEvalTensor* output;
};

TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
const RuntimeShape& extended_lhs_shape,
const RuntimeShape& extended_rhs_shape,
bool adj_x, bool adj_y, int output_rank,
TfLiteTensor* output) {
int64_t orig_size = NumElements(output);

// make sure the new output dims rank does not exceed the original rank
TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));

// make sure output tensor dims are not in the FlatBuffer
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));

// Fill in any broadcast dimensions.
for (int i = 0; i < output_rank - 2; ++i) {
const int lhs_dim = extended_lhs_shape.Dims(i);
const int rhs_dim = extended_rhs_shape.Dims(i);
int broadcast_dim = lhs_dim;
if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
broadcast_dim = rhs_dim;
}
output->dims->data[i] = broadcast_dim;
}
// Fill in the matmul dimensions.
int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;

output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
output->dims->size = output_rank;

// Check that output tensor has not been resized
// since TFLM doesn't support tensor resizing.
TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output));

return kTfLiteOk;
}
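Note (worked example, not part of this diff): with lhs shape [2, 1, 3, 4], rhs shape [1, 5, 4, 6], and adj_x = adj_y = false, the loop above produces broadcast batch dims [2, 5] and matmul dims 3 x 6, so the output shape becomes [2, 5, 3, 6]. The pre-allocated output tensor must already have rank >= 4 and exactly 2 * 5 * 3 * 6 = 180 elements, because TFLM cannot resize tensors at runtime.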

TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor(
TfLiteContext* context, const TfLiteTensor& tensor) {
MicroContext* micro_context = GetMicroContext(context);
@@ -195,7 +128,7 @@ TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor(
// Allocate normal quantization data if needed.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
const PrepareOpContext& op_context) {
OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
const TfLiteTensor* lhs = op_context.lhs;
const TfLiteTensor* rhs = op_context.rhs;
MicroContext* micro_context = GetMicroContext(context);
@@ -231,62 +164,14 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

template <typename Scalar>
void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
TfLiteEvalTensor* tensor_out) {
const Scalar* input = tflite::micro::GetTensorData<Scalar>(&tensor_in);
Scalar* output = tflite::micro::GetTensorData<Scalar>(tensor_out);
RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
RuntimeShape shape(transposed_shape);
TransposeParams params;
const int rank = shape.DimensionsCount();
params.perm_count = rank;
for (int i = 0; i < rank - 2; ++i) {
params.perm[i] = i;
}
// Transpose the last two dimensions.
params.perm[rank - 2] = rank - 1;
params.perm[rank - 1] = rank - 2;
transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
reference_ops::Transpose(params, shape, input, transposed_shape, output);
}
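Note (not part of this diff): for a rank-4 input this builds params.perm = {0, 1, 3, 2}, leaving the batch dimensions alone and swapping only the trailing matrix, so a tensor of shape [B0, B1, R, C] is written out as [B0, B1, C, R].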

TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
TfLiteEvalTensor* tensor_out) {
if (tensor_in.type == kTfLiteFloat32) {
TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt8) {
TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else if (tensor_in.type == kTfLiteInt16) {
TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
return kTfLiteOk;
} else {
MicroPrintf(
"BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
"type.");
}
return kTfLiteError;
}

RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
RuntimeShape swapped_shape(shape);
const int32_t dims = shape.DimensionsCount();
swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
return swapped_shape;
}

void* BatchMatMulInit(TfLiteContext* context, const char* buffer,
size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
MicroContext* micro_context = GetMicroContext(context);
return micro_context->AllocatePersistentBuffer(sizeof(OpData));
return micro_context->AllocatePersistentBuffer(sizeof(OpDataBatchMatmul));
}

TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {
@@ -323,7 +208,7 @@ TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, op_context));

OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
// If the RHS is constant, we only transpose once.
op_data->rhs_is_transposed = false;
op_data->lhs_is_constant_tensor = IsConstantTensor(lhs_data);
@@ -393,7 +278,7 @@ TfLiteStatus BatchMatMulPrepare(TfLiteContext* context, TfLiteNode* node) {
return status;
}

TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data,
TfLiteStatus EvalInt8(TfLiteContext* context, const OpDataBatchMatmul& data,
const RuntimeShape& lhs_shape,
const TfLiteEvalTensor& lhs,
const RuntimeShape& rhs_shape,
@@ -423,7 +308,7 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data,
return kTfLiteOk;
}

TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data,
const RuntimeShape& lhs_shape,
const TfLiteEvalTensor& lhs,
const RuntimeShape& rhs_shape,
@@ -466,7 +351,7 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data,
// A X C row-oriented.
TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
EvalOpContext op_context(context, node);
OpData* op_data = op_context.op_data;
OpDataBatchMatmul* op_data = op_context.op_data;
const TfLiteEvalTensor* lhs = op_context.lhs;
const TfLiteEvalTensor* rhs = op_context.rhs;
TfLiteEvalTensor* output = op_context.output;