PyTorch深度学习实战(40)——零样本学习(Zero-Shot Learning)

news2024/11/28 2:45:04

PyTorch深度学习实战(40)——零样本学习

    • 0. 前言
    • 1. 零样本学习
    • 2. 实现零样本学习模型
      • 2.1 模型分析
      • 2.2 构建零样本学习模型
    • 小结
    • 系列链接

0. 前言

零样本学习 (Zero-Shot Learning) 是一种机器学习方法,旨在解决传统监督学习中,当训练数据中不存在某个类别的样本时,如何对该类别进行分类的问题。在传统监督学习中,分类模型需要通过训练数据学习到每个类别的特征和模式,并在测试阶段根据这些学习到的知识对新样本进行分类。然而,在现实世界中,我们无法获得所有可能类别的训练样本,因此零样本学习成为了一种重要的解决方案。在本节中,我们将学习零样本学习的基本概念,并使用 PyTorch 实现零样本学习模型。

1. 零样本学习

零样本学习 (Zero-Shot Learning) 是一种用于在没有相关样本数据的情况下分类或识别新的物体类别的机器学习技术。与传统监督学习不同,它并不需要训练数据集中包含所有可能的类别,而是通过学习如何从类别的语义描述(例如属性或关系)中推断出新类别的特征,并将这些特征用于分类或识别。在实际应用中,零样本学习可以帮助我们克服由于数据收集和标注成本高昂而产生的数据样本不足问题。
在零样本学习中,模型必须利用词向量自动生成属性(没有为训练提供属性),词向量包含单词之间的语义相似性。例如,所有动物都会有相似的词向量,而汽车则与动物之间的词向量表示有较大差距,本节中,我们将使用预训练的词向量。具有相似上下文的单词具有相似的词向量,词向量的 t-SNE 表示示例如下所示:

聚类结果

在上图中,可以看到相同类别的样本在二维空间中相互聚集,相似的类别也有相似的词向量。因此,单词就像图像一样,也有矢量嵌入,可以用于获取它们之间的相似性。
在下一小节中,实现零样本学习模型,利用以上原理识别模型在训练期间没有见到的类别。本质上,我们将直接学习如何将图像特征映射到单词特征。

2. 实现零样本学习模型

2.1 模型分析

在本节中,我们将使用 PyTorch 实现零样本学习模型,模型构建策略如下:

  • 导入训练数据集
  • 从预训练的词向量模型中获取每个类别对应的词向量
  • 将图像输入预训练的图像分类模型,如 VGG16
  • 网络预测输出图像中物体对应的词向量
  • 训练模型后,在新的测试图像上预测词向量
  • 最接近预测词向量的词向量类别作为测试图像的类别

2.2 构建零样本学习模型

接下来,使用 PyTorch 实现以上策略。

(1) 访问 GitHub 下载相关数据,并解压。

(2) 导入相关库:

import gzip
import _pickle as cPickle

import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder, normalize
device = 'cuda' if torch.cuda.is_available() else 'cpu'

(3) 定义特征数据的路径 (DATAPATH) 以及 word2vec 嵌入 (WORD2VECPATH):

WORD2VECPATH = "zero-shot-learning-master/data/class_vectors.npy"
DATAPATH = "zero-shot-learning-master/data/zeroshot_data.pkl"

(4) 提取可用类别列表:

with open('zero-shot-learning-master/src/train_classes.txt', 'r') as infile:
    train_classes = [str.strip(line) for line in infile]

(5) 加载特征向量数据:

with gzip.GzipFile(DATAPATH, 'rb') as infile:
    data = cPickle.load(infile)

(6) 定义训练数据和属于零样本类别的数据(训练期间不存在的类别)。在训练期间,只显示训练类别,并隐藏零样本模型类别:

training_data = [instance for instance in data if instance[0] in train_classes]
zero_shot_data = [instance for instance in data if instance[0] not in train_classes]
np.random.shuffle(training_data)

(7) 每个类别获取 300 张图像样本进行训练,其余图像用于进行验证:

train_size = 300 # per class
train_data, valid_data = [], []
for class_label in train_classes:
    ctr = 0
    for instance in training_data:
        if instance[0] == class_label:
            if ctr < train_size:
                train_data.append(instance)
                ctr+=1
            else:
                valid_data.append(instance)

(8) 打乱训练和验证数据,并将与类别对应的向量提取到字典 vectors 中:

np.random.shuffle(train_data)
np.random.shuffle(valid_data)
vectors = {i:j for i,j in np.load(WORD2VECPATH, allow_pickle=True)}

(9) 获取训练和验证数据的图像和词嵌入特征:

train_data = [(feat, vectors[clss]) for clss,feat in train_data]
valid_data = [(feat, vectors[clss]) for clss,feat in valid_data]

(10) 获取训练、验证和零样本类别:

train_clss = [clss for clss,feat in train_data]
valid_clss = [clss for clss,feat in valid_data]
zero_shot_clss = [clss for clss,feat in zero_shot_data]

(11) 定义训练数据、验证数据和零样本数据的输入和输出数组:

x_train, y_train = zip(*train_data)
x_train, y_train = np.squeeze(np.asarray(x_train)), np.squeeze(np.asarray(y_train))
x_train = normalize(x_train, norm='l2')

x_valid, y_valid = zip(*valid_data)
x_valid, y_valid = np.squeeze(np.asarray(x_valid)), np.squeeze(np.asarray(y_valid))
x_valid = normalize(x_valid, norm='l2')

y_zsl, x_zsl = zip(*zero_shot_data)
x_zsl, y_zsl = np.squeeze(np.asarray(x_zsl)), np.squeeze(np.asarray(y_zsl))
x_zsl = normalize(x_zsl, norm='l2')

(12) 定义训练、验证数据集和数据加载器:

from torch.utils.data import TensorDataset

trn_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_train, y_train]])
val_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_valid, y_valid]])

trn_dl = DataLoader(trn_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)

(13) 构建模型,以 4096 维特征作为为输入,预测 300 维向量为输出:

def build_model(): 
    return nn.Sequential(
        nn.Linear(4096, 1024), nn.ReLU(inplace=True),
        nn.BatchNorm1d(1024), nn.Dropout(0.8),
        nn.Linear(1024, 512), nn.ReLU(inplace=True),
        nn.BatchNorm1d(512), nn.Dropout(0.8),
        nn.Linear(512, 256), nn.ReLU(inplace=True),
        nn.BatchNorm1d(256), nn.Dropout(0.8),
        nn.Linear(256, 300)
    )

(14) 定义函数在批数据上训练和验证模型:

def train_batch(model, data, optimizer, criterion):
    ims, labels = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    ims, labels = data
    _preds = model(ims)
    loss = criterion(_preds, labels)
    return loss.item()

(15) 训练模型:

model = build_model().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 60
trn_loss_epochs = []
val_loss_epochs = []

for ex in range(n_epochs):
    N = len(trn_dl)
    trn_loss = []
    val_loss = []
    for bx, data in enumerate(trn_dl):
        loss = train_batch(model, data, optimizer, criterion)
        pos = (ex + (bx+1)/N)
        trn_loss.append(loss)
    trn_loss_epochs.append(np.average(trn_loss))

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss = validate_batch(model, data, criterion)
        pos = (ex + (bx+1)/N)
        val_loss.append(loss)
    val_loss_epochs.append(np.average(val_loss))
        
    if ex == 10:
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
    if ex == 40:
        optimizer = optim.Adam(model.parameters(), lr=1e-5)

epochs = np.arange(n_epochs)+1
plt.plot(epochs, trn_loss_epochs, 'bo', label='Training loss A')
plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss B')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型性能检测

(16) 预测属于零样本类别(模型未见过的类别)的图像 (x_zsl),并获取与所有可用类别相对应的实际特征 vectors 和类别名 classnames

pred_zsl = model(torch.Tensor(x_zsl).to(device)).cpu().detach().numpy()

class_vectors = sorted(np.load(WORD2VECPATH, allow_pickle=True), key=lambda x: x[0])
classnames, vectors = zip(*class_vectors)
classnames = list(classnames)

vectors = np.array(vectors)

(17) 计算每个预测向量与可用类别对应的向量之间的距离,并测量前五个预测中出现的零样本类别的数量:

dists = (pred_zsl[None] - vectors[:,None])
dists = (dists**2).sum(-1).T

best_classes = []
for item in dists:
    best_classes.append([classnames[j] for j in np.argsort(item)[:5]])

print(np.mean([i in J for i,J in zip(zero_shot_clss, best_classes)]))
# 0.7328664332166083

从以上结果可以看出,在模型的前 5 个预测中,可以正确预测约 73% 的图像,其中包含训练期间不存在类别的对象。其中,前 123 个预测的正确分类图像的百分比分别为 6%14%40%

小结

在零样本学习中,每个类别通常都与一些语义属性或描述相关联,这些属性可以包括文本描述、语义嵌入或语义关系等。模型将这些语义信息与特征空间进行联系,从而能够根据语义相似度将新的未见类别样本归类到正确的类别中。具体来说,零样本学习的过程可以分为两个主要步骤:建模和推理。在建模阶段,模型需要学习到每个类别的语义表示,通常是将语义属性映射到一个低维的嵌入空间中。在推理阶段,当遇到一个未见类别的样本时,模型会将其与已知类别的语义表示进行比较,并基于相似度进行分类。在本节中,我们学习了零样本分类模型,用以在训练中不存在某个类别的图像时进行预测。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现
PyTorch深度学习实战(37)——CycleGAN详解与实现
PyTorch深度学习实战(38)——StyleGAN详解与实现
PyTorch深度学习实战(39)——小样本学习(Few-shot Learning)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1560633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pulsar存在大量消费未ack的原因

问题起源&#xff1a; 某产品灰度上线后&#xff0c;从pulsar服务端监控发现存在一种现象&#xff1a;消费但未ack的信息不断增加&#xff0c;直到3000左右就稳定下来了且消费速度为0&#xff0c;但不清楚这3000是怎么来的&#xff0c;因为代码是消费到立马ack的&#xff1b; …

格雷希尔G10系列L150A和L200A气动快速连接器,在新能源汽车线束线缆剥线后的气密性测试密封方案

线束线缆在很多用电环境都有使用&#xff0c;比如说新能源汽车&#xff0c;从电池包放电开始&#xff0c;高低压、通讯都开始进行工作&#xff0c;线束在连接的地方需要具有较高的气密性和稳定性&#xff0c;才能保证车辆在不同环境下能够正常的运行。 线束在组装铜鼻子前需要剥…

【Oracle篇】expdp/impdp高效完成全部生产用户的全库迁移(第四篇,总共四篇)

☘️博主介绍☘️&#xff1a; ✨又是一天没白过&#xff0c;我是奈斯&#xff0c;DBA一名✨ ✌✌️擅长Oracle、MySQL、SQLserver、Linux&#xff0c;也在扩展大数据方向的知识面✌✌️ ❣️❣️❣️大佬们都喜欢静静的看文章&#xff0c;并且也会默默的点赞收藏加关注❣️❣️…

基于SpringBoot的游戏商城系统的设计与实现(论文+源码)_kaic

目录 1前言 1.1研究的背景及意义 1.2国内外的研究状况和发展趋势 2需求分析 2.1系统需求分析 2.1.1技术可行性 2.1.2经济可行性 2.1.3操作可行性 2.2系统的开发环境 2.2.1 Springboot框架 2.2.2 数据库Mysql 2.2.3 IntelliJ IDEA平台 2.2.4 Mybatis和MyBatis-plus 2.2.5 前端框…

火鸟门户同城模块

同城活动 同城活动是指在同一城市举办的活动&#xff0c;可以是多种类型&#xff0c;例如&#xff1a; 聚会&#xff1a;朋友聚会、同学聚会、兴趣爱好聚会等。展览&#xff1a;艺术​​展览、科技展览、文化展览等。演出节目&#xff1a;演唱会、音乐会、戏剧表演等。比赛项…

JumpServer 堡垒主机

JumpServer 堡垒机帮助企业以更安全的方式管控和登陆各种类型的资产 SSH&#xff1a;Linux/Unix/网络设备等Windows&#xff1a;Web方式连接/原生RDP连接数据库&#xff1a;MySQL、Oracle、SQLServer、PostgreSQL等Kubernetes&#xff1a;连接到K8s集群中的PodsWeb站点&#x…

Backend - gitea 首次建库(远端本地)

目录 一、建立远端储存库 1. 进入新增画面 2. 填写储存库名称&#xff08;如book&#xff09;&#xff0c;点击“建立”即可 二、本地关联远端储存库 1. 本地初始化储存库代码 &#xff08;1&#xff09;新建文件夹 &#xff08;2&#xff09;获取远端储存库 2. 本地编写…

前端学习<二>CSS基础——14-CSS3属性详解:Web字体

前言 开发人员可以为自已的网页指定特殊的字体&#xff08;将指定字体提前下载到站点中&#xff09;&#xff0c;无需考虑用户电脑上是否安装了此特殊字体。从此&#xff0c;把特殊字体处理成图片的方式便成为了过去。 支持程度比较好&#xff0c;甚至 IE 低版本的浏览器也能…

HTML期末作业-香水网站,逐步讲解

知名品牌 CHANEL I wear nothing but a few drops of Chanel No.5. 了解更多 GIVENCHY 纪梵希香水几乎就是赫本本人的化身——经典、优雅、高贵、简洁、女性化 了解更多 DIOR Dior Addict the now fragrance from Dior. 了解更多 BURBUEEY The good things in life neve…

大数据技术之 Apache Doris(一)

第 1 章 Doris 简介 1.1 Doris 概述 Apache Doris 由百度大数据部研发&#xff08;之前叫百度 Palo&#xff0c;2018 年贡献到 Apache 社区后&#xff0c;更名为 Doris &#xff09;&#xff0c;在百度内部&#xff0c;有超过 200 个产品线在使用&#xff0c;部署机器超过 10…

抽象类和接口(2)(接口部分)

❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&a…

微信公众号账号迁移主体怎么变更?

公众号迁移有什么作用&#xff1f;只能变更主体吗&#xff1f;大家都知道&#xff0c;公众号是不支持直接变更主体的&#xff1b;但是很多情况下&#xff0c;我们又不得不进行账号主体的更换&#xff1b;这时候&#xff0c;我么就可以通过账号迁移功能&#xff0c;将A公众号的粉…

MySQL使用技巧,高级Java开发必看

insert into tab(col1,col2…) select … 5、活用正则表达式 regexp ^ $ . * | 6、关联查询比子查询效率快&#xff0c;优先使用join关联查询 7、if(exp,v1,v2) if()函数的使用 exp:表达式 v1:exp为真时返回的值 v2:exp为假时返回的值 8、case when… then… else… en…

广和通发布基于高通高算力芯片的具身智能机器人开发平台Fibot

3月29日&#xff0c;为助力机器人厂商客户快速复现及验证斯坦福Mobile ALOHA机器人的相关算法&#xff0c;广和通发布具身智能机器人开发平台Fibot。作为首款国产Mobile ALOHA机器人的升级配置版本&#xff0c;开发平台采用全向轮底盘设计、可拆卸式训练臂结构&#xff0c;赋予…

AI如何影响装饰器模式与组合模式的选择与应用

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》《MYSQL应用》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 &#x1f680; 转载自热榜文章&#xff1a;设计模式深度解析&#xff1a;AI如何影响…

[C/C++] -- 二叉树

1.简介 二叉树是一种每个节点最多有两个子节点的树结构&#xff0c;通常包括&#xff1a;根节点、左子树、右子树。 满二叉树&#xff1a; 如果一棵二叉树只有度为0的结点和度为2的结点&#xff0c;并且度为0的结点在同一层上&#xff0c;则这棵二叉树为满二叉树。深度为k&a…

10、电科院FTU检测标准学习笔记-双遥信及变位优先验证

———————————————————————————————————— 作者简介&#xff1a; 本人从事电力系统多年&#xff0c;岗位包含研发&#xff0c;测试&#xff0c;工程等&#xff0c;具有丰富的经验 在配电自动化验收测试以及电科院测试中&#xff0c;本人全程参…

电脑win10系统更新后开机很慢,更新win10后电脑开机怎么变慢了

很多用户反映,更新win10后电脑开机怎么变慢了呢?现在动不动就要30几秒,以前都是秒开机的,要怎么设置才能提高开机速度?小伙伴们别着急,主要原因可能是关机设置中没有勾选启用快速启动,或者是开机启动设置的问题,针对开机变慢的情况,小编整理了2个处理方法,接下来,我…

U盘PE引导-系统安装操作

U盘PE引导-系统安装操作 1. U盘接入电脑&#xff0c;开机按F12&#xff08;DELL&#xff09;选择U盘引导&#xff0c;按回车&#xff0c;如图2.选择进入PE 系统3.进入PE 系统后&#xff0c;运行 WinNTSetup 安装器&#xff0c; 具体 设置如下图 1. U盘接入电脑&#xff0c;开机…

C语言键盘输入与屏幕输出——数据的格式化键盘输入

目录 数据的格式化键盘输入 输入数据的格式控制 scanf&#xff08;&#xff09;的格式字符 scanf()的格式修饰符 数据的格式化键盘输入 格式 scanf&#xff08;格式控制字符串&#xff0c;输入地址表&#xff09;&#xff1b; 输入数据的格式控制 格式 scanf&#xff08;…