Skip to content

Commit

Permalink
[luci-interpreter] Avoid Conv2D integer overflow (#13690)
Browse files Browse the repository at this point in the history
This adds a guard code to avoid integer overflow in optimized Conv2D kernel.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Aug 19, 2024
1 parent d4b32bd commit d82e145
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions compiler/luci-interpreter/pal/linux/PALConv2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,30 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
float *scratchpad_data)
{
(void)scratchpad_shape;
if (scratchpad_data)

const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
const int32_t output_height = output_shape.Dims(1);
const int32_t output_width = output_shape.Dims(2);
const int32_t filter_height = filter_shape.Dims(1);
const int32_t filter_width = filter_shape.Dims(2);

int64_t im2col_flat_size = 1;
im2col_flat_size *= batches;
im2col_flat_size *= output_height;
im2col_flat_size *= output_width;
im2col_flat_size *= input_depth;
im2col_flat_size *= filter_height;
im2col_flat_size *= filter_width;

// This condition checks if integer overflow will occur inside the optimized kernel.
// https://github.com/tensorflow/tensorflow/blob/v2.12.1/tensorflow/lite/kernels/internal/optimized/im2col_utils.h#L81
// If overflow is expected, we fall back to the reference kernel.
// NOTE This is just a rough check.
bool opt_kernel_overflow = im2col_flat_size > std::numeric_limits<int32_t>::max();

if (scratchpad_data and not opt_kernel_overflow)
{
const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
const int32_t output_height = output_shape.Dims(1);
const int32_t output_width = output_shape.Dims(2);
const int32_t filter_height = filter_shape.Dims(1);
const int32_t filter_width = filter_shape.Dims(2);
tflite::RuntimeShape im2col_shape{batches, output_height, output_width,
input_depth * filter_height * filter_width};

Expand Down

0 comments on commit d82e145

Please sign in to comment.