K-means 算法:先随机选取 k 个数据点作为初始簇中心;计算数据集中每个元素与各个簇中心的距离,将其分配给距离最近的簇;然后重新计算每个簇的平均值点作为新的簇中心,再按“离簇中心最近”的原则重新分配元素。如此反复,直到不再有元素被重新分配(或达到最大迭代次数)为止。
该算法要事先给出k的值,即划分为几个簇。
代码中用 vector<int> datoclu(data.size(), -1); 记录每个数据所属簇的下标,初始值 -1 表示该数据尚未分配到任何簇。
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
using namespace std;
/// A point in the plane, stored as two double coordinates.
///
/// Both coordinates default to 0.0 so that a default-constructed Point never
/// carries indeterminate values.
struct Point
{
    double x = 0.0; ///< First coordinate.
    double y = 0.0; ///< Second coordinate.
};
/// Returns the Euclidean distance between points `a` and `b`.
///
/// @param a First point.
/// @param b Second point.
/// @return sqrt((a.x - b.x)^2 + (a.y - b.y)^2).
double distance(const Point& a, const Point& b)
{
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    // Plain multiplication instead of pow(v, 2): squaring does not need the
    // general-purpose power routine.
    return std::sqrt(dx * dx + dy * dy);
}
vector<int> KMeans(vector<Point>& data, int k, int maxIterations)
{
vector<Point> centroids(k);
for (int i = 0; i < k; i++)
{
centroids[i] = data[rand() % data.size()]; //随机选择k个类聚中心。0到(data.size()-1)
}
vector<int> datoclu(data.size(), -1); //每个数据属于哪个簇
bool flag = 0;
while (!flag && maxIterations)
{
flag = 1;
for (int i = 0; i < data.size(); i++)
{
double minDis = numeric_limits<double>::max();
int index = -1;
for (int j = 0; j < centroids.size(); j++)
{
double dis = distance(data[i], centroids[j]);
if (dis < minDis)
{
minDis = dis;
index = j;
}
}
if (datoclu[i] != index) //记录每个数据属于的聚类中心
{
datoclu[i] = index;
flag = 0;
}
}
vector<Point> newClu(k);
vector<int> num(k, 0);
//计算每个簇平均值点
for (int i = 0; i < data.size(); i++)
{
newClu[datoclu[i]].x += data[i].x;
newClu[datoclu[i]].y += data[i].y;
num[datoclu[i]]++;
}
for (int i = 0; i < k; i++)
{
newClu[i].x /= num[i];
newClu[i].y /= num[i];
}
centroids = newClu;
maxIterations--;
}
return datoclu;
}
vector<Point> ReadData(string filename)
{
vector<Point> data;
ifstream file(filename);
if (file.is_open())
{
string line;
while (getline(file, line))
{
istringstream iss(line);
double x, y;
string token;
Point point;
if (getline(iss, token, ',') && istringstream(token) >> point.x &&
getline(iss, token, ',') && istringstream(token) >> point.y) {
data.push_back(point);
}
}
}
else
{
cout << "open fail";
}
file.close();
return data;
}
int main()
{
vector<Point> dataset = ReadData("data.txt");
vector<int> clusters;
int k, maxIterations;
cout << "输入簇的个数和最大迭代次数"<<endl;
cin >> k >> maxIterations;
clusters= KMeans(dataset, k, maxIterations);
vector <vector<int>> index(k);
for (int j = 0; j < k; j++)
{
for (int i = 0; i < clusters.size(); i++)
{
if (clusters[i] == j)
{
index[j].push_back(i);
}
}
}
for (int i = 0; i < index.size(); i++)
{
cout << "{";
for (int j = 0; j < index[i].size(); j++)
{
cout << index[i][j]+1;
if (j != index[i].size() - 1)
{
cout << ",";
}
}
cout << "}";
}
}
数据集(保存为 data.txt,每行一个点,格式为“x, y”)
1.0, 1.0
2.0, 1.0
1.0, 2.0
2.0, 2.0
4.0, 3.0
5.0, 3.0
4.0, 4.0
5.0,4.0
运行结果