分离轴定理的c++/python实现
现在要对BEV模型检查出来的车辆做NMS,把3d框的平面属性获取到后,配合旋转角度投影到地面就是2D图形。
开始碰撞检测,判断是否重叠,保留置信度高的框就行。
原理
分离轴定理(Separating Axis Theorem,简称 SAT)是一种常用的碰撞检测算法,尤其在检测两个凸多边形或凸多面体是否相交时非常有效。SAT 的核心思想是:如果两个多边形在任何一条“轴”上都没有重叠投影,那么它们就不相交。如果在所有潜在的轴上都存在重叠投影,那么它们就是相交的。
SAT 算法的核心思想:
- 分离轴的定义:两个凸多边形相交的必要条件是,在某个方向上,两个多边形的投影不重叠。这个方向称为“分离轴”。如果存在这样的轴,称为“分离轴”,那么这两个多边形一定不相交。
- 投影:将多边形的所有顶点投影到某个轴上,可以得到一个一维的区间。如果两个多边形在某条轴上的投影不重叠,说明在该方向上它们是分开的,所以不可能相交。
- 边法线作为分离轴:对凸多边形的每一条边,计算其垂直方向(即边的法线),并将所有顶点沿着这个法线方向进行投影。如果在任何一条法线方向上,两个多边形的投影不重叠,则这两个多边形不会相交。
- 反例与正例:如果在所有法线方向上,两个多边形的投影都重叠,那么它们是相交的;如果找到一个轴使得投影不重叠,那么它们是不相交的。
SAT 算法的步骤:
- 获取多边形的所有边:对于两个多边形,分别获取它们所有的边。
- 计算边的法线(轴):对每一条边,计算其法线。这些法线是我们用来做投影的方向(轴)。
- 将顶点投影到轴上:对每条轴,分别将两个多边形的所有顶点投影到该轴上,得到两个投影区间。
- 检查投影区间是否重叠:比较两个多边形在该轴上的投影区间。如果任何一个轴上的投影不重叠,立即返回不相交(即找到了分离轴)。
- 判断相交:如果所有轴上的投影都重叠,那么可以判断两个多边形相交。
具体例子:
假设有两个二维矩形 A 和 B,算法流程如下:
- 计算 A 的所有边以及 B 的所有边。
- 对每条边,计算它的垂直方向作为分离轴。
- 将 A 和 B 的顶点投影到每条分离轴上,得到两个投影区间。
- 检查每一条轴上 A 和 B 的投影是否有重叠。
- 如果在某条轴上投影没有重叠,A 和 B 不相交。
- 如果在所有轴上投影都重叠,A 和 B 相交。
示例图解:
- 假设两个矩形相对旋转,SAT 会先找出每个矩形的边,接着计算这些边的法线(垂直方向)。
- 在这些法线方向上,将矩形的所有顶点投影到这些法线上。
- 如果在所有法线上的投影都重叠,就说明矩形相交;如果存在一条轴上的投影不重叠,矩形就不相交。
SAT 算法的优点:
- 高效:特别适用于凸多边形,时间复杂度相对较低。
- 通用性强:适用于任何凸多边形,不仅仅是矩形,也包括任意形状的多边形或三维多面体。
SAT 的局限:
- 仅适用于凸多边形:对凹多边形不适用。
- 计算较复杂的形状时需要的边数较多:多边形的边数越多,所需的计算量也会随之增加。
应用场景:
- 游戏开发:在物理引擎中用于碰撞检测,判断玩家、敌人或者其他物体是否相互碰撞。
- 图形处理:用于检测两个二维或三维形状是否相交,通常用于图像合成、动画、模拟仿真等场景。
总结:
分离轴定理通过投影和比较的方式,能够有效地判断两个凸多边形是否相交。它利用边法线作为可能的分离轴,并通过检测是否存在投影不重叠的轴来快速判断是否有碰撞。
cpp实现
#include <vector>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <cassert>
#include <algorithm>
#include <fstream>
#include <vector>
#include <string>
#include <sstream>
struct Box3D {
float conf;
int sub_cls;
float x;
float y;
float z;
float w; // 宽度
float l; // 长度
float h; // 高度
float yaw; // 旋转角度
float velocity_x = 0.0f;
float velocity_y = 0.0f;
};
// 旋转一个点(x, y)围绕中心点(cx, cy)旋转angle_degrees角度
std::pair<float, float> RotatePoint(float x, float y, float cx, float cy, float angleRad) {
// float angleRad = angleDegrees * M_PI / 180.0f;
float cosA = std::cos(angleRad);
float sinA = std::sin(angleRad);
float rx = cx + (x - cx) * cosA - (y - cy) * sinA;
float ry = cy + (x - cx) * sinA + (y - cy) * cosA;
return {rx, ry};
}
// 获取旋转后的边框的顶点
std::vector<std::pair<float, float>> GetRotatedBboxVertices(float cx, float cy, float w, float h, float angle) {
float half_w = w / 2;
float half_h = h / 2;
// 定义未旋转的矩形的四个顶点
std::vector<std::pair<float, float>> vertices = {
{cx - half_w, cy - half_h},
{cx + half_w, cy - half_h},
{cx + half_w, cy + half_h},
{cx - half_w, cy + half_h}
};
// 旋转每个顶点
for (auto& vertex : vertices) {
vertex = RotatePoint(vertex.first, vertex.second, cx, cy, angle);
}
return vertices;
}
// 获取边框的边
std::vector<std::pair<std::pair<float, float>, std::pair<float, float>>> GetEdges(const std::vector<std::pair<float, float>>& vertices) {
std::vector<std::pair<std::pair<float, float>, std::pair<float, float>>> edges;
for (size_t i = 0; i < vertices.size(); ++i) {
edges.push_back({vertices[i], vertices[(i + 1) % vertices.size()]});
}
return edges;
}
// 获取边缘的轴
std::pair<float, float> GetAxis(const std::pair<std::pair<float, float>, std::pair<float, float>>& edge) {
float x1 = edge.first.first;
float y1 = edge.first.second;
float x2 = edge.second.first;
float y2 = edge.second.second;
return {y2 - y1, x1 - x2}; // 垂直于边的向量
}
// 投影顶点到轴上
std::pair<float, float> Project(const std::vector<std::pair<float, float>>& vertices, const std::pair<float, float>& axis) {
std::vector<float> dots;
for (const auto& vertex : vertices) {
dots.push_back(vertex.first * axis.first + vertex.second * axis.second);
}
return {*std::min_element(dots.begin(), dots.end()), *std::max_element(dots.begin(), dots.end())};
}
// 判断投影是否重叠
bool Overlap(const std::pair<float, float>& projection1, const std::pair<float, float>& projection2) {
return std::min(projection1.second, projection2.second) >= std::max(projection1.first, projection2.first);
}
// 使用分离轴定理(SAT)判断两个二维框是否相交
bool SatIntersection(const std::vector<float>& box1, const std::vector<float>& box2) {
std::vector<std::pair<float, float>> vertices1 = GetRotatedBboxVertices(box1[0], box1[1], box1[2], box1[3], box1[4]);
std::vector<std::pair<float, float>> vertices2 = GetRotatedBboxVertices(box2[0], box2[1], box2[2], box2[3], box2[4]);
std::vector<std::pair<std::pair<float, float>, std::pair<float, float>>> edges;
auto edges1 = GetEdges(vertices1);
auto edges2 = GetEdges(vertices2);
edges.insert(edges.end(), edges1.begin(), edges1.end());
edges.insert(edges.end(), edges2.begin(), edges2.end());
for (const auto& edge : edges) {
auto axis = GetAxis(edge);
auto projection1 = Project(vertices1, axis);
auto projection2 = Project(vertices2, axis);
std::cout << projection1.first << " " << projection1.second << " " << projection2.first << " " << projection2.second << std::endl;
if (!Overlap(projection1, projection2)) {
return false; // 找到了分离轴,两个框不相交
}
}
return true; // 没有分离轴,两个框相交
}
// 解析 Box3D 对象为二维 box 参数
std::vector<float> ParseBox3D(const Box3D& box) {
return {box.x, box.y, box.l, box.w, box.yaw};
}
// SAT NMS 实现,返回选择的索引
std::vector<int> SatNms(std::vector<Box3D> obj_list) {
// 如果输入为空,返回空结果
if (obj_list.empty()) {
return {};
}
if (obj_list.size() == 1) {
return {1}; // 只有一个框时,直接返回
}
// 保存原始索引
std::vector<int> original_indices(obj_list.size());
std::iota(original_indices.begin(), original_indices.end(), 0); // 生成 0 到 obj_list.size()-1 的索引
// 按置信度从高到低排序,同时调整原始索引顺序
std::sort(original_indices.begin(), original_indices.end(), [&](int i, int j) {
return obj_list[i].conf > obj_list[j].conf;
});
std::sort(obj_list.begin(), obj_list.end(), [&](const Box3D& a, const Box3D& b) {
return a.conf > b.conf;
});
std::vector<std::vector<float>> boxes;
std::vector<int> keep(obj_list.size(), 0); // 初始化 keep 向量为 0,大小与 obj_list 相同
// 解析 Box3D 为二维 box 参数并存储
for (const auto& obj : obj_list) {
boxes.push_back({obj.x, obj.y, obj.l, obj.w, obj.yaw});
}
// 执行 NMS 逻辑
while (!boxes.empty()) {
std::vector<float> next_box = boxes.front(); // 获取置信度最高的 box
boxes.erase(boxes.begin()); // 移除第一个框
// 获取对应的原始索引
int original_index = original_indices.front();
original_indices.erase(original_indices.begin()); // 移除处理过的索引
keep[original_index] = 1; // 标记为保留
std::cout << "see: " << original_index << std::endl;
// 同时移除 obj_list 中对应的 Box3D 对象
obj_list.erase(obj_list.begin());
std::vector<std::vector<float>> remain_boxes;
std::vector<Box3D> remaining_obj;
std::vector<int> remaining_indices; // 新增:存储不相交框的原始索引
// 遍历剩下的 boxes,检查与当前 box 的相交情况
for (size_t i = 0; i < boxes.size(); ++i) {
std::cout << "compare with: " << original_indices[i] << std::endl;
if (!SatIntersection(next_box, boxes[i])) {
remain_boxes.push_back(boxes[i]); // 保留不相交的框
remaining_obj.push_back(obj_list[i]); // 保留不相交的 obj_list 对象
remaining_indices.push_back(original_indices[i]); // 保留不相交框的原始索引
} else {
std::cout << "remove: " << original_indices[i] << std::endl;
}
}
// 更新剩余的 boxes、obj_list 和 original_indices
boxes = remain_boxes;
obj_list = remaining_obj;
original_indices = remaining_indices; // 更新原始索引
}
return keep; // 返回保留的框的索引向量
}
std::vector<Box3D> readData(const std::string& filename) {
std::vector<Box3D> boxes;
std::ifstream file(filename);
if (!file.is_open()) {
std::cerr << "Failed to open file for reading: " << filename << std::endl;
return boxes;
}
std::string line;
while (std::getline(file, line)) {
std::istringstream iss(line);
Box3D box;
if (!(iss >> box.conf >> box.sub_cls >> box.x >> box.y >> box.z >> box.w >> box.l >> box.h >> box.yaw >> box.velocity_x >> box.velocity_y)) {
std::cerr << "Error reading line: " << line << std::endl;
continue;
}
boxes.push_back(box);
}
file.close();
return boxes;
}
void dumpData(const std::vector<Box3D>& boxes, const std::string& filename) {
std::ofstream file(filename);
if (!file.is_open()) {
std::cerr << "Failed to open file for writing: " << filename << std::endl;
return;
}
for (const auto& box : boxes) {
file << box.conf << " " << box.sub_cls << " " << box.x << " " << box.y << " "
<< box.z << " " << box.w << " " << box.l << " " << box.h << " " << box.yaw
<< " " << box.velocity_x << " " << box.velocity_y << "\n";
}
file.close();
}
int main() {
// 创建测试用例
#define PI 3.1415926f
Box3D bbox1;
bbox1.conf = 0.9f;
bbox1.sub_cls = 0;
bbox1.x = 100.0f;
bbox1.y = 150.0f;
bbox1.z = 0.0f;
bbox1.w = 40.0f;
bbox1.l = 80.0f;
bbox1.h = 0.0f;
bbox1.yaw = fmod(10.0f*PI/180, PI * 2);
Box3D bbox2;
bbox2.conf = 0.85f;
bbox2.sub_cls = 0;
bbox2.x = 105.0f;
bbox2.y = 160.0f;
bbox2.z = 0.0f;
bbox2.w = 60.0f;
bbox2.l = 90.0f;
bbox2.h = 0.0f;
bbox2.yaw = fmod(20.0f*PI/180, PI * 2);
Box3D bbox3;
bbox3.conf = 0.7f;
bbox3.sub_cls = 0;
bbox3.x = 300.0f;
bbox3.y = 400.0f;
bbox3.z = 0.0f;
bbox3.w = 30.0f;
bbox3.l = 70.0f;
bbox3.h = 0.0f;
bbox3.yaw = fmod(45.0f*PI/180, PI * 2);
Box3D bbox4;
bbox4.conf = 0.75f;
bbox4.sub_cls = 0;
bbox4.x = 110.0f;
bbox4.y = 160.0f;
bbox4.z = 0.0f;
bbox4.w = 40.0f;
bbox4.l = 85.0f;
bbox4.h = 0.0f;
bbox4.yaw = fmod(30.0f*PI/180, PI * 2);
Box3D bbox5;
bbox5.conf = 0.6f;
bbox5.sub_cls = 0;
bbox5.x = 500.0f;
bbox5.y = 600.0f;
bbox5.z = 0.0f;
bbox5.w = 50.0f;
bbox5.l = 100.0f;
bbox5.h = 0.0f;
bbox5.yaw = fmod(90.0f*PI/180, PI * 2);
Box3D bbox6;
bbox6.conf = 0.5f;
bbox6.sub_cls = 0;
bbox6.x = 505.0f;
bbox6.y = 610.0f;
bbox6.z = 0.0f;
bbox6.w = 45.0f;
bbox6.l = 105.0f;
bbox6.h = 0.0f;
bbox6.yaw = fmod(75.0f*PI/180, PI * 2);
Box3D bbox7;
bbox7.conf = 0.3f;
bbox7.sub_cls = 0;
bbox7.x = 700.0f;
bbox7.y = 800.0f;
bbox7.z = 0.0f;
bbox7.w = 30.0f;
bbox7.l = 60.0f;
bbox7.h = 0.0f;
bbox7.yaw = fmod(15.0f*PI/180, PI * 2);
// 创建框列表
// std::vector<Box3D> boxes = {bbox1, bbox2, bbox3, bbox4, bbox5, bbox6, bbox7};
std::vector<Box3D> boxes = readData("data.txt");
// 执行 SAT NMS
std::vector<int> keep = SatNms(boxes);
std::vector<Box3D> filtered_boxes;
for (size_t i = 0; i < keep.size(); ++i) {
if (keep[i]) {
filtered_boxes.push_back(boxes[i]);
}
}
boxes = filtered_boxes;
dumpData(boxes, "filtered_data.txt");
return 0;
}
python实现
import math
def sat_nms(obj_list, conf_score_key, parse_obj_func):
original_indices = list(range(len(obj_list)))
original_indices = sorted(original_indices, key=lambda i: obj_list[i][conf_score_key], reverse=True)
obj_list = sorted(obj_list, key=lambda i: i[conf_score_key], reverse=True)
if len(obj_list) < 2:
return obj_list
boxes = []
for obj in obj_list:
X, Y, L, W, _, yaw = parse_obj_func(obj)
box = [X, Y, L, W, yaw]
boxes.append(box)
selected_boxes = []
selected_obj = []
while boxes:
next_box = boxes.pop(0)
original_index = original_indices.pop(0)
selected_boxes.append(next_box)
next_obj = obj_list.pop(0)
selected_obj.append(next_obj)
print(f"See {original_index} with conf: {next_obj[conf_score_key]}")
remain_boxes = []
remaining_obj = []
remaining_indices = []
for idx, (rest_box, rest_obj) in enumerate(zip(boxes, obj_list)):
print(f"compare: {original_indices[idx]} ")
if not sat_intersection(next_box, rest_box):
remain_boxes.append(rest_box)
remaining_obj.append(rest_obj)
remaining_indices.append(original_indices[idx])
else:
print(f"remove {original_indices[idx]} with conf: {rest_obj[conf_score_key]}")
boxes = remain_boxes
obj_list = remaining_obj
original_indices = remaining_indices
return selected_obj
def rotate_point(x, y, cx, cy, angle_rad):
cos_a, sin_a = math.cos(angle_rad), math.sin(angle_rad)
rx = cx + (x - cx) * cos_a - (y - cy) * sin_a
ry = cy + (x - cx) * sin_a + (y - cy) * cos_a
return rx, ry
def get_rotated_bbox_vertices(cx, cy, w, h, angle):
half_w, half_h = w / 2, h / 2
vertices = [ (cx - half_w, cy - half_h), (cx + half_w, cy - half_h), (cx + half_w, cy + half_h), (cx - half_w, cy + half_h) ]
return [rotate_point(x, y, cx, cy, angle) for x, y in vertices]
def get_edges(vertices):
return [(vertices[i], vertices[(i + 1) % len(vertices)]) for i in range(len(vertices))]
def get_axis(edge):
x1, y1 = edge[0]
x2, y2 = edge[1]
return y2 - y1, x1 - x2
def project(vertices, axis):
dots = [vertex[0] * axis[0] + vertex[1] * axis[1] for vertex in vertices]
return min(dots), max(dots)
def overlap(projection1, projection2):
return min(projection1[1], projection2[1]) >= max(projection1[0], projection2[0])
def sat_intersection(bbox1, bbox2):
vertices1 = get_rotated_bbox_vertices(*bbox1)
vertices2 = get_rotated_bbox_vertices(*bbox2)
edges = get_edges(vertices1) + get_edges(vertices2)
for edge in edges:
axis = get_axis(edge)
projection1 = project(vertices1, axis)
projection2 = project(vertices2, axis)
print(f"Projection1: {projection1[0]} {projection1[1]}, Projection2: {projection2[0]} {projection2[1]}")
if not overlap(projection1, projection2):
return False # Separating axis found, no intersection
return True # No separating axis found, boxes intersect
def parse_box3d(obj):
return obj['x'], obj['y'], obj['l'], obj['w'], obj['z'], obj['yaw']
# 创建测试用例
PI = 3.1415926
boxes = [
{'conf': 0.9, 'sub_cls': 0, 'x': 100.0, 'y': 150.0, 'z': 0.0, 'w': 40.0, 'l': 80.0, 'h': 0.0, 'yaw': math.fmod(10.0 * PI / 180, PI * 2)},
{'conf': 0.85, 'sub_cls': 0, 'x': 105.0, 'y': 160.0, 'z': 0.0, 'w': 60.0, 'l': 90.0, 'h': 0.0, 'yaw': math.fmod(20.0 * PI / 180, PI * 2)},
{'conf': 0.7, 'sub_cls': 0, 'x': 300.0, 'y': 400.0, 'z': 0.0, 'w': 30.0, 'l': 70.0, 'h': 0.0, 'yaw': math.fmod(45.0 * PI / 180, PI * 2)},
{'conf': 0.75, 'sub_cls': 0, 'x': 110.0, 'y': 160.0, 'z': 0.0, 'w': 40.0, 'l': 85.0, 'h': 0.0, 'yaw': math.fmod(30.0 * PI / 180, PI * 2)},
{'conf': 0.6, 'sub_cls': 0, 'x': 500.0, 'y': 600.0, 'z': 0.0, 'w': 50.0, 'l': 100.0, 'h': 0.0, 'yaw': math.fmod(90.0 * PI / 180, PI * 2)},
{'conf': 0.5, 'sub_cls': 0, 'x': 505.0, 'y': 610.0, 'z': 0.0, 'w': 45.0, 'l': 105.0, 'h': 0.0, 'yaw': math.fmod(75.0 * PI / 180, PI * 2)},
{'conf': 0.3, 'sub_cls': 0, 'x': 700.0, 'y': 800.0, 'z': 0.0, 'w': 30.0, 'l': 60.0, 'h': 0.0, 'yaw': math.fmod(15.0 * PI / 180, PI * 2)},
]
# 执行 SAT NMS
selected_boxes = sat_nms(boxes, 'conf', parse_box3d)
# 输出结果
print("Kept boxes:")
for box in selected_boxes:
print(box)