笔记:Few-Shot Learning小样本分类问题 + 孪生网络 + 预训练与微调

news2024/9/22 8:35:15

内容摘自王老师的B站视频,大家还是尽量去看视频,老师讲的特别好,不到一小时的时间就缕清了小样本学习的基础知识点~Few-Shot Learning (1/3): 基本概念_哔哩哔哩_bilibili

Few-Shot Learning(小样本分类)

假设现在每类只有一两个样本,计算机能否做到像人一样的正确分类?

  • 这个例子Support Set有两类,每类只有一两个样本,靠这些样本,难以训练出一个深度神经网络,这个集合只能提供一些参考信息。对于小样本问题,不能用传统的分类方法。

小样本分类与传统的监督学习有所不同,小样本学习的目标不是让机器通过学习训练集中图片,知道哪类是什么样子;当我拿一个很大的训练集来训练神经网络后进行小样本分类,预训练模型的目的是让机器自己学会学习-----也就是学习事物的异同,学会区分不同的事物。

现在训练集有五类,其中并没有松鼠这个类别

训练完成之后,可以问模型这两张图片是否是相同的东西呢?这时候模型已经学会分辨了事物的异同,比如给出两张松鼠图片,模型知道这两个动物之间长得很像,模型能够告诉你两张图片很可能是相同的东西。

支持集

给出一张图片,神经网络不知道这是什么。

这时候就需要支持集(Support Set),每类给出少样本(1~2)张,神经网络将Query图片和支持集中的每个类别依次对比,找出最相似的。

训练集和支持集的区别
  • 训练集规模很大,每类有很多张图片,可以训练一个深度神经网络

  • 支持集每类只有一张或几张图片,不足以训练一个大的神经网络,只能在做预测时候提供一些额外信息。

  • 用足够大的训练集训练的目的不是让模型识别训练集中的大象、老虎,而是知道事物的异同。对于训练的模型,只要提供含有该类别的小样本信息,模型就能区分类别,尽管训练集中没有这个类别。

小样本分类:Learn To Learn

带小朋友去动物园,小朋友不知道这个动物是什么,但是小朋友只需要翻一遍卡片(将目标与卡片上动物对应),就知道看到的动物是什么,这个卡片就是支持集,前提是小朋友有读卡片的能力,也就是得先经过训练学习。

如果卡片中每类只有一张,那就是One-Shot Learning(单样本学习)

传统监督学习 和 小样本学习 步骤的区别

  • 传统监督学习:测试图片虽然不是训练集中图片,但包含在训练集类别,模型已经见过上千张该类别图片,能够判断出是哪类。

  • 小样本学习:测试图片不但不包含在训练集中,也不是训练集中的类别。所以小样本学习比传统监督学习更难。因为不是训练集中的类别,所以要提供支持集,提供更多信息(给模型看小卡片,每张卡片有一个图片和一个标签,模型发现测试图片和某张卡片相似度高,就知道测试图片属于哪个标签)

小样本学习两个术语
  • k-way :支持集含有的种类数

  • n-shot : 支持集中每个种类有多少张图片

小样本学习预测准确率

  • 横轴是支持集类别数量。随着类别数量增加,分类准确率会降低。

  • 比如从三选一变成六选一

  • 每类样本越多,做预测越容易

相似度函数

sim(x, x'), x,x'为两个input

理想情况:sim(x1,x2) = 1 , sim(x1,x3) = 0, sim(x2,x3) = 0

从一个很大的训练集上学习一个相似度函数,它可以判断两张图片的相似度有多高。

孪生神经网络就可以作为相似度函数,可以拿大规模数据集做训练,训练结束之后,可以拿得到的相似度函数做预测。给一个测试图片,可以拿他跟支持集中的图片逐一对比,计算相似度,找到相似度最高作为预测结果。

  • Omniglot 特点:小样本(20个,105*105)

孪生网络(Siamese Network)

孪生网络要解决的问题
  • 第一类,分类数量较少,每一类的数据量较多,比如ImageNet、VOC等。这种分类问题可以使用神经网络或者SVM解决,只要事先知道了所有的类。

  • 第二类,分类数量较多(或者说无法确认具体数量),每一类的数据量较少,比如人脸识别、人脸验证任务。(少样本问题)

孪生网络的优点
  • 这个网络主要的优点是淡化了标签,使得网络具有很好的扩展性,可以对那些没有训练过的小样本类别进行分类,这点是优于很多算法的。

第一种训练孪生网络方法:每次取2样本,比较相似度 。
  • 训练这个神经网络要用一个大的数据集,每类有标注,每类下面都有很多个样本。

  • 我们需要用训练集来构造正样本和负样本

    • 正样本告诉神经网络什么东西是同一类。

    • 负样本告诉神经网络事物之间的区别。

  • 正样本获取

    • 每次从训练集中抽取一张图片(老虎),然后从同一类中随机抽取另一张图片(老虎),标签设置为1 (tiger, tiger, 1),意思是相似度满分。

  • 负样本获取

    • 每次从训练集中抽取一张图片(汽车),排除汽车这个类别,再从数据集中随机抽样(大象),标签设计为(car, elephant, 0),意思是相似度为0

  • 搭建一个卷积神经网络CNN用来提取特征,这个神经网络有很多卷积层,Pooling层,以及一个flatten层。输入是一张图片x,输出是提取的特征向量 f(x)

  • 现在开始训练神经网络,输入为(x1, x2 , 0或1),把这两张图片输入神经网络,把刚才搭建的卷积神经网络记作函数f。

  • 对于提取的特征向量,第一张图片特征向量记作h1 = f(x1),第二张图片特征向量记作h2 = f(x2),如果都是用CNN,这两个f需要是相同的卷积神经网络,共享相同的权值W(之所以叫孪生,就是因为共享特征提取的部分)。也可以不同权值,则不同场景,允许不同神经网络。

  • 然后拿h1 - h2 得到一个向量,再对这个向量所有元素求绝对值,记作z = ||h1 - h2||,表示两个特征向量之间的区别,再用一些全连接层来处理z向量,输出一些标量。

  • 最后用Sigmoid激活函数,得到输出是一个介于0~1之间的实数,可以衡量两个图片之间的相似度。如果两张图片是同一个类别,输出应该接近1,如果两张图片不同类别,输出应该接近0(希望神经网络的训练输出接近1),把标签与预测之间的差别作为损失函数

  • 损失函数可以是标签与预测的交叉熵损失函数cross-entropy loss function,可以衡量标签与预测的差别

  • 有了损失函数可以用反向传播计算梯度,用梯度下降来更新模型参数。

  • 模型主要有两部分,一个是卷积神经网络f用来从图片提取特征,一个是全连接层预测相似度,训练部分就是更新这两个的参数

  • 做反向传播,梯度从损失函数传回到向量z以及全连接层的参数,有了损失函数关于全连接层的梯度,就可以更新全连接层的参数了。

  • 然后梯度进一步从向量z传回到卷积神经网络,更新卷积神经网络参数,这样就完成了一轮训练

  • 做训练时候,我们要准备同样数量正样本和负样本。负样本标签设置为0,希望神经网络预测接近0,意思是这两张图片不同。还是用同样方法做反向传播,更新参数。

训练好模型之后,可以做One-Shot Prediction

  • 六个类别,每个类别一张图片,这六个类别可以都不在训练集中

  • 将Query与Support Set支持集中图片作对比:

    • 将Query图片与支持集中某一类一张图片作为input1 和 input2 ,输入到孪生网络中,孪生网络会输出一个0~1之间的值。用同样方法算出Query与所有图片相似度,查找相似度最高的。

孪生网络第二种训练方法:Triplet Loss
准备数据
  • 有这样一个训练集,每次选出三张图片

  • 首先从训练集随机选一张图片,作为anchor(锚点),记录这个锚点,然后从同类中随机抽取一张图片作为正样本Positive;排除该类别,从数据集中作随机抽样,得到不同类别的负样本Negative。

  • 现在有锚点x^a,正样本x+,负样本x-,把三张图片分别输入卷积神经网络f来提取特征(f指的是同一个卷积神经网络),得到三个特征向量

  • 计算正样本和锚点再特征空间上的距离,将特征向量 f(x+)与f(xa)求差,然后算二范数的平方,得到距离d+

  • 类似操作得到d-

  • 我们希望得到的神经网络有这样性质,像同类别特征向量聚在一起,不同类别的特征向量能够被分开,所以d+应该很小,d-应该很大

  • 这个坐标系是特征空间,卷积神经网络可以把图片映射到这个特征空间

  • d-应该比d+大很多,否则模型分辨不了同类和不同类

  • 所以鼓励正样本在特征空间接近锚点(d+尽量小),鼓励负样本在特征空间远离锚点(d-尽量大)

  • 指定一个margin :α,α>0。如果d- >= d+ + α,我们就认为没有损失loss=0,分类正确。假如条件不满足,则会有loss = d+ + α - d- , 我们希望loss越小越好

  • 有了损失函数,就可以求损失函数关于神经网络的梯度,作梯度下降来更新模型参数

测试模型
  • 给一个query,一个支持集,用神经网络提取特征,把所有这些图片变为特征向量,比较特征向量之间的距离。找出距离最小的。

总结

我们使用了Siamese Network解决了少样本学习

基本思路:

  • 用一个比较大的训练集来训练孪生网络,让孪生网络知道事物之间的异同

  • 训练结束之后拿孪生网络作预测,解决少样本问题。少样本的问题是少样本的类别不在训练集中。比如query是松鼠,但训练集中没有松鼠这个类别,需要额外的信息来识别query的图片,这个额外的信息就是少样本支持集。

  • 支持集称为k-way, n-shot,k个类别,类别越多,预测越困难,n个样本,样本越少,预测越困难,one-shot learning单样本预测最困难。

  • 有了训练好的孪生网络,我们就可以将query与support set中的样本逐一对比,选出距离最小或相似度最高作为分类结果。

  • 两种训练孪生网络方法:1.两个input,标签0或1,输出0~1之间数值,与标签差值作为loss,目标是让预测尽量接近标签。 2.另一种是Triplet Loss,xa,x+,x-,用CNN提取得到三个特征向量,输出d+,d-,目标是让d+尽量小,d-尽量大。有了这样一个神经网络就可以用它提取特征,比较两张图片在特征空间距离,作出few-shot分类

Fine Tuning

基本思路

在大规模数据上预训练模型,然后再小规模的support set上做fine-tuning。方法简单,准确率高。

  • 看个例子,余弦相似度consine similarity,衡量两个向量之间相似度,现在两个向量长度都是1,即他们的二范数都为1。

  • 把向量x和w的夹角记作θ,由于向量x和w长度都是1,cosθ就是x和w的内积,表示两个向量的相似度

  • 可以理解,把向量x投影到w方向上,投影长度就是-1到+1之间

  • 如果向量x和w的长度不是1,则需要做归一化把他们程度变为1,然后求得的内积才是余弦相似度

微调主要用到Softmax Function
  • 它是一个常用的激活函数,可以把一个k维向量映射成一个概率分布

  • 输入为Φ,它是任意的k维向量。把Φ的每一个元素做指数变换,得到k个大于0的数;然后对其作归一化,让得到的k个数相加等于1,把得到的k个数记为向量p

  • 向量p就是softmax函数的输出

  • 性质

    • 输入Φ和输出p都是k维向量

    • 向量p的元素都是正数,而且相加等于1

    • 所以p是个概率分布

  • softmax通常用于分类器的输出层,如果有k个类别,那么softmax的输出就是k个概率值,每个概率值表示对一个类别的confidence

  • softmax会让最大的值变大,其余的值变小。softmax比max函数要温柔一些

Softmax分类器
  • 是一个全连接层加一个Softmax函数

  • 分类器的输入是特征向量x,表示输入的测试图片的特征向量,把x乘到参数矩阵w上,再加上向量b,得到一个向量

  • 对得到的向量做softmax变换,得到输出向量p

  • 假如类别数量为k,那么向量p就是k维的

  • 矩阵W和b是这一层的参数,可以从训练数据中学习。W有K行,k是类别数量,所以W每一行对应一个类别,d是每个类别的特征数量

使用预训练好的神经网络,在query和support set上做fine-tuning的过程
  • 把query和support set中的图片都映射成特征向量,这样可以比较query和support set在特征空间上的相似度,比如可以计算两两之间的cosine similarity。最后选择相似度最高的作为query的分类结果

  • 预训练

    • 搭一个卷积神经网络用来提取特征,有很多卷积层、Pooling层以及一个Flatten层,也可以有全连接层

    • 神经网络输入是一张图片x,输出一个特征向量f(x)

    • 可以用传统的监督学习,预训练好后把全连接层都去掉;也可以用孪生网络训练

  • Few-Shot分类方法

    • 3-way 2-shot,三类别,每类别两样本

    • 拿预训练的神经网络提取特征,每张图片变成一个特征向量,每个类别两个特征向量

    • 平均每个类别特征向量作平均,得到一个同样大小的向量,也就是均值向量

    • 有三个类别,一共得到三个均值向量

    • 均值向量归一化,得到三个向量μ1,μ2,μ3,它们的二范数都等于一,μ1,μ2,μ3就是对三个类别的表征

    • 做分类的时候,要拿query的特征向量对μ1,μ2,μ3作对比

  • 对query作分类

    • 给一张query图片,需要判断是三个类别中的哪一个

    • 拿预训练的神经网络f来提取特征,得到一个特征向量

    • 对特征向量作归一化,得到向量q,它的二范数等于1

    • 与刚才从support set中提取的三个向量μ1,μ2,μ3,它们的二范数也是1,每个μ向量表征一个类别

    • 可以把三个μ向量堆叠起来,作为矩阵M的三个行向量

  • 做few-shot预测

    • query的特征向量q乘到矩阵M上,再做Softmax变换,得到p = Softmax(Mq),p是个概率分布,这个例子里,p是三维向量,表示对三个类别的confidence

    • 三个元素分别是q与μ1,μ2,μ3的内积

    • 很显然,在向量p中,第一个元素最大,分类结果是第一类

Fine-tuning可以大幅提高预测准确率
  • 基本都是先做预训练,后做Fine-Tuning

  • 刚才我们用了固定的W和b,没有学习这两个参数

  • 可以在Support Set上学习W和b,这叫做fine tuning

    • Cross Entropy来衡量yj与pj的差别有多大,yj是真实标签,pj是分类器做出的预测,损失函数就是Cross Entropy Loss

    • Support set中有几个或者几十个有标注的样本,每个样本都对应一个Cross Entropy Loss,把这些Cross entropy loss加起来,作为损失函数

    • 也就是说我们用support set中所有的图片和标签来学习这个分类器

    • CrossEntropyLoss做最小化Minimization,让预测pj尽量接近真实标签yj

    • Minimization是对分类器参数W和b求的,希望学习W和b;当然也可以让梯度传播到卷积神经网络,更新神经网络参数,让提取的特征向量更有效

    • support通常很小几十个到几百个样本,最好加个regularization来防止过拟合。有一篇文章建议用Entropy Regularization

  • 有一篇ICLR2020的论文说 对于5-way 1-shot,做fine tuning可以提到2%~7%的准确率;对5-way 5-shot,提高1.5%~4%准确率

  • 尽管support set很小,但用support set来训练分类器有助于提高准确率,预训练+fine tuning比只用预训练好很多

  • W,b默认值

  • Entropy Regularization防止过拟合

    • 希望Entropy Regularization越小越好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux】基础I/O——动静态库的制作

我想把我写的头文件和源文件给别人用 1.把源代码直接给他2.把我们的源代码想办法打包为库 1.制作静态库 1.1.制作静态库的过程 我们先看看怎么制作静态库的! makefile 所谓制作静态库 需要将所有的.c源文件都编译为(.o)目标文件。使用ar指令将所有目标文件打包…

Linux应用——网络基础

一、网络结构模型 1.1C/S结构 C/S结构——服务器与客户机; CS结构通常采用两层结构,服务器负责数据的管理,客户机负责完成与用户的交互任务。客户机是因特网上访问别人信息的机器,服务器则是提供信息供人访问的计算机。 例如&…

[2019红帽杯]Snake

[2019红帽杯]Snake-CSDN博客 unity的题 下载下来看看是什么类型就是 这道题就是贪吃蛇 unity无脑找Assembly 用dnspy打开 一般就在这里慢慢找 但是你可以发现没有任何的信息 这里外接库 只能从这里下手试试 64位链接库的意思 游戏题,win!很关键 进入了Gameobject 看a1,小…

复现Android中GridView的bug并解决

几年前的一个bug,GridView的item高度不一致。如下图: 复现bug的代码: import android.os.Bundle; import android.widget.BaseAdapter; import android.widget.GridView; import androidx.appcompat.app.AppCompatActivity; import java.uti…

【Day12】登录认证、异常处理

1 登录 先创建一个新的 controller 层:LoginController RestController public class LoginController {Autowiredprivate EmpService empService;// 注入PostMapping("/login")public Result login(RequestBody Emp emp) { // 包装对象Emp e empServic…

html 单页面引用vue3和element-plus

引入方式: element-plus基于vue3.0,所以必须导入vue3.0的js文件,然后再导入element-plus自身所需的js以及css文件,导入文件有两种方法:外部引用、下载本地使用 通过外部引用ElementPlus的css和js文件 以及Vue3.0文件 …

Golang | Leetcode Golang题解之第260题只出现一次的数字III

题目: 题解: func singleNumber(nums []int) []int {xorSum : 0for _, num : range nums {xorSum ^ num}lsb : xorSum & -xorSumtype1, type2 : 0, 0for _, num : range nums {if num&lsb > 0 {type1 ^ num} else {type2 ^ num}}return []in…

【数据结构】二叉树OJ题_对称二叉树_另一棵的子树

对称二叉树 题目 101. 对称二叉树 - 力扣(LeetCode) 给你一个二叉树的根节点 root , 检查它是否轴对称。 示例 1: 输入:root [1,2,2,3,4,4,3] 输出:true示例 2: 输入:root [1,2…

不同类型的指针变量进行++操作的效果

可以看到 不同变量的指针进行操作的时候,他的地址移动的大小是不一样的 运行了打印了一些东西 , 没想到可以用sizeof来打印出 names[0][]这个字符串的长度方法 , 只能用这个 strlen1来判断这个字符串的长度。

使用minio cllient(mc)完成不同服务器的minio的数据迁移和mc基本操作

minio client 前言使用1.拉取minio client 镜像2.部署mc容器3.添加云存储服务器4.迁移数据1.全量迁移2.只迁移某个桶3.覆盖重名文件 5.其他操作1.列出所有alias、列出列出桶中的文件和目录1.1.列出所有alias1.2.列出桶中的文件和目录 2.创建桶、删除桶2.1.创建桶2.2.删除桶 3.删…

DX-10A信号继电器 柜内安装,板前接线 约瑟JOSEF

DX-10型闪光信号继电器型号: DX-10A闪光信号继电器; DX-10B闪光信号继电器; DX-10C闪光信号继电器; 用途 DX-10 闪光继电器用于电力系统断路器的位置信号灯不对应闪光,该继电器是为了适应当前推广使用发光二极管节能指示灯而…

“狂飙”过后,大模型未来在何方?

2024年6月14日,第六届“北京智源大会”在中关村展示中心开幕。 开幕现场,智源研究院、OpenAI、百度、零一万物、百川智能、智谱AI、面壁智能等国内主流大模型公司CEO与CTO,人工智能顶尖学者和产业专家,在围绕人工智能关键技术路径…

rockchip的yolov5 rknn python推理分析

rockchip的yolov5 rknn推理分析 对于rockchip给出的这个yolov5后处理代码的分析,本人能力十分有限,可能有的地方描述的很不好,欢迎大家和我一起讨论,指出我的错误!!! RKNN模型输出 将官方的Y…

GD 32 环形队列

1.0 为什么要使用环形队列 在代码中使用环形队列进行程序的编写,由于在实际开发过程中,会出现接收数据频率太快快于主流程读取数据的频率,这个时候后面来的数据会覆盖前面一包数据,这个时候可以使用环形队列的方式解决这个问题。 …

离散数学,格与子格,格的性质,格的代数系统定义,格的同态与同构,特殊格

目录 1.格与子格 相互对偶 2.格的性质 对偶式 格的保序性 3.格的代数系统定义 格对应的偏序关系就是s的子集之间的包含关系 该格对应的偏序关系就是整除关系 子格必然是格 4.格的同态与同构 格同态,序同态 同态是保序的 例子 5.特殊格 全下…

明星应援系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,线上应援管理,线下应援管理,应援物品管理,购买订单管理,集资应援管理,集资订单管理,市集订单管理&#xff0…

CentOS部署MySQL

1.配置yum仓库 #更新秘钥 rpm --import https://repo.mysql.com/RPM-GPG-KEY-mysql-2023 #安装MySQL rpm -Uvh http://dev.mysql.com/get/mysql80-community-release-el7-2.noarch.rpm 2.使用yum安装MySQL yum -y install mysql-community-server 3.启动MySQL并配置开机自启…

PCB系统学习(1)--PCB印制电路板

PCB印制电路板 1.1PCB的定义1.2PCB的层叠结构1.2.1PCB单层板1.2.2PCB双层板1.2.3PCB四层板 1.3PCB的通孔,盲孔,埋孔1.4元器件的符号与封装1.5PCB的生产过程 1.1PCB的定义 PCB(PrintedCircuitBoard),中文即印制电路板,或印刷线路板…

C语言八皇后问题可视化界面

插件使用easyx 以下是部分代码。需要源码的私信 #include<stdio.h> #include<easyx.h> #define width 1100//设置窗口的宽度和高度 #define height 900 int place[8] { 0 };//皇后位置 int flag[8] { 1,1,1,1,1,1,1,1 };//定义列 int d1[15] { 1,1,1,1,1,1,1,…

【Node.js基础03】利用http模块创建Web服务

一&#xff1a;使用步骤 1 加载http模块&#xff0c;并创建Web服务程序 2 利用Web服务程序监听request事件&#xff0c;设置响应头和响应体 3 配置端口号并启动Web服务 4 浏览器请求设置的端口号&#xff0c;进行Web服务程序测试 二&#xff1a;简单应用 const http requir…