深度学习——微调笔记+代码

news2024/10/2 22:21:28

1.微调在深度学习中计算机视觉最重要的技术微调也是迁移学习

2.标注一个数据集很

 

ImageNet注了1000多万张图片,实际使用120万张图片,类别是1000,大型数据集

Fashion-MNIST一共有6万张图片,类别是10,小型数据集

③通常的数据集是两者之间,5万图片左右。类别是100,每一类物体有500张图片

训练样本有限,训练模型的准确性可能无法满足实际的要求。

解决方案:

更多的数据集,但耗时间和金钱

应用迁移学习(transfer leanring):将源数据集上学到的知识(提取更通用的特征,有助于识别边缘,纹理,形状和对象组合)迁移到目标数据集上提升数据集上的精度 。思想:某个模型有一定的识别物体基础,不需要自己提供太大的数据集就能够获得更好的精度。

3.网络架构

①一个神经网络一般可以分为两块:Ⅰ特征提取将原始像素变成容易线性分割的特征

线性分类器来做分类

 

4.微调

 

①在源数据集上(大型数据集)训练好了一个模型pre-train。模型中特征提取的部分对目标数据集有效的,优于随机生成特征提取。在目标数据集上也会根据数据重新学习,训练次数不会太多。

分类不能直接使用的,因为标号变化了,难以重用。

思想:在一个大型数据集上训练好了模型用于特征提取部分,在目标数据集上提取特征进行重用

5.微调中的权重初始化

 

源数据集上源模型pre-train。创建一个新的神经网络模型,目标模型。新模型重用源模型的模型设计和参数(输出层除外)

向目标模型添加输出层,输出数是目标数据集的类别数。随机初始化该层的模型参数(标号不同)

③在目标数据集上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调

注意:因为损失 Loss 是从后往前进行传递的,所以最后的分类部分训练比较快,进行随机初始化也不会有太大的影响;而前面的特征提取的部分本身已经具备很好的特征提取效果,只是根据源数据集和目标数据集的差异进行微调,可能在最开始训练的时候就已经比较接近最终的结果,所以不用做过多的训练和变动。

6.训练

①一个目标数据集上的正常的训练任务,但使用更强的正则化

Ⅰ使用更小的学习率(模型比较好) Ⅱ使用更少的数据迭代

源数据集复杂目标数据,通常微调效果更好。

7.重用分类器权重

①源数据集可能有目标数据中的部分标号。比如源数据集有“车”,你的目标数据集也有车

②使用预训练好模型分类器对应标号  对应的向量来做初始化

8.固定一些层

 

①神经网络通常学习有层次的特征表示

低层次的特征更加通用

高层次的特征跟数据集相关

Ⅲ高层次对标号的关联度很大,低层次的特征更加通用。

②可以固定底部一些层的参数,不参与更新(不做优化,不改变底层的权重。模型复杂度变低):更强的正则(底部参数不更新,不容易过拟合)

【总结】

微调通过使用在大数据上得到的预训练好的模型来初始化模型权重完成提升精度。(白嫖)

②预训练模型质量很重要

③微调通常速度更快,精度更高。

【代码实现】

1.获取热狗数据集    小数据集上使用微调ResNet模型,该模型在ImageNet预训练

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

# 热狗识别  小数据集上使用微调ResNet模型,该模型在ImageNet预训练
# 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                          'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = d2l.download_extract('hotdog')

2.读取训练和测试数据集

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

3.数据增广

normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    #  RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差
    # 在Image模型已经做了这个
)
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),  # 缩放224*224
    torchvision.transforms.RandomHorizontalFlip(),  # 水平翻转
    torchvision.transforms.ToTensor(),
    normalize
])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),  # 缩放256
    torchvision.transforms.CenterCrop(224),  # 剪裁中央224*224区域输入
    torchvision.transforms.ToTensor(),
    normalize
])

4.定义和初始化模型 pretrained=True自动下载预训练的模型参数

# 定义和初始化模型  pretrained=True自动下载预训练的模型参数
pretrained_net = torchvision.models.resnet18(pretrained=True)
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)  # 输出层  预训练模型输入一样,输出是自己的
nn.init.xavier_uniform_(finetune_net.fc.weight)  #输出层w随机初始化

5. 模型微调 训练函数

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率 param_group模型初始值

def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:  # 体征提取层学习率n低,随机初始化的输出层w学习率是10n
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

6.使用较小的学习率

train_fine_tuning(finetune_net, 5e-5)

 

loss 0.552, train acc 0.870, test acc 0.951

 

7.对比 不使用模型初始化 aram_group=False

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)

loss 0.348, train acc 0.843, test acc 0.863

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/95249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

matlab:鼠标循环点击器

目录简介使用说明板块1采点板块作用名称解释板块2坐标板块作用名称解释板块3历史数据板块作用名称解释板块4循环点击板块作用名称解释程序附注简介 采集PC端一个或是多个点的位置坐标,对这些位置可以按照次序循环点击。(之前玩阴阳师的时候,…

动态规划问题——矩阵的最小路径和

题目: 给定一个矩阵m,从左上角开始每次只能向右或者向下走,最后到达右下角的位置,路径上所有的数字累加起来就是路径和,返回所有路径中最小的路径和。 示例: 给定的m如下: 1 3 …

灌区信息化管理系统解决方案 灌区用水量测系统介绍

平升电子灌区信息化管理系统解决方案/灌区用水量测系统,对灌区的渠道水位、流量、水雨情、土壤墒情、气象等信息进行监测,同时对泵站、闸门进行远程控制,对重点区域进行视频监控,实现了信息的采集、统计、分析、控制等功能&#x…

我国均温板行业发展趋势:5G手机领域需求强劲 今年市场空间或超15亿

均温板(VaporChamber)技术从原理上类似于热管,但在传导方式上有所区别。热管为一维线性热传导,而真空腔均热板中的热量则是在一个二维的面上传导,因此效率更高。具体来说,真空腔底部的液体在吸收芯片热量后…

【图像去噪】非局部均值(NLM)滤波图像去噪【含Matlab源码 420期】

⛄一、图像去噪及滤波简介 1 图像去噪 1.1 图像噪声定义 噪声是干扰图像视觉效果的重要因素,图像去噪是指减少图像中噪声的过程。噪声分类有三种:加性噪声,乘性噪声和量化噪声。我们用f(x,y)表示图像,g(x,y&#xff0…

Spring Cloud Alibaba Sentinel - - > 容错机制

文章目录Sentinel 的作用分布式微服务系统遇到的问题导致服务不可用的原因:Sentinel - - > 容错机制参考:Sentinel 的作用 Sentinel 主要用来解决微服务架构中出现的一些可用性问题,从而实现系统的高可用。系统在运行过程中不可能不出现问…

单例模式的创建(饿汉模式懒汉模式)

目录 一.什么是单例模式 二.用static来创建单例模式 三.饿汉模式与懒汉模式 四.饿汉模式与懒汉模式的线程安全问题 五.New引发的指令重排序问题 六.小结 一.什么是单例模式 单例模式就是指某个类有且只有一个实例(instance) 这个是由需求决定的,有些需求场景就要求实例不…

关于mysql学习

1.索引 1.1 索引概述 Mysql官方对索引的定义是:索引(index)是帮助mysql高效获取数据的数据结构(有序)。在数据库之外,数据库系统还维护着满足特定查找算法的数据结构,这些数据结构以某些方式引用(指向)数据,这样就可以在数据结构…

密码技术扫盲:对称加密

个人博客 🎯 密码技术扫盲:对称加密密码技术扫盲:非对称加密密码技术扫盲:认证 人类最较真、技艺最精湛的事业是军事,密码技术最大放异彩的地方也在军事,战争中需要通过无线电或其他手段来传达指令&#…

LeetCode刷题复盘笔记—一文搞懂动态规划之309. 最佳买卖股票时机含冷冻期问题(动态规划系列第二十四篇)

今日主要总结一下动态规划的一道题目,309. 最佳买卖股票时机含冷冻期 题目:309. 最佳买卖股票时机含冷冻期 Leetcode题目地址 题目描述: 给定一个整数数组prices,其中第 prices[i] 表示第 i 天的股票价格 。​ 设计一个算法计算…

Python编程 函数的定义与参数

作者简介:一名在校计算机学生、每天分享Python的学习经验、和学习笔记。 座右铭:低头赶路,敬事如仪 个人主页:网络豆的主页​​​​​​ 目录 前言 一.函数 1.函数例子 不会让代码重复的出现。CVout 2.函数介绍(熟悉) 3.…

如何在AdsPower中设置Oxylabs住宅代理和数据中心代理?

AdsPower是一款适用于Windows和Mac系统的浏览器管理工具,允许多用户登录。AdsPower的主要功能有多账户管理、浏览器指纹处理等。 集成操作流程 在官网(www.adspower.com/download)下载AdsPower并完成安装工作后,单击新建配置文件…

分布式文件系统之NFS

「分布式」是现在蛮流行的一个词,而其盛行,离不开底层网络通信能力的迅速发展。在文件系统这个领域,早期的分布式实现更多的是用来实现「共享」,而不是「容错」。传统的集中式文件系统允许单个系统中的多用户共享本地存储的文件&a…

SVG公众号排版 | GIF动图如何禁止循环播放?PS设置了也没用!

在SVG公众号排版中,我们经常使用到GIF动图,有些排版需求是想让GIF动图一直无限循环播放,也有其他排版需求是只想让GIF动图播放一次就停止了,这种情况我们可以通过Photoshop软件来设置GIF动图的播放次数,详见下图。 但是呢,也有一种情况,即使在Photoshop软件设置了GIF动图…

大话设计模型 Task02:策略、装饰、代理

目录一、简单工厂模式问题描述模式定义问题分析代码实现二、策略模式问题描述问题分析模式定义代码实现三、装饰模式问题描述问题分析模式定义代码实现四、代理模式问题描述问题分析模式定义代码实现五、工厂方法模式问题描述问题分析模式定义简单工厂 vs. 工厂方法代码实现一、…

上传项目代码到Github|Gitee

上传项目代码到Github|Gitee 文章目录上传项目代码到Github|Gitee1、前置准备1.1 Git 安装1.2 在 Git 中设置用户名1.2.1 为计算机上的每个存储库设置 Git 用户名1.2.2 为一个仓库设置 Git 用户名1.3 SSH免密登录1.4 Github创建一个新的仓库2、上传项目2.1 初始化本地库2.2 添加…

蓝桥杯入门即劝退(十六)查找元素范围(双解法)

欢迎关注点赞评论,共同学习,共同进步! ------持续更新蓝桥杯入门系列算法实例-------- 如果你也喜欢Java和算法,欢迎订阅专栏共同学习交流! 你的点赞、关注、评论、是我创作的动力! -------希望我的文章…

什么是制造业数字化转型?制造业数字化转型的核心与意义

对于生产制造企业来讲,当下如果不进行数字化转型的话,很大概率会被时代所抛弃的。为什么这么讲?因为在未来的很长一段时间,你可以充分了解到,数字化转型已然成为了制造业向前的主旋律。既然数字化势在必行,…

可以赚钱的副业项目,简单易上手兼职副业推荐

在当前的经济环境下,对每个人来说,仅仅依靠那点薪水生活是非常紧张的。为了改善你的生活,你需要找到其他赚钱的方法,在互联网上做兼职是一个不错的选择。 今天推荐几个普通人可以做的兼职副业,希望对大家有所帮助。 一…

微信公众号的文章可以修改几次?修改的步骤有哪些

许多小伙伴们在运营微信公众号的时候,可能会遇到过这些难题,在发布微信公众号之前检查没有检查好,导致有错字或者是错句。有的时候可能配图还会配错! 今天伯乐网络传媒就给大家带来一些实用的东西,比如微信公众号可以…