2-Embedding例子:简单NN网络、迁移学习例子(glove语料预训练)

news2024/12/22 22:42:50

一、简单例子:构造简单NN网络生成Embedding

1、pytorch例子

2、tensorflow例子

# 1导入模块
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np

# 2构建语料库
corpus=[
  ["The", "weather", "will", "be", "nice", "tomorrow"],
  ["How", "are", "you", "doing", "today"],
  ["Hello", "world", "!"]
]

# 3生成字典
#获取语料不同单词,并过滤掉一些字符如"!"
word_set=set([i for item in corpus for i in item if i!='!']) 
word_dicts={}

#索引从1开始,0用来填充
j=1
for i in word_set:
    word_dicts[i]=j 
    j=j+1
   
# 4用索引表示语料
raw_inputs=[]
for i in range(len(corpus)):
    raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])
    padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,padding='post')
    print(padded_inputs)

# 5构建网络
model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape
# 6 查看结果
output_array[1]

输出结果:

二、迁移学习: 使用预训练模型生成Embedding

1、什么是迁移学习?不同任务场景下,如何使用预训练模型?

迁移学习是在一个任务上学习到的模型(结构、权重)作为初始点,应用到另一个新的任务上。

那该如何使用预训练模型呢?

场景1: 数据集小,数据相似度高

去掉输出层,然后将剩下的整个网络当作一个固定的特征提取机,应用到新的数据集中。
过程如图3-11所示,调整分类器中的几个参数,其他模块保持“冻结”即可。
这种微调方法,有时又称为特征抽取,因为预训练模型可以作为目标数据的特征提取器。

场景2: 数据集大, 数据相似度高

因为目标数据与预训练模型的训练数据之间高度相似,故采用预训练模型会非常有效。
另外,训练系统有一个较大的数据集,采用冻结预处理模型中少量较低层,修改分类器,然后在新数据集的基础上重新开始训练是一种较好的方式,具体处理过程如图3-12所示。

场景3:  数据集小,数据相似度不高

在这种情况下,可以冻结预训练模型中较少的网络高层,然后重新训练后面的网络,修改分类器。因为数据的相似度不高,重新训练的过程就变得非常关键。而新数据集大小的不足,则是通过冻结预训练模型中一些较低的网络层进行弥补,具体处理过程如图3-13所示。

场景4: 数据集大, 数据相似度不高

在这种情况下,因为有一个很大的数据集,所以神经网络的训练过程将会比较有效率。然而,因为目标数据与预训练模型的训练数据之间存在很大差异,采用预训练模型不是一种高效的方式。因此最好的方法还是将预处理模型中的权重全都初始化后再到新数据集的基础上重新开始训练,具体处理过程如图3-14所示。

2、使用Glove预训练数据集迁移学习例子

import os

imdb_dir = './aclImdb' # 电影评论数据集
train_dir = os.path.join(imdb_dir, 'train')

labels = []
texts = []

for label_type in ['neg', 'pos']:
    dir_name = os.path.join(train_dir, label_type)
    for fname in os.listdir(dir_name):
        if fname[-4:] == '.txt':
            f = open(os.path.join(dir_name, fname))
            texts.append(f.read())
            f.close()
            if label_type == 'neg':
                labels.append(0)
            else:
                labels.append(1)

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np

maxlen = 100  # 只保留前100单词的评论
training_samples = 200  # 在200个样本上训练
validation_samples = 10000  # W对10000个样品进行验证
max_words = 10000  # 只考虑数据集中最常见的10000 个单词

tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))

data = pad_sequences(sequences, maxlen=maxlen)

labels = np.asarray(labels)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)

# 将数据划分为训练集和验证集
# 首先打乱数据, 因一开始数据集是排序好的
# 负面评论在前, 正面评论在后
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]

x_train = data[:training_samples]
y_train = labels[:training_samples]
x_val = data[training_samples: training_samples + validation_samples]
y_val = labels[training_samples: training_samples + validation_samples]

glove_dir = './glove.6B/'

embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

for key,value in embeddings_index.items():
    print(key,value)
    break

embedding_dim = 100

embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if i < max_words:
        if embedding_vector is not None:
            # 在嵌入索引(embedding index)找不到的词,其嵌入向量都设为0
            embedding_matrix[i] = embedding_vector
            
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Dense

model = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = False

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32, 
                    validation_data=(x_val, y_val))
model.save_weights('pre_trained_glove_model.h5')
import matplotlib.pyplot as plt
%matplotlib inline

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

输出结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1615538.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【更新】cyのMemo(20240422~)

序言 胡哥首马在淮安325完赛&#xff0c;他的本硕都在淮安度过&#xff0c;七年的跑步生涯画上句号&#xff0c;真的是很圆满。七年&#xff0c;从180斤瘦到120斤&#xff0c;历经种种&#xff0c;胡哥理解的跑步&#xff0c;不是快&#xff0c;而是稳&#xff0c;他在比赛中从…

线性表的顺序存储如何设计实现?

如何存储 顺序及链式实现 计算机中的状态

【Java】变量零基础教程

目录 一、引言 二、基本组成单位 三、变量的基本原理 四、变量的基本使用步骤 五、变量快速入门 六、变量使用的注意事项 一、引言 为什么需要变量&#xff1f; ​​​​​​一个程序就是一个世界。 大家看下图&#xff0c;是我们现实中的一张生活照&#xff0c;图里有树…

汕头联想 ibm x3500 M5服务器上门维修记录

汕头联想服务器现场检修&#xff1b;汕尾IBM服务器故障维修&#xff1b;揭阳戴尔服务器维修&#xff1b;汕头ERP服务器维修&#xff1b;潮阳地区各种服务器故障维修&#xff1b;各类服务器主板齐全&#xff1b; 分享一例从东莞到汕头某染料厂维修ibm system x3500 M5服务器的真…

47.基于SpringBoot + Vue实现的前后端分离-校园外卖服务系统(项目 + 论文)

项目介绍 本站是一个B/S模式系统&#xff0c;采用SpringBoot Vue框架&#xff0c;MYSQL数据库设计开发&#xff0c;充分保证系统的稳定性。系统具有界面清晰、操作简单&#xff0c;功能齐全的特点&#xff0c;使得基于SpringBoot Vue技术的校园外卖服务系统设计与实现管理工作…

分布式技术在文本摘要生成中的应用

摘要 自然语言处理首先要应对的是如何表示文本以供机器处理&#xff0c;随着网络技术的发展和信息的公开&#xff0c;因特网上可供访问的数字文档成爆炸式的增长&#xff0c;文本摘要生成逐渐成为了自然语言处理领域的重要研究课题。本文主要介绍了分布式技术在文本摘要生成中…

Oracle21C 引入HR实例(linux)

1、下载资源 https://github.com/oracle-samples/db-sample-schemas点击code&#xff08;代码&#xff09;下载 2、上传Sql文件 解压之后将human_resources里的文件复制到demo\schema\目录&#xff08;具体目录前面的路径是你安装的路径&#xff09;下&#xff0c;如下图 3、…

argparse模块(详解)

文章目录 一、argparse模块&#xff08;1&#xff09;创建命令行解析对象&#xff1a;parser argparse.ArgumentParser()&#xff08;2&#xff09;添加命令行参数和选项&#xff1a;parser.add_argument()&#xff08;3&#xff09;解析命令行参数&#xff1a;args parser.p…

就业班 第三阶段(nginx) 2401--4.22 day1 nginx1 http+nginx初识+配置+虚拟主机

一、HTTP 介绍 HTTP协议是Hyper Text Transfer Protocol&#xff08;超文本传输协议&#xff09;的缩写,是用于从万维网&#xff08;WWW:World Wide Web &#xff09;服务器传输超文本到本地浏览器的传送协议。 HTTP是一个基于TCP/IP通信协议来传递数据&#xff08;HTML 文件…

Web3钱包开发获取测试币-Polygon Mumbai(一)

Web3钱包开发获取测试币-Polygon Mumbai(一) 由于主网区块链上的智能合约需要真正的代币&#xff0c;而部署和使用需要花费真金白银&#xff0c;因此测试网络为 Web3 开发人员提供了一个测试环境&#xff0c;用于部署和测试他们的智能合约&#xff0c;以识别和修复在将智能合约…

❤️新版Linux零基础快速入门到精通——第三部分❤️

❤️新版Linux零基础快速入门到精通——第三部分❤️ 非科班的我&#xff01;Ta&#xff01;还是来了~~~3. Linux权限管控3.1 认知root用户3.1.1 Switch User——su3.1.2 sudo命令3.1.3 为普通用户配置sudo认证 3.2 用户和用户组3.2.1 用户、用户组3.2.2 用户组管理3.2.3 用户管…

辽宁梵宁教育设计培训:赋能大学生,新技能学习再升级

辽宁梵宁教育设计培训&#xff1a;赋能大学生&#xff0c;新技能学习再升级 在当今这个日新月异、信息爆炸的时代&#xff0c;大学生们面临着前所未有的挑战与机遇。为了帮助他们更好地适应社会的快速变化&#xff0c;提升个人的综合素质和竞争力&#xff0c;辽宁梵宁教育设计…

基于python实现的医疗领域用户问答的意图识别算法研究(django)

基于python实现的医疗领域用户问答的意图识别算法研究(django) 开发语言&#xff1a;Python语言 数据库&#xff1a;MySQL&#xff0c;Neo4j知识&#xff1a;深度学习&#xff0c;知识图谱工具&#xff1a;pycharm、Navicat、Maven 系统的实现 系统登录界面 医疗领域用户问答…

华媒舍:百度竞价排名如何提升点击率

在网络推广中&#xff0c;提升点击率是十分重要的。运用百度搜索引擎广告是一种常用的提升点击率的形式。而百度竞价推广是搜索引擎所提供的一种付费流量方法&#xff0c;根据提高网站在搜索结果中的排名&#xff0c;可以有效提升点击率。下面我们就详细介绍如何运用百度竞价推…

钢管钢材地板踢脚线定购规格采购批发商城h5公众号开发

钢管钢材地板踢脚线定购规格采购批发商城h5公众号开发 商品管理&#xff0c;订单管理&#xff0c;用户管理&#xff0c;售后管理&#xff0c;商品评价&#xff0c;虚拟商品自动发货&#xff0c;优惠劵&#xff0c;购物送劵。 您可以在这个H5公众号商城上找到以下功能列表&…

flink Unsupported operand types: IF(boolean, NULL, String)

问题&#xff1a;业务方存储了NULL 字符串&#xff0c;需要处理为 null select if(anull&#xff0c;null&#xff0c;a); 结果遇到了 Unsupported operand types: IF(boolean, NULL, String)&#xff0c;根据报错反馈&#xff0c;很明显应该是没有对 null 自动转换&#xff…

Day39 网络编程(一):计算机网络,网络编程,网络模型,网络编程三要素

Day39 网络编程&#xff08;一&#xff09;&#xff1a;计算机网络&#xff0c;网络编程&#xff0c;网络模型&#xff0c;网络编程三要素 文章目录 Day39 网络编程&#xff08;一&#xff09;&#xff1a;计算机网络&#xff0c;网络编程&#xff0c;网络模型&#xff0c;网络…

JAVA学习笔记28(常用类)

1.常用类 1.1 包装类 1.包装类的分类 ​ 1.针对八中基本数据类型相应的引用类型–包装类 ​ 2.有了类的特点&#xff0c;就可以调用类中的方法 2.包装类和基本数据类型的转换 ​ *装箱&#xff1a;基本类型 --> 包装类型 //手动装箱 int n1 100; Integer integer ne…

Web3钱包开发获取测试币-Base Sepolia(二)

Web3钱包开发获取测试币-Base Sepolia(二) ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/b0c0ac86b04a496087471388532bc54a.png) 基于上篇 Web3钱包开发获取测试币-Polygon Mumbai(一) &#xff1a;https://suwu150.blog.csdn.net/article/details/137949473 我…

Centos7.9云计算CloudStack4.15 高级网络配置(3)

上两章的文章都是用的CloudStack的基本网络&#xff0c;这一篇我们来介绍CloudStack的高级网络&#xff0c;这里虚拟机用的是自己配置的内部网络&#xff0c;通过nat方式到物理网络。按照第一篇的文章&#xff0c;安装管理服务器和计算服务器。 并且在管理服务器配置好如下的全…