基于粒子群优化的中文文本分类

news2024/10/6 18:25:28

基本思路:

方法:使用优化算法(如粒子群)优化支持向量机SVM;

本文所使用的应用背景:中文文本分类(同时可以应用到其他背景领域,如)

应用背景(元启发式算法优化SVM):

  1. 图像分类和识别:图像分类和识别,如人脸识别、数字识别等
  2. 自然语言处理:SVM可以用于文本分类和情感分析等自然语言处理任务。它可以通过学习文本特征来判断一个文本属于哪个类别或者情感极性。
  3. 金融预测:SVM可以用于股票价格预测、信用评级、欺诈检测等金融领域的问题。它可以通过学习历史数据来预测未来的趋势或者检测异常交易。
  4. 医学诊断:SVM可以用于医学诊断,如癌症分类、药物分子筛选等。它可以通过学习医学数据特征来辅助医生进行诊断和治疗。
  5. 视频分类和检测:SVM可以用于视频分类和检测,如视频目标跟踪、行人检测等。它可以通过学习视频特征来判断一个视频中的目标或者行为。

本文所使用的数据集

数据集:搜狐新闻文本语料库

其他可使用的数据集:复旦大学文本语料库

参考文章:

搜狐新闻文本分类:机器学习大乱斗 - 简书

【Python NLP】:搜狗语料库-新闻语料处理_搜狗新闻语料库_QuantCoder的博客-CSDN博客

源码部分:

工具:pycharm,

PSO代码:

import numpy as np
import random
import copy
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


''' 种群初始化函数 '''
def initial(pop, dim, ub, lb):
    X = np.zeros([pop, dim])
    for i in range(pop):
        for j in range(dim):
            X[i, j] = random.random()*(ub[j] - lb[j]) + lb[j]
    
    return X,lb,ub
            
'''边界检查函数'''
def BorderCheck(X,ub,lb,pop,dim):
    for i in range(pop):
        for j in range(dim):
            if X[i,j]>ub[j]:
                X[i,j] = ub[j]
            elif X[i,j]<lb[j]:
                X[i,j] = lb[j]
    return X
    
    
'''计算适应度函数'''
def CaculateFitness(X,fun):
    pop = X.shape[0]
    fitness = np.zeros([pop, 1])
    for i in range(pop):
        fitness[i] = fun(X[i, :])
    return fitness

'''适应度排序'''
def SortFitness(Fit):
    fitness = np.sort(Fit, axis=0)
    index = np.argsort(Fit, axis=0)
    return fitness,index


'''根据适应度对位置进行排序'''
def SortPosition(X,index):
    Xnew = np.zeros(X.shape)
    for i in range(X.shape[0]):
        Xnew[i,:] = X[index[i],:]
    return Xnew


'''粒子群算法'''
def PSO(pop,dim,lb,ub,MaxIter,fun,Vmin,Vmax):
    # 参数设置
    w = 0.9      # 惯性因子
    c1 = 2       # 加速常数
    c2 = 2       # 加速常数
    X,lb,ub = initial(pop, dim, ub, lb) #初始化种群
    V,Vmin,Vmax = initial(pop, dim, Vmax, Vmin) #初始速度
    fitness = CaculateFitness(X,fun) #计算适应度值
    fitness,sortIndex = SortFitness(fitness) #对适应度值排序
    X = SortPosition(X,sortIndex) #种群排序
    GbestScore = copy.copy(fitness[0])
    GbestPositon = copy.copy(X[0,:])
    Curve = np.zeros([MaxIter,1])
    Pbest = copy.copy(X)
    fitnessPbest = copy.copy(fitness)
    for i in range(MaxIter):
        for j in range(pop):
           #速度更新
           V[j,:] = w*V[j,:] + c1*np.random.random()*(Pbest[j,:] - X[j,:]) + c2*np.random.random()*(GbestPositon - X[j,:])
           #速度边界检查
           for ii in range(dim):
               if V[j,ii]<Vmin[ii]:
                   V[j,ii]=Vmin[ii]           
               if V[j,ii]>Vmax[ii]:
                   V[j,ii] = Vmax[ii]
            #位置更新
           X[j,:] = X[j,:] + V[j,:]
            #位置边界检查
           for ii in range(dim):
               if X[j,ii]<lb[ii]:
                   V[j,ii]=lb[ii]           
               if X[j,ii]>ub[ii]:
                   V[j,ii] = ub[ii]
           fitness[j] = fun(X[j,:])
           if fitness[j]<fitnessPbest[j]:
               Pbest[j,:]=copy.copy(X[j,:])
               fitnessPbest[j] = copy.copy(fitness[j])
           if fitness[j]<GbestScore[0]:
               GbestScore[0] = copy.copy(fitness[j])
               # 修改  2022/07/21 04:51
               #if(fitnessPbest[j]>0):
               GbestPositon = copy.copy(X[j,:])


               
        Curve[i] = GbestScore
    
    return GbestScore,GbestPositon,Curve
    









PSO_SVM代码:

# 导包
import re
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
import numpy as np
from sklearn.metrics import accuracy_score
import PSO_.PSO # 该行是导入粒子群PSO的源码,可能需要做对应修改
import sys
import os
import time
sys.path.append(os.path.abspath("../")) 
#上一行代码和主要是我将PSO.py和PSO_SVM未放到同一个文件夹
#PSO_SVM放入到了202207SSA_SVM_combin,PSO.py放到了PSO_文件夹中
#两个文件夹同级目录
token = "[0-9\s+\.\!\/_,$%^*()?;;:【】+\"\'\[\]\\]+|[+——!,;:。?《》、~@#¥%……&*()“”.=-]+"
stopwords = open('../dict/stop_words.txt', encoding='utf-8').read().split()    # read() split()
#上两行分词去停用词

def preprocess(text):
    text1 = re.sub('&nbsp', ' ', text)          #去掉text中的 空格
    str_no_punctuation = re.sub(token, ' ', text1)  # 去掉标点
    text_list = list(jieba.cut(str_no_punctuation))  # 分词列表
    text_list = [item for item in text_list if item != ' ']  # 去掉空格
    return ' '.join(text_list)

def load_datasets():
    base_dir = '../data/'  
#base_idr需要注意!放的是对应的数据集
    X_data = {'train': [], 'test': []}
    y = {'train':[], 'test':[]}
    for type_name in ['train', 'test']:
        corpus_dir = os.path.join(base_dir, type_name)
        corpus_list = []                                ###
        for label in os.listdir(corpus_dir):            #
            label_dir = os.path.join(corpus_dir, label) #标签
            file_list = os.listdir(label_dir)           #标签目录下对应的文件数
            print("label: {}, len: {}".format(label, len(file_list)))

            for fname in file_list:
                file_path = os.path.join(label_dir, fname)
                with open(file_path, encoding='gb2312', errors='ignore') as text_file:
                    text_content = preprocess(text_file.read())           # 使用preprocess处理文件的内容
                X_data[type_name].append(text_content)                  #
                y[type_name].append(label)                              #

        print("{} corpus len: {}\n".format(type_name, len(X_data[type_name])))

    return X_data['train'], y['train'], X_data['test'], y['test']

def fun(parameter): ##适应度函数!优化算法的关键
    #data_train, label_train, data_test, label_test = load_datasets(
    # 自己单独修改的 20220721 5:02
    # 下面的if多余了,只要__main__方法中的lb 和ub的范围即可
    if(parameter[0]<=0 ):
        parameter[0] = -1 * parameter[0] + 0.1
    if(parameter[1]<0 ):
        parameter[1] = -1 * parameter[1]
    text_clf_svm = Pipeline([
        ('vect', TfidfVectorizer()),
        ('svm_clf', SVC(C=parameter[0], kernel='rbf', gamma=parameter[1])),
    ])
    text_clf_svm.fit(X_train_data, y_train)


    y_predict  = text_clf_svm.predict(X_test_data)
    acc = accuracy_score(y_test, y_predict)
    print(1-acc)
    return 1-acc

if __name__ == '__main__':
    start_time = time.time()  # 设置起始的时间

    X_train_data, y_train, X_test_data, y_test = load_datasets()
    # 设置参数
    pop = 2  # 种群数量 默认是50
    MaxIter = 2  # 最大迭代次数 默认是1000
    dim = 2  # 维度 默认是10最后输出时设置10会报错的
    lb = np.matrix([[0.1], [0.1]])  # 下边界
    ub = np.matrix([[200], [200]])  # 上边界
    Vmin = -5 * np.ones([dim, 1])  # 速度下边界
    Vmax = 5 * np.ones([dim, 1])  # 速度上边界
    fobj = fun

    # 设置麻雀参数
    # pop = 20  # 种群数量
    # MaxIter = 50  # 最大迭代次数
    # dim = 2  # 维度
    # lb = np.matrix([[0.1], [0.1]])  # 下边界
    # ub = np.matrix([[200], [200]])  # 上边界
    # fobj = fun

    # 出错的粒子群PSO, 多了参数Vmin,Vmax 先设置-1 和 1 试试,还是有问题
    GbestScore, GbestPositon, Curve = PSO_.PSO.PSO(pop, dim, lb, ub, MaxIter, fobj, Vmin, Vmax)

    print('最优适应度值:', GbestScore)
    print('c,g最优解:', GbestPositon)

    # # 构建 SVM分类器
    print("======SVM")
    text_clf_svm = Pipeline([
        ('vect', TfidfVectorizer()),
        ('svm_clf', SVC(C=GbestPositon[0], kernel='rbf', gamma=GbestPositon[1])),
    ])
    text_clf_svm.fit(X_train_data, y_train)
    predicted_svm = text_clf_svm.predict(X_test_data)
    print(classification_report(predicted_svm, y_test))
    print('completed')
    elapsed_time = time.time() - start_time  # 设置截止时间
    print('inference time cost: {}'.format(elapsed_time))  # 输出消耗的时间

 

运行结果这样:(由于种群数和迭代次数只设置了,效果会较差,而且PSO可能存在局部最优解等问题)

 

改进思路1:

使用麻雀搜索算法/改进的麻雀搜索算法对支持向量机进行优化;

使用其他优化算法对支持向量机进行优化;

目前可以参考的其他优化算法:资源文件正在审核,后面可能放在评论区

改进思路2:

使用优化算法优化LSTM等深度学习模型的相关参数!(目前自己在情感分析中有实现,最差和最好的效果大概相差2%-3%左右)

完整代码文件正在审核,后面可能放在评论区

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/508356.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(学习日记)2023.5.9

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

QTableview常用几种代理总结

QTableview常用几种代理总结 [1] QTableview常用几种代理总结1、QCheckBox和QRadioButton的嵌入2、QHeadView中嵌入QCheckBox类3、QCombobox的嵌入4、 QCombox QCheckBox类5、SpinBox的嵌入类6、QProcess的嵌入类7、QProcess绘制版本的嵌入类8、QPushButton/QLabel/QImage的嵌…

鸿蒙Hi3861学习八-Huawei LiteOS-M(事件标记)

一、简介 事件是一种实现任务间通信的机制&#xff0c;可用于实现任务间的同步。但事件通信只能是事件类型的通信&#xff0c;无数据传输。一个任务可以等待多个事件的发生&#xff1a;可以是任意一个事件发生时唤醒任务进行事件处理&#xff1b;也可以是几个事件都发生后才唤醒…

mongodb副本集搭建

1.本次搭建使用三台centos7主机搭建集群&#xff0c;关闭防火墙和selinux服务 2.主机信息如下图所示 主机名称IPPortServiceA10.1.60.11427017mongodbB10.1.60.11527017mongodbC10.1.60.11827017mongodb 3.从官网下载mongodb安装包(我这里下载的是6.0.5版本的tgz包) Instal…

小家电LED显示驱动多功能语音芯片IC方案 WT2003H4 B002

随着时代的进步&#xff0c;智能家电的普及已经成为了一个趋势。而在智能家电中&#xff0c;LED显示屏也成为了不可或缺的一部分。因此&#xff0c;在小家电的设计中&#xff0c;LED显示驱动芯片的应用也越来越广泛。比如&#xff1a;电饭煲、电磁炉、数字时钟、咖啡机、电磁炉…

【Vue3】如何创建Vue3项目及组合式API

文章目录 前言 一、如何创建vue3项目&#xff1f; ①使用 vue-cli 创建 ②使用可视化ui创建 ③npm init vite-app ④npm init vuelatest 二、 API 风格 2.1 选项式 API (Options API) 2.2 组合式 API (Composition API) 总结 前言 例如&#xff1a;随着前端领域的不断发展&am…

【SSM框架】SpringMVC 中常见的注解和用法

SSM框架 SpringMVC 中常见的注解和用法基础注解介绍RequestMapping 注解介绍PostMapping 和 GetMapping 注解介绍 获取参数相关注解的介绍只通过 RequestMapping 来获取参数只传递一个参数传递对象参数传递多个参数(非对象) RequestParam 后端参数重命名required 必传参数的设置…

SpringBoot+Redis+自定义注解实现接口防刷(限制不同接口单位时间内最大请求次数)

场景 SpringBoot搭建的项目需要对开放的接口进行防刷限制&#xff0c;不同接口指定多少秒内可以请求指定次数。 比如下方限制接口一秒内最多请求一次。 注&#xff1a; 博客&#xff1a;霸道流氓气质的博客_CSDN博客-C#,架构之路,SpringBoot领域博主 实现 1、实现思路 首…

flink学习37:DataStream/DataSet与Table的互相转换

DataStream/DataSet转换成视图 DataStream/DataSet转换成表 表转换成DataStream/DataSet 表转换为DataStream/DataSet时&#xff0c;需要指定字段数据类型&#xff0c;最方便的就是把数据类型定为row&#xff0c;即行数据。 两种模式&#xff1a; 把表转为dataStream 把表转为d…

100ASK-V853-PRO编译烧写

100ASK_V853-PRO 环境配置及编译烧写 0.前言 本章主要介绍关于100ASK_V853-PRO开发板的Tina SDK包的下载和编译打包生成镜像&#xff0c;并将镜像烧录到100ASK_V853-PRO开发板上。在进行100ASK_V853-PRO开发板的环境配置前需要获取配置虚拟机系统&#xff0c;可以参考&#x…

【C语言督学训练营 第十二天】三篇文章吃透数据结构中的线性表(三)----- 线性表考研真题

文章目录 前言题目描述题目分析代码实战 前言 本篇博客从头到尾都在解析一道2019年考研真题中的一道关于链表的大题&#xff0c;虽然题目没有竞赛算法题那么复杂&#xff0c;那么难想&#xff0c;但是我们依旧可以从中收获到好多知识&#xff0c;本题的突破点就是快慢指针与链…

AP5153 线性降压恒流驱动芯片 2.5A

AP5153 是一种 PWM 调光的、低压 差的 LED 线性降压恒流驱动器。 AP5153 仅需要外接一个电阻和一个 NMOS 管就可以构成一个完整的 LED 恒 流驱动电路&#xff0c; 调节该外接电阻就可以调节 输出电流&#xff0c;输出电流可调范围为 20mA 到 3.0A。 AP5153 还可以通过在 DIM…

echarts x轴与y轴 刻度 数据设置

xAxis: {nameTextStyle: {fontWeight: "bold",fontSize: "20",align: "left",},splitLine: {show: false,},axisLine: {show: true,symbol: ["none", "arrow"], //加箭头处symbolOffset: 0,lineStyle: {color: "rgb(12…

aardio的优缺点,强烈推荐大家试用一下,可以用它在windows下面写一些小工具

概述 官网 aardio是一种用于Windows平台的脚本编程语言&#xff0c;以及一个功能丰富的集成开发环境&#xff08;IDE&#xff09;。它结合了强大的原生Windows API访问能力和简单易学的语法。以下是aardio的一些优缺点。 优点&#xff1a; 简单易学&#xff1a;aardio的语法简…

Linux | 本地Yum源 | 网络Yum源(阿里云Yum源)

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; 本地Yum源配置 创建挂载点目录 [rootlocalhost ~]# mkdir /mnt/cdrom [rootlocalhost ~]# df /mnt/cdrom/ 文件系统 1K-块 已用 可用 已用%…

慎入坑:腾讯云轻量2核2G3M服务器30元不建议选择

腾讯云轻量应用服务器2核2G3M带宽30元3个月不建议买&#xff0c;自带3M带宽&#xff0c;下载速度可达384KB/秒&#xff0c;100%CPU性能&#xff0c;系统盘为40GB SSD盘&#xff0c;200GB月流量&#xff0c;折合每天6.6G流量&#xff0c;地域节点可选上海/广州/北京&#xff0c;…

React Router 6 函数式组件withRouter 路由属性配置

withRouter为解决开发过程中函数组件路由参数获取问题&#xff0c;之前版本的withRouter是直接可以导入使用的&#xff0c;现在的需要手写 这里使用了hooks&#xff0c;获取路由、参数等相关信息 需要在函数式组件内使用props&#xff0c;用法&#xff1a; 1.需要先使用高阶组…

K8s常见面试题20问

K8s常见面试题19问 收集了一些K8s常见问题和同学们面试常被问到的问题. 如果有新的面试题私聊或者留言给我 1. Docker和虚拟机有那些不同 虚拟化环境下每个 VM 是一台完整的计算机&#xff0c;在虚拟化硬件之上运行所有组件&#xff0c;包括其自己的操作系统。 容器之间可以共…

MySQL数据库备份并还原

使用Navicat和命令行备份并恢复数据库 第三方工具备份并恢复步骤1步骤2步骤3步骤4&#xff1a;步骤5 命令行方式备份并恢复&#xff1a;步骤1步骤2步骤3步骤4 第三方工具备份并恢复 步骤1 步骤2 在弹出的窗口上选择要备份的路径&#xff0c;单击保存&#xff0c;下图为备份完…

MySQL与Hadoop数据同步方案:Sqoop与Flume的应用探究【上进小菜猪大数据系列】

&#x1f4ec;&#x1f4ec;我是上进小菜猪&#xff0c;沈工大软件工程专业&#xff0c;爱好敲代码&#xff0c;持续输出干货&#xff0c;欢迎关注。 MySQL与Hadoop数据同步 随着大数据技术的发展&#xff0c;越来越多的企业开始采用分布式系统和云计算技术来处理和存储海量数…