预训练--微调

news2024/11/24 13:01:19

预训练–微调

一个很简单的道理,如果我们的模型是再ImageNet下训练的,那么这个模型一定是会比较复杂的,意思就是这个模型可以识别到很多种类别的即泛化能力很强,但是如果要它精确的识别是否某种类别,它的表现可能就不佳了,因此,我们需要在原来的基础上再对特定的我们需要识别的类别进行重新训练,微调原来网络结构中的参数,此时模型还是可以抽取较通用的图像特征。
在这里插入图片描述
参考自https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter09_computer-vision/9.2_fine-tuning
当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

热狗识别

源数据集是ImageNet,超过1000万个图像和1000类物体,热狗数据集包含1400个正类图像和其他多种负类图像
最开始还是导入所需要的库以及设置cuda

import torch
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os
import d2lzh_pytorch as d2l
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下载数据集https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/hotdog.zip
我直接放在了我的默认路径下,读数据如下

train_imgs = ImageFolder("hotdog/train")
test_imgs = ImageFolder("hotdog/test")

然后我们观察一下数据集,可以看到大小,宽高比各不同

# 前八张正类图像和最后八张负类图像,可以看到宽高比、大小各不同
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [test_imgs[-1-i][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs,2, 8, scale=2)

在这里插入图片描述
接下来就是训练时,我们先从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入,测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高、宽均为224的中心区域作为输入,此外对RGB三通道作标准化,每个数值减去通道的平均值,再除以标准差需要注意的是,在使用预训练模型时,一定要和预训练时作同样的预处理。 如果你使用的是torchvision的models,
那就要求: All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
如果你使用的是pretrained-models.pytorch仓库,请务必阅读其README,其中说明了如何预处理。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([
    #transforms.Resize(size=256),  # 是将最小边调整到256
    #transforms.CenterCrop(size=224),
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

test_augs = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    normalize
])

需要注意的是,首先我有最开始有两点疑惑

  1. 为什么不能需要从图像中随机裁剪一块区域,然后将该区域缩放成224*224的图像进行输入。然后我测试了一下,如果不这样做的话,那么泛化能力会比较差
  2. 如果非要这么做,那么可不可以直接transforms.Resize(size=224)?不可以的,transforms.Resize(size=224)是把最短的边变为224,宽高比没变,那么这样就会导致图像的尺寸不一样,后面自然会报错,所以需要先transforms.Resize(size=256),然后transforms.CenterCrop(size=224)

之后我们使用在ImageNet上预训练的ResNet18,pretrained=True,自动下载预训练参数
不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home目录下.torch文件夹。
你可以通过修改环境变量$TORCH_MODEL_ZOO来更改下载目录

pretrained_net = models.resnet18(pretrained=True)

修改最后一层

pretrained_net.fc = nn.Linear(512, 2)

接下来设置训练的参数,由于除了最后一层,之前的参数都经过预训练,所以我们学习率调小一点,最后的fc层是初始化过的,于是我们学习率调大一点

output_params = list(map(id, pretrained_net.fc.parameters()))  # fc层
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())  # 除了fc层
lr = 0.01 # 用来更新特征层
# fc层是lr * 10
optimizer = optim.SGD([
    {"params":feature_params},
    {"params":pretrained_net.fc.parameters(), "lr":lr*10}
] ,lr = lr, weight_decay=0.001)

在之后就是训练了

def train_fine_tuning(net, optimizer, batch_size=64, num_epochs=5):
    train_iter = DataLoader(ImageFolder("hotdog/train", transform=train_augs), batch_size, shuffle=True)
    test_iter = DataLoader(ImageFolder("hotdog/test", transform=test_augs), batch_size, shuffle=False)
    loss = torch.nn.CrossEntropyLoss()
    d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
train_fine_tuning(pretrained_net, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1297181.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

html、css类名命名思路整理

开发页面时,老是遇到起名问题,越想越头疼,严重影响开发进度,都是在想名字,现在做一下梳理,统一一下思想,希望以后能减少这块的痛苦。 命名规则 [功能名称]__[组成部分名称]--[样式名称] 思路…

小红书自动点赞工具,其成功分享与引流攻略从入门到精通

先来看实操成果,↑↑需要的同学可看我名字↖↖↖↖↖,或评论888无偿分享 一、小红书引流的优势 小红书作为一个社交电商平台,具有巨大的引流潜力。其独特的UGC(用户生成内容)模式使得用户可以轻松地分享自己的购物心得…

python主流开发工具排名,python开发工具有哪些

本篇文章给大家谈谈python的开发工具软件有哪些,以及python主流开发工具排名,希望对各位有所帮助,不要忘了收藏本站喔。 python中用到哪些软件 一、Python代码编辑器1、sublime Textsublime Text是一款非常流行的代码编辑器,支持P…

Python渗透测试——一、数据包的编辑工具——Scapy

Python渗透测试 一、Scapy简介二、Scapy中的分层结构三、Scapy中的常用函数四、在Scapy 中发送和接收数据包五、Scapy 中的抓包函数 一、Scapy简介 提到数据包(这里泛指帧、段和报文等)的构造,我们首先需要了解协议和分层这两个概念。在“互联世界的规则一协议”中…

概率论之 证明 正态分布的上a 分位点的对称的性质

公式(Z(a) -Z(1-a)) 表示正态分布的上(a)分位点与下(1-a)分位点在分布曲线上关于均值的对称性。 左侧 (Z(a)): 这是分布曲线上累积概率为(a)的那个点。也就是说,这是一个使得这个点及其左侧的面积占据整个曲线下方(a)的位置。 右侧 (Z(1-a))&#xff1…

网页设计--第6次课后作业

试用Vue相关指令完成对以下json数据的显示。显示效果如下: 其中:gender1 显示为女,gender2显示为男。价格超过30元,显示“有点小贵”。价格少于等于30元,则显示“价格亲民”。 data: {books: [{"id": "…

Selenium IED安装及简单使用

本文已收录于专栏 《自动化测试》 目录 背景介绍优势特点安装步骤录制脚本总结提升 背景介绍 Selenium 通过使用 WebDriver 支持市场上所有主流浏览器的自动化。 Webdriver 是一个 API 和协议,它定义了一个语言中立的接口,用于控制 web 浏览器的行为。 每…

基于Unity3D 低多边形地形模型纹理贴图

在线工具推荐: 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时,有几种不同的风格&#xf…

springboot + thymeleaf + layui 初尝试

一、背景 公司运营的同事有个任务,提供一个数据文件给我,然后从数据库中找出对应的加密串再导出来给他。这个活不算是很难,但时不时就会有需求。 同事给我的文件有时是给excel表格,每一行有4列,逗号隔开,…

RocketMQ-RocketMQ高性能核心原理(流程图)

1.NamesrvStartup 2.BrokerStartup 3. DefualtMQProducer 4.DefaultMQPushConsumer

LeetCode-数组-重叠、合并、覆盖问题-中等难度

435. 无重叠区间 我认为区间类的题型,大多数考验的是思维能力,以及编码能力,该类题型本身并无什么算法可言,主要是思维逻辑,比如本题实际上你只需要能够总结出重叠与不重叠的含义,再加上一点编码技巧&#…

【上海大学数字逻辑实验报告】五、记忆元件测试

一、实验目的 掌握R-S触发器、D触发器和JK触发器的工作原理及其相互转换。学会用74LS00芯片构成钟控RS触发器。学会用74LS112实现D触发器学会在Quartus II上用D触发器实现JK触发器。 二、实验原理 基本R-S触发器是直接复位-置位的触发器,它是构成各种功能的触发器…

解读Stable Video Diffusion:详细解读视频生成任务中的数据清理技术

Diffusion Models视频生成-博客汇总 前言:Stable Video Diffusion已经开源一周多了,技术报告《Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets》对数据清洗的部分描述非常详细,虽然没有开源源代码,但是博主正在尝试复现其中的操作。这篇…

基于ssm平面设计课程在线学习平台系统源码和论文

idea 数据库mysql5.7 数据库链接工具:navcat,小海豚等 随着信息化时代的到来,管理系统都趋向于智能化、系统化,平面设计课程在线学习平台系统也不例外,但目前国内的市场仍都使用人工管理,市场规模越来越大,…

chrome安装jsonview

写在前面 通过jsonview可以实现,当http响应时application/json时直接在浏览器格式化显示,增加可读性。本文看下如何安装该插件到chrome中。 1:安装 首先在这里 下载插件包,然后解压备用。接着在chrome按照如下步骤操作&#xf…

JAVEE初阶 多线程基础(七)

懒汉模式 指令重排序问题 一. 懒汉模式的意义和代码实现二. 饿汉模式和懒汉模式的线程安全三. 懒汉模式的线程安全问题解决3.1 加锁阶段3.2 嵌套if阶段3.3 指令重排序问题3.4 解决线程安全问题阶段 一. 懒汉模式的意义和代码实现 在上一章节中,我们先学习了单例模式中的饿汉模式…

go语言学习-并发编程(并发并行、线程协程、通道channel)

1、 概念 1.1 并发和并行 并发:具有处理多个任务的能力 (是一个处理器在处理任务),cpu处理不同的任务会有时间错位,比如有A B 两个任务,某一时间段内在处理A任务,这时A任务需要停止运行一段时间,那么会切换到处理B任…

基于redisson实现发布订阅(多服务间用避坑)

前言 今天要分享的是基于Redisson实现信息发布与订阅(以前分享过直接基于redis的实现),如果你是在多服务间基于redisson做信息传递,并且有服务压根就收不到信息,那你一定要看完。 今天其实重点是避坑&#xff0…

MySQL 对null 值的特殊处理

需求 需要将不再有效范围内的所有数据都删除,所以用not in (有效list)去实现,但是发现库里,这一列为null的值并没有删除,突然想到是不是跟 anull 不能生效一样,not in 对null不生效,也需要特殊处理。 解决 …

$sformat在仿真中打印文本名的使用

在仿真中,定义队列,使用任务进行函数传递,并传递文件名,传递队列,进行打印 $sformat(filename, “./data_log/%0d_%0d_%0d_0.txt”, f_num, lane_num,dt); 使用此函数可以自定义字符串,在仿真的时候进行文件…