文章目录
- day56-57 kMeans 聚类
- 1.kMeans聚类理解
- 2.代码理解
- 2.1代码中变量的理解
- 2.2代码理解
day56-57 kMeans 聚类
1.kMeans聚类理解
无监督的机器学习算法,其中k是划分为几个簇,并且选择k个数据作为不同簇的聚类中心,计算每个数据样本和聚类中心的距离(欧式距离或曼哈顿距离)并将数据样本分配给离聚类中心最近的类别。在遍历完所有数据后,则可以把数据集分成k个簇,对每个簇又要重新计算他的聚类中心(求平均值)。我们会进行多次迭代,直到聚类中心不变或者是到达一定次数的迭代。
2.代码理解
2.1代码中变量的理解
(主要是clustering()方法中的变量)
- tempClusterArray
当前循环中每个数据样本属于哪一个簇。如下值2=2代表数据样本2通过与k个聚类中心之间的计算,发现离2这个聚类中心距离最近,故将数据样本2聚类到1这个簇中。
- tempOldClusterArray
用于存储旧的聚类分配结果的数组(可以理解为上一次迭代对数据聚类的结果)
- tempCenters
存放聚类的中心。初始化时赋值为:对数据样本集随机排序,再随机选择数据集中的数据点作为初始聚类中心
- tempNewCenters
对循环后分类后的不同簇重新选择聚类中心(求平均值)
2.2代码理解
只要理解了KMeans的核心,代码分段读很好理解。
- 1.选择簇的数量K(目前设置为3)、
- 2.初始化聚类中心tempCenters(将数据集随机排序后选择前K个作为聚类中心)
- 3.分配数据样本到簇(计算数据样本与聚类中心的距离,选择距离最短的)
- 4.重新计算聚类中心(计算不同簇的平均值)
- 5.重复步骤3,4(调出循环的条件是:tempOldClusterArray与tempClusterArray相等时 即上一次迭代和当前迭代聚类分配结果不再发生变化时)
package machinelearing.knn;
import weka.core.Instances;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
/**
* @author: fulisha
* @date: 2023-05-28 10:36
* @description:
*/
public class KMeans {
/**
* Manhattan distance.
*/
public static final int MANHATTAN = 0;
/**
* Euclidean distance.
*/
public static final int EUCLIDEAN = 1;
/**
* The distance measure.
*/
public int distanceMeasure = EUCLIDEAN;
/**
* A random instance;
*/
public static final Random random = new Random();
/**
* The data.
*/
Instances dataset;
/**
* The number of clusters.
*/
int numClusters = 2;
/**
* The clusters.
*/
int[][] clusters;
/**
* The first constructor.
* @param paraFilename The data filename.
*/
public KMeans(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
System.exit(0);
}
}
public void setNumClusters(int paraNumClusters) {
numClusters = paraNumClusters;
}
/**
* Get a random indices for data randomization.
* @param paraLength The length of the sequence.
* @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
// Step 1. Initialize.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
}
// Step 2. Randomly swap.
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
// Generate two random indices.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
}
return resultIndices;
}
/**
* The distance between two instances.
* @param paraI The index of the first instance.
* @param paraArray The array representing a point in the space.
* @return The distance.
*/
public double distance(int paraI, double[] paraArray) {
int resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
}
}
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
resultDistance += tempDifference * tempDifference;
}
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}
return resultDistance;
}
public void clustering() {
int[] tempOldClusterArray = new int[dataset.numInstances()];
tempOldClusterArray[0] = -1;
int[] tempClusterArray = new int[dataset.numInstances()];
Arrays.fill(tempClusterArray, 0);
double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];
// Step 1. Initialize centers.
int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
for (int i = 0; i < numClusters; i++) {
for (int j = 0; j < tempCenters[0].length; j++) {
tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
}
}
int[] tempClusterLengths = null;
while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
System.out.println("New loop ...");
tempOldClusterArray = tempClusterArray;
tempClusterArray = new int[dataset.numInstances()];
// Step 2.1 Minimization. Assign cluster to each instance.
int tempNearestCenter;
double tempNearestDistance;
double tempDistance;
for (int i = 0; i < dataset.numInstances(); i++) {
tempNearestCenter = -1;
tempNearestDistance = Double.MAX_VALUE;
for (int j = 0; j < numClusters; j++) {
tempDistance = distance(i, tempCenters[j]);
if (tempNearestDistance > tempDistance) {
tempNearestDistance = tempDistance;
tempNearestCenter = j;
}
}
tempClusterArray[i] = tempNearestCenter;
}
// Step 2.2 Mean. Find new centers.
tempClusterLengths = new int[numClusters];
Arrays.fill(tempClusterLengths, 0);
double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
// Arrays.fill(tempNewCenters, 0);
for (int i = 0; i < dataset.numInstances(); i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
}
tempClusterLengths[tempClusterArray[i]]++;
}
// Step 2.3 Now average
for (int i = 0; i < tempNewCenters.length; i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[i][j] /= tempClusterLengths[i];
}
}
System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
tempCenters = tempNewCenters;
}
// Step 3. Form clusters.
clusters = new int[numClusters][];
int[] tempCounters = new int[numClusters];
for (int i = 0; i < numClusters; i++) {
clusters[i] = new int[tempClusterLengths[i]];
}
for (int i = 0; i < tempClusterArray.length; i++) {
clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
tempCounters[tempClusterArray[i]]++;
}
System.out.println("The clusters are: " + Arrays.deepToString(clusters));
}
public static void testClustering() {
KMeans tempKMeans = new KMeans("C:/Users/王忠云/Desktop/sampledata/iris.arff");
tempKMeans.setNumClusters(3);
tempKMeans.clustering();
}
public static void main(String arags[]) {
testClustering();
}
}
- 代码结果