From 0f998710e3e0e8c38f70f6732b18cd58ac4999b3 Mon Sep 17 00:00:00 2001 From: ncdhz <1137436221@qq.com> Date: Fri, 10 May 2024 11:40:29 +0800 Subject: [PATCH] delete maxNms --- include/Yolo.h | 70 +++++++++++++++++++++++++------------------------- src/Yolo.cpp | 65 ++++++++++++++++++++++------------------------ 2 files changed, 65 insertions(+), 70 deletions(-) diff --git a/include/Yolo.h b/include/Yolo.h index 1bf0fb9..f062741 100644 --- a/include/Yolo.h +++ b/include/Yolo.h @@ -13,47 +13,47 @@ class ImageResizeData { public: - // 设置处理过后的图片 + // 设置处理过后的图片 void setImg(cv::Mat img); - // 获取处理过后的图片 + // 获取处理过后的图片 cv::Mat getImg(); - // 当原始图片宽高比大于处理过后图片宽高比时此函数返回 true - bool isW(); - // 当原始图片高宽比大于处理过后图片高宽比时此函数返回 true - bool isH(); - // 设置处理之后图片的宽 + // 当原始图片宽高比大于处理过后图片宽高比时此函数返回 true + bool isW() const; + // 当原始图片高宽比大于处理过后图片高宽比时此函数返回 true + bool isH() const; + // 设置处理之后图片的宽 void setWidth(int width); - // 获取处理之后图片的宽 + // 获取处理之后图片的宽 int getWidth(); - // 设置处理之后图片的高 + // 设置处理之后图片的高 void setHeight(int height); - // 获取处理之后图片的高 - int getHeight(); - // 设置原始图片的宽 + // 获取处理之后图片的高 + int getHeight() const; + // 设置原始图片的宽 void setW(int w); - // 获取原始图片的宽 - int getW(); - // 设置原始图片的高 + // 获取原始图片的宽 + int getW() const; + // 设置原始图片的高 void setH(int h); - // 获取原始图片的高 - int getH(); - // 设置从原始图片到处理过后图片所添加黑边大小 + // 获取原始图片的高 + int getH() const; + // 设置从原始图片到处理过后图片所添加黑边大小 void setBorder(int border); - // 获取从原始图片到处理过后图片所添加黑边大小 - int getBorder(); + // 获取从原始图片到处理过后图片所添加黑边大小 + int getBorder() const; private: - // 处理过后图片高 + // 处理过后图片高 int height; // 处理过后图片宽 - int width; + int width; // 原始图片宽 - int w; + int w; // 原始图片高 - int h; + int h; // 从原始图片到处理图片所添加的黑边大小 - int border; + int border; // 处理过后的图片 - cv::Mat img; + cv::Mat img; }; /** @@ -62,17 +62,17 @@ class ImageResizeData class Yolo { public: - /** - * 构造函数 - * @param ptFile Yolo pt文件路径 + /** + * 构造函数 + * @param ptFile Yolo pt文件路径 * @param version Yolo的版本 ["v5", "v6", "v7", "v8"] 中选一 * @param device 推理使用的设备默认为cpu * @param height Yolo 训练时图片的高 * @param width Yolo 训练时图片的宽 * @param confThres 非极大值抑制中的 scoreThresh * @param iouThres 非极大值抑制中的 iouThresh - */ - Yolo(std::string ptFile, std::string version="v8", std::string device="cpu", bool isHalf=false, int height=640, int width=640, float confThres=0.25, float iouThres=0.45); + */ + Yolo(std::string ptFile, std::string version = "v8", std::string device = "cpu", bool isHalf = false, int height = 640, int width = 640, float confThres = 0.25, float iouThres = 0.45); /** * 预测函数 * @param data 需要预测的数据格式 (batch, rgb, height, width) @@ -90,7 +90,7 @@ class Yolo */ std::vector prediction(cv::Mat img); /** - * 预测函数 + * 预测函数 * @param imgs 需要预测的图片集合 */ std::vector prediction(std::vector imgs); @@ -144,7 +144,7 @@ class Yolo * @param imgs 原始图片集合 * @param rectangles 通过预测函数处理好的结果 * @param colors 每种类型对应颜色 - * @param labels 类别标签 + * @param labels 类别标签 * @return 画好框的图片 */ std::vector drawRectangle(std::vector imgs, std::vector rectangles, std::map colors, std::map labels, int thickness = 2); @@ -210,7 +210,7 @@ class Yolo // 随机获取一种颜色 cv::Scalar getRandScalar(); // 图片通道转换为 rgb - cv::Mat img2RGB(cv::Mat img); + void img2RGB(cv::Mat& img, cv::Mat& dst); // 图片变为 Tensor torch::Tensor img2Tensor(cv::Mat img); // (center_x center_y w h) to (left, top, right, bottom) @@ -220,5 +220,5 @@ class Yolo // 预测出来的框根据原始图片还原算法 std::vector sizeOriginal(std::vector result, std::vector imgRDs); // 非极大值抑制算法整体 - std::vector non_max_suppression(torch::Tensor preds, float confThres = 0.25, float iouThres = 0.45); + std::vector nonMaxSuppression(torch::Tensor preds, float confThres = 0.25, float iouThres = 0.45); }; diff --git a/src/Yolo.cpp b/src/Yolo.cpp index dcccb4b..d70a7bc 100644 --- a/src/Yolo.cpp +++ b/src/Yolo.cpp @@ -8,7 +8,7 @@ Yolo::Yolo(std::string ptFile, std::string version, std::string device, bool isH model.to(torch::kHalf); } model.to(device); - this->device=device; + this->device = device; this->height = height; this->width = width; this->iouThres = iouThres; @@ -20,11 +20,9 @@ Yolo::Yolo(std::string ptFile, std::string version, std::string device, bool isH std::srand(seed); } -std::vector Yolo::non_max_suppression(torch::Tensor prediction, float confThres, float iouThres) +std::vector Yolo::nonMaxSuppression(torch::Tensor prediction, float confThres, float iouThres) { torch::Tensor xc = prediction.select(2, 4) > confThres; - int maxWh = 4096; - int maxNms = 30000; std::vector output; for (int i = 0; i < prediction.size(0); i++) { @@ -35,8 +33,7 @@ std::vector Yolo::non_max_suppression(torch::Tensor prediction, f torch::Tensor x = prediction[i]; x = x.index_select(0, torch::nonzero(xc[i]).select(1, 0)); if (x.size(0) == 0) continue; - - x.slice(1, 5, x.size(1)).mul_(x.slice(1, 4, 5)); + torch::Tensor box = xywh2xyxy(x.slice(1, 0, 4)); std::tuple max_tuple = torch::max(x.slice(1, 5, x.size(1)), 1, true); x = torch::cat({ box, std::get<0>(max_tuple), std::get<1>(max_tuple) }, 1); @@ -46,11 +43,7 @@ std::vector Yolo::non_max_suppression(torch::Tensor prediction, f { continue; } - else if (n > maxNms) - { - x = x.index_select(0, x.select(1, 4).argsort(0, true).slice(0, 0, maxNms)); - } - torch::Tensor c = x.slice(1, 5, 6) * maxWh; + torch::Tensor c = x.slice(1, 5, 6) * 4096; torch::Tensor boxes = x.slice(1, 0, 4) + c, scores = x.select(1, 4); torch::Tensor ix = nms(boxes, scores, iouThres).to(x.device()); output[i] = x.index_select(0, ix).cpu(); @@ -63,7 +56,7 @@ cv::Scalar Yolo::getRandScalar() return cv::Scalar(std::rand() % 256, std::rand() % 256, std::rand() % 256); } -cv::Mat Yolo::img2RGB(cv::Mat img) +void Yolo::img2RGB(cv::Mat &img, cv::Mat& dst) { int imgC = img.channels(); if (imgC == 1) @@ -74,12 +67,11 @@ cv::Mat Yolo::img2RGB(cv::Mat img) { cv::cvtColor(img, img, cv::COLOR_BGR2RGB); } - return img; } torch::Tensor Yolo::img2Tensor(cv::Mat img) { - torch::Tensor data = torch::from_blob(img.data, {(int)height, (int)width, 3 }, torch::kByte); + torch::Tensor data = torch::from_blob(img.data, { (int)height, (int)width, 3 }, torch::kByte); data = data.permute({ 2, 0, 1 }); data = data.toType(torch::kFloat); data = data.div(255); @@ -108,15 +100,15 @@ torch::Tensor Yolo::nms(torch::Tensor bboxes, torch::Tensor scores, float thresh auto order = std::get<1>(tuple_sorted); std::vector keep; - while (order.numel() > 0) + while (order.numel() > 0) { - if (order.numel() == 1) + if (order.numel() == 1) { auto i = order.item(); keep.push_back(i.toInt()); break; } - else + else { auto i = order[0].item(); keep.push_back(i.toInt()); @@ -132,7 +124,7 @@ torch::Tensor Yolo::nms(torch::Tensor bboxes, torch::Tensor scores, float thresh auto iou = inter / (areas[keep.back()] + areas.index({ order.narrow(0,1,order.size(-1) - 1) }) - inter); auto idx = (iou <= thresh).nonzero().squeeze(); - if (idx.numel() == 0) + if (idx.numel() == 0) { break; } @@ -180,7 +172,7 @@ std::vector Yolo::sizeOriginal(std::vector result, } } } - + resultOrg.push_back(data); } return resultOrg; @@ -195,12 +187,12 @@ std::vector Yolo::prediction(torch::Tensor data) data = data.to(this->device); auto pred = model.forward({ data }); - + if (strcmp(this->version.c_str(), V8) == 0) { torch::Tensor pT = pred.toTensor(); torch::Tensor score = std::get<0>(pT.slice(1, 4, pT.size(1)).max(1, true)); - data = torch::cat({pT.slice(1, 0, 4), score, pT.slice(1, 4, pT.size(1))}, 1).permute({0, 2, 1}); + data = torch::cat({ pT.slice(1, 0, 4), score, pT.slice(1, 4, pT.size(1)) }, 1).permute({ 0, 2, 1 }); } else if (strcmp(this->version.c_str(), V6) == 0) { @@ -211,7 +203,7 @@ std::vector Yolo::prediction(torch::Tensor data) data = pred.toTuple()->elements()[0].toTensor(); } - return non_max_suppression(data, confThres, iouThres); + return nonMaxSuppression(data, confThres, iouThres); } std::vector Yolo::prediction(std::string filePath) @@ -223,7 +215,8 @@ std::vector Yolo::prediction(std::string filePath) std::vector Yolo::prediction(cv::Mat img) { ImageResizeData imgRD = resize(img); - cv::Mat reImg = img2RGB(imgRD.getImg()); + cv::Mat reImg = imgRD.getImg(); + img2RGB(reImg, reImg); torch::Tensor data = img2Tensor(reImg); std::vector result = prediction(data); std::vector imgRDs; @@ -239,7 +232,8 @@ std::vector Yolo::prediction(std::vector imgs) { ImageResizeData imgRD = resize(imgs[i]); imageRDs.push_back(imgRD); - cv::Mat img = img2RGB(imgRD.getImg()); + cv::Mat img = imgRD.getImg(); + img2RGB(img, img); datas.push_back(img2Tensor(img)); } torch::Tensor data = torch::cat(datas, 0); @@ -259,18 +253,19 @@ ImageResizeData Yolo::resize(cv::Mat img, int height, int width) cv::resize(img, img, cv::Size( isW ? width : (int)((float)height / (float)h * w), - isW ? (int)((float)width / (float)w * h) : height)); + isW ? (int)((float)width / (float)w * h) : height), + 0, 0, cv::INTER_AREA); w = img.cols, h = img.rows; if (isW) { imgResizeData.setBorder((height - h) / 2); - cv::copyMakeBorder(img, img, (height - h) / 2, height - h - (height - h) / 2, 0, 0, cv::BORDER_CONSTANT); + cv::copyMakeBorder(img, img, (height - h) / 2, height - h - (height - h) / 2, 0, 0, cv::BORDER_CONSTANT, cv::Scalar(114.)); } else { imgResizeData.setBorder((width - w) / 2); - cv::copyMakeBorder(img, img, 0, 0, (width - w) / 2, width - w - (width - w) / 2, cv::BORDER_CONSTANT); + cv::copyMakeBorder(img, img, 0, 0, (width - w) / 2, width - w - (width - w) / 2, cv::BORDER_CONSTANT, cv::Scalar(114.)); } imgResizeData.setImg(img); return imgResizeData; @@ -362,7 +357,7 @@ cv::Mat Yolo::drawRectangle(cv::Mat img, torch::Tensor rectangle, std::mapsecond << " "; @@ -385,7 +380,7 @@ bool Yolo::existencePrediction(std::vector classs) { for (int i = 0; i < classs.size(); i++) { - if (existencePrediction(classs[i])) + if (existencePrediction(classs[i])) { return true; } @@ -404,12 +399,12 @@ cv::Mat ImageResizeData::getImg() return img; } -bool ImageResizeData::isW() +bool ImageResizeData::isW() const { return (float)w / (float)h > (float)width / (float)height; } -bool ImageResizeData::isH() +bool ImageResizeData::isH() const { return (float)h / (float)w > (float)height / (float)width; } @@ -429,7 +424,7 @@ void ImageResizeData::setHeight(int height) this->height = height; } -int ImageResizeData::getHeight() +int ImageResizeData::getHeight() const { return height; } @@ -439,7 +434,7 @@ void ImageResizeData::setW(int w) this->w = w; } -int ImageResizeData::getW() +int ImageResizeData::getW() const { return w; } @@ -449,7 +444,7 @@ void ImageResizeData::setH(int h) this->h = h; } -int ImageResizeData::getH() +int ImageResizeData::getH() const { return h; } @@ -459,7 +454,7 @@ void ImageResizeData::setBorder(int border) this->border = border; } -int ImageResizeData::getBorder() +int ImageResizeData::getBorder() const { return border; }