【NLP】从预训练模型中获取Embedding

news2024/10/6 16:23:53

从预训练模型中获取Embedding

  • 背景说明
  • 下载IMDB数据集
  • 进行分词
  • 下载并预处理GloVe词嵌入数据
  • 构建模型
  • 训练模型并可视化结果
  • 结果对比
  • 其他代码

在NLP领域中,构建大规模的标注数据集非常困难,以至于仅用当前语料无法有效完成特定任务。可以采用迁移学习的方法,即将预训练好的词嵌入作为模型的权重,然后在此基础上的微调。

背景说明

本文基于数据集IMDB进行情感分析,做预测之前需要先对数据集进行预处理,把词转换为词嵌入。这里采用迁移学习的方法来提升模型性能,并使用2014年英文维基百科的预计算词嵌入数据集。**该数据集文件名为glove.6B.zip,大小为822MB,里面包含400000个单词的100维嵌入向量。**把预训练好的词嵌入导入模型的第一层,并冻结该层,然后增加分类层进行分类预测。GloVe词嵌入数据集的下载地址为https://nlp.stanford.edu/projects/glove/
glove6B

下载IMDB数据集

下载IMDB数据集并保存在本地,这里下载glove.6B.zip完成之后保存到“./aclImdb”。
IMDB数据库
读取电影评论并根据其评论的正面或负面将其分类为1或0:

def preprocessing(data_dir):
    # 提取数据集中所有的评论文本以及对应的标签,neg表示负面标签,用0表示,pos表示正面标签,用1表示
    texts, labels = [], []
    train_dir = os.path.join(data_dir, 'train')
    for label_type in ['neg', 'pos']:
        dir_name = os.path.join(train_dir, label_type)
        for fname in os.listdir(dir_name):
            if fname[-4:] == '.txt':
                f = open(os.path.join(dir_name, fname), encoding='utf-8')
                texts.append(f.read())
                f.close()
                if label_type == 'neg':
                    labels.append(0)
                else:
                    labels.append(1)
    return texts, labels

进行分词

利用TensorFlow 2.0提供的Tokenizer函数进行分词操作,如下所示:

def word_cutting(args, texts, labels):
    # 利用Tokenizer函数进行分词操作
    tokenizer = Tokenizer(num_words=args.max_words)
    tokenizer.fit_on_texts(texts)
    sequences = tokenizer.texts_to_sequences(texts)

    word_index = tokenizer.word_index
    print(f'Found {len(word_index)} unique tokens.')

    data = pad_sequences(sequences, maxlen=args.maxlen)
    labels = np.asarray(labels)
    print(f'Shape of data tensor: {data.shape}')
    print(f'Shape of label tensor: {labels.shape}')
    return word_index, data, labels

将分词后的数据划分为训练集和测试集:

def data_spliting(data, labels, training_samples, validation_samples):
    # 将数据划分为训练集和验证集, 首席按打乱数据,因为一开始数据集是排好序的,负面评论在前,正面评论灾后
    indices = np.arange(data.shape[0])
    np.random.shuffle(indices)
    data = data[indices]
    labels = labels[indices]

    x_train = data[:training_samples]
    y_train = labels[:training_samples]
    x_val = data[training_samples:training_samples + validation_samples]
    y_val = labels[training_samples:training_samples + validation_samples]
    return x_train, y_train, x_val, y_val

下载并预处理GloVe词嵌入数据

对glove.6B.100d.txt文件进行解析,构建一个由单词映射为其向量表示的索引,这里根据前面的描述,已经下载好并放到本地目录中。

def glove_embedding(glove_dir):
    # 对glove.6B文件进行解析,构建一个由单词映射为其向量表示的索引
    embeddings_index = {}
    f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'), encoding='utf-8')
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs
    f.close()
    print(f'Found {len(embeddings_index)} word vectors.')
    # 查看字典的样例
    for key, value in embeddings_index.items():
        print(key, value)
        print(value.shape)
        break
    return embeddings_index

构建模型

采用Keras的序列(Sequential)构建模型

def model_building(max_words, embedding_dim, maxlen):
    model = Sequential()
    model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
    model.add(Flatten())
    model.add(Dense(32, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()
    return model

训练模型并可视化结果

训练的时候区分有嵌入和无嵌入两种情况:

def training(args):
    texts, labels = preprocessing(args.data_dir)
    word_index, data, labels = word_cutting(args, texts, labels)
    x_train, y_train, x_val, y_val = data_spliting(data, labels, args.training_samples, args.validation_samples)
    embeddings_index = glove_embedding(args.embed_dir)
    # 在embeddings_index字典的基础上构建一个矩阵,对单词索引中索引为i的单词来说,该矩阵的元素i就是这个单词对应的词向量
    embedding_matrix = np.zeros((args.max_words, args.embedding_dim))
    for word, i in word_index.items():
        embedding_vector = embeddings_index.get(word)
        if i < args.max_words:
            if embedding_vector is not None:
                # 在嵌入索引(embedding_index)找不到的词,其嵌入向量都设为0
                embedding_matrix[i] = embedding_vector

    model = model_building(args.max_words, args.embedding_dim, args.maxlen)
    # 在模型中加载GloVe词嵌入,并冻结Embedding Layer
    model.layers[0].set_weights([embedding_matrix])
    model.layers[0].trainable = False
    # 训练模型
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
    history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val))
    model.save_weights('./pre_trained_glove_model.h5')
    visualizating(history)

    # 不使用预训练词嵌入的情况
    model = model_building(args.max_words, args.embedding_dim, args.maxlen)
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
    history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val))
    visualizating(history)

可视化代码如下:

def visualizating(history):
    # 可视化,并查看验证集的准确率acc
    acc, val_acc = history.history['acc'], history.history['val_acc']
    loss, val_loss = history.history['loss'], history.history['val_loss']

    epochs = range(1, len(acc) + 1)
    plt.figure()
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.show()

结果对比

1)在模型中加载GloVe词嵌入的情况,并冻结Embedding Layer:
有嵌入情况
可以看出,训练准确率和验证准确率相差较大,但是这里使用的较少的训练数据。验证准确率接近60%。
2)不使用预训练词嵌入的情况
无嵌入
显然,不使用预训练模型的验证准确率只有55%左右。因此可以初步得出结论,使用预训练词嵌入的模型性能优于未使用的模型性能。

其他代码

import argparse
import os
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Dense


def main():
    parser = argparse.ArgumentParser(description='IMDB Emotional Analysis')
    parser.add_argument('--data_dir', type=str, default='./aclImdb', help='the dir path of IMDB dataset')
    parser.add_argument('--embed_dir', type=str, default='./glove.6B/', help='the dir path of embeddings')

    parser.add_argument('--maxlen', type=int, default=100, help='Keep comments for only the first 100 words')
    parser.add_argument('--training_samples', type=int, default=500, help='Train on 200 samples')
    parser.add_argument('--validation_samples', type=int, default=10000, help='Verify 10,000 samples')
    parser.add_argument('--max_words', type=int, default=10000, help='Consider the most common 10000 words')
    parser.add_argument('--embedding_dim', type=int, default=100, help='the dim of word vector')
    args = parser.parse_args()
    training(args)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/768690.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

hbuilder创建基于vue2的uniapp小程序项目

参考vant官网&#xff1a;https://vant-contrib.gitee.io/vant/v3/#/zh-CN/quickstart#an-zhuang官网 参考别人博客&#xff1a;https://www.yii666.com/blog/465379.html 1.创建项目 1.1 hbuilder进去右上角点击文件–新建–项目 1.2 vue2项目如下图 2.安装依赖 2.1 2.2…

Linux搭建SVN环境(最新版)

最新版本号(svn-1.14) https://opensource.wandisco.com/centos/7 更新版本库 sudo tee /etc/yum.repos.d/wandisco-svn.repo <<-EOF [WandiscoSVN] nameWandisco SVN Repo baseurlhttp://opensource.wandisco.com/centos/$releasever/svn-1.14/RPMS/$basearch/ enabled…

TypeScript 学习笔记(七):条件类型

条件类型 TS中的条件类型就是在类型中添加条件分支&#xff0c;以支持更加灵活的泛型&#xff0c;满足更多的使用场景。内置条件类型是TS内部封装好的一些类型处理&#xff0c;使用起来更加便利。 一、基本用法 当T类型可以赋值给U类型时&#xff0c;则返回X类型&#xff0c…

一探究竟:人工智能、机器学习、深度学习

一、人工智能 1.1 人工智能是什么&#xff1f; 1956年在美国Dartmounth 大学举办的一场研讨会中提出了人工智能这一概念。人工智能&#xff08;Artificial Intelligence&#xff09;&#xff0c;简称AI&#xff0c;是计算机科学的一个分支&#xff0c;它企图了解智能的实质&am…

拦截器是什么

拦截器 package com.qf.config;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView;import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse;public class MyIntercep…

VSCode下载安装(保姆级--一步到胃)

前言 Visual Studio Code&#xff08;简称“VSCode” &#xff09;是Microsoft在2015年4月30日Build开发者大会上正式宣布一个运行于 Mac OS X、Windows和 Linux 之上的&#xff0c;针对于编写现代Web和云应用的跨平台源代码编辑器&#xff0c;可在桌面上运行&#xff0c;并且…

零售行业门店综合管理系统怎么做?店务系统有什么功能?

线下门店则变成了零售行业的重要战场。今时不同往日&#xff0c;现在线下门通常得需要兼多种角色&#xff0c;无论是对于门店员工还是管理者来说经营难度和工作强度都在显著增加。像传统落后的门店管理存在着库存失衡&#xff0c;服务效率低&#xff0c;信息滞后且准确度低等问…

使用IDEA社区版创建SpringBoot项目

文章目录 1.关于IDEA社区版的版本2.下载Spring Boot Helper3.创建项目4.配置Maven国内源4.1找不到settings.xml的情况4.2找得到settings.xml的情况 4.3删除repository目录下的所有文件和目录5.加载项目6.解决org.springframework.boot:spring-boot-starter-parent:pom:2.7.13.R…

学员管理系统——面向对象

文章目录 前言基本思路Student.pymain.pyStudentManage.py菜单 menu()根据菜单实现程序的大概逻辑add_student() 添加学员信息delete_student() 删除学员信息modify_studnet() 修改学员信息search_student() 查找学员信息print_student() 显示所有学员信息save_student() 保存学…

使用qt的webengine让客户端嵌入网页

前提 在windows下&#xff0c;qt下 界面 用qt的界面设计拉上一些东西&#xff0c;一个跑按钮&#xff0c;一个刷新按钮&#xff0c;一个弹出框按钮&#xff0c;地址栏是为了填入新的https地址&#xff0c;一个verticalLayout是为了限定webengine的显示&#xff0c;需要包含 …

UI界面中的图标设计趋势与最佳实践

作为UI设计师&#xff0c;在日常的工作中&#xff0c;避免不了做图标规范。今天跟大家聊一聊&#xff0c;UI设计中的图标设计。 规范的重要性不用多说了&#xff0c;没有规范多个设计师绘制的图标会有很多差异&#xff0c;描边粗细、角度、圆角度等等。今天的文章和大家聊一下…

opencv-14 图像加密和解密

在OpenCV中&#xff0c;图像加密和解密是通过对图像像素进行一系列的变换和操作来实现的 通过按位异或运算可以实现图像的加密和解密。 通过对原始图像与密钥图像进行按位异或&#xff0c;可以实现加密&#xff1b;将加密后的图像与密钥图像再次进行按位异或&#xff0c;可以实…

MFC第十八天 非模式对话框、对话框颜色管理、记事本项目(查找替换、文字和背景色、Goto(转到)功能的开发)

文章目录 非模式对话框非模式对话框的特点非模式对话框与QQ聊天窗口开发非模态对话框&#xff08;Modeless Dialog&#xff09;和模态对话框&#xff08;Modal Dialog&#xff09;区别 记事本开发CFindReplaceDialog类的成员查找替换(算法分析)使用RichEdit控件 开发Goto(转到)…

Django实现接口自动化平台(十三)接口模块Interfaces序列化器及视图【持续更新中】

相关文章&#xff1a; Django实现接口自动化平台&#xff08;十二&#xff09;自定义函数模块DebugTalks 序列化器及视图【持续更新中】_做测试的喵酱的博客-CSDN博客 本章是项目的一个分解&#xff0c;查看本章内容时&#xff0c;要结合整体项目代码来看&#xff1a; pytho…

3.13 Bootstrap 页面标题(Page Header)

文章目录 Bootstrap 页面标题&#xff08;Page Header&#xff09; Bootstrap 页面标题&#xff08;Page Header&#xff09; 页面标题&#xff08;Page Header&#xff09;是个不错的功能&#xff0c;它会在网页标题四周添加适当的间距。当一个网页中有多个标题且每个标题之间…

MotionBert论文解读及详细复现教程

MotionBert&#xff1a;统一视角学习人体运动表示 通过学习人体运动表征&#xff0c;论文原作者提出了处理以人为中心的视频任务的统一方法。使用双流时空transformer&#xff08;DSTformer&#xff09;网络实现运动编码器&#xff0c;能够全面、自适应地捕获骨骼关节之间的远…

数据结构——六大排序 (插入,选择,希尔,冒泡,堆,快速排序)

1. 插入排序 1.1基本思路 把待排序的记录按其关键码值的大小逐个插入到一个已经排好序的有序序列中&#xff0c;直到所有的记录插入完为止&#xff0c;得到一个新的有序序列 我们熟知的斗地主就是一个插入排序 1.2 代码实现 我们这里将一个无序数组变成有序数组 插入排序时…

CVE-2017-15715

CVE-2017-15715 一、环境搭建二、漏洞原理三、漏洞复现 一、环境搭建 如下介绍kali搭建的教程 cd ~/vulhub/httpd/CVE-2017-15715 // 进入指定环境 docker-compose build // 进行环境编译 docker-compose up -d // 启动环境docker-compose ps使用这条命令查看当前正在…

放射显影多肽1778691-88-5,DOTA Methyltetrazine ,四甲基四嗪修饰大环配体

资料编辑|陕西新研博美生物科技有限公司小编MISSwu​ 中文名称&#xff1a;四甲基四嗪修饰大环配体 英文名称&#xff1a;FOLATE-NOTA&#xff0c; Methyltetrazine-DOTA 规格标准&#xff1a;1g、5g、10g CAS&#xff1a;1778691-88-5 分子式&#xff1a;C37H52N12O12 分子量…