(动手学习深度学习)第13章 计算机视觉---微调

news2024/9/24 5:32:58

文章目录

    • 微调
      • 总结
    • 微调代码实现

微调

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

在这里插入图片描述
查看数据集中图像的形状

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)

在这里插入图片描述

  1. 数据增强
# 图像增广
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强
    [torchvision.transforms.RandomResizedCrop(224),
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor(),
     normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强
    [torchvision.transforms.Resize(256),
     torchvision.transforms.CenterCrop(224),
     torchvision.transforms.ToTensor(),
     normalize]
)
  1. 定义和初始化模型
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)

在这里插入图片描述

  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512—>1000,改成512—>2)
  • (2)在增加一层分类层(如:512—>1000, 改成512—>1000, 1000—>2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

在这里插入图片描述

print(finetune_net)

在这里插入图片描述

  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):
    train_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir,'train'), transform=train_augs
        ),
        batch_size=batch_size,
        shuffle=True
    )
    test_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir, 'test'), transform=test_augs
        ),
        batch_size=batch_size
    )
    device = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction='none')
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ['fc.weight', 'fc.bias']]
        trainer = torch.optim.SGD(
            [{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],
            lr=learning_rate, weight_decay=0.001
        )
    else:
        trainer = torch.optim.SGD(
            net.parameters(),
            lr=learning_rate,weight_decay=0.001
        )
    d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train_fine_tuning(finetune_net, 5e-5, 128, 10)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

直接训练:整个模型都使用相同的学习率,重新训练

scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)

import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train_fine_tuning(scracth_net, 5e-4, param_group=False)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1228353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java文件压缩加密,使用流的方式

使用net.lingala.zip4j来进行文件加密压缩。 添加依赖net.lingala.zip4j包依赖&#xff0c;这里使用的是最新的包2.11.5版本。 <dependency><groupId>net.lingala.zip4j</groupId><artifactId>zip4j</artifactId><version>${zip4j.versi…

丹麦能源袭击预示着更关键的基础设施成为目标

5 月&#xff0c;22 个丹麦能源部门组织在与俄罗斯 Sandworm APT 部分相关的攻击中受到损害。 丹麦关键基础设施安全非营利组织 SektorCERT 的一份新报告描述了不同的攻击者群体利用合勤防火墙设备中的多个关键漏洞&#xff08;包括两个零日漏洞&#xff09;侵入工业机械&…

Dockerfile自定义镜像以及案例分析

文章目录 一、Dockerfile自定义镜像1.1 镜像结构1.2 Dockerfile语法 二、构建Java项目三、基于java8构建java四、小结 一、Dockerfile自定义镜像 常见的镜像在DockerHub就能找到&#xff0c;但是我们自己写的项目就必须自己构建镜像了。 而要自定义镜像&#xff0c;就必须先了…

boomYouth

上一周实在是过得太颓废了&#xff0c;我感觉还是要把自己的规划做好一下&#xff1a; 周计划 这周截至周四&#xff0c;我可以用vue简单的画完登陆注册的界面并且弄一点预处理&#xff1a; 周一 的话可以把这些都学一下&#xff1a; 父传子&#xff0c;子传父&#xff1a…

配置iTerm2打开自动执行命令

打开iTerm2&#xff0c;commado&#xff0c;打开profies->edit profies&#xff0c;点击号&#xff0c;创建一个新的profile 在新的profile中填写 name&#xff1a;随意 command&#xff1a;Login Shell Send text at start&#xff1a;执行脚本的命令&#xff0c;不想写路…

python django 小程序图书借阅源码

开发工具&#xff1a; PyCharm&#xff0c;mysql5.7&#xff0c;微信开发者工具 技术说明&#xff1a; python django html 小程序 功能介绍&#xff1a; 用户端&#xff1a; 登录注册&#xff08;含授权登录&#xff09; 首页显示搜索图书&#xff0c;轮播图&#xff0…

某60区块链安全之不安全的随机数实战一

区块链安全 文章目录 区块链安全不安全的随机数实战一实验目的实验环境实验工具实验原理实验内容攻击过程分析合约源代码漏洞EXP利用 不安全的随机数实战一 实验目的 学会使用python3的web3模块 学会以太坊不安全的随机数漏洞分析及利用 实验环境 Ubuntu18.04操作机 实验工…

环境配置|GitHub——解决Github无法显示图片以及README无法显示图片

一、问题背景 最近在整理之前写过的实验、项目&#xff0c;打算把这些东西写成blog&#xff0c;并把工程文件整理上传到Github上。但在上传README文件的时候&#xff0c;发现github无法显示README中的图片&#xff0c;如下图所示&#xff1a; 在README中该图片路径为&#xff1…

Unity Meta Quest 一体机开发(七):配置玩家 Hand Grab 功能

文章目录 &#x1f4d5;教程说明&#x1f4d5;玩家物体配置 Hand Grab Interactor⭐添加 Hand Grab Interactor 物体⭐激活 Hand Grab Visual 和 Hand Grab Glow⭐更新 Best Hover Interactor Group &#x1f4d5;配置可抓取物体&#xff08;无抓取手势&#xff09;⭐刚体和碰撞…

【算法】树形DP③ 监控二叉树 ⭐(二叉树染色二叉树灯饰)!

文章目录 前期知识 & 相关链接例题968. 监控二叉树解法1——标记状态贪心解法2——动态规划 相关练习题目P2458 [SDOI2006] 保安站岗⭐&#xff08;有多个儿子节点&#xff09;&#x1f6b9;LCP 34. 二叉树染色⭐&#xff08;每个节点 单独dp[k 1]数组&#xff09;LCP 64.…

时间序列预测实战(十七)利用Prophet实现电力负荷长期预测(附代码+数据集+详细讲解)

一、本文介绍 Prophet是一个由Facebook开发的开源工具&#xff0c;用于时间序列预测。这个工具特别适合于具有强季节性影响和多个历史数据季节的业务时间序列数据。Prophet的主要思想是将数据分解为如下三个部分&#xff1a;趋势、季节性、节假日和特殊事件。这个模型非常适合…

GIT无效的源路径/URL

ssh-add /Users/haijunyan/.ssh/id_rsa ssh-add -K /Users/haijunyan/.ssh/id_rsa

SQL基础理论篇(七):多表关联的连接算法

文章目录 简介Nested LoopsMerge JoinHash Join总结参考文献 简介 多表之间基础的关联算法一共有三种&#xff1a; Hash JoinNested LoopsMerge Join 还有很多基于这三种基础算法的变体&#xff0c;以Nested Loops为例&#xff0c;就有用于in和exist的半连接&#xff08;Nes…

【Android Jetpack】Hilt的理解与浅析

文章目录 依赖注入DaggerHiltKoin添加依赖项Hilt常用注解的含义HiltAndroidAppAndroidEntryPointInjectModuleInstallInProvidesEntryPoint Hilt组件生命周期和作用域如何使用 Hilt 进行依赖注入 本文只是进行了简单入门&#xff0c;博客仅当做笔记用。 依赖注入 依赖注入是一…

文档向量化工具(一):Apache Tika介绍

Apache Tika是什么&#xff1f;能干什么&#xff1f; Apache Tika是一个内容分析工具包。 该工具包可以从一千多种不同的文件类型&#xff08;如PPT、XLS和PDF&#xff09;中检测并提取元数据和文本。 所有这些文件类型都可以通过同一个接口进行解析&#xff0c;这使得Tika在…

node实战——koa实现文件上传

文章目录 ⭐前言⭐koa实现文件上传⭐foxapi测试⭐总结⭐结束⭐前言 大家好,我是yma16,本文分享关于node实战——node实战——koa实现文件上传。 本文适用对象:前端初学者转node方向,在校大学生,即将毕业的同学,计算机爱好者。 node系列往期文章 node_windows环境变量配置…

Day35力扣打卡

打卡记录 相邻字符不同的最长路径&#xff08;树状DP&#xff09; 链接 若节点也存在父节点的情况下&#xff0c;传入父节点参数&#xff0c;若是遍历到父节点&#xff0c;直接循环里 continue。 class Solution:def longestPath(self, parent: List[int], s: str) -> in…

《微信小程序开发从入门到实战》学习二十

3.3 开发创建投票页面 3.3.8 使用icon图标文件 原来已经实现了投票选项的增加和修改功能&#xff0c;现在还差删除。现在为每一个选项增加删除按钮&#xff0c;可以以通过icon图标组件实现。 icon常用属性如下&#xff1a; type icon的类型&#xff0c;有success、s…

Linux虚拟机中网络连接的三种方式

Linux 虚拟机中网络连接的三种方式 先假设一个场景&#xff0c;在教室中有三个人&#xff1a;张三、李四和王五&#xff08;这三个人每人有一台主机&#xff09;&#xff0c;他们三个同处于一个网段中&#xff08;192.169.0.XX&#xff09;&#xff0c;也就是说他们三个之间可…

ICASSP2023年SPGC多语言AD检测的论文总结

文章目录 引言正文AbstractRelated ArticleNo.1: CONSEN: COMPLEMENTARY AND SIMULTANEOUS ENSEMBLE FOR ALZHEIMERSDISEASE DETECTION AND MMSE SCORE PREDICTION特征相关模型结构数据处理结果分析 No.2: CROSS-LINGUAL TRANSFER LEARNING FOR ALZHEIMERS DETECTION FROM SPON…