pytorch迁移学习训练图像分类

news2025/1/11 18:08:37

pytorch迁移学习训练图像分类

  • 一、环境配置
  • 二、迁移学习关键代码
  • 三、完整代码
  • 四、结果对比

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:Pytorch迁移学习训练自己的图像分类模型

一、环境配置

1,安装所需的包

pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装Pytorch

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,创建目录

import os
# 存放训练得到的模型权重
os.mkdir('checkpoint')

4,下载数据集压缩包(下载之后需要解压数据集)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip

二、迁移学习关键代码

以下是迁移学习的三种选择,根据训练的需求选择不同的迁移方法:

  • 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与 当前数据集类别数n_class 对应
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
  • 选择二:微调训练所有层。

适用于训练数据集与预训练模型相差大时,可以选择微调训练所有层,此时只使用预训练模型的部分权重和特征,例如原始模型为imageNet,而训练数据为医疗相关

model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
  • 选择三:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())

三、完整代码

import time
import os

import numpy as np
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn

# 忽略出现的红色提示
import warnings
warnings.filterwarnings("ignore")

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

from torchvision import transforms

# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

# 数据集文件夹路径
dataset_dir = 'fruit30_split'
train_path = os.path.join(dataset_dir, 'train')	# 测试集路径
test_path = os.path.join(dataset_dir, 'val')	# 测试集路径

from torchvision import datasets

# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)

# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)

# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)

# 定义数据加载器DataLoader
from torch.utils.data import DataLoader

BATCH_SIZE = 32

# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=4
                         )

# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=4
                        )

from torchvision import models
import torch.optim as optim

# 选择一:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True,指定张量需要梯度计算
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc	# 查看全连接层
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())    # optim 是 PyTorch 的一个优化器模块,用于实现各种梯度下降算法的优化方法


# 选择二:微调训练所有层
# 训练数据集与预训练模型相差大时,可以选择微调训练所有层,只使用预训练模型的部分权重和特征,例如原始模型为imageNet,训练数据为医疗相关
# model = models.resnet18(pretrained=True) # 载入预训练模型
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())


# 选择三:随机初始化模型全部权重,从头训练所有层
# model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())

# 训练配置
model = model.to(device)

# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 训练轮次 Epoch
EPOCHS = 30

# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):

    model.train()

    for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)           # 前向预测,获得当前 batch 的预测结果
        loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数
        
        optimizer.zero_grad()
        loss.backward()                   # 损失函数对神经网络权重反向传播求梯度
        optimizer.step()                  # 优化更新神经网络权重

# 测试集上初步测试
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度
        _, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果
        total += labels.size(0)
        correct += (preds == labels).sum()   # 预测正确样本个数

    print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

# 保存模型
torch.save(model, 'checkpoint/fruit30_pytorch_A1.pth') # 选择一:微调全连接层
# torch.save(model, 'checkpoint/fruit30_pytorch_A2.pth') # 选择二:微调所有层
# torch.save(model, 'checkpoint/fruit30_pytorch_A3.pth') # 选择三:随机权重

四、结果对比

调用不同迁移学习得到的模型对比测试集准确率

# 测试集导入和图像预处理等代码和上述完整代码中一致,此处省略……

# 调用自己训练的模型
model = torch.load('checkpoint/fruit30_pytorch_A1.pth')

# 测试集上进行测试
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度
        _, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果
        total += labels.size(0)
        correct += (preds == labels).sum()   # 预测正确样本个数

    print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

结果如下:
对于微调全连接层的选择一,测试集准确率为 72.078%
在这里插入图片描述
而所有权重随机的选择三测试集准确率为 43.228%
43.228

总体而言,迁移学习能够利用已有的知识和经验,加速模型的训练过程,提高模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工业检测 ocr

采用OpenCV和深度学习的钢印识别_菲斯奇的博客-CSDN博客采用OpenCV和深度学习的钢印识别[这个帖子标题党了很久,大概9月初立贴,本来以为比较好做,后来有事情耽搁了,直到现在才有了一些拿得出手的东西。肯定不会太监的。好&#xf…

RL 暂态电路与磁能

前言 RL 电路是一个电阻 R 和 自感线圈 L 组成的 RL 电路,在连接或者接通电源U 的时候,由于自感电动势的作用,电路中的电流不会瞬间改变,而是一个连续的渐变的过程,通常这个时间很短暂,所以被称为暂态过程…

MySQL 面试题——MySQL 基础

目录 1.什么是 MySQL?有什么优点?2.MySQL 中的 DDL 与 DML 是分别指什么?3.✨数据类型 varchar 与 char 有什么区别?4.数据类型 BLOB 与 TEXT 有什么区别?5.DATETIME 和 TIMESTAMP 的异同?6.✨MySQL 中 IN …

xss-labs实操

文章目录 1.Level2.Level23.Level34.Level45.Level56.Level67.Level78.Level89.Level910.Level1011.Level1112.Level1213.Level13 1.Level 无过滤法 2.Level2 “>闭合 “>&submit搜索 通过观察发现alert里的引号没了,是不是被过滤了呢 因为如果该语句类…

leetcode:69. x 的平方根

一、题目 函数原型:int mySqrt(int x) 二、思路 利用二分查找思想,在0与x区间进行查找。 设置左边界 left (初始值为0),右边界 right(初始值为x)和中值 mid (值为区间的中间值&#…

VCP-DCV VMware vSphere:安装、配置和管理[V8.x]

VMware官方授权合作活动,全国招生! VCP-DCV VMware vSphere:安装、配置和管理[V8.x] 课程名称:VMware vSphere安装、配置和管理[V8.x] 培训课时:40课时 培训天数:5天 课程介绍:本课程重点讲…

MySQL优化技巧:提升数据库性能

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

盐碱地改良通用技术 铁尾砂改良学习

声明 本文是学习GB-T 42828.1-2023 盐碱地改良通用技术 第1部分:铁尾砂改良. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件描述了铁尾砂改良盐碱地技术的技术原理,规定了技术要求、田间管理和效果评价。 本文…

计算机d3dx9_43.dll丢失怎么解决,简单的5个解决方法分享

在当今这个高度依赖计算机技术的时代,我们的生活和工作都离不开各种软件的支持。然而,有时候我们可能会遇到一些棘手的问题,比如计算机中的某个dll文件丢失,导致程序无法正常运行。最近,我就遇到了这样一个问题&#x…

无涯教程-JavaScript - TRANSPOSE函数

描述 TRANSPOSE函数将单元格的垂直范围作为水平范围返回,反之亦然。必须将TRANSPOSE函数作为数组公式输入,该范围必须具有与行范围和列范围相同的行和列数。 您可以使用TRANSPOSE在工作表上移动数组或范围的垂直和水平方向。 语法 TRANSPOSE (array)键入函数后,按CTRL SHI…

【计算机视觉】Vision Transformers算法介绍合集(二)

文章目录 一、Transformer in Transformer二、Bottleneck Transformer三、Pyramid Vision Transformer v2四、Class-Attention in Image Transformers五、Co-Scale Conv-attentional Image Transformer六、XCiT七、Focal Transformers八、CrossViT九、ConViT十、CrossTransform…

【C++】map,set简单操作的封装实现(利用红黑树)

文章目录 一、STL中set与map的源码二、 红黑树结点的意义三、仿函数的妙用四、set,map定义迭代器的区别五、map,set迭代器的基本操作:1.begin() end()2.operator3.operator-- 六、迭代器拷贝构造…

A Framework to Evaluate Fusion Methods for Multimodal Emotion Recognition

题目A Framework to Evaluate Fusion Methods for Multimodal Emotion Recognition译题一种评估多模态情感识别融合方法的框架时间2022年仅用于记录学习,不作为商用 一种评估多模态情感识别融合方法的框架 摘要:情绪识别的多模态方法考虑了预测情绪的几…

2023/9/17周报

摘要 本周阅读了两篇论文,其一为一种基于空气质量时频域特征提取的hybrid预测方法,另一篇为基于烛台与视觉几何群模型的 PM2.5 变化趋势特征提取与分类预测方法。在第一篇文章中,通过小波变化,对数据进行分频,并设计了…

详解3dMax中渲染线框的两种简单方法

在3dMax中渲染线框是你在某个时候想要完成的事情,例如为了演示分解步骤,或是仅仅为了在模型上创建线框覆盖的独特效果。为三维模型渲染线框最常见的原因是能够在模型上显示干净的拓扑。这篇文章将带你了解在3dMax中渲染三维模型线框的两种最常见、最简单…

太炫酷,3分钟学会,视频倒放技能

一,视频倒放 视频倒放是一种有趣的视频编辑技术,可以为您的视频带来一些特殊的效果。通过倒放视频,您可以实现以下效果(如果有其它需要的软件和技术,可以私信小编;更多精彩可关注微信公众号:黑…

推荐一个高质量专栏:「前端面试必备」

文章目录 专栏作者介绍专栏介绍目录(前25篇)目录(后25篇)专栏文章部分摘抄JavaScriptVue网络请求和HTTPNode.jswebpackBabelVite微信小程序Vuexuni-appGitECharts前端工程化 写在结尾 专栏作者介绍 🤍 前端开发工程师&…

DT Paint Effects工具(三)

管 分支 使用细枝 叶 力 使用湍流 流动画 渲染全局参数 建造盆栽植物

Ubuntu 22.04安装过程

iso下载地址 Ubuntu Releases 1.进入引导菜单 选择Try or Install Ubuntu Server安装 2.选择安装语言 默认选择English 3.选择键盘布局 默认即可 4.选择安装服务器版本 最小化安装 5.配置网络 选择ipv4 选择自定义 DHCP也可 6.配置代理 有需要可以配置 这里跳过 7.软件源 …

「UG/NX」Block UI 指定轴SpecifyAxis

✨博客主页何曾参静谧的博客📌文章专栏「UG/NX」BlockUI集合📚全部专栏「UG/NX」NX二次开发「UG/NX」BlockUI集合「VS」Visual Studio「QT」QT5程序设计「C/C+&#