第八章 模型篇:transfer learning for computer vision

news2025/1/11 8:05:49

参考教程:
transfer-learning
transfer-learning tutorial

文章目录

  • transfer learning
    • 对卷积网络进行finetune
    • 把卷积网络作为特征提取器
    • 何时、如何进行fine tune
  • 代码示例
    • 加载数据集
    • 构建模型
      • fine-tune 模型
      • 模型作为feature extractor
    • 定义train_loop和test_loop
    • 定义超参数,开始训练
    • 结果可视化

transfer learning

很少会有人从头开始训练一个卷积神经网络,因为并不是所有人都有机会接触到大量的数据。常用的选择是在一个非常大的模型上预训练一个模型,然后用这个模型为基础,或者固定它的参数用作特征提取,来完成特定的任务。

对卷积网络进行finetune

进行transfer-learning的一个方法是在基于大数据训练的模型上进行fine-tune。可以选择对模型的每一个层都进行fine-tune,也可以选择freeze特定的层(一般是比较浅的层)而只对模型的较深的层进行fine-tune。理论支持是,模型的浅层通常是一些通用的特征,比如edge或者colo blob,这些特征可以应用于多种类型的任务,而高层的特征则会更倾向于用于训练的原始数据集中的数据特点,因为不太能泛化到新数据上去。

把卷积网络作为特征提取器

将ConvNet作为一个特征提取器,通常是去掉它最后一个用于分类的全连接层,把剩余的层用来提取新数据的特征。你可以在该特征提取器后加上你自己的head,比如分类head或者回归head,用于完成你自己的任务。

何时、如何进行fine tune

使用哪种方法有多种因素决定,最主要的因素是你的新数据集的大小和它与原始数据集的相似度。

  • 当你的新数据集很小,并和原始数据集比较相似时。
    因为你的数据集很小,所以从过拟合的角度出发,不推荐在卷积网络上进行fine-tune。又因为你的数据和原始数据比较相似,所以卷积网络提取的高层特征和你的数据也是相关的。因此你可以直接卷积网络当作特征提取器,在此基础上训练一个线性分类器。
  • 当你的新数据集很大,并和原始数据集比较相似时。
    新数据集很大时,我们可以对整个网络进行fine-tune,因为我们不太会有过拟合的风险。
  • 当你的新数据集很小,并和原始数据集不太相似时。
    因为你的数据集很小,我们还是推荐只训练一个线性的分类器。但是新数据和原始数据又不相似,所以不建议在网络顶端接上新的分类器,因为网络顶端包含很多的dataset-specific的特征,所以更推荐的是从浅层网络的一个位置出发接上一个分类器。
  • 当你的新数据集很大,并和原始数据集不太相似时。
    因为你的数据集很大,我们仍然选择对整个网络进行fine-tune。因为通常情况下以一个pretrained-model对模型进行初始化的效果比随机初始化要好。

代码示例

我们使用与第四章 模型篇:模型训练与示例一样的流程进行模型训练。

加载数据集

首先是加载数据集,方便起见我们直接使用torchvision中的cifar10数据进行训练。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform
)


test_data = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

train_dataloader = DataLoader(training_data, batch_size = 64)
test_dataloader = DataLoader(test_data, batch_size = 64)

使用官方提供的代码对我们的dataset进行可视化,注意训练时使用的batchsize为64,这里可视化时为了方便暂时使用了batchsize=4。
在这里插入图片描述

构建模型

在第四章中我们用了自定义的model。在这里我们使用预训练好的模型,并对模型结构进行修改。

transfer-learning对模型的处理有两种,一种是fine-tune整个模型,一种是将模型作为feature-extractor。第二种和第一种的区别是,模型中的部分层被freeze,不在训练过程中更新。

fine-tune 模型

model_ft = models.resnet18(weights = 'IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10) # 因为cifar10是十分类,所以输出这里为10

模型作为feature extractor

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False  # requires_grad 设为False,不随训练更新

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 10)

定义train_loop和test_loop

这两个部分直接参考第四章的代码就可以,复制过来直接使用。

# 训练过程的每个epoch的操作,代码来自pytorch_tutorial
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        optimizer.zero_grad() # 重置梯度计算
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward() # 反向传播计算梯度
        optimizer.step() # 调整模型参数
        

        if batch % 10 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

定义超参数,开始训练

全都准备好以后,我们定义一下要使用的优化器和loss,和一些别的超参数,就可以开始训练了。

learning_rate = 1e-3
momentum=0.9
epochs = 20

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,momentum=momentum)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model_ft, loss_fn, optimizer)
    test_loop(test_dataloader, model_ft, loss_fn)
print("Done!")

因为是在个人pc跑的,所以就随便放一个效果。。。。。
在这里插入图片描述

结果可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/663017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【K8S系列】如何高效查看 k8s日志

序言 你只管努力,其他交给时间,时间会证明一切。 文章标记颜色说明: 黄色:重要标题红色:用来标记结论绿色:用来标记一级论点蓝色:用来标记二级论点 Kubernetes (k8s) 是一个容器编排平台&#x…

【C#每日一记】多线程实现的贪吃蛇原理—不允许你还不知道

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

软件测试技能,JMeter压力测试教程(一)

目录 前言 一、安装Java环境 二、安装JMeter环境 三、启动JMeter脚本测试 四、查看报告文件 前言 使用jmeter做压测的时候,在windows上不太稳定,所有一直在 Linux 服务器上使用 jmeter 做压力测试 本篇记录下 Linux上搭建 jmeter 环境&#xff0c…

分布式学习第二天 redis学习

目录 1. 数据库类型 1.1 基本概念 1.2 关系/非关系型数据库搭配使用 2. Redis 2.1 基本知识点 2.2 redis常用命令 2.4 redis数据持久化 3 hiredis的使用 4. 复习 1. 数据库类型 1.1 基本概念 关系型数据库 - sql 操作数据必须要使用sql语句 数据存储在磁盘 存储的…

如何使用CDN给OSS做加速详解

意义 用户直接访问OSS资源,速度会受到OSS下行带宽以及Bucket地域的限制,若通过CDNOSS的方式进行访问,带宽上限更高,并且可以将OSS的资源缓存至就近的CDN节点,通过CDN节点进行分发,可以缩短网络传输距离&am…

Linux学习之CentOS(八)--Linux系统的分区概念

不知不觉已经记录了8篇Linux学习随笔了,虽然还是漂浮在Linux系统的表面,还有很多很多没有学,但是坚持学下去、坚持写下去就是成功的!!!! 在讲Linux系统分区之前,首先得介绍一下硬盘…

【SpringCloud】2.微服务的熔断和降级

目 录 1. 熔 断1.1 发生场景1.2 熔断实现1.3 熔断测试 2. 降 级2.1 发生场景2.2 降级处理2.3 降级测试 在 上篇博客,我们完成了项目的基本搭建工作,那这篇博客就来实现一下微服务的熔断和降级。 1. 熔 断 1.1 发生场景 在前面,我们用 spri…

【Java高级语法】(八)反射机制:有朋友问反射到底是怎样玩的?看完这篇文章你就清楚了~

Java高级语法详解之反射机制 :one: 概念:two: 优势和缺点:three: 使用3.1 Class类3.2 获取类的结构信息- 构造函数3.3 获取类的结构信息- 方法3.4 获取类的结构信息- 字段3.5 动态创建对象、调用方法和设置属性3.6 动态代理 :four: 底层原理:five: 应用场景:ear_of_rice: 总结:…

SedonaSQL 聚合函数使用说明

ST_Envelope_Aggr 函数说明: 返回几何的外边界 语法: ST_Envelope_Aggr (A:geometryColumn) 支持版本: v1.0.0 Spark SQL 举例说明: SELECT ST_Envelope_Aggr(pointdf.arealandmark) FROM pointdf运行示例(AggregateFunctionTest.java): ST_Intersection_Aggr 函数说明: …

大文件如何传输到电脑?亲测又快又简单!

我们平时可以因为各种原因,如更换新电脑、高清视频分享等,需要将大文件传输到另一台电脑上。大文件如何传输到电脑?相信这是很多朋友都想知道如何实现吧。我们为您提供了2种轻松将大文件从PC传输到PC的方法。话不多说,上技巧! 方…

腾讯云服务器地域有什么区别?怎么选比较好

腾讯云服务器地域有什么区别?云服务器地域怎么选择?地域是指云服务器所在机房的地理位置,用户距离地域越近网络延迟越低,速度越快,所以地域就近选择即可。广州上海北京等地域网站域名需要备案,中国香港或其…

SpringBoot使用MockMVC单元测试Controller

前言: 在SpringBoot应用程序中,Controller是接受客户端请求并返回响应数据的核心组件。为了保证Controller的正确性和稳定性,我们可以使用MockMVC框架进行单元测试。MockMVC是Spring框架提供的一个HTTP客户端,用于模拟HTTP请求和响…

2023年智能优化算法之——增长优化器Growth Optimizer(GO),附MATLAB代码

增长优化器的主要设计灵感来源于个人在社会成长过程中的学习和反思机制。学习是个体通过从外部世界获得知识而成长的过程。反思是检查个人自身不足并调整个人学习策略以帮助个人成长的过程。参考文献如下: Zhang, Qingke, et al. “Growth Optimizer: A Powerful M…

二维码在固定资产实物盘点中的应用

很多企业为了掌握固定资产的后续使用情况和状态,会定期对投入使用的固定资产进行盘点,然而固定资产常会出现分散情况,在这种情况下让财务人员到达每个固定资产的所在处进行实地盘点显得极为不现实。 也有不少企业会在盘点过程中使用到固定资…

聊天室(一)___常见的基本功能实现

最近搞聊天室的人还挺多,正好自己也弄就总结自己遇到必不可少的一些功能,本篇文章主要为自己和看到我文章的人一种思路,希望大家不要把聊天室想的太复杂。 上图是我自己做的一个聊天室,类似微信的单聊群聊收藏等功能,根…

python+requests接口自动化框架详解,没有比这个更详细的了

目录 为什么要做接口自动化框架 正常接口测试的流程是什么? 一、接口框架如下: 二、接口的数据规范设计---Case设计 2.1注册接口用例 2.2登录接口用例 三、创建utils包:用来存放公共的类 3.1 ParseExcel.py 操作封装excel的类&#xf…

【AI工具】-Stable Diffusion本地化部署教程

前言 今天我们要介绍的是时下最流行的AI绘图软件Stable Diffusion,虽然Diffusion.ai已经开放api,但是长时间的商业化调用我们需要购买很多的金币。所以我们需要找一个平替的AI绘图平台,现在主流市场中AI绘图软件主要就是OpenAI的DALLE、midj…

SSM会议管理系统

SSM会议管理系统 后端基于SSM、前端基于Freemarker写的会议管理系统、使用JDK8、mysql使用5.7版本 技术栈 Spring SpringMVC MyBatis Mysql Freemarker jqueryajax[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JGo0luHu-1687163482019)(img.png)] …

【Python 随练】打印水仙花数

题目: 打印出所有的"水仙花数",所谓"水仙花数"是指一个三位数,其各位数字立方和等于该数 简介: 在本篇博客中,我们将解决一个经典的数学问题:打印出所有的水仙花数。水仙花数是指一…

Unity核心5——Tilemap

Tilemap 一般称之为瓦片地图或者平铺地图,是 Unity2017 中新增的功能,主要用于快速编辑 2D 游戏中的场景,通过复用资源的形式提升地图多样性 ​ 工作原理就是用一张张的小图排列组合为一张大地图 ​ 它和 SpriteShape 的异同 共同点&#x…