【深度学习-图像识别】使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)

news2024/12/23 23:52:59

文章目录

  • 前言
    • fastai介绍
      • 数据集介绍
  • 一、环境准备
  • 二、数据集处理
    • 1.数据目录结构
    • 2.导入依赖项
    • 2.读入数据
    • 3.模型构建
      • 3.1 寻找合适的学习率
      • 3.2 模型调优
    • 4.模型保存与应用
  • 总结
      • 人工智能-图像识别 系列文章目录


前言

fastai介绍

fastai 是一个深度学习库,它为从业人员提供了高级组件,可以快速、轻松地在标准深度学习领域提供最先进的结果,并为研究人员提供了低级组件,可以混合和匹配以构建新的方法。以解耦抽象的方式表达了许多深度学习和数据处理技术的通用底层模式。
fastai 有两个主要的设计目标:易于使用、快速高效,同时具有很强的可破解性和可配置性。它建立在提供可组合构件的低级应用程序接口的层次结构之上。这样,如果用户想重写部分高级应用程序接口或添加特定行为以满足自己的需求,就不必学习如何使用最底层的应用程序接口。
在这里插入图片描述

数据集介绍

下载链接
Caltech101国内下载地址
Caltech101

Caltech101数据集内部有 101 个类别的物体图片。每个类别约有 40 至 800 张图片。大多数类别约有 50 张图片。每张图片的大小大约为 300 x 200 像素。并且作者还标注了这些图片中每个物体的轮廓,这些都包含在 "Annotations.tar "中。还有一个 MATLAB 脚本 "show_annotations.m "可以查看注释。

Collected in September 2003 by Fei-Fei Li, Marco Andreetto, and
Marc’Aurelio Ranzato。

一、环境准备

这里展示使用GPU进行训练的环境搭建,只用CPU也可以进行训练,只是训练时间比较慢。
首先安装Anaconda,通过conda安装我们需要的包

 conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
 conda install -c nvidia fastai anaconda

详情可见第一篇文章。

二、数据集处理

1.数据目录结构

├───data_iamge
│   ├───101_ObjectCategories
│   │   ├───accordion
│   │   ├───airplanes
│   │   ├───anchor
│   │   ├───ant
│   │   ├───BACKGROUND_Google
│   │   ├───barrel
│   │   ├───bass
│   │   ├───beaver
│   │   ├───binocular
│   │   ├───bonsai
│   │   ├───brain
│   │   ├───brontosaurus
...

2.导入依赖项

from fastai import *
from fastai.vision.all import *
from fastai.metrics import error_rate

import os
#from keras.utils import plot_model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

查看环境以及版本信息,cuda.is_available()判断是否可以用GPU。

print(torch.cuda.is_available())
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

True
2.0.1
11.8
8700

'''SEED Everything'''
def seed_everything(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = True # keep True if all the input have same size.
SEED=42
seed_everything(SEED=SEED)
'''SEED Everything'''

2.读入数据

代码如下(示例):

path='./data_image/101_ObjectCategories/'
image_rsize=224
item_tfms = [Resize((image_rsize,image_rsize))]
data = ImageDataLoaders.from_folder(path, train = '.', valid_pct=0.2,
                                   size=image_rsize,
                                  item_tfms=item_tfms)
data.show_batch(figsize=(7,6))

在这里插入图片描述

3.模型构建

这里使用预训练模型resnet101,这是一个非常优秀的残差网络模型。
这些残差网络更容易优化,并且可以从显着增加的深度中获得准确性。
这些残差网络的集合在 ImageNet 测试集上实现了 3.57% 的误差。该结果在ILSVRC 1分类任务中获得第一名。

learn = cnn_learner(data, models.resnet101, model_dir='./model', path = Path("."))

3.1 寻找合适的学习率

learn.lr_find()

在这里插入图片描述

接下来使用fit_one_cycle方法用更小的学习率进一步训练。fit_one_cycle使用的是一种周期性学习率,从较小的学习率开始学习,缓慢提高至较高的学习率,然后再慢慢下降,周而复始,每个周期的长度略微缩短,在训练的最后部分,允许学习率比之前的最小值降得更低。这不仅可以加速训练,还有助于防止模型落入损失平面的陡峭区域,使模型更倾向于寻找更平坦的极小值,从而缓解过拟合现象。

lr1 = 1e-3
lr2 = 1e-1
epoch	train_loss	valid_loss	time
0	1.417713	1.648756	00:45
1	3.097069	9.964518	00:43
2	5.385355	5.347832	00:44
3	4.194504	12.162844	00:44
4	2.985504	3.486863	00:43
5	2.152388	22.297184	00:43
6	1.295905	3.554162	00:43
7	0.630879	9.193820	00:43
8	0.361619	49.334236	00:43
9	0.255115	9.832499	00:43

3.2 模型调优

unfreeze
在fastai课程中使用的是预训练模型,模型卷积层的权重已经提前在ImageNet
上训练好了,在使用的时候一般只需要在预训练模型最后一层卷积层后添加自定义的全连接层即可。卷积层默认是freeze的,即在训练阶段进行反向传播时不会更新卷积层的权重,只会更新全连接层的权重。在训练几个epoch之后,全连接层的权重已经训练的差不多了,但accuracy还没有达到你的要求,这时你可以调用unfreeze然后再进行训练,这样在进行反向传播时便会更新卷积层的权重(一般不会对卷积层权重进行较大的更新,只会进行一点点的微调,越靠前的卷积层调整的幅度越小,所以有了differential
learning rate 这一想法)

precompute
当precompute=True时,会提前计算出每一个训练样本(不包括增强样本)在预训练模型最后一层卷积层的activation,
并将其缓存下来,之后在训练阶段进行前向传播的时候,直接将precompute 的activation 作为后面全连接层(FC
Layer)的输入,这样便省去前面卷积层进行前向传播的计算量,减少训练所需时间(这种优势在epoch比较大的时候能够显著0提高训练速度)。当precompute=False时,则不会提前计算训练样本的activation,每一个epoch都需要重新将训练样本+增强样本(前提是进行了增强操作)进行卷积层的前向传播,然后进行反向传播更新对应的权重。

learn.unfreeze()
learn.show_results()

在这里插入图片描述
从展示的部分训练结果可以看出,只有一张图被预测错误了,其他的都是正确的。

4.模型保存与应用

最后我们可以将模型保存下来,并且对验证集的图片的类别进行预测。

learn.export(Path("./model/export.pkl"))
from PIL import Image
img = Image.open(path+'ant/image_0001.jpg')
image_rsize=224
# Resize the image to 224x224
img_resized = img.resize((image_rsize,image_rsize))
pred, pred_idx, probs = learn.predict(img_resized)
im_t = cast(array(img_resized), TensorImage)
# Print the predicted label and probability
print(f"Predicted label: {pred}, probability: {probs[pred_idx]:.4f}")
img

在这里插入图片描述

总结

epoch	train_loss	valid_loss	time
0	1.030772	979.477417	00:52
1	1.074642	86.289436	00:52
2	0.553576	0.457210	00:52
3	0.302997	0.546438	00:52
4	0.176070	0.596845	00:52

我们借助fastai训练了resnet101模型,对 101 个类别的图像数据集进行了分类。
使用基于pytorch的fastai库,使用resnet模型和有101个类别的Caltech101图像数据集,训练了一个高准确率的多分类的深度学习模型,能够对101个类别的图像大数据集进行准确的图像类别识别。
使用简洁高效的代码,借助GPU提升训练速度(也可以使用CPU训练,本项目会自动识别硬件),首先数据集进行预处理,然后对模型进行训练,并将模型保存为pkl格式,最后对测试集的图像的类别进行预测。

可见,使用fastai进行图像多分类是非常简便的,所使用的代码行数非常少却能达到很高的准确率,而且借助GPU训练速度非常快。

这里将全部的代码和图片数据集打包起来了,方便大家复现。
开箱即用,欢迎下载
使用fastai对Caltech101数据集进行图像多分类


人工智能-图像识别 系列文章目录

  1. 环境搭建: pytorch以及fastai安装,配置GPU训练环境 待更新。。。
  2. 使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/899291.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PyTorch学习笔记(十五)——完整的模型训练套路

以 CIFAR10 数据集为例,分类问题(10分类) model.py import torch from torch import nn# 搭建神经网络 class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.Ma…

C语言:深度学习知识储备

目录 数据类型 每种类型的大小是多少呢? 变量 变量的命名: 变量的分类: 变量的作用域和生命周期 作用域: 生命周期: 常量 字符串转义字符注释 字符串: 转义字符 操作符: 算术操作符…

nginx反向代理、负载均衡

修改nginx.conf的配置 upstream nginx_boot{# 30s内检查心跳发送两次包,未回复就代表该机器宕机,请求分发权重比为1:2server 192.168.87.143 weight100 max_fails2 fail_timeout30s; server 192.168.87.1 weight200 max_fails2 fail_timeout30s;# 这里的…

【流程引擎】--Camunda基础及sprringboot简单集成Camunda

目录 一、前言二、Camunda基本介绍2.1、camunda基础--符号表示2.2、camunda基础--网关表示2.3、camunda基础--事件表示 三、springboot集成Camunda四、后续 一、前言 目前市场上有常见的流程引擎:JBPM、Activiti、Camunda、Flowable、CompileFlow。它们的发展史如下…

TR 已经释放 task未释放的问题

货铺QQ群号:834508274 微信群不能扫码进了,可以加我微信SAPliumeng拉进群,申请时请提供您哪个模块顾问,否则是一律不通过的。 进群统一修改群名片,例如BJ_ABAP_森林木。群内禁止发广告及其他一切无关链接,小…

16-案例-记账单

功能需求: <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </head> &l…

224、仿真-基于51单片机音乐播放器流水灯控制Proteus仿真设计(程序+Proteus仿真+原理图+程序流程图+元器件清单+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括&#xff1a; 需要完整的资料可以点击下面的名片加下我&#xff0c;找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选…

C++音乐播放系统

C音乐播放系统 音乐的好处c发出声音乐谱与赫兹对照把歌打到c上 学习c的同学们都知道&#xff0c;c是一个一本正经的编程语言&#xff0c;因该没有人用它来做游戏、做病毒、做…做…做音乐播放系统吧&#xff01;&#xff01; 音乐的好处 提升情绪&#xff1a;音乐能够影响我们…

【C++进阶】继承、多态的详解(多态篇)

【C进阶】继承、多态的详解&#xff08;多态篇&#xff09; 目录 【C进阶】继承、多态的详解&#xff08;多态篇&#xff09;多态的概念多态的定义及实现多态的构成条件&#xff08;重点&#xff09;虚函数虚函数的重写&#xff08;覆盖、一种接口继承&#xff09;C11 override…

解决C#报“MSB3088 未能读取状态文件*.csprojAssemblyReference.cache“问题

今天在使用vscode软件C#插件&#xff0c;编译.cs文件时&#xff0c;发现如下warning: 图(1) C#报cache没有更新 出现该warning的原因&#xff1a;当前.cs文件修改了&#xff0c;但是其缓存文件*.csprojAssemblyReference.cache没有更新&#xff0c;需要重新清理一下工程&#x…

双层优化入门(4)—基于对偶变换的双层优化求解

之前的博客介绍了双层优化的基本原理、以及如何使用KKT条件和智能优化算法求解双层优化问题&#xff0c;这篇博客将继续介绍如何通过对偶变换求解双层优化问题。 1.线性规划的对偶问题 参考资料&#xff1a; 运筹学修炼日记&#xff1a;如何优雅地写出大规模线性规划的对偶_刘…

spring boot 整合支付宝微信支付

1.目录结构 2.引入依赖 <!--引入阿里支付--><dependency><groupId>com.alipay.sdk</groupId><artifactId>alipay-sdk-java</artifactId><version>4.11.8.ALL</version></dependency><!--引入微信支付--><depe…

Redis中的淘汰策略

前言 本文主要说明在Redis面临key过期和内存不足的情况时&#xff0c;可以采用什么策略进行解决问题。 Redis中是如何应对过期数据的 正如我们知道的Redis是基于内存的、单线程的一个中间件&#xff0c;在面对过期数据的时候&#xff0c;Redis并不会去直接把它从内存中进行剔…

运用工具Postman快速导出python接口测试脚本

Postman的脚本可以导出多种语言的脚本&#xff0c;方便二次维护开发。 Python的requests库&#xff0c;支持python2和python3&#xff0c;用于发送http/https请求 使用unittest进行接口自动化测试 一、环境准备 1、安装python&#xff08;使用python2或3都可以&#xff09;…

HCIP之VLAN实验

目录 一、实验题目 二、实验思路 三、实验步骤 3.1 将接口划入vlan&#xff0c;设置trunk干道 3.2 启动DHCP服务&#xff0c;下发地址 四、测试 一、实验题目 实验要求&#xff1a; 1&#xff0c;PC1/3的接口均为access模式&#xff0c;且属于vlan2&#xff0c;处于同一…

pyltp 0.2.1安装

1. LTP及pyltp pyltp是 LTP的 Python封装&#xff0c;它里面提供了包括分词&#xff0c;词性标注&#xff0c;命名实体识别&#xff0c;句法分析等等能力。 比较坑的是我们可能无法直接通过pip install pyltp0.2.1方式来安装&#xff0c;所以本文就简单记录下如何通过源码安装…

04_15页表缓存(TLB)和巨型页

前言 linux里面每个物理内存(RAM)页的一般大小都是4kb(32位就是4kb),为了使管理虚拟地址数变少 加快从虚拟地址到物理地址的映射 建议配值并使用HugePage巨型页特性 cpu和mmu和页表缓存(TLB)和cache和ram的关系 CPU看到的都是虚拟地址&#xff0c;需要经过MMU的转化&#xf…

langchain-ChatGLM源码阅读:模型加载

文章目录 使用命令行参数初始化加载器模型实例化清空显存加载模型调用链loader.py的_load_model方法auto_factory.py的from_pretrained方法modeling_utils.py的from_pretrained方法hub.py的get_checkpoint_shard_files方法modeling_utils.py的_load_pretrained_mode方法回到loa…

电脑远程接入软件可以进行文件传输吗?快解析内网穿透

电脑远程接入软件的出现&#xff0c;让我们可以在两台电脑之间进行交互和操作。但是&#xff0c;很多人对于这些软件能否进行文件传输还存在一些疑问。下面的文章将解答这个问题。 1.电脑远程接入软件可以进行文件传输。传统上&#xff0c;我们可能会通过传输线或者移动存储设…