本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在AI学习笔记:
AI学习笔记(10)---《k-NN 剪辑近邻法》
k-NN 剪辑近邻法
目录
1. 前言
2.相关知识
2.1最邻近法决策规则
2.2剪辑最近邻算法基本原理
3.K-means 算法原理
3.1k 最近邻法( k-NN )
3.2重复剪辑最近邻法
4.编程实现
4.1最近邻法MATLAB代码
4.2 剪辑最近邻法MATLAB代码
4.3 测试输入
1. 前言
理解并掌握基本最邻近法( k-NN )的算法思想以及基本过程,并实现一个 k-NN 算法的函数。
理解并掌握剪辑最近邻算法思想以及基本过程,并实现一个剪辑最近邻算法的函数。
2.相关知识
2.1最邻近法决策规则
2.2剪辑最近邻算法基本原理
这种方法的思想是,在最近邻算法的基础上,理清两类间的边界,去掉类别混杂的样本,使两类别的边界更清晰。这样,不仅能够减少最近邻法判别类别的样本数量从而提升分类效率,还在理论上明显好于一般的最近邻法。
可以证明,经过剪辑的最近邻法的渐进误判概率接近 Bayes 误判概率。
3.K-means 算法原理
3.1k 最近邻法( k-NN )
由于上述方法只根据离待识别模式最近的一个样本的类别而决定其类别,通常称其为最近邻法或 1-NN 法。
3.2重复剪辑最近邻法
只要样本足够多,就可以重复地执行剪辑程序,以进一步提高分类性能。这里给出一种称为 MULTIEDIT 的使用算法:
4.编程实现
4.1最近邻法MATLAB代码
function [rm] = step1_mission(samples, labels, k, x)
[m,n] = size(samples); % 获取样本矩阵 samples 的维度,m 为样本数量,n 为特征数量
E = zeros(1,m); % 初始化长度为 m 的零向量,用于存储距离值
c = 0; % 初始化类别计数器
for i=1:m % 遍历每一个样本
E(i) = norm(samples(i,:)-x)/sqrt(n); % 计算每个样本与输入向量 x 的欧氏距离,并标准化
c = max(c,labels(i)); % 更新类别计数器,获取样本中最大的类别标签
end
kc = zeros(c,1); % 初始化长度为 c 的零向量,用于存储每个类别的计数
for i=1:k % 选取前 k 个最近邻
[b,j] = min(E); % 找到距离最小的样本及其索引
kc(labels(j)) = kc(labels(j)) + 1; % 对应类别计数加一
E(j) = []; % 删除已选择的最近邻的距离值,防止重复选择
end
[~,rm] = max(kc); % 找到计数最多的类别作为最终分类结果,并返回其标签值
end
4.2 剪辑最近邻法MATLAB代码
function [samples, labels] = step2_ans(samples, labels, s, k)
while 1==1 % 无限循环,直到满足终止条件
[m,n] = size(samples); % 获取样本矩阵 samples 的维度,m 为样本数量,n 为特征数量
stride = ceil(m/s); % 计算每个子集的步长,即每个子集的样本数量
edi = zeros(1,m); % 初始化长度为 m 的零向量,用于标记需要编辑的样本
head = zeros(1,s); % 初始化长度为 s 的零向量,用于存储每个子集的起始索引
tail = zeros(1,s); % 初始化长度为 s 的零向量,用于存储每个子集的结束索引
for i=1:s
head(i)=(i-1)*stride+1; % 计算每个子集的起始索引
tail(i)=min(head(i)+stride-1,m); % 计算每个子集的结束索引
end
for i=1:s % 遍历每个子集
eh = head(mod(i,s)+1); % 获取下一个子集的起始索引
et = tail(mod(i,s)+1); % 获取下一个子集的结束索引
for j=head(i):tail(i) % 遍历当前子集中的每个样本
y = step1_ans_func(samples(eh:et,:),labels(eh:et),k,samples(j)); % 调用 step1_ans_func 函数预测样本 j 的标签
if abs(y-labels(j)) > 10e-6 % 判断预测标签与真实标签是否相差过大
edi(j) = 1; % 标记需要编辑的样本
end
end
end
if max(edi)<1 % 如果没有需要编辑的样本,跳出循环
break
end
edi_idx = find(edi==0); % 找出所有未标记的样本索引
samples = samples(edi_idx,:); % 更新样本矩阵,只保留未标记的样本
labels = labels(edi_idx); % 更新标签向量,只保留未标记的样本对应的标签
size(samples); % 获取更新后的样本矩阵的大小
end
end
function [rm] = step1_ans_func(samples, labels, k, x)
[m,n] = size(samples); % 获取样本矩阵 samples 的维度,m 为样本数量,n 为特征数量
E = zeros(1,m); % 初始化长度为 m 的零向量,用于存储距离值
c = 0; % 初始化类别计数器
for i=1:m % 遍历每一个样本
E(i) = norm(samples(i,:)-x)/sqrt(n); % 计算每个样本与输入向量 x 的欧氏距离,并标准化
c = max(c,labels(i)); % 更新类别计数器,获取样本中最大的类别标签
end
kc = zeros(c,1); % 初始化长度为 c 的零向量,用于存储每个类别的计数
for i=1:k % 选取前 k 个最近邻
[b,j] = min(E); % 找到距离最小的样本及其索引
kc(labels(j)) = kc(labels(j)) + 1; % 对应类别计数加一
E(j) = []; % 删除已选择的最近邻的距离值,防止重复选择
end
[~,rm] = max(kc); % 找到计数最多的类别作为最终分类结果,并返回其标签值
end
4.3 测试输入
MATLAB终端输入下面指令
step1_mission(samples, labels, k, x)
samples = [-7.82 -4.58 -3.97; -6.68 3.16 2.71; 4.36 -2.91 2.09; 6.72 0.88 2.80; -8.64 3.06 3.50; -6.87 0.57 -5.45; 4.47 -2.62 5.76; 6.73 -2.01 4.18; -7.71 2.34 -6.33; -6.91 -0.49 -5.68; 6.18 2.81 5.82; 6.72 -0.93 -4.04; -6.25 -0.26 0.56; -6.94 -1.22 1.13; 8.09 0.20 2.25; 6.81 0.17 -4.15; -5.19 4.24 4.04; -6.38 -1.74 1.43; 4.08 1.30 5.33; 6.27 0.93 -2.78];
labels = [1 2 2 2 1 1 2 2 1 1 2 2 1 1 2 2 2 1 2 2];
disp('task1');
rm = step1_mission(samples(1:20,:),labels(1:20),1,[10 10 10])
disp('task2');
rm = step1_mission(samples(1:20,:),labels(1:20),3,[10 10 10])
disp('task3');
rm = step1_mission(samples(1:20,:),labels(1:20),5,[10 10 10])
disp('task4');
rm = step1_mission(samples(1:20,:),labels(1:20),7,[10 10 10])
disp('task5');
rm = step1_mission(samples(1:20,:),labels(1:20),9,[10 10 10])
测试step2_ans(samples, labels, s, k)
randn('state',1);
warning off;
num = 200;
R1 = [5 0; 0 1];
R2 = [10 0; 0 25];
u1 = [-3 0];
u2 = [5 0];
Y1 = multivrandn(u1,R1,num,1);
Y2 = multivrandn(u2,R2,num,2);
L1 = ones(1,num);
L2 = L1 * 2;
Y = [Y1; Y2];
L = [L1 L2];
size(Y);
size(L);
randIndex = randperm(size(Y,1));
Y = Y(randIndex,:);
L = L(randIndex);
subplot(211);
plot(Y1(:,1),Y1(:,2),'*');
hold on;
plot(Y2(:,1),Y2(:,2),'o');
[ry, rl] = step2_mission(Y, L, 4, 5);
idx1 = find(rl==1);
idx2 = find(rl==2);
size(ry);
ry1 = ry(idx1,:);
ry2 = ry(idx2,:);
subplot(212);
plot(ry1(:,1),ry1(:,2),'*');
hold on;
plot(ry2(:,1),ry2(:,2),'o');
saveas(1,'./result/myfig.png');
[ans_ry, ans_rl] = step2_ans(Y, L, 4, 5);
ans_idx1 = find(ans_rl==1);
ans_idx2 = find(ans_rl==2);
size(ans_ry);
ans_ry1 = ans_ry(ans_idx1,:);
ans_ry2 = ans_ry(ans_idx2,:);
subplot(212);
plot(ans_ry1(:,1),ans_ry1(:,2),'*');
hold on;
plot(ans_ry2(:,1),ans_ry2(:,2),'o');
saveas(1,'./answer/ans.png');
diff = sum(ans_rl-rl)+sum(sum(ans_ry-ry));
if diff==0
disp('OK');
else
disp('Not Equal!')
end
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者私信联系作者。