Kmeans 是一种可以将一个数据集按照距离(相似度)划分成不同类别的算法,它无需借助外部标记,因此也是一种无监督学习算法。
什么是聚类
用官方的话说聚类就是将物理或抽象对象的集合分成由类似的对象组成的多个类的过程。用自己的话说聚类是根据不同样本数据间的相似度进行种类划分的算法。这种划分可以基于我们的业务需求或建模需求来完成,也可以单纯地帮助我们探索数据的自然结构和分布。
什么是K-means聚类
用官方的话说:k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是,预将数据分为K组,则随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。
K-means聚类实现流程
K-means聚类聚类的优劣性
优点:
- K-means聚类可以支持无监督学习,无需人工标记即可进行分类
- K-means聚类有处理不同类型数据的能力,如二元、序数、标称、数值等类型数据都可以处理。
- K-means聚类算法基于欧几里得或者曼哈顿距离度量来决定聚类。基于这样的距离度量的算法趋向于发现具有相近尺度和密度的球状簇。但是,一个簇可能是任意形状的。提出能发现任意形状簇的算法是很重要的。
缺点:
- 需要提前确定几何中心的数量
- 设置初始几何中心需要考虑尽可能选取差异较大的数据作为初始几何中心
- 适用于有明显中心的数据样本,对于相对分散的数据样本处理效果欠佳。
典型案例
学校A有若干不同年龄分布的学生,并且性别也不一样,想要依据这两个参数对学生进行分类。
学生类
import java.util.List;
public class Student{
@Override
public String toString() {
return "Student [name=" + name + ", age=" + age + ", gender=" + gender + ", myHobby=" + myHobby
+ ", myDream=" + myDream + "]";
}
public List<MyHobby> getMyHobby() {
return myHobby;
}
public Student setMyHobby(List<MyHobby> myHobby) {
this.myHobby = myHobby;
return this;
}
public String getName() {
return name;
}
public Student setName(String name) {
this.name = name;
return this;
}
public int getAge() {
return age;
}
public Student setAge(int age) {
this.age = age;
return this;
}
public String getGender() {
return gender;
}
public Student setGender(String gender) {
this.gender = gender;
return this;
}
String name;
@Elem(type = ElemType.NUMBER)
int age;
@Elem(type = ElemType.XUSHU,list={"男","女"})
String gender;
@Elem()
List<MyHobby> myHobby;
@Elem()
List<String> myDream;
public Student(String name, int age, String gender) {
super();
this.name = name;
this.age = age;
this.gender = gender;
}
public Student(String name, int age, String gender,List<MyHobby> myHobby) {
this(name,age,gender);
this.myHobby = myHobby;
}
public Student(String name, int age, String gender,List<MyHobby> myHobby, List<String> myDreams) {
this(name,age,gender);
this.myHobby = myHobby;
this.myDream = myDreams;
}
}
配置类
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Elem {
ElemType type() default ElemType.BASIC; //属性类型
String[] list() default {}; //选择项
}
package kmeans;
/**
* 元素属性类型(标称属性、序数属性、数值属性、二元属性)
* @author zygswo
*
*/
public enum ElemType {
BASIC("标称属性"),
XUSHU("序数属性"),
NUMBER("数值属性"),
ERYUAN("二元属性");
private String name;
private ElemType(String name) {
this.setName(name);
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
}
package kmeans;
public enum DistanceType {
EUCLID("欧几里得距离"),
MANHATTAN("曼哈顿距离"),
QIEBIXUEFU("切比雪夫距离");
private String name;
private DistanceType(String name) {
this.setName(name);
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
}
主方法
package kmeans;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* kmeans聚类工具类
* @author zygswo
*
*/
public class KmeansUtils<T> {
private int initKNodeNb; //kmeans初始几何中心数量
private List<T> trainData; //kmeans训练数据
private DistanceType distanceType;
/**
* kmeans构造方法(默认为欧式距离公式)
* @param initKNodeNb kmeans初始几何中心数量
* @param trainData 训练数据
*/
public KmeansUtils(List<T> trainData, int initKNodeNb) {
this.initKNodeNb = initKNodeNb;
this.trainData = trainData;
this.distanceType = DistanceType.EUCLID;
}
/**
* kmeans构造方法(默认为欧式距离公式)
* @param initKNodeNb kmeans初始几何中心数量
* @param trainData 训练数据
* @param distanceType 距离公式
*/
public KmeansUtils(List<T> trainData, int initKNodeNb, DistanceType distanceType) {
this.initKNodeNb = initKNodeNb;
this.trainData = trainData;
this.distanceType = distanceType;
}
/**
* kmeans模型训练
*/
public void fit(){
//计算距离
List<Map<String,Double>> initKNodeDistanceVal = Collections.synchronizedList(
new ArrayList<>()
);
//初始化几何列表
List<List<T>> resList = Collections.synchronizedList(
new ArrayList<>()
);
if (this.trainData == null || this.trainData.isEmpty()) {
throw new IllegalArgumentException("训练集为空");
}
if (this.initKNodeNb <=0) {
throw new IllegalArgumentException("几何中心数量小于0");
}
if (this.initKNodeNb > this.trainData.size()) {
throw new IllegalArgumentException("几何中心数量超过数组数量");
}
if (this.distanceType == null) {
throw new IllegalArgumentException("距离类型为空");
}
//1.获取前initKNodeNb个数据放入initKNodeList列表中
//初始化的几何中心,需要选择差异较大的
this.trainData.sort((T item1,T item2)-> {
return (int)(calcDiff(item1,this.trainData.get(0)) - calcDiff(item2,this.trainData.get(0)));
});
int step = this.trainData.size() / initKNodeNb;
//选择从小到大的initKNodeNb个元素作为初始几何
for (int i = 0; i < this.trainData.size() && resList.size() < initKNodeNb; i+=step) {
List<T> temp = Collections.synchronizedList(
new ArrayList<>()
);
temp.add(this.trainData.get(i));
resList.add(temp); //多个几何列表设置初始结点
}
//2.计算所有变量到不同的几何中心距离,如果稳定了(几何中心固定了),就退出循环
while(true) {
boolean balanced = true; //是否已经平衡
for (T item: this.trainData) {
double distance, minDistance = Double.MAX_VALUE; //求最小距离
int preIndex = 0,afterIndex = 0; //preIndex-原位置
initKNodeDistanceVal.clear();
// for (List<T> list : resList) {
// System.out.println(list.toString());
// }
//计算几何中心
for (int i = 0; i < initKNodeNb; i++) {
if (resList.get(i).size() > 0)
initKNodeDistanceVal.add(calc(resList.get(i))); //计算初始结点距离
}
//计算原来的位置
for (int i = 0; i < initKNodeNb; i++) {
if(resList.get(i).contains(item)) {
preIndex = i;
break;
}
}
// System.out.println("item = " + item.toString());
//计算不同变量到不同的几何中心距离
for (int i = 0; i < initKNodeNb; i++) {
if (resList.get(i).size() > 0 && i < initKNodeDistanceVal.size()) {
distance = calcDistance(item, initKNodeDistanceVal.get(i));
// System.out.println("distance = " + distance);
// System.out.println("minDistance = " + minDistance);
if (distance < minDistance) {
minDistance = distance;
afterIndex = i;
}
}
}
// System.out.println("preIndex = " + preIndex);
// System.out.println("afterIndex = " + afterIndex);
//位置替换,如果替换就还没结束
if (preIndex != afterIndex) {
resList.get(preIndex).remove(item);
resList.get(afterIndex).add(item);
balanced = false;
}
if (preIndex == afterIndex) {
//如果新增就还没结束
if (!resList.get(preIndex).contains(item)) {
resList.get(preIndex).add(item);
balanced = false;
}
}
}
if (balanced){
break;
}
}
// //打印结果
for (List<T> list : resList) {
System.out.println(list.toString());
}
}
/**
* 计算距离
* @param item1 item1
* @param item2 item2
* @return
*/
private double calcDiff(T item1, T item2) {
List<T> list = Collections.synchronizedList(new ArrayList<>());
list.add(item2);
Map<String, Double> map = calc(list);
double dist = calcDistance(item1, map);
return dist;
}
/**
* 计算距离
* @param item 当前对象
* @param map 几何中心
* @return
*/
private double calcDistance(T item, Map<String, Double> map) {
double distance = 0.0;//距离
int level = 0;//根据距离公式判断距离计算等级
Class<?> cls = item.getClass();
Field[] fs = cls.getDeclaredFields();
for (Field f : fs) {
double dist1 = 0.0, dist2 = 0.0;
f.setAccessible(true);
//获取需要计算的参数
Elem el = f.getAnnotation(Elem.class);
if (el == null) {
continue;
}
try {
switch(el.type()) {
case BASIC: break;
case XUSHU:
//获取数组
String[] arr = el.list();
if (arr == null) {
throw new IllegalArgumentException("序数属性需配置属性集合数组");
}
//数组排序
Arrays.sort(arr);
List<String> list = Arrays.asList(arr);
//计算差距步长
Double diffStep = 1 / (list.size() * 1.0);
//获取当前对象序数属性的值
Object value = f.get(item);
dist1 = list.indexOf(value) * diffStep;
break;
case NUMBER:
//获取当前对象数值属性的值
Object value1 = f.get(item);
//数据转换
Double intVal = Double.parseDouble(String.valueOf(value1));
dist1 = intVal;
break;
case ERYUAN:
//获取数组
String[] arr1 = el.list();
if (arr1 == null) {
arr1 = new String[]{"0","1"};
} else {
//数组排序
Arrays.sort(arr1);
}
//转列表
List<String> list1 = Arrays.asList(arr1);
//计算差距步长
Double diffStep1 = 1 / (list1.size() * 1.0);
Object value2 = f.get(item);
int ind = list1.indexOf(value2);
dist1 = ind * diffStep1;
break;
}
//获取当前几何中心属性的值
dist2 = map.get(f.getName());
//计算距离
switch(distanceType) {
case EUCLID: level = 2; break;
case MANHATTAN: level = 1;break;
case QIEBIXUEFU: level = 100;break;
}
distance += Math.pow(Math.abs(dist1 - dist2),level);
} catch(Exception ex) {
throw new RuntimeException(ex.getMessage());
}
distance = Math.pow(distance, 1/(level * 1.0));
}
return distance;
}
/**
* 计算几何中心坐标
* @param kNodeList
* @return 几何中心坐标map
*/
private Map<String, Double> calc(List<T> kNodeList) {
if (kNodeList == null || kNodeList.size() <= 0) {
throw new IllegalArgumentException("几何中心列表数组为空");
}
//反射获取参数,形成数值数组
Map<String, Double> result = new ConcurrentHashMap<>();
T item = kNodeList.get(0);
Class<?> cls = item.getClass();
Field[] fs = cls.getDeclaredFields();
for (Field f: fs) {
//获取需要计算的参数
Elem el = f.getAnnotation(Elem.class);
if (el == null) {
continue;
}
//将数据转换成数值
Double dist = 0.0;
switch(el.type()) {
case BASIC: break;
case XUSHU:
//获取数组
String[] arr = el.list();
if (arr == null) {
throw new IllegalArgumentException("序数属性需配置属性集合数组");
}
//数组排序
Arrays.sort(arr);
//转列表
List<String> list = Arrays.asList(arr);
//计算差距步长
Double diffStep = 1 / (list.size() * 1.0);
for (T kNode : kNodeList) {
try {
//获取当前对象序数属性的值
Object value = f.get(kNode);
int ind = list.indexOf(value);
//求和
dist += ind * diffStep;
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
break;
case NUMBER:
for (T kNode : kNodeList) {
try {
//获取当前对象数值属性的值
Object value = f.get(kNode);
//数据转换
Double intVal = Double.parseDouble(String.valueOf(value));
dist += intVal;
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
break;
case ERYUAN:
//获取数组
String[] arr1 = el.list();
if (arr1 == null) {
arr1 = new String[]{"0","1"};
} else {
//数组排序
Arrays.sort(arr1);
}
//转列表
List<String> list1 = Arrays.asList(arr1);
//计算差距步长
Double diffStep1 = 1 / (list1.size() * 1.0);
for (T kNode : kNodeList) {
try {
//获取当前对象二元属性的值
Object value = f.get(kNode);
int ind = list1.indexOf(value);
//求和
dist += ind * diffStep1;
} catch (IllegalArgumentException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
break;
}
dist /= (kNodeList.size() * 1.0); //求平均值
result.put(f.getName(), dist);
}
return result;
}
public static void main(String[] args) {
List<Student> trainData = new ArrayList<>();
trainData.add(new Student("zyl",28,"男"));
trainData.add(new Student("sjl",28,"女"));
trainData.add(new Student("xxx",27,"男"));
trainData.add(new Student("stc",30,"男"));
trainData.add(new Student("wxq",30,"女"));
trainData.add(new Student("zzz",27,"男"));
trainData.add(new Student("sss",27,"女"));
trainData.add(new Student("mmm",20,"男"));
trainData.add(new Student("qqq",20,"女"));
trainData.add(new Student("666",30,"男"));
// trainData.add(new Student("mmm",19,"男"));
KmeansUtils<Student> utils = new KmeansUtils<>(trainData, 4);
utils.fit();
}
}
运行结果