李沐37_微调——自学笔记

news2024/9/8 23:48:07

标注数据集很贵

网络架构

1.一般神经网络分为两块,一是特征抽取原始像素变成容易线性分割的特征,二是线性分类器来做分类

微调

1.原数据集不能直接使用,因为标号发生改变,通过微调可以仍然对我数据集做特征提取

2.pre-train源数据集后,迁移模型中的架构,调整权重

3.是一个目标数据集上的正常训练任务,但使用更强的正则化(更小的学习率、更少的数据迭代)

4.源数据集远复杂于目标数据,通常微调效果更好

重用分类器权重

1.源数据集可能也有目标数据中的部分标号

2.可以使用预训练好的模型分类器中对应标号对应的向量来初始化

固定一些层

1.神经网络通常学习有层次的特征表达,低层次的特征更加通用,高层次的特征则和数据集相关

2.可以固定底部一些层的参数,不参与更新,更强的正则化

总结

1.微调通过使用在大数据上得到的预训练好的模型来初始化模型权重来提升精度

2.预训练模型质量很重要

3.微调通常速度更快、精度更高

代码实现——热狗识别

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

获取数据集

解压下载热狗数据集,一共1400张热狗“正向”图片,其他尽可能多的其他食物“负向”图片,解压下载的数据集,我们获得了两个文件夹hotdog/train和hotdog/test。 这两个文件夹都有hotdog(有热狗)和not-hotdog(无热狗)两个子文件夹, 子文件夹内都包含相应类的图像。

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

读取训练集和测试集的图片

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

前8个正类、后8个负类,图片的大小和高宽比不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在这里插入图片描述

训练:从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224X224
输入图像。
测试:我们将图像的高度和宽度都缩放到256像素,然后裁剪中央224X224
区域作为输入。

此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

初始化模型

使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True以自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。

pretrained_net = torchvision.models.resnet18(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]

fc:预训练的源模型实例包含许多特征层和一个输出层fc。 此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。

# fc
pretrained_net.fc
Linear(in_features=512, out_features=1000, bias=True)

目标模型finetune_net中成员变量features的参数被初始化为源模型相应层的模型参数。 由于模型参数是在ImageNet数据集上预训练的,并且足够好,因此通常只需要较小的学习率即可微调这些参数。

成员变量output的参数是随机初始化的,通常需要更高的学习率才能从头开始训练。 假设Trainer实例中的学习率为n,我们将成员变量output中参数的学习率设置为10n。

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

训练模型微调

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

使用较小的学习率

train_fine_tuning(finetune_net, 5e-5)
loss 0.238, train acc 0.910, test acc 0.941
360.1 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

为了比较,定义相同的模型,但需较大学习率

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
loss 0.403, train acc 0.821, test acc 0.791
366.6 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1597712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL12 获取每个部门中当前员工薪水最高的相关信息

题目&#xff1a;获取每个部门中当前员工薪水最高的相关信息 注意了&#xff0c;这道题目&#xff0c;分组函数只能查出来&#xff1a;每个部门的最高薪水&#xff0c;group by dept_no &#xff0c;根据部门分组&#xff0c;绝对不能group by dept_no,emp_no&#xff0c;不能…

el-tree如何修改节点点击颜色

el-tree修改点击节点颜色三大步 使用elementui库时&#xff0c;有时候我们会对里面提供的组件做一些样式修改。如果我们想要修改el-tree组件点击节点时的颜色&#xff0c;可以使用下面这种方式实现&#xff1a;

C++ | Leetcode C++题解之第29题两数相除

题目&#xff1a; 题解&#xff1a; class Solution { public:int divide(int dividend, int divisor) {// 考虑被除数为最小值的情况if (dividend INT_MIN) {if (divisor 1) {return INT_MIN;}if (divisor -1) {return INT_MAX;}}// 考虑除数为最小值的情况if (divisor I…

密码学基础 -- 走进RSA(2)(放弃数学原理版)

目录 1.概述 2. RSA测试 2.1 加解密实验 2.2 签名验签测试 3. RSA原理简介 4.小结 1.概述 从上面密码学基础 -- 走进RSA(1)(放弃数学原理版)-CSDN博客我们知道了非对称算法的密钥对使用时机&#xff0c;那么接下里我们继续讲解RSA&#xff0c;我们分别从RSA加解密、签名验…

【Unity】RPG小游戏创建游戏中的交互

RPG小游戏创建游戏中的交互 创建可交互的物体的公共的父类&#xff08;Interactable&#xff09;InteractableObject 类NPCObject 类PickableObject 类 创建可交互的物体的公共的父类&#xff08;Interactable&#xff09; InteractableObject 类 using System.Collections; u…

数字乡村创新实践推动农业现代化发展:科技赋能农业产业升级、提升农民收入水平与乡村治理效能

随着信息技术的迅猛发展和数字化转型的深入推进&#xff0c;数字乡村创新实践已成为推动农业现代化发展的重要引擎。数字技术的广泛应用不仅提升了农业生产的智能化水平&#xff0c;也带动了农民收入的增加和乡村治理的现代化。本文旨在探讨数字乡村创新实践如何科技赋能农业产…

61、ARM/串口通信相关学习20240415

一、串口通信&#xff1a;实现PC端串口助手与开发板的字符串通信。 代码&#xff1a; main&#xff1a; #include "uart4.h"int main(){uart4_config();//char a;char s[64];while (1){//a getchar();//putchar(a1);gets(s);puts(s);}return 0;}usrt4.c&#xff…

供应链投毒预警 | 开源供应链投毒202403月报发布啦!(含投毒案例分析)

悬镜供应链安全情报中心通过持续监测全网主流开源软件仓库&#xff0c;结合程序动静态分析方式对潜在风险的开源组件包进行动态跟踪和捕获&#xff0c;能够第一时间捕获开源组件仓库中的恶意投毒攻击。在2024年3月份&#xff0c;悬镜供应链安全情报中心在NPM官方仓库&#xff0…

golang 迷宫回溯算法(递归)

// Author sunwenbo // 2024/4/14 20:13 package mainimport "fmt"// 编程一个函数&#xff0c;完成老鼠找出路 // myMap *[8][7]int 地图&#xff0c;保证是同一个地图&#xff0c;因此是引用类型 // i,j表示对地图的哪个点进行测试 func SetWay(myMap *[8][7]int, …

学习一门语言的方法和套路(B站转述)

视频链接 up虽然长相英(ping)俊(ping)&#xff0c;但是讲的干活&#xff0c;没恰饭。 学习流程&#xff1a; 1.快速阅读&#xff0c;掌握概况 2.深入细节内容 例如&#xff1a;java (JDBC)、html 、netty 不管三七二十一&#xff0c;先了解套路&#xff0c;再深入研究。 高…

【华为】Telnet实验配置

【华为】Telnet 实验配置 应用场景三种认证方式配置注意事项拓扑无认证&#xff08;None&#xff09;交换机配置顺序Telnet ServerTelnet Client测试 密码认证&#xff08;Password&#xff09;配置顺序Telnet ServerTelnet Client测试 AAA认证&#xff08;scheme&#xff09;配…

密码学 | 椭圆曲线 ECC 密码学入门(四)

目录 正文 1 曲线方程 2 点的运算 3 求解过程 4 补充&#xff1a;有限域 ⚠️ 知乎&#xff1a;【密码专栏】动手计算双线性对&#xff08;中&#xff09; - 知乎 ⚠️ 写在前面&#xff1a;本文属搬运博客&#xff0c;自己留着学习。注意&#xff0c;这篇博客与前三…

验证ElasticSearch 分词的BUG

验证ElasticSearch 分词的BUG 环境介绍 ElasticSearch 版本号: 6.7.0 BUG 重现 创建测试案例索引 PUT test_2022 {"settings": {"analysis": {"filter": {"pinyin_filter": {"type": "pinyin"}},"analy…

万界星空科技商业开源MES+项目合作+低代码平台

今天我想和大家分享的是一套商业开源的 MES制造执行管理系统。对于制造业而言&#xff0c;MES 是一个至关重要的系统&#xff0c;它可以帮助企业提高生产效率、优化资源利用、提高产品质量&#xff0c;从而增强市场竞争力。什么是 MES&#xff1f; MES 是指通过计算机技术、自动…

uniapp开发 如何获取IP地址?

一、需求 使用uniapp开发小程序时&#xff0c;需要调取【记录日活动统计】的接口&#xff0c;而这个接口需要传递一个ip给后台&#xff0c; 那么前端如何获取ip呢&#xff1f;下面代码里可以实现 二、代码实现 1.在项目的manifest.json中配置一下网络权限&#xff1a; &quo…

企业网络日益突出的难题与SD-WAN解决方案

随着企业规模的迅速扩张和数字化转型的深入推进&#xff0c;企业在全球范围内需要实现总部、分支机构、门店、数据中心、云等地点的网络互联、数据传输和应用加速。SD-WAN作为当今主流解决方案&#xff0c;在网络效率、传输质量、灵活性和成本等方面远远超越传统的互联网、专线…

C语言之探秘:访问结构体空指针与结构体空指针的地址的区别(九十三)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

前端标记语言HTML

HTML&#xff08;HyperText Markup Language&#xff09;是一种用于创建网页的标准标记语言。它是构建和设计网页及应用的基础&#xff0c;通过定义各种元素和属性&#xff0c;HTML使得开发者能够组织和格式化文本、图像、链接等内容。 HTML的基本结构 文档类型声明&#xff0…

Adobe Premiere 2020 下载地址及安装教程

Premiere是一款专业的视频编辑软件&#xff0c;由Adobe Systems开发。它为用户提供了丰富的视频编辑工具和创意效果&#xff0c;可用于电影、电视节目、广告和其他多媒体项目的制作。 Premiere具有直观的用户界面和强大的功能&#xff0c;使得编辑和处理视频变得简单而高效。它…

【AngularJs】前端使用iframe预览pdf文件报错

<iframe style"width: 100%; height: 100%;" src"{{vm.previewUrl}}"></iframe> 出现报错信息&#xff1a;Cant interpolate: {{vm.previewUrl}} 在ctrl文件中信任该文件就可以了 vm.trustUrl $sce.trustAsResourceUrl(vm.previewUrl);//信任…