Skip to content

Commit

Permalink
chore: optimize yolo world postprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
nullptr authored Apr 29, 2024
1 parent f4f3ddb commit 059356c
Showing 1 changed file with 59 additions and 49 deletions.
108 changes: 59 additions & 49 deletions core/algorithm/el_algorithm_yolo_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "el_algorithm_yolo_world.h"

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "core/el_common.h"
Expand Down Expand Up @@ -149,71 +151,79 @@ el_err_code_t AlgorithmYOLOWorld::postprocess() {
_results.clear();

// get outputs
auto* data_bboxes{static_cast<int8_t*>(this->__p_engine->get_output(0))};
auto* data_scores{static_cast<int8_t*>(this->__p_engine->get_output(1))};
const auto* data_bboxes{static_cast<int8_t*>(this->__p_engine->get_output(0))};
const auto* data_scores{static_cast<int8_t*>(this->__p_engine->get_output(1))};

auto width{this->__input_shape.dims[1]};
auto height{this->__input_shape.dims[2]};
const auto width{this->__input_shape.dims[1]};
const auto height{this->__input_shape.dims[2]};

float scale_scores{_output_quant_params[1].scale};
scale_scores = scale_scores < 0.1f ? scale_scores * 100.f : scale_scores; // rescale
int32_t zero_point_scores{_output_quant_params[1].zero_point};
const int32_t zero_point_scores{_output_quant_params[1].zero_point};

float scale_bboxes{_output_quant_params[0].scale};
int32_t zero_point_bboxes{_output_quant_params[0].zero_point};
const float scale_bboxes{_output_quant_params[0].scale};
const int32_t zero_point_bboxes{_output_quant_params[0].zero_point};

auto num_bboxes{this->__output_shape.dims[1]};
auto num_element{this->__output_shape.dims[2]};
auto num_classes{static_cast<uint8_t>(_output_shapes[1].dims[2])};
const auto num_bboxes{this->__output_shape.dims[1]};
const auto num_elements{this->__output_shape.dims[2]};
const auto num_classes{_output_shapes[1].dims[2]};

ScoreType score_threshold{get_score_threshold()};
IoUType iou_threshold{get_iou_threshold()};
const ScoreType score_threshold{get_score_threshold()};
const IoUType iou_threshold{get_iou_threshold()};

// parse output
for (size_t bbox_i = 0; bbox_i < num_bboxes; ++bbox_i) {
size_t idx_s = bbox_i * num_classes;
for (size_t target_i = 0; target_i < num_classes; ++target_i) {
uint8_t bbox_i_score =
static_cast<decltype(scale_scores)>(data_scores[idx_s + target_i] - zero_point_scores) * scale_scores;
if (bbox_i_score < score_threshold) {
continue;
}
for (int bbox_i = 0; bbox_i < num_bboxes; ++bbox_i) {
const auto score_pre = bbox_i * num_classes;

ScoreType max_score = score_threshold;
int target = -1;

{
BoxType box{
.x = 0,
.y = 0,
.w = 0,
.h = 0,
.score = bbox_i_score,
.target = static_cast<decltype(BoxType::target)>(target_i),
};

size_t idx_b = bbox_i * num_element;
auto tl_x{((data_bboxes[idx_b + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)};
auto tl_y{((data_bboxes[idx_b + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)};
auto br_x{((data_bboxes[idx_b + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)};
auto br_y{((data_bboxes[idx_b + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)};

box.w = br_x - tl_x;
box.h = br_y - tl_y;
box.x = tl_x + box.w / 2;
box.y = tl_y + box.h / 2;

box.x = EL_CLIP(box.x, 0, width) * _w_scale;
box.y = EL_CLIP(box.y, 0, height) * _h_scale;
box.w = EL_CLIP(box.w, 0, width) * _w_scale;
box.h = EL_CLIP(box.h, 0, height) * _h_scale;

_results.emplace_front(std::move(box));
for (int class_i = 0; class_i < num_classes; ++class_i) {
const auto score = static_cast<ScoreType>(
static_cast<decltype(scale_scores)>(data_scores[score_pre + class_i] - zero_point_scores) * scale_scores);

if (score > max_score) {
max_score = score;
target = class_i;
}
}

if (target < 0) {
continue;
}

BoxType box;

box.score = max_score;
box.target = static_cast<decltype(BoxType::target)>(target);

auto bbox_idx = bbox_i * num_elements;

auto tl_x{((data_bboxes[bbox_idx + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)};
auto tl_y{((data_bboxes[bbox_idx + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)};
auto br_x{((data_bboxes[bbox_idx + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)};
auto br_y{((data_bboxes[bbox_idx + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)};

box.w = static_cast<decltype(BoxType::w)>(br_x - tl_x);
box.h = static_cast<decltype(BoxType::h)>(br_y - tl_y);

// if constexpr would be better (C++17)
static_assert(std::is_integral<decltype(box.w)>::value);
static_assert(std::is_integral<decltype(box.h)>::value);

box.x = tl_x + (box.w >> 1);
box.y = tl_y + (box.h >> 1);

box.x = EL_CLIP(box.x, 0, width) * _w_scale;
box.y = EL_CLIP(box.y, 0, height) * _h_scale;
box.w = EL_CLIP(box.w, 0, width) * _w_scale;
box.h = EL_CLIP(box.h, 0, height) * _h_scale;

_results.emplace_front(std::move(box));
}

el_nms(_results, iou_threshold, score_threshold, false, true);

_results.sort([](const BoxType& a, const BoxType& b) { return a.x < b.x; });

return EL_OK;
}

Expand Down

0 comments on commit 059356c

Please sign in to comment.