Embedding例子:简单NN网络、迁移学习例子

news2024/11/26 5:52:27

一、简单例子:构造简单NN网络生成Embedding

1、pytorch例子

2、tensorflow例子

# 1导入模块
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np

# 2构建语料库
corpus=[
  ["The", "weather", "will", "be", "nice", "tomorrow"],
  ["How", "are", "you", "doing", "today"],
  ["Hello", "world", "!"]
]

# 3生成字典
#获取语料不同单词,并过滤掉一些字符如"!"
word_set=set([i for item in corpus for i in item if i!='!']) 
word_dicts={}

#索引从1开始,0用来填充
j=1
for i in word_set:
    word_dicts[i]=j 
    j=j+1
   
# 4用索引表示语料
raw_inputs=[]
for i in range(len(corpus)):
    raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])
    padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,padding='post')
    print(padded_inputs)

# 5构建网络
model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape
# 6 查看结果
output_array[1]

输出结果:

二、迁移学习: 使用预训练模型生成Embedding

1、使用Glove预训练数据集迁移学习

import os

imdb_dir = './aclImdb' # 电影评论数据集
train_dir = os.path.join(imdb_dir, 'train')

labels = []
texts = []

for label_type in ['neg', 'pos']:
    dir_name = os.path.join(train_dir, label_type)
    for fname in os.listdir(dir_name):
        if fname[-4:] == '.txt':
            f = open(os.path.join(dir_name, fname))
            texts.append(f.read())
            f.close()
            if label_type == 'neg':
                labels.append(0)
            else:
                labels.append(1)

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np

maxlen = 100  # 只保留前100单词的评论
training_samples = 200  # 在200个样本上训练
validation_samples = 10000  # W对10000个样品进行验证
max_words = 10000  # 只考虑数据集中最常见的10000 个单词

tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))

data = pad_sequences(sequences, maxlen=maxlen)

labels = np.asarray(labels)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)

# 将数据划分为训练集和验证集
# 首先打乱数据, 因一开始数据集是排序好的
# 负面评论在前, 正面评论在后
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]

x_train = data[:training_samples]
y_train = labels[:training_samples]
x_val = data[training_samples: training_samples + validation_samples]
y_val = labels[training_samples: training_samples + validation_samples]

glove_dir = './glove.6B/'

embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

for key,value in embeddings_index.items():
    print(key,value)
    break

embedding_dim = 100

embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if i < max_words:
        if embedding_vector is not None:
            # 在嵌入索引(embedding index)找不到的词,其嵌入向量都设为0
            embedding_matrix[i] = embedding_vector
            
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Dense

model = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = False

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32, 
                    validation_data=(x_val, y_val))
model.save_weights('pre_trained_glove_model.h5')
import matplotlib.pyplot as plt
%matplotlib inline

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

输出结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1606977.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

配置静态IP【windows+ubuntu】

Windows配置静态IP 如下图所示&#xff0c;通过“网络和Internet进入设置界面”&#xff0c;依次操作“更改适配器选项”->选择要配置静态ip的网络“属性”->选择IPV4的属性->配置静态ip的地址、子网掩码、默认网关。默认网关应和路由器上的设置保持一致。 Ubuntu配…

2024红明谷杯——Misc 加密的流量

2024红明谷杯——Misc 加密的流量 写在前面&#xff1a; 这里是贝塔贝塔&#xff0c;照例来一段闲聊 打比赛但赛前一波三折&#xff0c;又是成功签到的一个比赛 说起来比赛全名叫红明谷卫星应用数据安全场景赛&#xff0c;但好像真的跟卫星的关系不大&#xff0c;没有bin方…

Redis中的订阅发布(三)

订阅发布 发送消息 当一个Redis客户端执行PUBLISH 命令将消息message发送给频道channel的时候&#xff0c;服务器需要执行以下 两个动作: 1.将消息message发送给channel频道的所有订阅者2.如果一个或多个模式pattern与频道channel相匹配&#xff0c;那么将消息message发送给…

基于SpringBoot+Vue的便利店管理系统 免费获取源码

项目源码获取方式放在文章末尾处 项目技术 数据库&#xff1a;Mysql5.7/8.0 数据表&#xff1a;11张 开发语言&#xff1a;Java(jdk1.8) 开发工具&#xff1a;idea 前端技术&#xff1a;vue 后端技术&#xff1a;SpringBoot 功能简介 (有文档) 项目获取关键字&#…

GUI02-在窗口上跟踪并输出鼠标位置(Win32版)

(1) 响应 WM_MOUSEMOVE 消息获得鼠标位置&#xff1b; (2) 响应 WM_PAINT 将鼠标位置输出到窗口中&#xff1b; (3) 学习二者之间的关键步骤&#xff1a;调用 InvalidateRect() 以通知窗口重绘。 零. 课堂视频 在窗口上跟踪输出鼠标位置-Win32版 一、关键知识点 1. BeginPaint…

Syncovery for Mac:高效文件备份和同步工具

Syncovery for Mac是一款专为Mac用户设计的文件备份和同步工具&#xff0c;凭借其高效、安全和易用的特点&#xff0c;深受用户好评。 Syncovery for Mac v10.14.2激活版下载 该软件具备强大的备份功能&#xff0c;支持多种备份方案和数据格式&#xff0c;用户可以根据需求轻松…

vscode自动生成返回值的快捷键

vscode中类似idea的altenter功能&#xff0c;可以添加返回值 idea中是Introduce local variable&#xff0c; vscode中按下command.(句号) 然后选extract to local variable或者 Assign statement to new local variable都行&#xff0c; 光标在分号前如图&#xff1a; 光标在…

维护SQLite的私有分支(二十六)

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLite、MySQL 和 PostgreSQL 数据库速度比较&#xff08;本文阐述时间很早比较&#xff0c;不具有最新参考性&#xff09;&#xff08;二十五&#xff09; 下一篇&#xff1a;SQLite数据库中JSON 函数和运算符 1…

前端css中transition的使用

前端css中transition的使用 一、前言二、transition的4个属性三、例子1.源码12.源码1运行效果 四、结语五、定位日期 一、前言 CSS中的transition&#xff08;过渡&#xff09;&#xff0c;根据字面意思就可以理解成一种变化状态的过程。当我们有一个方形&#xff0c;我们想让…

深度学习数据处理——对比标签文件与图像文件,把没有打标签的图像文件标记并删除

要对比目录下的jpg文件与json文件&#xff0c;并删除那些没有对应json文件的jpg文件&#xff0c;这个在深度学习或者机器学习时常会遇到。比如对一个数据集做处理时&#xff0c;往往会有些图像不用标注&#xff0c;那么这张图像是没有对应的标签文件的&#xff0c;这个时候又不…

MySQL 的事务概念

事务概念 MySQL事务是一个或者多个的数据库操作&#xff0c;要么全部执行成功&#xff0c;要么全部失败回滚。 事务是通过事务日志来实现的&#xff0c;事务日志包括&#xff1a;redo log和undo log。 事务状态 事务有以下五种状态&#xff1a; 活动的部分提交的失败的中止的…

模拟相机拍照——对文档进行数据增强

一. 背景 假如我们有一个标准文件&#xff0c;我们对其进行文字识别、版面分析或者其他下游任务就比较容易。然而&#xff0c;当图片是手机拍照获取的&#xff0c;图片中往往有阴影、摩尔纹、弯曲。 那么&#xff0c;如何通过标准的文档&#xff0c;获得类似相机拍照的图片呢&…

24华中杯C题10页论文+代码+思路

问题1&#xff1a;估算传感点的曲率 问题2&#xff1a;重构平面曲线 问题3&#xff1a;重构平面曲线并分析误差 详细资料如图所示 10页论文 需要的宝子们&#xff1a;2024华中杯A题思路数据可执行代码参考论文https://mbd.pub/o/bread/ZZ6am5dw 2024华中杯B题思路数据可执行…

C语言简单的数据结构:双向链表的实现

目录&#xff1a; 1.双向链表的结构和初始化1.1双向链表的结构1.2双向链表的初始化 2.双向链表的相关操作2.1双向链表的尾插、打印和头插2.11双向链表的尾插2.12双向链表的打印2.13双向链表的头插 2.2双向链表的尾删和头删2.21双向链表的尾删2.22双向链表的头删 2.3双向链表查找…

Linux 网络测速

1.开发背景 网络测速&#xff0c;为了测试开发板的网络速度是否达标的通用测试方法 2.开发需求 搭建 iperf3 &#xff0c;在 ubuntu 下安装服务端&#xff0c;在板卡上安装客户端&#xff0c;服务端和客户端互发 3.开发环境 ubuntu20.04 嵌入式开发板&#xff08;debian 千…

了解MySQL InnoDB多版本MVCC(Multi-Version Concurrency Control)

了解MySQL InnoDB多版本MVCC&#xff08;Multi-Version Concurrency Control&#xff09; 在数据库管理系统中&#xff0c;多版本并发控制&#xff08;MVCC&#xff09;是一种用于实现高并发和事务隔离的技术。MySQL的InnoDB存储引擎支持MVCC&#xff0c;这使得它可以在提供高…

22长安杯电子取证复现(检材一,二)

检材一 先用VC容器挂载&#xff0c;拿到完整的检材 从检材一入手&#xff0c;火眼创建案件&#xff0c;打开检材一 1.检材1的SHA256值为 计算SHA256值&#xff0c;直接用火眼计算哈希计算 9E48BB2CAE5C1D93BAF572E3646D2ECD26080B70413DC7DC4131F88289F49E34 2.分析检材1&am…

Spring (三) 之Aop及事务控制

文章目录 目标 一、AOP 思想和重要术语&#xff08;理解&#xff09;1、需求问题2、AOP3、AOP 术语 二、AOP 实现及 Pointcut 表达式&#xff08;了解&#xff09;1、AOP 规范及实现2、AspectJ3、AspectJ 切入点语法&#xff08;掌握&#xff09;3.1、切入点语法通配符3.2、切入…

Linux 网络基本命令

一、查看网络信息 ifconfig 二、关闭网络 ifdown ens33 (有的电脑不一定是ens33&#xff0c;具体看上图画线的地方) 三、开启网络 ifup ens33

【电路笔记】-数字逻辑门总结

数字逻辑门总结 文章目录 数字逻辑门总结1、概述2、逻辑门真值表3、总结 数字逻辑门有三种基本类型&#xff1a;与门、或门和非门。 1、概述 我们还看到&#xff0c;数字逻辑门具有与其相反或互补的形式&#xff0c;分别为“与非门”、“或非门”和“缓冲器”&#xff0c;并且…