目录
一、解析yolov5的输出结果
1、对1*25200*85的向量进行解析
2、预测框中心(x,y),预测框的高和宽(h,w)详解
3、解析代码:
二、confidence过滤
1、confidence计算
三、Non-Maximum Suppression
1、非极大值抑制
2、NMS代码:
四、预测框校正
1、校正的目的
2、代码:
五,检测框可视化
1、代码:
2、coco数据集的类别和颜色
yolov5后处理大致分为四步:confidence过滤,NMS,预测框校准,结果可视化。
首先需要对yoloV5的输出解析,认识清楚之后才能进行下一步。
一、解析yolov5的输出结果
首先yolov5的输出是一组1*25200*85的向量,我们想要的是检测框的坐标和置信度等,怎么得到?
1、对1*25200*85的向量进行解析
yolov5是一种目标检测算法,但它主要做的是对检测框进行预测,在导出为trt文件之后,yolov5的检测框固定为20*20;40*40;80*80大小。yolov5的输出结果是1*25200*85,也就是1*[(3*20*20)+(3*40*40)+(3*80*80)]*85。
1是代表输出图片的batchsize,我们在输入数据的时候bs维度设置为1。
25200也就是(3*20*20)+(3*40*40)+(3*80*80),可以这样理解:yolov5对图像依次做了8,16,32倍下采样,依次得到了80*80;40*40;20*20大小的特征图,每个特征图上有三个anchor box。
85可以分开理解:前5个数据分别代表了预测框中心(x,y),预测框的高和宽(h,w),还有obj_score(这个输出是只得框中是否有object的得分),剩下的80个数据分别是每一类别的得分,分别代表该框中的物体属于每一类的得分,代码中记作class_score。为什么是80呢?
note:最后的confidence的计算是obj_score * class_score。
2、预测框中心(x,y),预测框的高和宽(h,w)详解
x和y是预测框的中心坐标,h和w是预测框的高和宽,通过简单计算可以达到框的坐标。但是需要注意我们传进来的image的坐标是归一化之后的,需要判断是否反归一化。
至此我们完成了对yolo输出的解析,创建一个结构体包含每一个框的信息,然后创建一个该结构体类型的vector用来保存所有的信息。
//存放结果的结构体
struct Detection {
float x1, y1, x2, y2;
int class_id;
float confidence;
};
//存放所有结果的vector
std::vector<Detection> detections;
3、解析代码:
std::vector<Detection> parse_yolo(const std::vector<float>& yoloout,int img_width,int img_height){
std::vector<Detection> detections;
for (int i = 0; i < NUM_ANCHORS; ++i) {
const float* data = &yoloout[i * NUM_OUTPUTS];
//边界框坐标
float x = data[0];
float y = data[1];
float w = data[2];
float h = data[3];
//计算obj score
float obj_score = data[4];
float class_score = 0;
int class_id = -1;
//遍历每一类的score,取最大值并记录该类id
for(int j = 0; j < NUM_CLASSES; ++j){
float score = data[5 + j];
if (score > class_score){
class_score = score;
class_id = j;
}
};
//计算最后的置信度
float confidence = class_score * obj_score;
//设置置信度阈值进行过滤
if(confidence > CONFIDENCE_THRESHOLD){
// yolov5推理输出的坐标框已经是反归一化之后的,不需要再反归一化
// ----------------------------------------------------------
float x_center = x;
float y_center = y;
float width = w;
float height = h;
//图像坐标系和笛卡尔坐标系的y轴是反的
float x1 = x_center - width / 2;
float y1 = y_center - height / 2;
float x2 = x_center + width / 2;
float y2 = y_center + height / 2;
//push倒detection这个vector的末尾
detections.push_back({x1, y1, x2, y2, class_id, confidence});
}
}
//NMS
std::vector<Detection> nms_result;
nms_process(detections,nms_result,NMS_THRESHOLD);
return nms_result;
}
二、confidence过滤
设定阈值对检测框进行初步过滤,即过滤掉置信度较低的框。
1、confidence计算
Confidence = obj_score * class_score。很显然obj_score是yolo直接输出的。class_score怎么来?
class_score需要遍历后面的80的数据,选择出得分最大的一个数据。我们可以得到class_score和下标i,代表该预测框预测出的是第i类的得分。
这一步相对简单,上一步操作完之后可以得到confidence,遍历判断,大于confidence阈值则将检测框的信息push_back到容器末尾。代码在上面。
//push倒detection这个vector的末尾
detections.push_back({x1, y1, x2, y2, class_id, confidence});
三、Non-Maximum Suppression
1、非极大值抑制
经过置信度初步过滤之后,剩下的框都是置信度较高的。但是会存在两个坐标相似的框框住同一个物体的情况,所以需要对框做抑制。根据什么做抑制呢?
每个框都有坐标,可以计算出两个框之间的iou,我们通过设定iou_threshold阈值,判断iou大于该阈值,意味着这两个框重叠度过高,则抑制该框。iou示意图如下左图,NMS示意图如下右图所示。
2、NMS代码:
void nms_process(std::vector<Detection>& detection,
std::vector<Detection>& result, float nms_threshold){
//sort函数默认升序,这边使用confidence进行降序排序
std::sort(detection.begin(),detection.end(),
//Lambda表达式作为第三个参数输入。
[](const Detection& a,const Detection& b){return a.confidence > b.confidence;}
);
// 创建一个布尔类型的std::vector is_suppressed,其长度等于detections的大小。
// 所有元素被初始化为false,表示初始时所有检测结果都没有被抑制。
std::vector<bool> is_suppressed(detection.size(),false);
for(size_t i = 0; i < detection.size(); ++i){
if(is_suppressed[i]) continue;
//依次遍历整个vector,先选择i,然后依次和(i+1)开始对比iou
//每次只从第一个循环中写值
result.push_back(detection[i]);
for(size_t j = i + 1; j < detection.size(); ++j){
if(is_suppressed[j]) continue;
if(detection[i].class_id == detection[j].class_id){
float iou = calculate_iou(detection[i],detection[j]);
if(iou > nms_threshold){
is_suppressed[j] = true;
}
}
}
}
}
四、预测框校正
1、校正的目的
在图像预处理的时候,大多数时候图像并不是正方形的,所以设计到padding操作,所以预测框的尺寸是在带着padding的图像的,但是在最后可视化阶段我们需要去掉padding的部分,但是框的坐标还是没变,所以我们需要对预测框进行校准。如下图所示,去掉offset部分之后,框会变化:
坦白讲:这段代码没有什么难度,关键得明白是在做什么事情,主要做的事情就是对框进行缩放。
2、代码:
std::vector<Detection> correct_boxes(const std::vector<Detection>& detections, const cv::Size& input_shape, const cv::Size& image_shape) {
// 构建存放结果的结构体
std::vector<Detection> corrected_boxes;
// 得到训练缩放的尺寸 640
float input_w = input_shape.width; //640
float input_h = input_shape.height; //640
// 计算原图的w和h
float image_w = image_shape.width; //810
float image_h = image_shape.height; //1080
// 计算缩放因子
float scale = std::min(input_w / image_w, input_h / image_h); // 0.592593
float new_w = image_w * scale; //480
float new_h = image_h * scale; //640
float pad_w = (input_w - new_w) / 2.0f;
float pad_h = (input_h - new_h) / 2.0f;
// std::cout<<"校准框之后的的边框坐标分别是"<<std::endl;
for (const auto& detection : detections) {
float x1 = (detection.x1 - pad_w) / scale;
float y1 = (detection.y1 - pad_h) / scale;
float x2 = (detection.x2 - pad_w) / scale;
float y2 = (detection.y2 - pad_h) / scale;
corrected_boxes.push_back({x1, y1, x2, y2, detection.class_id, detection.confidence});
}
return corrected_boxes;
}
五,检测框可视化
需要准备的是categories,colors,坐标。
1、代码:
使用cv2.rectangle函数讲矩形的顶点画在原图上,同时绘制text。比较简单。
//draw box
void draw_detections(cv::Mat& img, const std::vector<Detection>& detections, const std::vector<std::string>& categories) {
for (const auto& det : detections) {
// 绘制检测框
cv::rectangle(img, cv::Point(det.x1, det.y1), cv::Point(det.x2, det.y2), cv::Scalar(0, 255, 0), 2);
// 准备标签文本
std::string label = categories[det.class_id] + ": " + std::to_string(det.confidence).substr(0, 4);
// 获取文本大小
int baseline;
cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseline);
// 确保文本不超出图像边界
int text_x = det.x1;
int text_y = det.y1;
// 绘制文本背景框
cv::rectangle(img, cv::Point(text_x, text_y - textSize.height),
cv::Point(text_x + textSize.width, text_y + baseline),
cv::Scalar(255, 255, 255), cv::FILLED);
// 绘制文本
cv::putText(img, label, cv::Point(text_x, text_y),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0), 1);
}
}
2、coco数据集的类别和颜色
const std::vector<std::string> categories = {"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
"frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
"surfboard", "tennis racket", "bottle",
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed",
"dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
"oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
"toothbrush"};
//COCO80的颜色
const std::vector<cv::Scalar> color80{
cv::Scalar(128, 77, 207),cv::Scalar(65, 32, 208),cv::Scalar(0, 224, 45),cv::Scalar(3, 141, 219),cv::Scalar(80, 239, 253),cv::Scalar(239, 184, 12),
cv::Scalar(7, 144, 145),cv::Scalar(161, 88, 57),cv::Scalar(0, 166, 46),cv::Scalar(218, 113, 53),cv::Scalar(193, 33, 128),cv::Scalar(190, 94, 113),
cv::Scalar(113, 123, 232),cv::Scalar(69, 205, 80),cv::Scalar(18, 170, 49),cv::Scalar(89, 51, 241),cv::Scalar(153, 191, 154),cv::Scalar(27, 26, 69),
cv::Scalar(20, 186, 194),cv::Scalar(210, 202, 167),cv::Scalar(196, 113, 204),cv::Scalar(9, 81, 88),cv::Scalar(191, 162, 67),cv::Scalar(227, 73, 120),
cv::Scalar(177, 31, 19),cv::Scalar(133, 102, 137),cv::Scalar(146, 72, 97),cv::Scalar(145, 243, 208),cv::Scalar(2, 184, 176),cv::Scalar(219, 220, 93),
cv::Scalar(238, 153, 134),cv::Scalar(197, 169, 160),cv::Scalar(204, 201, 106),cv::Scalar(13, 24, 129),cv::Scalar(40, 38, 4),cv::Scalar(5, 41, 34),
cv::Scalar(46, 94, 129),cv::Scalar(102, 65, 107),cv::Scalar(27, 11, 208),cv::Scalar(191, 240, 183),cv::Scalar(225, 76, 38),cv::Scalar(193, 89, 124),
cv::Scalar(30, 14, 175),cv::Scalar(144, 96, 90),cv::Scalar(181, 186, 86),cv::Scalar(102, 136, 34),cv::Scalar(158, 71, 15),cv::Scalar(183, 81, 247),
cv::Scalar(73, 69, 89),cv::Scalar(123, 73, 232),cv::Scalar(4, 175, 57),cv::Scalar(87, 108, 23),cv::Scalar(105, 204, 142),cv::Scalar(63, 115, 53),
cv::Scalar(105, 153, 126),cv::Scalar(247, 224, 137),cv::Scalar(136, 21, 188),cv::Scalar(122, 129, 78),cv::Scalar(145, 80, 81),cv::Scalar(51, 167, 149),
cv::Scalar(162, 173, 20),cv::Scalar(252, 202, 17),cv::Scalar(10, 40, 3),cv::Scalar(150, 90, 254),cv::Scalar(169, 21, 68),cv::Scalar(157, 148, 180),
cv::Scalar(131, 254, 90),cv::Scalar(7, 221, 102),cv::Scalar(19, 191, 184),cv::Scalar(98, 126, 199),cv::Scalar(210, 61, 56),cv::Scalar(252, 86, 59),
cv::Scalar(102, 195, 55),cv::Scalar(160, 26, 91),cv::Scalar(60, 94, 66),cv::Scalar(204, 169, 193),cv::Scalar(126, 4, 181),cv::Scalar(229, 209, 196),
cv::Scalar(195, 170, 186),cv::Scalar(155, 207, 148)
};
使用cv::rectangle函数在输入图片上进行绘制框,同时绘制文本。