From 059356caa80d1598886bcf26af7151b77f13522b Mon Sep 17 00:00:00 2001 From: nullptr Date: Mon, 29 Apr 2024 02:04:01 +0000 Subject: [PATCH] chore: optimize yolo world postprocess --- core/algorithm/el_algorithm_yolo_world.cpp | 108 +++++++++++---------- 1 file changed, 59 insertions(+), 49 deletions(-) diff --git a/core/algorithm/el_algorithm_yolo_world.cpp b/core/algorithm/el_algorithm_yolo_world.cpp index 73735cd7..d8708ceb 100644 --- a/core/algorithm/el_algorithm_yolo_world.cpp +++ b/core/algorithm/el_algorithm_yolo_world.cpp @@ -26,6 +26,8 @@ #include "el_algorithm_yolo_world.h" #include +#include +#include #include #include "core/el_common.h" @@ -149,71 +151,79 @@ el_err_code_t AlgorithmYOLOWorld::postprocess() { _results.clear(); // get outputs - auto* data_bboxes{static_cast(this->__p_engine->get_output(0))}; - auto* data_scores{static_cast(this->__p_engine->get_output(1))}; + const auto* data_bboxes{static_cast(this->__p_engine->get_output(0))}; + const auto* data_scores{static_cast(this->__p_engine->get_output(1))}; - auto width{this->__input_shape.dims[1]}; - auto height{this->__input_shape.dims[2]}; + const auto width{this->__input_shape.dims[1]}; + const auto height{this->__input_shape.dims[2]}; float scale_scores{_output_quant_params[1].scale}; scale_scores = scale_scores < 0.1f ? scale_scores * 100.f : scale_scores; // rescale - int32_t zero_point_scores{_output_quant_params[1].zero_point}; + const int32_t zero_point_scores{_output_quant_params[1].zero_point}; - float scale_bboxes{_output_quant_params[0].scale}; - int32_t zero_point_bboxes{_output_quant_params[0].zero_point}; + const float scale_bboxes{_output_quant_params[0].scale}; + const int32_t zero_point_bboxes{_output_quant_params[0].zero_point}; - auto num_bboxes{this->__output_shape.dims[1]}; - auto num_element{this->__output_shape.dims[2]}; - auto num_classes{static_cast(_output_shapes[1].dims[2])}; + const auto num_bboxes{this->__output_shape.dims[1]}; + const auto num_elements{this->__output_shape.dims[2]}; + const auto num_classes{_output_shapes[1].dims[2]}; - ScoreType score_threshold{get_score_threshold()}; - IoUType iou_threshold{get_iou_threshold()}; + const ScoreType score_threshold{get_score_threshold()}; + const IoUType iou_threshold{get_iou_threshold()}; // parse output - for (size_t bbox_i = 0; bbox_i < num_bboxes; ++bbox_i) { - size_t idx_s = bbox_i * num_classes; - for (size_t target_i = 0; target_i < num_classes; ++target_i) { - uint8_t bbox_i_score = - static_cast(data_scores[idx_s + target_i] - zero_point_scores) * scale_scores; - if (bbox_i_score < score_threshold) { - continue; - } + for (int bbox_i = 0; bbox_i < num_bboxes; ++bbox_i) { + const auto score_pre = bbox_i * num_classes; + + ScoreType max_score = score_threshold; + int target = -1; - { - BoxType box{ - .x = 0, - .y = 0, - .w = 0, - .h = 0, - .score = bbox_i_score, - .target = static_cast(target_i), - }; - - size_t idx_b = bbox_i * num_element; - auto tl_x{((data_bboxes[idx_b + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)}; - auto tl_y{((data_bboxes[idx_b + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)}; - auto br_x{((data_bboxes[idx_b + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)}; - auto br_y{((data_bboxes[idx_b + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)}; - - box.w = br_x - tl_x; - box.h = br_y - tl_y; - box.x = tl_x + box.w / 2; - box.y = tl_y + box.h / 2; - - box.x = EL_CLIP(box.x, 0, width) * _w_scale; - box.y = EL_CLIP(box.y, 0, height) * _h_scale; - box.w = EL_CLIP(box.w, 0, width) * _w_scale; - box.h = EL_CLIP(box.h, 0, height) * _h_scale; - - _results.emplace_front(std::move(box)); + for (int class_i = 0; class_i < num_classes; ++class_i) { + const auto score = static_cast( + static_cast(data_scores[score_pre + class_i] - zero_point_scores) * scale_scores); + + if (score > max_score) { + max_score = score; + target = class_i; } } + + if (target < 0) { + continue; + } + + BoxType box; + + box.score = max_score; + box.target = static_cast(target); + + auto bbox_idx = bbox_i * num_elements; + + auto tl_x{((data_bboxes[bbox_idx + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)}; + auto tl_y{((data_bboxes[bbox_idx + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)}; + auto br_x{((data_bboxes[bbox_idx + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)}; + auto br_y{((data_bboxes[bbox_idx + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)}; + + box.w = static_cast(br_x - tl_x); + box.h = static_cast(br_y - tl_y); + + // if constexpr would be better (C++17) + static_assert(std::is_integral::value); + static_assert(std::is_integral::value); + + box.x = tl_x + (box.w >> 1); + box.y = tl_y + (box.h >> 1); + + box.x = EL_CLIP(box.x, 0, width) * _w_scale; + box.y = EL_CLIP(box.y, 0, height) * _h_scale; + box.w = EL_CLIP(box.w, 0, width) * _w_scale; + box.h = EL_CLIP(box.h, 0, height) * _h_scale; + + _results.emplace_front(std::move(box)); } el_nms(_results, iou_threshold, score_threshold, false, true); - _results.sort([](const BoxType& a, const BoxType& b) { return a.x < b.x; }); - return EL_OK; }