Skip to content

Commit

Permalink
chore: update the serialization process for encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
nullptr committed Nov 25, 2024
1 parent 268027a commit a899a9f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 25 deletions.
45 changes: 28 additions & 17 deletions sscma/server/at/callback/refactor_required.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ ma_err_t setAlgorithmInput(Model* algorithm, ma_img_t& img) {
return static_cast<Detector*>(algorithm)->run(&img);

case MA_MODEL_TYPE_YOLOV8_POSE:
case MA_MODEL_TYPE_YOLO11_POSE:
return static_cast<PoseDetector*>(algorithm)->run(&img);

default:
Expand All @@ -76,16 +77,22 @@ ma_err_t serializeAlgorithmOutput(Model* algorithm, Encoder* encoder, int width,

switch (algorithm->getType()) {
case MA_MODEL_TYPE_PFLD: {
const auto& results = static_cast<PointDetector*>(algorithm)->getResults();
// encoder->write(results);
auto results = static_cast<PointDetector*>(algorithm)->getResults();
for (auto& result : results) {
result.x = static_cast<int>(std::round(result.x * width));
result.y = static_cast<int>(std::round(result.y * height));
result.score = static_cast<int>(std::round(result.score * 100));
}
ret = encoder->write(results);

break;
}

case MA_MODEL_TYPE_IMCLS: {

auto results = static_cast<Classifier*>(algorithm)->getResults();
for (auto& result : results) {
result.score *= 100;
result.score = static_cast<int>(std::round(result.score * 100));
}
ret = encoder->write(results);

Expand All @@ -95,36 +102,38 @@ ma_err_t serializeAlgorithmOutput(Model* algorithm, Encoder* encoder, int width,
case MA_MODEL_TYPE_FOMO:
case MA_MODEL_TYPE_YOLOV5:
case MA_MODEL_TYPE_YOLOV8:
case MA_MODEL_TYPE_YOLO11:
case MA_MODEL_TYPE_NVIDIA_DET:
case MA_MODEL_TYPE_YOLO_WORLD: {

auto results = static_cast<Detector*>(algorithm)->getResults();
MA_LOGD(MA_TAG, "Results size: %d", std::distance(results.begin(), results.end()));
for (auto& result : results) {
result.x *= width;
result.y *= height;
result.w *= width;
result.h *= height;
result.score *= 100;
result.x = static_cast<int>(std::round(result.x * width));
result.y = static_cast<int>(std::round(result.y * height));
result.w = static_cast<int>(std::round(result.w * width));
result.h = static_cast<int>(std::round(result.h * height));
result.score = static_cast<int>(std::round(result.score * 100));
}
ret = encoder->write(results);

break;
}

case MA_MODEL_TYPE_YOLOV8_POSE: {
case MA_MODEL_TYPE_YOLOV8_POSE:
case MA_MODEL_TYPE_YOLO11_POSE: {

auto results = static_cast<PoseDetector*>(algorithm)->getResults();
for (auto& result : results) {
auto& box = result.box;
box.x *= width;
box.y *= height;
box.w *= width;
box.h *= height;
box.score *= 100;
box.x = static_cast<int>(std::round(box.x * width));
box.y = static_cast<int>(std::round(box.y * height));
box.w = static_cast<int>(std::round(box.w * width));
box.h = static_cast<int>(std::round(box.h * height));
box.score = static_cast<int>(std::round(box.score * 100));
for (auto& pt : result.pts) {
pt.x *= width;
pt.y *= height;
pt.x = static_cast<int>(std::round(pt.x * width));
pt.y = static_cast<int>(std::round(pt.y * height));
}
}
ret = encoder->write(results);
Expand Down Expand Up @@ -246,6 +255,7 @@ struct TriggerRule {
case MA_MODEL_TYPE_FOMO:
case MA_MODEL_TYPE_YOLOV5:
case MA_MODEL_TYPE_YOLOV8:
case MA_MODEL_TYPE_YOLO11:
case MA_MODEL_TYPE_NVIDIA_DET:
case MA_MODEL_TYPE_YOLO_WORLD: {
auto results = static_cast<Detector*>(algorithm)->getResults();
Expand All @@ -259,7 +269,8 @@ struct TriggerRule {
break;
}

case MA_MODEL_TYPE_YOLOV8_POSE: {
case MA_MODEL_TYPE_YOLOV8_POSE:
case MA_MODEL_TYPE_YOLO11_POSE: {
auto results = static_cast<PoseDetector*>(algorithm)->getResults();
for (auto& result : results) {
if (result.box.target == class_id && comp(result.box.score, threshold)) {
Expand Down
22 changes: 14 additions & 8 deletions sscma/server/at/codec/ma_codec_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,16 +323,11 @@ ma_err_t EncoderJSON::write(const std::forward_list<ma_keypoint3f_t>& value) {
if (item == nullptr) {
return MA_FAILED;
}
// pts
cJSON* pts = cJSON_CreateArray();
for (const auto& pt : it->pts) {
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.x));
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.y));
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.z));
}
cJSON_AddItemToArray(item, pts);
// box
cJSON* box = cJSON_CreateArray();
if (box == nullptr) {
return MA_FAILED;
}
cJSON_AddItemToArray(box, cJSON_CreateNumber(it->box.x));
cJSON_AddItemToArray(box, cJSON_CreateNumber(it->box.y));
cJSON_AddItemToArray(box, cJSON_CreateNumber(it->box.w));
Expand All @@ -341,6 +336,17 @@ ma_err_t EncoderJSON::write(const std::forward_list<ma_keypoint3f_t>& value) {
cJSON_AddItemToArray(box, cJSON_CreateNumber(it->box.target));
cJSON_AddItemToArray(item, box);
cJSON_AddItemToArray(array, item);
// pts
cJSON* pts = cJSON_CreateArray();
if (pts == nullptr) {
return MA_FAILED;
}
for (const auto& pt : it->pts) {
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.x));
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.y));
cJSON_AddItemToArray(pts, cJSON_CreateNumber(pt.z));
}
cJSON_AddItemToArray(item, pts);
}

return MA_OK;
Expand Down

0 comments on commit a899a9f

Please sign in to comment.