NLP之文本分类模型调优(模型基于tensorflow1.14)

news2025/1/23 10:34:10

项目链接:https://pan.baidu.com/s/1yOu0DogWkL8WOJksJmeiPw?pwd=4bsg 提取码: 4bsg 复制这段内容后打开百度网盘手机App,操作更方便哦 
--来自百度网盘超级会员v4的分享

1.基于上一篇文章中的文本分类项目进行精度调优,提升模型准确率,需要改动的地方除了参数以外,就是模型结构,这里主要就是进行模型结构调优,因此代码只需要修改rnn_model.py,修改前后代码如下:

原模型结构:

双层LSTM+两层全连接神经网络+softmax

 原模型代码:

#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf

class TRNNConfig(object):
    """RNN配置参数"""

    # 模型参数
    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    vocab_size = 5000       # 词汇表达小

    num_layers= 2           # 隐藏层层数
    hidden_dim = 128        # 隐藏层神经元
    rnn = 'gru'             # lstm 或 gru

    dropout_keep_prob = 0.8 # dropout保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 128         # 每批训练大小
    num_epochs = 10          # 总迭代轮次

    print_per_batch = 100    # 每多少轮输出一次结果
    save_per_batch = 10      # 每多少轮存入tensorboard


class TextRNN(object):
    """文本分类,RNN模型"""
    def __init__(self, config):
        self.config = config

        # 三个待输入的数据
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')#定义dropout占位符,输入失活比例

        self.rnn()

    def rnn(self):
        """rnn模型"""

        def lstm_cell():   # lstm核
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

        def gru_cell():  # gru核
            return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

        def dropout(): # 为每一个rnn核后面加一个dropout层
            if (self.config.rnn == 'lstm'):
                cell = lstm_cell()
            else:
                cell = gru_cell()
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # 词向量映射
        with tf.device('/cpu'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])  # [5000, 64]
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)  # [?, 600, 64]

        with tf.name_scope("lstm"):
            '''
            TF 低版本RNN多层获取cell,使用一下代码
            cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
            cell = rnn.MultiRNNCell([cell] * 2, state_is_tuple=True)
            '''
            # 多层rnn网络
            cells = [dropout() for _ in range(self.config.num_layers)]
            rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

            _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)  # [?, 600, 128]
            last = _outputs[:, -1, :]  # 取最后一个时序输出作为结果  # [?, 128]

        with tf.name_scope("score"):
            # 全连接层,后面接dropout以及relu激活
            fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            fc=tf.layers.dense(fc,units=64,name='fc2')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            fc = tf.layers.dense(fc, units=32, name='fc3')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)


            # 分类器
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc4')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别

        with tf.name_scope("optimize"):
            # 损失函数,交叉熵
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # 优化器
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # 准确率
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

修改为双层双向LSTM+4层全连接神经网络+softmax

#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf

class TRNNConfig(object):
    """RNN配置参数"""

    # 模型参数
    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    vocab_size = 5000       # 词汇表达小

    num_layers= 2           # 隐藏层层数
    hidden_dim = 128        # 隐藏层神经元
    rnn = 'gru'             # lstm 或 gru

    dropout_keep_prob = 0.8 # dropout保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 128         # 每批训练大小
    num_epochs = 10          # 总迭代轮次

    print_per_batch = 100    # 每多少轮输出一次结果
    save_per_batch = 10      # 每多少轮存入tensorboard


class TextRNN(object):
    """文本分类,RNN模型"""
    def __init__(self, config):
        self.config = config

        # 三个待输入的数据
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')#定义dropout占位符,输入失活比例

        self.rnn()

    def rnn(self):
        """rnn模型"""

        def lstm_cell():   # lstm核
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

        def gru_cell():  # gru核
            return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

        def bilstm_cell():
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim)

        def dropout(): # 为每一个rnn核后面加一个dropout层
            if (self.config.rnn == 'lstm'):
                cell = lstm_cell()

            else:
                cell = gru_cell()
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # 词向量映射
        with tf.device('/cpu'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])  # [5000, 64]
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)  # [?, 600, 64]

        with tf.name_scope("lstm"):
            '''
            TF 低版本RNN多层获取cell,使用一下代码
            cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
            cell = rnn.MultiRNNCell([cell] * 2, state_is_tuple=True)
            '''
            cell_fw=[dropout() for _ in range(self.config.num_layers)]
            cell_bw=[dropout() for _ in range(self.config.num_layers)]
            lstm_forward = tf.contrib.rnn.MultiRNNCell(cells=cell_bw, state_is_tuple=True)
            lstm_backward=tf.contrib.rnn.MultiRNNCell(cells=cell_fw, state_is_tuple=True)
            outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_forward, cell_bw=lstm_backward, inputs=embedding_inputs,
                                                              dtype=tf.float32)
            outputs_fw = outputs[0]
            outputs_bw = outputs[1]
            last = outputs_fw[:, -1, :] + outputs_bw[:, 0, :]  # 取最后一个时序输出作为结果  # [?, 128]

        with tf.name_scope("score"):
            # 全连接层,后面接dropout以及relu激活
            fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            fc=tf.layers.dense(fc,units=64,name='fc2')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            fc = tf.layers.dense(fc, units=32, name='fc3')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)


            # 分类器
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc4')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别

        with tf.name_scope("optimize"):
            # 损失函数,交叉熵
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # 优化器
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # 准确率
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

修改结束后,记得在原项目中的TextRNN\checkpoints\textrnn中将训练好的模型删除,避免冲突。

注:如使用tensorflow gpu版本进行模型训练,将with tf.device('/cpu'):改为with tf.device('/gpu:0'):

注:可通过此命令下载scikit-learn包:conda install -c anaconda scikit-learn  注:(安装sklearn)

注:conda环境:
conda create -n 环境名 python=版本号   : 新建虚拟环境
activate 环境名 进入环境
conda deactivate 退出
conda env list 显示环境
conda remove -n 环境名 --all

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/91115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CET-4 week10 语法

0基础入门 point 谓语中自带 助动词 Such as ,I do like computer (强调且默认带有这个属性 大多数情况隐藏状态) 所有时态都有被动 do ->donewill do -> will be donehave down ->have been downbe doing ->be being donedid -> been downhad down ->ha…

flink-cdc-connectors-release-2.3.0自己编译

最新的cdc是2.21仅支持flink 1.13* 1.14*,而flink已经有1.15版本;自己编译支持1.15 下载官方包 https://github.com/ververica/flink-cdc-connectors/releases/tag/release-2.3.0 我下载的是source包,大家也可以去下载源码 1、下载后我们只需…

JUC并发编程第十三篇,AQS的作用与体系结构

JUC并发编程第十三篇,AQS的作用与体系结构一、AQS是什么?二、AQS在JUC中的地位与作用三、AQS体系结构一、AQS是什么? AbstractQueuedSynchronizer(抽象队列同步器),是用来构建锁或者其它同步器组件的重量级…

[附源码]Node.js计算机毕业设计高校社团管理系统Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

java EE初阶 — 多线程案例单例模式

文章目录1单例模式主要模式1.1 饿汉模式1.2 懒汉模式2 单例模式安全性问题1单例模式主要模式 在某些场景中,有些特定的类只能输出一个实例(对象),不应该创建多个实例,此时就可以使用 单例模式。 使用了单例模式后&…

5款轻量级小软件,第一款更是近期必备!

今天的主题是简洁,轻便,都是轻量级的小软件,界面都是非常简洁,而且无广告的。 1.自动抢火车票工具——12306Bypass 12306Bypass是一款专用于帮助用户抢购火车车票的工具,春运马上就到了,又到了抢票回家的…

Docker数据卷操作

1. 为什么使用数据卷 卷是在一个或多个容器内被选定的目录,为docker提供持久化数据或共享数据,是docker存储容器生成和使用的数据的首选机制。对卷的修改会直接生效,当提交或创建镜像时,卷不被包括在镜像中。 总结为两个作用&am…

功率放大器在压电传感器矩形阵列成像研究中的应用

实验名称:激光和压电传感器密集型矩形阵列成像质量的比较分析 研究方向:Lamb波、无损检测、缺陷成像和定位 测试目的: 将密集型矩形阵列分别与压电传感器检测技术和激光检测技术相结合,利用幅值成像和符号相干因子成像实现对铝板结…

vector模拟实现下篇及迭代器失效和深浅拷贝问题详解

文章目录1:构造函数1.1默认构造函数1.2迭代器构造1.3用n个val构造1.4拷贝构造2:operator3:析构函数和clear4:迭代器失效问题4.1:删除偶数深浅拷贝1:构造函数 1.1默认构造函数 vector():_start(nullptr),_end(nullptr),_endofstorage(nullptr){}1.2迭代器构造 template<clas…

手动安装Kylin5.0版本的过程

官方文档 https://kylin.apache.org/目前kylin3,4版本是有docker版本和安装包的,5.0只有docker没有安装包 安装包 https://kylin.apache.org/download/安装kylin5.0 Kylin5.0文档拉取镜像 docker pull apachekylin/apache-kylin-standalone:5.0.0运行镜像 docker run -d \ …

linux-jdk、nginx

一、安装nginx Nginx是一个web服务器也可以用来做负载均衡及反向代理使用&#xff0c; 目前使用最多的就是负载均衡&#xff0c;这篇文章主要介绍了centos8 安装 nginx Nginx是一种开源的高性能HTTP和反向代理服务器&#xff0c;负责处理Internet上一些最大站点的负载。 它…

数据结构——重点代码汇总

顺序表 设计算法&#xff0c;从顺序表L中删除值为x的元素。要求算法的时间复杂度为O(n)&#xff0c;空间复杂度为O(1)。设计算法&#xff0c;判断一个字符串是否是回文。如abc3cba是回文序列&#xff0c;而1331不是回文序列。从顺序存储结构的线性表a中删除第i个元素起的k个元…

SuperMap GIS 倾斜摄影数据处理 QA

一、倾斜摄影数据简介 倾斜摄影&#xff08;Oblique photography&#xff09;是指由一定倾斜角度的航摄相机所获取的影像。倾斜摄影技术是国际摄影测量领域近十几年发展起来的一项高新技术&#xff0c;该技术通过从一个垂直、四个倾斜、五个不同的视角同步采集影像&#xff0c…

深度学习训练营之天气识别P3

深度学习训练营之天气识别原文链接环境介绍前置工作设置GPU导入数据数据查看数据预处理加载数据可视化数据检查数据配置数据集prefetch()功能详细介绍&#xff1a;构建CNN网络编译模型训练结果可视化原文链接 &#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记…

卫龙上市首日破发:高瓴、红杉、腾讯等账面亏损,刘卫平为董事长

12月15日&#xff0c;卫龙美味全球控股有限公司&#xff08;下称“卫龙”&#xff0c;HK:09985&#xff09;在港交所上市。本次上市&#xff0c;卫龙的发行价格为10.56港元/股&#xff0c;募资总额约为10.18亿港元&#xff0c;募资净额约为8.99亿港元。 上市首日&#xff0c;卫…

Web3中文|NFT如何助力项目进入Web3?

自NFT流行以来&#xff0c;一直有人将这些由区块链驱动的代币视作贯彻人类精神的最终疗法。 但是NFT真的都存储在区块链上了吗&#xff1f;如果是这样的话&#xff0c;怎么还会出现百万NFT被盗的事件呢&#xff1f; 如果你也想过这些问题&#xff0c;那么请相信我&#xff0c…

在现有项目里面添加 TSX 并编写组件过程记录

首先需要安装编译支持和 vite 支持插件 ## babel 基础插件 yarn add vue/babel-plugin-jsx -D## 项目用 vite 构建的就需要按照这个 yarn add vitejs/plugin-vue-jsx -D 使用插件 按照 babel-plugin-jsx 的指引在 babel 配置项中启用插件&#xff1a; {"plugins":…

Linux操作系统常见问题汇总

1.系统启动流程。 uboot -> kernel -> 根文件系统。 uboot第一阶段属于汇编阶段&#xff1a; 定义入口&#xff08;start.S&#xff09;&#xff1a;uboot中因为有汇编阶段参与&#xff0c;因此不能直接找main.c。 设置异常向量&#xff1a;当硬件发生故障的时候CPU会…

K8s Dashboard 部署

1、下载 Dashboard 的 yaml 文件 点击链接下载 YAML 文件 2、源码包中 yaml 文件在哪里 3、修改 yaml 文件 vim recommended.yaml... kind: Service apiVersion: v1 metadata:labels:k8s-app: kubernetes-dashboardname: kubernetes-dashboardnamespace: kubernetes-dashboard…

Java web 2022跟学尚硅谷(十) 后端基础 书城

Java web 2022跟学尚硅谷十 后端基础 书城验证码kaptcha和缓存cookie简单了解cookie步骤简单创建cookie的样例代码CookieServlet01hello.html页面结果Cookie保存结果第二次请求cookie的APIKaptcha验证码使用步骤显示效果验证码的校验相关类KaptchaServlet01书城1.2正则表达式正…