文章目录
- cv2.kmeans
- 牛刀小试
cv2.kmeans
cv2.kmeans 是 OpenCV 库中用于执行 K-Means 聚类算法的函数。以下是根据参考文章整理的 cv2.kmeans 函数的中文文档:
一、函数功能
cv2.kmeans 用于执行 K-Means 聚类算法,将一组数据点划分到 K 个簇中,使得簇内的数据点尽可能相似,而簇间的数据点尽可能不同。
二、函数格式
retval, labels, centers = cv2.kmeans(data, K, None, criteria, attempts, flags, centers=None)
三、参数说明
-
data:需要被聚类的原始数据集合,数据类型应为 np.float32。数据应是一维或多维的,每个样本应使用一行表示。例如,Mat points(count, 2, CV_32F) 表示二维浮点数据集。
-
K:聚类簇数,即希望将数据分成的簇的数量。
-
None:在原始 API 中,此位置是用于传递之前迭代的标签的,但在大多数情况下,可以设置为 None,因为算法会自动处理。
-
criteria:算法的终止条件。通常是一个包含三个元素的元组 (type, max_iter, epsilon):
-
type:终止条件类型,可以是
cv2.TERM_CRITERIA_EPS
(仅当 epsilon 满足时停止)、cv2.TERM_CRITERIA_MAX_ITER
(当迭代次数超过阈值时停止)或两者之和。 -
max_iter:最大迭代次数。
-
epsilon:精确度阈值。
-
-
attempts:使用不同的初始中心(或种子)来执行算法的次数。算法会返回最好的结果。
-
flags:用于设置如何选择起始重心。可以是
cv2.KMEANS_PP_CENTERS
(使用 K-Means++ 初始化)或cv2.KMEANS_RANDOM_CENTERS
(随机初始化)。
centers(可选):输出的聚类中心。如果未提供,则算法会返回一个。
四、返回值
- retval:紧密度(compactness),即每个点到其相应簇中心的距离的平方和。
- labels:每个数据点的最终分类标签数组。
- centers:由聚类中心组成的数组。
五、注意事项
-
在调用 cv2.kmeans 之前,通常需要将数据转换为 np.float32 类型,并确保数据的形状是 (样本数, 特征数)。
-
聚类结果可能受初始中心选择的影响,因此设置 attempts 参数为较高的值可能会得到更稳定的结果。
-
根据问题的具体需求和数据特性,可能需要调整 K、max_iter 和 epsilon 等参数以获得最佳聚类效果。
牛刀小试
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
num_classes = 6
img = cv2.imread('2.jpg', 0) # image read be 'gray'
# change img(2D) to 1D
img1 = img.reshape((img.shape[0]*img.shape[1], 1))
img1 = np.float32(img1)
# define criteria = (type,max_iter,epsilon)
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
# set flags: hou to choose the initial center
# ---cv2.KMEANS_PP_CENTERS ;
# cv2.KMEANS_RANDOM_CENTERS
flags = cv2.KMEANS_RANDOM_CENTERS
# apply kmenas
compactness, labels, centers = cv2.kmeans(img1, num_classes, None, criteria, 10, flags)
print(len(centers))
mask = labels.reshape((img.shape[0],img.shape[1]))
cmap = cm.get_cmap('Set1', num_classes) # 使用'viridis' colormap,但你可以使用其他colormap
# 绘制mask图像
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(mask, cmap=cmap, interpolation='nearest', alpha=0.8)
plt.title('mask')
plt.xticks([])
plt.yticks([])
# 你可以添加颜色条(colorbar)来显示每个颜色对应的类别
cbar = fig.colorbar(ax.images[-1], ax=ax, ticks=np.arange(num_classes))
cbar.ax.set_yticklabels(['background'] + [f'class {i}' for i in range(1, num_classes)])
# 显示图像
plt.show()
输入的原图
显示的灰度图
聚成2类
聚成3类
聚成4类
聚成5类
聚成6类