Skip to content

Commit

Permalink
fix: apply NMS again after nmspostprocess to ensure external paramete…
Browse files Browse the repository at this point in the history
…rs are effective
  • Loading branch information
LynnL4 committed Nov 28, 2024
1 parent 9bf406c commit 38f4cfe
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions sscma/core/model/ma_model_yolov5.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <algorithm>
#include <forward_list>
#include <vector>
#include <utility>
#include <vector>

#include "../utils/ma_nms.h"

Expand All @@ -24,7 +24,7 @@ YoloV5::YoloV5(Engine* p_engine_) : Detector(p_engine_, "yolov5", MA_MODEL_TYPE_
YoloV5::~YoloV5() {}

static bool generalValid(Engine* engine) {
const auto inputs_count = engine->getInputSize();
const auto inputs_count = engine->getInputSize();
const auto outputs_count = engine->getOutputSize();

if (inputs_count != 1 || outputs_count != 1) {
Expand All @@ -37,8 +37,7 @@ static bool generalValid(Engine* engine) {
if (input_shape.size != 4)
return false;

int n = input_shape.dims[0], h = input_shape.dims[1], w = input_shape.dims[2],
c = input_shape.dims[3];
int n = input_shape.dims[0], h = input_shape.dims[1], w = input_shape.dims[2], c = input_shape.dims[3];
bool is_nhwc = c == 3 || c == 1;

if (!is_nhwc)
Expand All @@ -55,8 +54,7 @@ static bool generalValid(Engine* engine) {
if (output_shape.size != 3 && output_shape.size != 4)
return false;

if (output_shape.dims[0] != 1 || output_shape.dims[1] != ibox_len || output_shape.dims[2] < 6 ||
output_shape.dims[2] > 85)
if (output_shape.dims[0] != 1 || output_shape.dims[1] != ibox_len || output_shape.dims[2] < 6 || output_shape.dims[2] > 85)
return false;

return true;
Expand All @@ -67,7 +65,7 @@ static bool nmsValid(Engine* engine) {
if (engine->getInputSize() != 1 || engine->getOutputSize() != 1)
return false;

auto input = engine->getInput(0);
auto input = engine->getInput(0);
auto output = engine->getOutput(0);

if (input.shape.size != 4 || output.shape.size != 4)
Expand All @@ -86,7 +84,7 @@ static bool nmsValid(Engine* engine) {
auto mb = output.shape.dims[2];
auto f = output.shape.dims[3];

if (b != 1 || cs <= 0 || mb <= 1 || f != 0)
if (b != 1 || cs <= 0 || mb <= 1 || f != 0)
return false;

return true;
Expand Down Expand Up @@ -146,12 +144,7 @@ ma_err_t YoloV5::generalPostProcess() {
h /= img_.height;
}

ma_bbox_t box{.x = MA_CLIP(x, 0, 1.0f),
.y = MA_CLIP(y, 0, 1.0f),
.w = MA_CLIP(w, 0, 1.0f),
.h = MA_CLIP(h, 0, 1.0f),
.score = score,
.target = target};
ma_bbox_t box{.x = MA_CLIP(x, 0, 1.0f), .y = MA_CLIP(y, 0, 1.0f), .w = MA_CLIP(w, 0, 1.0f), .h = MA_CLIP(h, 0, 1.0f), .score = score, .target = target};

results_.emplace_front(box);
}
Expand Down Expand Up @@ -187,12 +180,7 @@ ma_err_t YoloV5::generalPostProcess() {
h /= img_.height;
}

ma_bbox_t box{.x = MA_CLIP(x, 0, 1.0f),
.y = MA_CLIP(y, 0, 1.0f),
.w = MA_CLIP(w, 0, 1.0f),
.h = MA_CLIP(h, 0, 1.0f),
.score = score,
.target = target};
ma_bbox_t box{.x = MA_CLIP(x, 0, 1.0f), .y = MA_CLIP(y, 0, 1.0f), .w = MA_CLIP(w, 0, 1.0f), .h = MA_CLIP(h, 0, 1.0f), .score = score, .target = target};

results_.emplace_front(box);
}
Expand Down Expand Up @@ -254,7 +242,7 @@ ma_err_t YoloV5::nmsPostProcess() {
ptr += sizeof(P);

ma_bbox_t res;

auto x_min = static_cast<float>(bbox.x_min - zp) * scale;
auto y_min = static_cast<float>(bbox.y_min - zp) * scale;
auto x_max = static_cast<float>(bbox.x_max - zp) * scale;
Expand All @@ -264,7 +252,7 @@ ma_err_t YoloV5::nmsPostProcess() {
res.x = x_min + res.w * 0.5;
res.y = y_min + res.h * 0.5;
res.score = static_cast<float>(bbox.score - zp) * scale;

res.target = static_cast<int>(i);

res.x = MA_CLIP(res.x, 0, 1.0f);
Expand All @@ -276,7 +264,7 @@ ma_err_t YoloV5::nmsPostProcess() {
}
}
} break;

case MA_TENSOR_TYPE_NMS_BBOX_F32: {
using T = float32_t;
using P = hailo_bbox_float32_t;
Expand All @@ -297,13 +285,13 @@ ma_err_t YoloV5::nmsPostProcess() {
ptr += sizeof(P);

ma_bbox_t res;

res.w = bbox.x_max - bbox.x_min;
res.h = bbox.y_max - bbox.y_min;
res.x = bbox.x_min + res.w * 0.5;
res.y = bbox.y_min + res.h * 0.5;
res.score = bbox.score;

res.target = static_cast<int>(i);

res.x = MA_CLIP(res.x, 0, 1.0f);
Expand All @@ -315,11 +303,15 @@ ma_err_t YoloV5::nmsPostProcess() {
}
}
} break;

default:
return MA_ENOTSUP;
}

ma::utils::nms(results_, threshold_nms_, threshold_score_, false, false);

results_.sort([](const ma_bbox_t& a, const ma_bbox_t& b) { return a.x < b.x; });

return MA_OK;
#else
return MA_FAILED;
Expand All @@ -342,7 +334,7 @@ ma_err_t YoloV5::postprocess() {
threshold_score_ = thr;
}
thr = threshold_nms_;
rc = (*ph)(3, &thr, sizeof(float));
rc = (*ph)(3, &thr, sizeof(float));
if (rc == MA_OK) {
threshold_nms_ = thr;
}
Expand All @@ -355,6 +347,7 @@ ma_err_t YoloV5::postprocess() {
return generalPostProcess();
}


return MA_ENOTSUP;
}
} // namespace ma::model

0 comments on commit 38f4cfe

Please sign in to comment.