Skip to content

Commit

Permalink
feat: add support for YOLO11 classification model
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnL4 committed Nov 22, 2024
1 parent d688561 commit d737124
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions sscma/core/model/ma_model_classifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,36 @@ Classifier::~Classifier() {}

bool Classifier::isValid(Engine* engine) {

const auto& input_shape = engine->getInputShape(0);
auto is_nhwc{input_shape.dims[3] == 3 || input_shape.dims[3] == 1};

if (is_nhwc) {
if (input_shape.size != 4 || // N, H, W, C
input_shape.dims[0] != 1 || // N = 1
input_shape.dims[1] < 16 || // H >= 16
input_shape.dims[2] < 16 || // W >= 16
(input_shape.dims[3] != 3 && // C = RGB or Gray
input_shape.dims[3] != 1))
return false;
} else {
const auto inputs_count = engine->getInputSize();
const auto outputs_count = engine->getOutputSize();

if (input_shape.size != 4 || // N, C, H, W
input_shape.dims[0] != 1 || // N = 1
input_shape.dims[2] < 16 || // H >= 16
input_shape.dims[3] < 16 || // W >= 16
(input_shape.dims[1] != 3 && // C = RGB or Gray
input_shape.dims[1] != 1))
return false;
if (inputs_count != 1 || outputs_count != 1) {
return false;
}


const auto& input_shape = engine->getInputShape(0);
const auto& output_shape{engine->getOutputShape(0)};

if (output_shape.size != 2 || // N, C
output_shape.dims[0] != 1 || // N = 1
int n = input_shape.dims[0], h = input_shape.dims[1], w = input_shape.dims[2], c = input_shape.dims[3];
bool is_nhwc = c == 3 || c == 1;

if (!is_nhwc)
std::swap(h, c);

if (n != 1 || h < 32 || h % 32 != 0 || (c != 3 && c != 1))
return false;


if (output_shape.dims[0] != 1 || // N = 1
output_shape.dims[1] < 2 // C >= 2
) {
return false;
}

if (output_shape.size == 4 && (output_shape.dims[2] != 1 || output_shape.dims[3] != 1)) {
return false;
}

return true;
}

Expand Down Expand Up @@ -107,6 +105,16 @@ ma_err_t Classifier::postprocess() {
if (score > threshold_score_)
results_.emplace_front(ma_class_t{score, i});
}
}
if (output_.type == MA_TENSOR_TYPE_F32) {
auto* data = output_.data.f32;
auto pred_l{output_.shape.dims[1]};
for (decltype(pred_l) i{0}; i < pred_l; ++i) {
auto score{data[i]};
if (score > threshold_score_)
results_.emplace_front(ma_class_t{score, i});
}

} else {
return MA_ENOTSUP;
}
Expand All @@ -122,7 +130,7 @@ const std::forward_list<ma_class_t>& Classifier::getResults() {
}

const void* Classifier::getInput() {
return static_cast<const void*>(&img_);
return &img_;
}

ma_err_t Classifier::run(const ma_img_t* img) {
Expand Down

0 comments on commit d737124

Please sign in to comment.