【Python】 剪辑法欠采样 CNN压缩近邻法欠采样

news2025/1/22 21:06:59

借鉴:关于K近邻(KNN),看这一篇就够了!算法原理,kd树,球树,KNN解决样本不平衡,剪辑法,压缩近邻法 - 知乎

但是不要看他里面的代码,因为作者把代码里的一些符号故意颠倒了 ,比如“==”改成“!=”,还有乱加“~”,看明白逻辑才能给他改过来

一、剪辑法

        当训练集数据中存在一部分不同类别数据的重叠时(在一部分程度上说明这部分数据的类别比较模糊),这部分数据会对模型造成一定的过拟合,那么一个简单的想法就是将这部分数据直接剔除掉即可,也就是剪辑法。

        剪辑法将训练集 D 随机分成两个部分,一部分作为新的训练集 Dtrain,一部分作为测试集 Dtest,然后基于 Dtrain,使用 KNN 的方法对 Dtest 进行分类,并将其中分类错误的样本从整体训练集 D 中剔除掉,得到 Dnew。

        由于对训练集 D 的划分是随机划分,难以保证数据重叠部分的样本在第一次剪辑时就被剔除,因此在得到 Dnew 后,可以对 Dnew 继续进行上述操作数次,这样可以得到一个比较清爽的类别分界。

        效果如下图:

        附上可直接运行的代码:

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where

# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,
                                            n_informative=2, n_redundant=0, n_repeated=0,
                                            n_classes=4, n_clusters_per_class=1)

# # # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()

# 剪辑10次
for i in range(10):
    x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

    k = 5
    KNN_clf = KNN(n_neighbors=k)
    KNN_clf.fit(x_train, y_train)  # 用训练集训练KNN
    y_predict = KNN_clf.predict(x_test)  # 用测试集测试

    cond = y_predict == y_test
    x_test = x_test[cond]  # 把预测错误的从整体数据集中剔除掉
    y_test = y_test[cond]  # 把预测错误的从整体数据集中剔除掉

    X = np.vstack([x_train, x_test])  # 为下一次循环做准备(剔除掉本轮预测错误的
    y = np.hstack([y_train, y_test])  # 为下一次循环做准备(剔除掉本轮预测错误的

# summarize the new class distribution
counter = Counter(y)
print(counter)

# 画出二维散点图
for label, _ in counter.items():
	row_ix = where(y == label)[0]
	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        以上使用了k=20的参数进行剪辑的结果,循环了10次,一般而言,k越大,被抛弃的样本会越多,因为被分类的错误的概率更大。

二、CNN压缩近邻法欠采样

        

        压缩近邻法的想法是认为同一类型的样本大量集中在类簇的中心,而这些集中在中心的样本对分类没有起到太大的作用,因此可以舍弃掉这些样本。

        其做法是将训练集随机分为两个部分,第一个部分为 store,占所有样本的 10% 左右,第二个部分为 grabbag,占所有样本的 90% 左右,然后将 store 作为训练集训练 KNN 模型,grabbag 作为测试集,将分类错误的样本从 grabbag 中移动到 store 里,然后继续用增加了样本的 store 和减少了样本的 grabbag 再次训练和测试 KNN 模型,直到 grabbag 中所有样本被分类正确,或者 grabbag 中样本数为0。

        在压缩结束之后,store 中存储的是初始化时随机选择的 10% 左右的样本,以及在之后每一次循环中被分类错误的样本,这些被分类错误的样本集中在类簇的边缘,认为是对分类作用较大的样本。

        CNN欠采样已经有相应的Python实现库了,相应的方法是CondensedNearestNeighbour(),下面是可直接运行的代码。

# Undersample and plot imbalanced dataset with the Condensed Nearest Neighbor Rule
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour
from matplotlib import pyplot
from numpy import where

# make_classification方法用于生成分类任务的人造数据集
# X是数据,几维都可以,n_features=4表示4维
# y用0/1表示类别,weights调整0和1的占比
X, y = make_classification(n_samples=500, n_classes=2, n_features=3, n_redundant=0,
	# n_clusters_per_class表示每个类别多少簇  # flip_y噪声,增加分类难度
	n_clusters_per_class=2, weights=[0.5], flip_y=0, random_state=1)

# summarize class distribution
counter = Counter(y)  # {0: 990, 1: 10} counter是一个字典,value存储类别,key存储类别个数
print(counter)

# ==================CNN有直接可以调用的包  n_neighbors设置k值,k值越小越省时间,就设置为1吧
undersample = CondensedNearestNeighbour(n_neighbors=1)
# transform the dataset
X, y = undersample.fit_resample(X, y)

# summarize the new class distribution
counter = Counter(y)
print(counter)

# scatter plot of examples by class label
for label, _ in counter.items():
	row_ix = where(y == label)[0]
	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        但是我觉得这个CondensedNearestNeighbour()方法的可操作性太低,所以没用这个方法,而是根据CNN的原理(CNN底层是训练KNN)去写的

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where

# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,
                                            n_informative=2, n_redundant=0, n_repeated=0,
                                            n_classes=4, n_clusters_per_class=1, random_state=1)
counter = Counter(y)
# 画出二维散点图
for label, _ in counter.items():
	row_ix = where(y == label)[0]
	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)


while True:
	k = 1
	KNN_clf = KNN(n_neighbors=k)
	KNN_clf.fit(x_train, y_train)
	y_predict = KNN_clf.predict(x_test)

	cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True
	# 都分类正确,退出
	if  cond.all():
		print('所有测试集都分类正确,CNN正常结束')
		break

	x_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里
	y_train = np.hstack([y_train, y_test[~cond]])
	x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集
	y_test = y_test[cond]

	if len(x_test) == 0:
		print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")
		break


# summarize the new class distribution
counter = Counter(y_train)
print(counter)

# 画出二维散点图
for label, _ in counter.items():
	row_ix = where(y_train == label)[0]
	pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

2.1 改进版——指定压缩后样本大小的CNN

在如下代码中,用sampleNum指定全体样本数量,用endNum指定压缩后样本数量

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where


sampleNum = 1000
endNum = 500
k = 1  # KNN算法的K值
# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=sampleNum, n_features=2,
                                            n_informative=2, n_redundant=0, n_repeated=0,
                                            n_classes=4, n_clusters_per_class=1, random_state=1)
# counter = Counter(y)
# # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()

# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)
# print(x_train.shape[0])  # 100

nowNum = x_train.shape[0]  # 用来控制 训练集/筛选后的样本数 满足resultNum就停下, 初始有x_train这么多个

while True:
	KNN_clf = KNN(n_neighbors=k)
	KNN_clf.fit(x_train, y_train)
	y_predict = KNN_clf.predict(x_test)
	cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True
	# 都分类正确,退出
	if cond.all():
		print('所有测试集都分类正确,CNN自动结束,但是结果集没凑够呢!')
		break

	# 如果结果集数量不够要求的endNum,继续下一轮
	if nowNum+y_test[~cond].shape[0] < endNum:
		nowNum = nowNum+y_test[~cond].shape[0]
		print("目前结果集数量:", nowNum)
		x_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里
		y_train = np.hstack([y_train, y_test[~cond]])
		x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集
		y_test = y_test[cond]
	# 如果结果集数量超过endNum,我们只要测试集里分类错误的前endNum-nowNum个
	else:
		# 记录前endNum-nowNum个的位置(截取位置
		condCut = 0  # 记录截取位置
		for i in range(cond.shape[0]):
			if not cond[i]:
				nowNum = nowNum + 1
			if nowNum == endNum:
				condCut = i  # 在cond[condCut]处刚好是我们要的第endNum个结果集样本
				break
		# 把cond[condCut]后面的都设置成True
		cond[condCut+1:] = True
		x_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里
		y_train = np.hstack([y_train, y_test[~cond]])
		print("结果集的数量为", x_train.shape[0], "满足endNum=", endNum)
		break

	if len(x_test) == 0:
		print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")
		break


# summarize the new class distribution
counter = Counter(y_train)
print(counter)

# 画出二维散点图
for label, _ in counter.items():
	row_ix = where(y_train == label)[0]
	pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1460790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Paddlepaddle使用自己的VOC数据集训练目标检测(0废话简易教程)

一 安装paddlepaddle和paddledection&#xff08;略&#xff09; 笔者使用的是自己的数据集 二 在dataset目录下新建自己的数据集文件&#xff0c;如下&#xff1a; 其中 xml文件内容如下&#xff1a; 另外新建一个createList.py文件&#xff1a; # -- coding: UTF-8 -- imp…

云打印api接口收费吗?

随着近来云打印服务的发展&#xff0c;越来越多的用户都开始选择云打印服务。很多工具类、学习累的App和软件看到了这其中的甜头&#xff0c;也都想要对接云打印业务来完成变现。对接云打印服务则需要找到合适的平台进行api对接。那么云打印api接口收费吗&#xff1f;收费标准是…

TF卡辨别指南|拓优星辰

在存储领域&#xff0c;TF卡&#xff08;MicroSD卡&#xff09;是一种常见的存储设备&#xff0c;但市场上也存在着各种品牌和型号。为了帮助用户准确辨别TF卡&#xff0c;我们提供了以下辨别指南&#xff0c;以确保用户能够选择符合其需求的高性能、高可靠性的TF卡。 二、外观…

数据结构笔记1线性表及其实现

终于开始学习数据结构了 c语言结束之后 我们通过题目来巩固了 接下来我们来学习数据结构 这里我们将去认识到数据结构的一些基础知识&#xff0c;我在第一次学习的时候会很迷糊现在重新学习发现之前的时候还是因为c语言学的不牢固导致学习数据结构困难 这里 我会尽量的多写代码…

fast-planner代码解读【kino_replan_fsm.cpp】

概述 kino_replan_fsm.cpp订阅实时定位和目标点信息&#xff0c;每隔0.01s执行一次状态机&#xff0c;进行状态切换&#xff1b;每隔0.05s执行一次碰撞检测&#xff0c;按需进行重新规划。核心为执行变量exec_state_ 主要函数及作用 KinoReplanFSM::init 输入&#xff1a;句…

SD-WAN解决方案:企业异地组网挑战之视频会议

随着企业的发展&#xff0c;不少企业开始面临规模扩大、分公司组建、异地办公的需求。其中&#xff0c;远程视频会议作为企业异地管理和运营的重要组成部分&#xff0c;对网络稳定性和视频传输质量有较高的要求。在本文&#xff0c;我们将探讨企业视频会议遇到的网络问题以及这…

SpringBoot+Vue+MySQL:图书管理系统的技术革新

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

面试经典150题——生命游戏

​"Push yourself, because no one else is going to do it for you." - Unknown 1. 题目描述 2. 题目分析与解析 2.1 思路一——暴力求解 之所以先暴力求解&#xff0c;是因为我开始也没什么更好的思路&#xff0c;所以就先写一种解决方案&#xff0c;没准写着写…

2月21日

Bean生命周期 过程概述 创建对象 实例化(构造方法) 依赖注入 初始化 执行Aware接口回调 执行BeanPostProcessor.psotProcessBeforeInitialization 执行InitializingBean回调(先执行PostConstruct) 执行BeanPsotProcessor.postProcessAfterInitialization 使用对象 销毁对象…

Javaweb之SpringBootWeb案例之切入点表达式的详细解析

3.3 切入点表达式 从AOP的入门程序到现在&#xff0c;我们一直都在使用切入点表达式来描述切入点。下面我们就来详细的介绍一下切入点表达式的具体写法。 切入点表达式&#xff1a; 描述切入点方法的一种表达式 作用&#xff1a;主要用来决定项目中的哪些方法需要加入通知 …

ffmpeg TS复用代码详解——mpegtsenc.c

一、mpegtsenc.c 整体架构 二、主要函数 mpegts_write_pes(AVFormatContext *s, AVStream *st, const uint8_t *payload, int payload_size, int64_t pts, int64_t dts)这个函数就是TS打包的主函数了&#xff0c;这个函数主要功能就是把一帧数据拆分成188字节的TS包&#xff0…

自助点餐系统微信小程序,支持外卖、到店等

总体介绍 系统总共分为三个端&#xff1a;后端&#xff0c;后台管理系统、微信小程序。 基于当前流行技术组合的前后端分离商城系统&#xff1a; SpringBoot2MybatisPlusSpringSecurityjwtredisVue的前后端分离的商城系统&#xff0c; 包含分类、sku、积分、多门店等 预览图…

FariyGUI × Cocos Creator 入门

前言 程序员向的初探Cocos Creator结和FairyGUI的使用&#xff0c;会比较偏向FairyGUI一点&#xff0c;默认各位读者都熟练掌握Cocos Creator以及js/ts脚本编写。 初探门径&#xff0c;欢迎大佬指教&#xff0c;欢迎在评论区或私信与本人交流&#xff0c;谢谢&#xff01; 下…

DBSCAN密度聚类介绍 样本点 样本集合 半径 邻域 核心对象 边界点 密度直达 密度可达 密度相连

DBSCAN密度聚类介绍 样本点 样本集合 半径 邻域 核心对象 边界点 密度直达 密度可达 密度相连 简介概念定义原理DBSCAN的优点DBSCAN的缺点小尝试制作不易&#xff0c;感谢三连&#xff0c;谢谢啦 简介 DBSCAN&#xff08;Density-Based Spatial Clustering of Applications wi…

Codeforces Round 927 (Div. 3)(A,B,C,D,E,F,G)

这场简单些&#xff0c;E题是个推结论的数学题&#xff0c;沾点高精的思想。F是个需要些预处理的DP&#xff0c;G题是用exgcd算边权的堆优化dijkstra。C题有点骗&#xff0c;硬啃很难做。 A Thorns and Coins 题意&#xff1a; 在你的电脑宇宙之旅中&#xff0c;你偶然发现了…

LeetCode 0105.从前序与中序遍历序列构造二叉树:分治(递归)——五彩斑斓的题解(若不是彩色的可以点击原文链接查看)

【LetMeFly】105.从前序与中序遍历序列构造二叉树&#xff1a;分治&#xff08;递归&#xff09;——五彩斑斓的题解&#xff08;若不是彩色的可以点击原文链接查看&#xff09; 力扣题目链接&#xff1a;https://leetcode.cn/problems/construct-binary-tree-from-preorder-a…

java数据类型、运算符

一、数据的表示详解 1.1 整数在计算机中的存储原理 任何数据在计算机中都是以二进制表示的。那这里肯定有人问&#xff0c;什么是二进制啊&#xff1f;所谓二进制其实就是一种数据的表示形式&#xff0c;它的特点是逢2进1。 数据的表示形式除了二进制&#xff08;逢2进1&…

Https证书续签-acme.sh-腾讯云之DnsPod

ename 域名切换到 DnsPod 上面解析 可以先看下之前的 acme.sh 介绍文章然后再来次补充更多。 之前说过了 acme.sh 在阿里云下的使用。 这里做个后续补充 之前的域名是在 ename 上的 &#xff0c;为了自动续签切换到 DnsPod 上面解析 注意事项 可以把原来 ename 上的解析先导出…

Android全新UI框架之Jetpack Compose入门基础

Jetpack Compose是什么 如果有跨端开发经验的同学&#xff0c;理解和学习compose可能没有那么大的压力。简单地说&#xff0c;compose可以让Android的原生开发也可以使用类似rn的jsx的语法来开发UI界面。以往&#xff0c;我们开发Android原生页面的时候&#xff0c;通常是在xml…