Tensorflow2基础代码实战系列之双层RNN文本分类任务

news2025/1/6 19:56:37

深度学习框架Tensorflow2系列

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#


文章目录

  • 深度学习框架Tensorflow2系列
  • 前言
  • 一、文本分类任务实战
  • 二、数据集介绍
  • 三、RNN模型所需数据解读
  • 四、实战代码
    • 1.数据预处理
    • 2.构建初始化embedding层
    • 3.构建训练数据
    • 4.自定义双层RNN网络模型
    • 5.设置参数和训练策略
    • 6.模型训练
  • 总结


前言

通过代码案例实战,学习Tensorflow2的各种api。


一、文本分类任务实战

任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别

二、数据集介绍

训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
在这里插入图片描述

三、RNN模型所需数据解读

在这里插入图片描述
RNN是一个比较基础的序列化模型,其中输入的数据为[batch_size,max_len,feature_dim]

四、实战代码

1.数据预处理

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import Counter
from pathlib import Path
from tqdm import tqdm

# 构建语料表,基于词频来进行统计
counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:
    for line in f:
        line = line.rstrip()
        label, words = line.split('\t')
        words = words.split(' ')
        counter.update(words)

words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))

Path('./vocab').mkdir(exist_ok=True)

with open('./vocab/word.txt', 'w',encoding='utf-8') as f:
    for w in words:
        f.write(w+'\n')

# 得到word2id映射表
word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:
    for i, line in enumerate(f):
        line = line.rstrip()
        word2idx[line] = i

得到的结果如下
在这里插入图片描述

2.构建初始化embedding层

# 做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknow

with open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的
    count = 0
    for i, line in enumerate(f):
        if i % 100000 == 0:
            print('- At line {}'.format(i)) #打印处理了多少数据
        line = line.rstrip()
        sp = line.split(' ')
        word, vec = sp[0], sp[1:]
        if word in word2idx:
            count += 1
            embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量
 
# 保存结果
print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')

得到的结果如下:word.txt中的每个单词转换成对应的向量
在这里插入图片描述

3.构建训练数据

def data_generator(f_path, params):
    with open(f_path,encoding='utf-8') as f:
        print('Reading', f_path)
        for line in f:
            line = line.rstrip()
            label, text = line.split('\t')
            text = text.split(' ')
            x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的ID
            if len(x) >= params['max_len']:#截断操作
                x = x[:params['max_len']]
            else:
                x += [0] * (params['max_len'] - len(x))#补齐操作
            y = int(label)
            yield x, y

def dataset(is_training, params):
    _shapes = ([params['max_len']], ())
    _types = (tf.int32, tf.int32)
  
    if is_training:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['train_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.shuffle(params['num_samples'])
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)#设置缓存序列,目的加速
    else:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['test_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  
    return ds

4.自定义双层RNN网络模型

class Model(tf.keras.Model):
    def __init__(self, params):
        super().__init__()
    
        self.embedding = tf.Variable(np.load('./vocab/word.npy'),
                                     dtype=tf.float32,
                                     name='pretrained_embedding',
                                     trainable=False,)

        self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])

        self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))

        self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])
        self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)

        self.out_linear = tf.keras.layers.Dense(2)

  
    def call(self, inputs, training=False):
        if inputs.dtype != tf.int32:
            inputs = tf.cast(inputs, tf.int32)
    
        batch_sz = tf.shape(inputs)[0]
        rnn_units = 2*params['rnn_units']

        x = tf.nn.embedding_lookup(self.embedding, inputs)
        
        x = self.drop1(x, training=training)
        x = self.rnn1(x)

        x = self.drop2(x, training=training)
        x = self.rnn2(x)

        x = self.drop3(x, training=training)
        x = self.rnn3(x)

        x = self.drop_fc(x, training=training)
        x = self.fc(x)

        x = self.out_linear(x)

        return x

5.设置参数和训练策略

params = {
  'vocab_path': './vocab/word.txt',
  'train_path': './data/train.txt',
  'test_path': './data/test.txt',
  'num_samples': 25000,
  'num_labels': 2,
  'batch_size': 32,
  'max_len': 1000,
  'rnn_units': 200,
  'dropout_rate': 0.2,
  'clip_norm': 10.,
  'num_patience': 3,
  'lr': 3e-4,
}
def is_descending(history: list):
    history = history[-(params['num_patience']+1):]
    for i in range(1, len(history)):
        if history[i-1] <= history[i]:
            return False
    return True  
word2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:
    for i, line in enumerate(f):
        line = line.rstrip()
        word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1


model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到

decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0

history_acc = []
best_acc = .0

t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)

6.模型训练

while True:
  # 训练模型
    for texts, labels in dataset(is_training=True, params=params):
        with tf.GradientTape() as tape:#梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度
            logits = model(texts, training=True)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
            loss = tf.reduce_mean(loss)
  
        optim.lr.assign(decay_lr(global_step))
        grads = tape.gradient(loss, model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, params['clip_norm']) #将梯度限制一下,有的时候回更新太猛,防止过拟合
        optim.apply_gradients(zip(grads, model.trainable_variables))#更新梯度

        if global_step % 50 == 0:
            logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(
                global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))
            t0 = time.time()
        global_step += 1

    # 验证集效果
    m = tf.keras.metrics.Accuracy()

    for texts, labels in dataset(is_training=False, params=params):
        logits = model(texts, training=False)
        y_pred = tf.argmax(logits, axis=-1)
        m.update_state(y_true=labels, y_pred=y_pred)
    
    acc = m.result().numpy()
    logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))
    history_acc.append(acc)
  
    if acc > best_acc:
        best_acc = acc
    logger.info("Best Accuracy: {:.3f}".format(best_acc))
  
    if len(history_acc) > params['num_patience'] and is_descending(history_acc):
        logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))
        break

总结

通过RNN文本分类任务代码案例实战,学习Tensorflow2的各种api。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/563320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动化测试工具selenium的使用方法

一、前言 由于requests模块是一个不完全模拟浏览器行为的模块&#xff0c;只能爬取到网页的HTML文档信息&#xff0c;无法解析和执行CSS、JavaScript代码&#xff0c;因此需要我们做人为判断&#xff1b; selenium模块本质是通过驱动浏览器&#xff0c;完全模拟浏览器的操作&…

Python爬虫入门案例6:scrapy的基本语法+使用scrapy进行网站数据爬取

几天前在本地终端使用pip下载scrapy遇到了很多麻烦&#xff0c;总是报错&#xff0c;花了很长时间都没有解决&#xff0c;最后发现pycharm里面自带终端&#xff01;&#xff08;狂喜&#xff09;&#xff0c;于是直接在pycharm终端里面写scrapy了 这样的好处就是每次不用切换路…

项目风险应对策略:项目经理应对不确定性的指南

风险应对是项目经理管理项目未来的工具箱。它可以帮助管理人员弄清楚可能会出现什么问题&#xff0c;并让他们有机会为这些问题做好准备。 对抗负面风险的5种策略 如果没有风险管理计划&#xff0c;项目可能会因意外问题或不良风险而迅速脱轨。什么策略可以用来对抗负面风险&…

Salesforce认证|新鲜出炉销售代表认证!

Salesforce一直致力于为专业人士提供测试知识与技能的方法&#xff0c;现在终于轮到销售人员了&#xff01; 前不久&#xff0c;Salesforce宣布推出销售代表认证&#xff0c;这不仅是首个面向销售人员的认证&#xff0c;也是为数不多的非技术类、非顾问类认证&#xff0c;这为…

记录 aaPanel 安装环境失败的经历及解决方案

最近我在一台Debian 11的国外服务器上安装aaPanel&#xff08;即宝塔面板的国际版&#xff09;。在安装完面板后&#xff0c;我继续安装LNMP环境。几分钟后&#xff0c;aaPanel提示LNMP环境已经安装成功。然而&#xff0c;在创建站点时&#xff0c;却提示环境没有安装。 问题排…

财务共享中心成功建立!用友帮助河南水投集团打造财务效率新高地

河南水投集团作为省级水务集团&#xff0c;自成立以来一直坚持以资产筹集资金&#xff0c;以资金建设项目&#xff0c;以运营扩张资本。即使在面对经济下行压力及疫情影响双重挑战下&#xff0c;仍坚持结果导向&#xff0c;通过项目建设推动发展&#xff0c;保持了较好的发展态…

MyBatisPlus更新字段为null的正确姿势以及lambda方式的条件字段解析之源码解析

文章目录 [toc] 1.问题2.原因3.解决方法3.1错误方法方式一&#xff1a;配置全局字段策略方式二&#xff1a;在实体上添加字段策略注解 3.2正确姿势方式一&#xff1a;使用LambdaUpdateWrapper &#xff08;推荐&#xff09;方式二&#xff1a;使用UpdateWrapper方式三 总结 1.问…

沉降仪工作原理

输电线路杆塔倾斜北斗在线监测装置 一、产品概述 杆塔、铁塔在时间、自然因素的影响下&#xff0c;发生的倾斜、偏离等现象&#xff0c;而在人工巡检电力设施时是不容易通过人眼判别的&#xff0c;在日积月累的变化中&#xff0c;铁塔、杆塔会因倾斜幅度过大进一步引发严重的坍…

基于 Bert 论文构建 Question-Answering 模型

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 摘要 本文拜读了提出 Bert 模型的论文&#xff0c;考虑了在 Bert 中算法模型的实现.比较了 Bert 与其他如 Transformer、GPT 等热门 NLP 模型.BERT 在概念上很简单&#xff0c;在经验上也很强大。它推动了 11 项自然语言处理任…

“她经济”崛起,茉莉智慧如何以科技赋能月子中心迭代升级?

近年来&#xff0c;利好生育政策频出&#xff0c;女性消费能力不断提升&#xff0c;以月子中心为核心的产后护理赛道发展势头良好。据iiMedia Research数据&#xff0c;2022年中国月子中心市场规模突破223.0亿元。iiMedia Research市场调查显示&#xff0c;93.5%的受访者认为产…

ubuntu命令记录

centos 下载地址&#xff1a; 网易镜像&#xff1a;http://mirrors.163.com/centos/6/isos/ 搜狐镜像&#xff1a;http://mirrors.sohu.com/centos/6/isos/ VM与LINUX的安装&#xff08;虚拟机的安装&#xff09; 注意&#xff1a;a.必须开启虚拟化&#xff08;一般电脑都默认…

BFT 最前线 | 王小川:2033机器智慧将超人类;扎克伯格财富暴涨;哈工大:能跳跃的昆虫机器人;北京支持“1+4”机器人领域

原创 | 文 BFT机器人 名人动态 CELEBRITY NEWS 01 王小川&#xff1a;10年后机器智慧将超过人类 年底将推出对标GPT-3.5的模型 科技预言大师雷库兹韦尔说人工智能的奇点&#xff0c;机器智慧超过人类会发生在2045年&#xff0c;王小川的判断比这更激进&#xff0c;他认为这一…

复杂的C++继承

文章目录 什么是继承继承方式赋值规则继承中的作用域&#xff08;隐藏&#xff09;子类中的默认成员函数需要自己写默认成员函数的情况 继承与友元及静态成员多继承菱形继承菱形继承的问题菱形虚拟继承 继承和组合 面向对象三大特性&#xff1a;封装继承和多态。封装在类和对象…

2172. 最大公约数

Powered by:NEFU AB-IN Link 文章目录 2172. 最大公约数题意思路代码 2022年第十三届决赛真题 2172. 最大公约数 题意 给定一个数组, 每次操作可以选择数组中任意两个相邻的元素 x , y x, yx,y 并将其 中的一个元素替换为 gcd ⁡ ( x , y ) \operatorname{gcd}(x, y)gcd(x,y),…

从月薪5000到月薪20000,自动化测试应该这样学...

绝大多数测试工程师都是从功能测试做起的&#xff0c;工作忙忙碌碌&#xff0c;每天在各种业务需求学习和点点中度过&#xff0c;过了好多年发现自己还只是一个功能测试工程师。 随着移动互联网的发展&#xff0c;从业人员能力的整体进步&#xff0c;软件测试需要具备的能力要…

征稿丨IJCAI‘23大模型论坛,优秀投稿推荐AI Open和JCST发表

第一届LLMIJCAI’23 Symposium征稿中&#xff0c;优秀投稿论文推荐《AI Open》和 《JCST》发表。 大规模语言模型&#xff08;LLMs&#xff09;&#xff0c;如ChatGPT和GPT-4&#xff0c;以其在自然语言理解和生成方面的卓越能力&#xff0c;彻底改变了人工智能领域。 LLMs广泛…

Go语言文件I/O操作

go语言中的io操作主要学习目标 掌握文件的常规操作掌握ioutil包的使用掌握bufio包的使用 在go中使用 FileInfo接口 定义了IO的一些函数 FileInfo接口 源码追溯 //type.go // A FileInfo describes a file and is returned by Stat and Lstat. type FileInfo fs.FileInfo/…

ChatGPT:你真的了解网络安全吗?浅谈攻击防御进行时之传统的网络安全

ChatGPT&#xff1a;你真的了解网络安全吗&#xff1f;浅谈网络安全攻击防御进行时 传统的网络安全总结 ChatGPT&#xff08;全名&#xff1a;Chat Generative Pre-trained Transformer&#xff09;&#xff0c;美国OpenAI 研发的聊天机器人程序&#xff0c;是人工智能技术驱动…

什么是网络安全?如何让小白简单的学习网络安全

一、什么是网络安全 网络安全是一个庞大的学科&#xff0c;如果只是普及网络安全技能是非常枯燥的&#xff0c;所以建议从大众容易接受的网络安全诈骗入手&#xff0c;可以先介绍一下近年来频发的网络安全诈骗案例&#xff0c;钓鱼邮件、中奖短信、冒充公检法等多种诈骗手段&am…

Koala:加州大学BAIR团队使用ChatGPT蒸馏数据和公开数据集微调LLaMA模型得到

自从Meta发布LLaMA以来&#xff0c;围绕它开发的模型与日俱增&#xff0c;比如Alpaca、llama.cpp、ChatLLaMA以及Vicuna等等&#xff0c;相关的博客可以参考如下&#xff1a; 【Alpaca】斯坦福发布了一个由LLaMA 7B微调的模型Alpaca&#xff08;羊驼&#xff09;&#xff0c;训…