Skip to content

Commit

Permalink
delete maxNms
Browse files Browse the repository at this point in the history
  • Loading branch information
ncdhz committed May 10, 2024
1 parent a8157c6 commit 0f99871
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 70 deletions.
70 changes: 35 additions & 35 deletions include/Yolo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,47 @@
class ImageResizeData
{
public:
// 设置处理过后的图片
// 设置处理过后的图片
void setImg(cv::Mat img);
// 获取处理过后的图片
// 获取处理过后的图片
cv::Mat getImg();
// 当原始图片宽高比大于处理过后图片宽高比时此函数返回 true
bool isW();
// 当原始图片高宽比大于处理过后图片高宽比时此函数返回 true
bool isH();
// 设置处理之后图片的宽
// 当原始图片宽高比大于处理过后图片宽高比时此函数返回 true
bool isW() const;
// 当原始图片高宽比大于处理过后图片高宽比时此函数返回 true
bool isH() const;
// 设置处理之后图片的宽
void setWidth(int width);
// 获取处理之后图片的宽
// 获取处理之后图片的宽
int getWidth();
// 设置处理之后图片的高
// 设置处理之后图片的高
void setHeight(int height);
// 获取处理之后图片的高
int getHeight();
// 设置原始图片的宽
// 获取处理之后图片的高
int getHeight() const;
// 设置原始图片的宽
void setW(int w);
// 获取原始图片的宽
int getW();
// 设置原始图片的高
// 获取原始图片的宽
int getW() const;
// 设置原始图片的高
void setH(int h);
// 获取原始图片的高
int getH();
// 设置从原始图片到处理过后图片所添加黑边大小
// 获取原始图片的高
int getH() const;
// 设置从原始图片到处理过后图片所添加黑边大小
void setBorder(int border);
// 获取从原始图片到处理过后图片所添加黑边大小
int getBorder();
// 获取从原始图片到处理过后图片所添加黑边大小
int getBorder() const;
private:
// 处理过后图片高
// 处理过后图片高
int height;
// 处理过后图片宽
int width;
int width;
// 原始图片宽
int w;
int w;
// 原始图片高
int h;
int h;
// 从原始图片到处理图片所添加的黑边大小
int border;
int border;
// 处理过后的图片
cv::Mat img;
cv::Mat img;
};

/**
Expand All @@ -62,17 +62,17 @@ class ImageResizeData
class Yolo
{
public:
/**
* 构造函数
* @param ptFile Yolo pt文件路径
/**
* 构造函数
* @param ptFile Yolo pt文件路径
* @param version Yolo的版本 ["v5", "v6", "v7", "v8"] 中选一
* @param device 推理使用的设备默认为cpu
* @param height Yolo 训练时图片的高
* @param width Yolo 训练时图片的宽
* @param confThres 非极大值抑制中的 scoreThresh
* @param iouThres 非极大值抑制中的 iouThresh
*/
Yolo(std::string ptFile, std::string version="v8", std::string device="cpu", bool isHalf=false, int height=640, int width=640, float confThres=0.25, float iouThres=0.45);
*/
Yolo(std::string ptFile, std::string version = "v8", std::string device = "cpu", bool isHalf = false, int height = 640, int width = 640, float confThres = 0.25, float iouThres = 0.45);
/**
* 预测函数
* @param data 需要预测的数据格式 (batch, rgb, height, width)
Expand All @@ -90,7 +90,7 @@ class Yolo
*/
std::vector<torch::Tensor> prediction(cv::Mat img);
/**
* 预测函数
* 预测函数
* @param imgs 需要预测的图片集合
*/
std::vector<torch::Tensor> prediction(std::vector<cv::Mat> imgs);
Expand Down Expand Up @@ -144,7 +144,7 @@ class Yolo
* @param imgs 原始图片集合
* @param rectangles 通过预测函数处理好的结果
* @param colors 每种类型对应颜色
* @param labels 类别标签
* @param labels 类别标签
* @return 画好框的图片
*/
std::vector<cv::Mat> drawRectangle(std::vector<cv::Mat> imgs, std::vector<torch::Tensor> rectangles, std::map<int, cv::Scalar> colors, std::map<int, std::string> labels, int thickness = 2);
Expand Down Expand Up @@ -210,7 +210,7 @@ class Yolo
// 随机获取一种颜色
cv::Scalar getRandScalar();
// 图片通道转换为 rgb
cv::Mat img2RGB(cv::Mat img);
void img2RGB(cv::Mat& img, cv::Mat& dst);
// 图片变为 Tensor
torch::Tensor img2Tensor(cv::Mat img);
// (center_x center_y w h) to (left, top, right, bottom)
Expand All @@ -220,5 +220,5 @@ class Yolo
// 预测出来的框根据原始图片还原算法
std::vector<torch::Tensor> sizeOriginal(std::vector<torch::Tensor> result, std::vector<ImageResizeData> imgRDs);
// 非极大值抑制算法整体
std::vector<torch::Tensor> non_max_suppression(torch::Tensor preds, float confThres = 0.25, float iouThres = 0.45);
std::vector<torch::Tensor> nonMaxSuppression(torch::Tensor preds, float confThres = 0.25, float iouThres = 0.45);
};
65 changes: 30 additions & 35 deletions src/Yolo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Yolo::Yolo(std::string ptFile, std::string version, std::string device, bool isH
model.to(torch::kHalf);
}
model.to(device);
this->device=device;
this->device = device;
this->height = height;
this->width = width;
this->iouThres = iouThres;
Expand All @@ -20,11 +20,9 @@ Yolo::Yolo(std::string ptFile, std::string version, std::string device, bool isH
std::srand(seed);
}

std::vector<torch::Tensor> Yolo::non_max_suppression(torch::Tensor prediction, float confThres, float iouThres)
std::vector<torch::Tensor> Yolo::nonMaxSuppression(torch::Tensor prediction, float confThres, float iouThres)
{
torch::Tensor xc = prediction.select(2, 4) > confThres;
int maxWh = 4096;
int maxNms = 30000;
std::vector<torch::Tensor> output;
for (int i = 0; i < prediction.size(0); i++)
{
Expand All @@ -35,8 +33,7 @@ std::vector<torch::Tensor> Yolo::non_max_suppression(torch::Tensor prediction, f
torch::Tensor x = prediction[i];
x = x.index_select(0, torch::nonzero(xc[i]).select(1, 0));
if (x.size(0) == 0) continue;

x.slice(1, 5, x.size(1)).mul_(x.slice(1, 4, 5));

torch::Tensor box = xywh2xyxy(x.slice(1, 0, 4));
std::tuple<torch::Tensor, torch::Tensor> max_tuple = torch::max(x.slice(1, 5, x.size(1)), 1, true);
x = torch::cat({ box, std::get<0>(max_tuple), std::get<1>(max_tuple) }, 1);
Expand All @@ -46,11 +43,7 @@ std::vector<torch::Tensor> Yolo::non_max_suppression(torch::Tensor prediction, f
{
continue;
}
else if (n > maxNms)
{
x = x.index_select(0, x.select(1, 4).argsort(0, true).slice(0, 0, maxNms));
}
torch::Tensor c = x.slice(1, 5, 6) * maxWh;
torch::Tensor c = x.slice(1, 5, 6) * 4096;
torch::Tensor boxes = x.slice(1, 0, 4) + c, scores = x.select(1, 4);
torch::Tensor ix = nms(boxes, scores, iouThres).to(x.device());
output[i] = x.index_select(0, ix).cpu();
Expand All @@ -63,7 +56,7 @@ cv::Scalar Yolo::getRandScalar()
return cv::Scalar(std::rand() % 256, std::rand() % 256, std::rand() % 256);
}

cv::Mat Yolo::img2RGB(cv::Mat img)
void Yolo::img2RGB(cv::Mat &img, cv::Mat& dst)
{
int imgC = img.channels();
if (imgC == 1)
Expand All @@ -74,12 +67,11 @@ cv::Mat Yolo::img2RGB(cv::Mat img)
{
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
}
return img;
}

torch::Tensor Yolo::img2Tensor(cv::Mat img)
{
torch::Tensor data = torch::from_blob(img.data, {(int)height, (int)width, 3 }, torch::kByte);
torch::Tensor data = torch::from_blob(img.data, { (int)height, (int)width, 3 }, torch::kByte);
data = data.permute({ 2, 0, 1 });
data = data.toType(torch::kFloat);
data = data.div(255);
Expand Down Expand Up @@ -108,15 +100,15 @@ torch::Tensor Yolo::nms(torch::Tensor bboxes, torch::Tensor scores, float thresh
auto order = std::get<1>(tuple_sorted);

std::vector<int> keep;
while (order.numel() > 0)
while (order.numel() > 0)
{
if (order.numel() == 1)
if (order.numel() == 1)
{
auto i = order.item();
keep.push_back(i.toInt());
break;
}
else
else
{
auto i = order[0].item();
keep.push_back(i.toInt());
Expand All @@ -132,7 +124,7 @@ torch::Tensor Yolo::nms(torch::Tensor bboxes, torch::Tensor scores, float thresh

auto iou = inter / (areas[keep.back()] + areas.index({ order.narrow(0,1,order.size(-1) - 1) }) - inter);
auto idx = (iou <= thresh).nonzero().squeeze();
if (idx.numel() == 0)
if (idx.numel() == 0)
{
break;
}
Expand Down Expand Up @@ -180,7 +172,7 @@ std::vector<torch::Tensor> Yolo::sizeOriginal(std::vector<torch::Tensor> result,
}
}
}

resultOrg.push_back(data);
}
return resultOrg;
Expand All @@ -195,12 +187,12 @@ std::vector<torch::Tensor> Yolo::prediction(torch::Tensor data)
data = data.to(this->device);

auto pred = model.forward({ data });

if (strcmp(this->version.c_str(), V8) == 0)
{
torch::Tensor pT = pred.toTensor();
torch::Tensor score = std::get<0>(pT.slice(1, 4, pT.size(1)).max(1, true));
data = torch::cat({pT.slice(1, 0, 4), score, pT.slice(1, 4, pT.size(1))}, 1).permute({0, 2, 1});
data = torch::cat({ pT.slice(1, 0, 4), score, pT.slice(1, 4, pT.size(1)) }, 1).permute({ 0, 2, 1 });
}
else if (strcmp(this->version.c_str(), V6) == 0)
{
Expand All @@ -211,7 +203,7 @@ std::vector<torch::Tensor> Yolo::prediction(torch::Tensor data)
data = pred.toTuple()->elements()[0].toTensor();
}

return non_max_suppression(data, confThres, iouThres);
return nonMaxSuppression(data, confThres, iouThres);
}

std::vector<torch::Tensor> Yolo::prediction(std::string filePath)
Expand All @@ -223,7 +215,8 @@ std::vector<torch::Tensor> Yolo::prediction(std::string filePath)
std::vector<torch::Tensor> Yolo::prediction(cv::Mat img)
{
ImageResizeData imgRD = resize(img);
cv::Mat reImg = img2RGB(imgRD.getImg());
cv::Mat reImg = imgRD.getImg();
img2RGB(reImg, reImg);
torch::Tensor data = img2Tensor(reImg);
std::vector<torch::Tensor> result = prediction(data);
std::vector<ImageResizeData> imgRDs;
Expand All @@ -239,7 +232,8 @@ std::vector<torch::Tensor> Yolo::prediction(std::vector<cv::Mat> imgs)
{
ImageResizeData imgRD = resize(imgs[i]);
imageRDs.push_back(imgRD);
cv::Mat img = img2RGB(imgRD.getImg());
cv::Mat img = imgRD.getImg();
img2RGB(img, img);
datas.push_back(img2Tensor(img));
}
torch::Tensor data = torch::cat(datas, 0);
Expand All @@ -259,18 +253,19 @@ ImageResizeData Yolo::resize(cv::Mat img, int height, int width)

cv::resize(img, img, cv::Size(
isW ? width : (int)((float)height / (float)h * w),
isW ? (int)((float)width / (float)w * h) : height));
isW ? (int)((float)width / (float)w * h) : height),
0, 0, cv::INTER_AREA);

w = img.cols, h = img.rows;
if (isW)
{
imgResizeData.setBorder((height - h) / 2);
cv::copyMakeBorder(img, img, (height - h) / 2, height - h - (height - h) / 2, 0, 0, cv::BORDER_CONSTANT);
cv::copyMakeBorder(img, img, (height - h) / 2, height - h - (height - h) / 2, 0, 0, cv::BORDER_CONSTANT, cv::Scalar(114.));
}
else
{
imgResizeData.setBorder((width - w) / 2);
cv::copyMakeBorder(img, img, 0, 0, (width - w) / 2, width - w - (width - w) / 2, cv::BORDER_CONSTANT);
cv::copyMakeBorder(img, img, 0, 0, (width - w) / 2, width - w - (width - w) / 2, cv::BORDER_CONSTANT, cv::Scalar(114.));
}
imgResizeData.setImg(img);
return imgResizeData;
Expand Down Expand Up @@ -362,7 +357,7 @@ cv::Mat Yolo::drawRectangle(cv::Mat img, torch::Tensor rectangle, std::map<int,
labelIt = labels.find(clazz);

std::ostringstream oss;

if (labelIt != labels.end())
{
oss << labelIt->second << " ";
Expand All @@ -385,7 +380,7 @@ bool Yolo::existencePrediction(std::vector<torch::Tensor> classs)
{
for (int i = 0; i < classs.size(); i++)
{
if (existencePrediction(classs[i]))
if (existencePrediction(classs[i]))
{
return true;
}
Expand All @@ -404,12 +399,12 @@ cv::Mat ImageResizeData::getImg()
return img;
}

bool ImageResizeData::isW()
bool ImageResizeData::isW() const
{
return (float)w / (float)h > (float)width / (float)height;
}

bool ImageResizeData::isH()
bool ImageResizeData::isH() const
{
return (float)h / (float)w > (float)height / (float)width;
}
Expand All @@ -429,7 +424,7 @@ void ImageResizeData::setHeight(int height)
this->height = height;
}

int ImageResizeData::getHeight()
int ImageResizeData::getHeight() const
{
return height;
}
Expand All @@ -439,7 +434,7 @@ void ImageResizeData::setW(int w)
this->w = w;
}

int ImageResizeData::getW()
int ImageResizeData::getW() const
{
return w;
}
Expand All @@ -449,7 +444,7 @@ void ImageResizeData::setH(int h)
this->h = h;
}

int ImageResizeData::getH()
int ImageResizeData::getH() const
{
return h;
}
Expand All @@ -459,7 +454,7 @@ void ImageResizeData::setBorder(int border)
this->border = border;
}

int ImageResizeData::getBorder()
int ImageResizeData::getBorder() const
{
return border;
}

0 comments on commit 0f99871

Please sign in to comment.