迁移学习实现图片分类任务

news2024/11/19 21:30:03

导入工具包

import time
import os

import numpy as np
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

获取计算硬件

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

图片预处理

from torchvision import transforms

# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

这里对train训练集和text集的处理不同,几个transforms的操作通过compose进行整合。

载入图片分类数据集

# 数据集文件夹路径
dataset_dir = 'fruit30_split'

train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets

# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)

# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)


print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)

print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

datasets下的ImageFolder,可以直接构建数据集。

类别与索引号一一对应

class_names = train_dataset.classes
n_class = len(class_names)


# 映射关系:类别 到 索引号
train_dataset.class_to_idx

定义数据加载器Dataloader,dataloader用于给模型喂数据。

from torch.utils.data import DataLoader

BATCH_SIZE = 32

# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=4
                         )

# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=4
                        )

查看一个batch的图像与标注

# DataLoader 是 python生成器,每次调用返回一个 batch 的数据
images, labels = next(iter(train_loader))

images. Shape
#torch.Size([32, 3, 224, 224])
labels
#tensor([11, 19,  3, 25, 29, 13, 21, 18, 11,  1, 13, 15, 13,  0, 15, 25,  0,  7,11, 10,  9,  6, 26,  2, 11, 10, 29, 29, 15,  8, 19,  8])

迁移学习范式

导入训练所用的工具包

from torchvision import models
import torch.optim as optim
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc
Linear(in_features=512, out_features=30, bias=True)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())

采用第一种迁移学习的方式,优化器采用的是Adam的优化器。

训练配置

model = model.to(device)

# 交叉熵损失函数
criterion = nn.CrossEntropyLoss() 

# 训练轮次 Epoch
EPOCHS = 20

模拟一个batch的训练

这里着重注意反向传播三部曲

# 反向传播“三部曲”
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 优化更新

 运行完整训练

# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):

    model. Train() #每次开始前将模型设置为训练模式

    for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)           # 前向预测,获得当前 batch 的预测结果
        loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数
        
        optimizer.zero_grad()
        loss.backward()                   # 损失函数对神经网络权重反向传播求梯度
        optimizer.step()                  # 优化更新神经网络权重

在测试集上进行初步测试

model.eval() #模型设置为测试模式
with torch.no_grad(): #不再回传梯度
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度
        _, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果
        total += labels.size(0)
        correct += (preds == labels).sum()   # 预测正确样本个数,如果预测类别等于标注类别

    print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

保存模型

torch.save(model, 'checkpoint/fruit30_pytorch_C1.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1434778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Haas 开发板连接阿里云上传温湿度和电池电压

目录 一、在阿里云上创建一个产品 二、开发环境的介绍 三、创建wifi示例 四、编写SI7006和ADC驱动 五、wifi配网 六、主要源码 七、查看实现结果 一、在阿里云上创建一个产品 登录自己的阿里云账号, 应该支付宝,淘宝账号都是可以的。 接着根据需求…

【leetcode题解C++】77.组合 and 216.组合总和III and 17.电话号码的字母组合

77. 组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[1,4], ] 示例 2: 输入&#xff1a…

5 分钟让你了解什么是搜索引擎

文章目录 搜索引擎概述基于业务模式分类垂直搜索(垂搜)通用搜索(通搜)本地搜索引擎 基于技术实现分类基于关键词的搜索引擎(Keyword-based Search Engine)语义搜索引擎(Semantic Search Engine&…

LeetCode 热题 100 | 链表(中下)

目录 1 19. 删除链表的倒数第 N 个节点 2 24. 两两交换链表中的节点 3 25. K 个一组翻转链表 4 138. 随机链表的复制 菜鸟做题第三周,语言是 C 1 19. 删除链表的倒数第 N 个节点 到底是节点还是结点。。。 解题思路: 设置双指针 left 和 ri…

ReactNative实现一个圆环进度条

我们直接看效果,如下图 我们在直接上代码 /*** 圆形进度条*/ import React, {useState, useEffect} from react; import Svg, {Circle,G,LinearGradient,Stop,Defs,Text, } from react-native-svg; import {View, StyleSheet} from react-native;// 渐变色 const C…

少儿编程教育新趋势:信息学奥赛与Scratch等级考试融合实践

近年来,信息学奥林匹克竞赛(简称信息学奥赛)以其独特的魅力吸引了大量热爱编程的青少年参与。这项赛事不仅考察参赛者的编程技能,更注重逻辑思维能力、问题解决能力和创新能力的培养。通过参加信息学奥赛,孩子们能够在…

OpenGL 入门(九)—Material(材质)和 光照贴图

文章目录 材质设置材质光的属性脚本实现 光照贴图漫反射贴图高光反射贴图 材质 材质本质是一个数据集,主要功能就是给渲染器提供数据和光照算法。 如果我们想要在OpenGL中模拟多种类型的物体,我们必须针对每种表面定义不同的材质(Material)属性。 我们…

设计模式1-访问者模式

访问者模式是一种行为设计模式,它允许你定义在对象结构中的元素上进行操作的新操作,而无需修改这些元素的类。这种模式的主要思想是将算法与元素的结构分离开,使得可以在不修改元素结构的情况下定义新的操作。 所谓算法与元素结构分离&#x…

不会PS怎么抠图?分享几个电商抠图的方法

在工作中,物品抠图是一项常见的任务。为了更好地展示物品,需要将其从背景中抠出来,以便与其他元素进行组合或展示。但是,手动抠图不仅费时费力,而且效果往往不尽如人意。这时,一款强大的物品抠图软件就成为…

【数据结构与算法】(10)基础数据结构 之 堆 建堆及堆排序 详细代码示例讲解

目录 2.9 堆建堆习题E01. 堆排序E02. 数组中第K大元素-Leetcode 215E03. 数据流中第K大元素-Leetcode 703E04. 数据流的中位数-Leetcode 295 2.9 堆 以大顶堆为例,相对于之前的优先级队列,增加了堆化等方法 public class MaxHeap {int[] array;int siz…

【已解决】Oracle 12541 TNS 无监听程序

目录 1、找到Oracle监听服务(OracleOraDb10g_homeTNLListener),停止运行 2、首先查看监听文件是否超过4G 3、修改配置文件 连接oracle突然报错,提示Oracle 12541 TNS 无监听程序,可以按照以下步骤解决 1、找到Ora…

【前沿技术杂谈:深度学习新纪元】探索人工智能领域的革命性进展

【前沿技术杂谈:深度学习新纪元】探索人工智能领域的革命性进展 深度学习的进展深度学习的基本原理和算法深度学习的历史发展神经网络的基本构成神经元层次结构激活函数 关键技术和算法反向传播算法卷积神经网络(CNN)循环神经网络&#xff08…

【操作系统·考研】I/O管理概述

1.I/O设备 1.1 块设备 信息交换以数据块为单位,它属于有结构设备。 块设备传输速率较高,可寻址,且可对该设备随机地的读写。 栗子🌰:磁盘。 1.2 字符设备 信息交换以字符为单位,属于无结构类型。 字符…

扩展鸿蒙textinput组件

扩展鸿蒙textinput组件,支持快速扩展展性,标题文本等,文本内容双向绑定、文本组件快速复用。 组件代码 /*** 单选文本*/ Component export default struct DiygwInput{//绑定的值Link value:string;//未选中图标State labelImg: Resource …

《热辣滚烫》预售狂潮来袭,贾玲、马丽、杨紫三大女神联袂出演。

♥ 为方便您进行讨论和分享,同时也为能带给您不一样的参与感。请您在阅读本文之前,点击一下“关注”,非常感谢您的支持! 文 |猴哥聊娱乐 编 辑|徐 婷 校 对|侯欢庭 《热辣滚烫》预售票房一日破1300万,燃爆春节档&am…

自定义Dockerfile构建运行springboot

自定义Dockerfile构建运行springboot 通过dockerfile生成自定义nginx镜像 !!!docker 必须在linux环境下才能进行如果你是window则需要装虚拟机 新建一个文件名字为Dockerfile,无需后缀 文件完整名就是Dockerfile,也可以自定义d…

有向图的拓扑排序-BFS求解

题目 给定一个n个点m条边的有向图,图中可能存在重边和自环。 请输出任意一个该有向图的拓扑序列,如果拓扑序列不存在,则输出-1。 若一个由图中所有点构成的序列A满足:对于图中的每条边(x, y),x在A中都出现在y之前,则称…

linux中的makefile

(码字不易,关注一下吧w~~w) makefile文件是用来管理项目文件,通过执行make命令,make就会解析并执行makefile文件。 命名:makefile或者Makefile 规则: 目标文件:依赖文件 (tab)命…

Narrative Visualization: Telling Stories with Data

作者:Edward Segel、Jeffrey Heer 发表:TVCG, 机构:UW Interactive Data Lab 【原斯坦福可视化组】 1.概述 静态可视化:在一大串的文本描述中,可视化作为提供证据和细节的图表出现新兴可视化&#xff1a…

设计模式学习笔记(一):基本概念;UML

文章目录 参考面向对象的设计原则创建型模式结构型模式行为型模式 UML视图图(Diagram)模型元素(Model Element)通用机制类之间的关系关联关系复杂!!聚合关系组合关系 依赖关系泛化关系接口与实现关系 参考 https://github.com/fa…