【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)

news2024/11/29 8:38:30

一、具体介绍

timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。

timm 的特点如下:

  1. PyTorch 原生实现:timm 的实现方式与 PyTorch 高度契合,开发者可以方便地使用 PyTorchAPI 进行模型训练和部署。
  2. 轻量级的设计:timm 的设计以轻量化为基础,根据不同的计算机视觉任务,提供了多种轻量级的网络结构。
  3. 大量的预训练模型:timm 提供了大量的预训练模型,可以直接用于各种计算机视觉任务。
  4. 多种模型组件:timm 提供了各种模型组件,如注意力模块、正则化模块、激活函数等等,这些模块都可以方便地插入到自己的模型中。
  5. 高效的代码实现:timm 的代码实现高效并且易于使用。

需要注意的是,timm 是一个社区驱动的项目,它由计算机视觉领域的专家共同开发和维护。在使用时需要遵循相关的使用协议。

二、图像分类案例

下面以使用 timm 实现图像分类任务为例,进行简单的介绍。

2.1 安装 timm 包

!pip install timm

2.2 导入相关模块,读取数据集

import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# 数据增强
train_transforms = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 数据集
train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transforms)

# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

导入相关模块,其中 timmtorchvision.datasets.CIFAR10 需要分别安装 timmtorchvision 包。

定义数据增强的方式,其中训练集和测试集分别使用不同的增强方式,并且对图像进行了归一化处理。transforms.Compose() 可以将各种操作打包成一个 transform 操作流,transforms.ToTensor() 将图像转化为 tensor 格式,transforms.Normalize() 将图像进行标准化处理。

使用自带的 CIFAR10 数据集,设置 train=True 定义训练集,设置 train=False 定义测试集。数据集会自动下载到指定的 root 路径下,并进行数据增强操作。

使用 torch.utils.data.DataLoader 定义数据加载器,将数据集包装成一个高效的可迭代对象,其中 batch_size 定义批次大小,shuffle 定义是否对数据进行随机洗牌,num_workers 定义使用多少个 worker 来加载数据。

在这里插入图片描述

2.3 定义模型

# 加载预训练模型
model = timm.create_model('resnet18', pretrained=True)

# 修改分类器
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

这里使用 timm.create_model() 函数来创建一个预训练模型,其中参数 resnet18 定义了使用的模型架构,参数 pretrained = True 表示要使用预训练权重。

这里修改了模型的分类器,首先使用 model.fc.in_features 获取模型 fc 层的输入特征数,然后使用 nn.Linear() 重新定义了一个 nn.Linear 层,输入为上一层的输出特征数,输出为类别数(即 len(train_dataset.classes))。这里直接使用了数据集类别数来定义输出层,以适配不同分类任务的需求。

在这里插入图片描述
在这里,我们使用了 timm 中的 ResNet18 模型,并将其修改为我们需要的分类器,同时在创建模型时,设置参数 pretrained=True 来加载预训练权重。

2.4 定义损失函数和优化器

# 损失函数
criterion = nn.CrossEntropyLoss()

# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

在深度学习中,损失函数是评估模型预测结果与真实标签之间差异的一种指标,常用于模型训练过程中。nn.CrossEntropyLoss() 是一个常用的损失函数,适用于多分类问题。

优化器用于更新模型参数以使损失函数最小化。在这里,我们使用了随机梯度下降法(SGD)优化器,以控制模型权重的变化。通过 model.parameters() 指定需要优化的参数,lr 定义了学习率,表示每次迭代时参数必须更新的量的大小,momentum 则是添加上次迭代更新值的一部分到这一次的更新值中,以减小参数更新的方差,稳定训练过程。

2.5 训练模型

num_epochs = 10

for epoch in range(num_epochs):
    # 训练
    model.train()
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 测试
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct/total))

这段代码是模型训练和测试的循环。num_epochs 定义了循环的次数,每次循环表示一个训练周期。

在训练阶段,首先将模型切换到训练模式,然后使用 train_loader 迭代地读取训练集数据,进行前向传播、计算损失、反向传播和优化器更新等操作。

在测试阶段,模型切换到评估模式,然后使用 test_loader 读取测试集数据,进行前向传播和计算模型预测结果,使用预测结果和真实标签进行准确率计算,并输出每个训练周期的准确率。

其中,torch.max() 函数用于返回每行中最大值及其索引,total 记录了总的测试样本数,correct 记录了正确分类的样本数,最后计算准确率并输出。

输出结果为:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/523980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java之线程池

目录 一.上节复习 1.阻塞队列 二.线程池 1.什么是线程池 2.为什么要使用线程池 3.JDK中的线程池 三.工厂模式 1.工厂模式的目的 四.使用线程池 1.submit()方法 2.模拟两个阶段任务的执行 五.自定义一个线程池 六.JDK提供线程池的详解 1.如何自定义一个线程池? 2.创…

【计网】第三章 数据链路层(3)信道划分介质访问控制

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 3.5-1 信道划分介质访问控制(播报信道中应用)一、传输数据使用的两种链路二、介质访问控制 三、信道划分 介质访问控制(静态划分…

协程切换原理与实践 -- 从ucontext api到x86_64汇编

目录 1.协程切换原理理解 2.ucontext实现协程切换 2.1 实现流程 2.2 根据ucontext流程看协程实现 2.3 回答开头提出的问题 3.x86_64汇编实现协程切换 3.1libco x86_64汇编代码分析 3.2.保存程序返回代码地址流程 3.3.恢复程序地址以及上下文 4.实现简单协程框架 1.协程…

《编程思维与实践》1071.猜猜猜

《编程思维与实践》1071.猜猜猜 题目 思路 对于首字符而言,如果后一位字符与之相同,则首位选法只有1种,不同则2种; 对于最后一位字符而言,如果前一位字符与之相同,则末位选法只有1种,不同则2种; 对于中间的字符而言,有以下几种可能: 1.中间字符与前后字符均不同且前后字符不同…

企业挑选人力资源管理系统,需要从哪些角度考察?

企业在挑选人力资源管理系统时,除了要考虑到企业自身的主要需求外,还应该从哪些角度考察人力资源管理系统呢?一起来看看吧~ 一. 数据是否共通 企业在人力资源管理系统时通常有多个功能模块的需求。除了要看系统是否具备这些功能模块&#xff…

一分钟图情论文:《数据与信息之间逻辑关系的探讨——兼及DIKW概念链模式》

一分钟图情论文:《数据与信息之间逻辑关系的探讨——兼及DIKW概念链模式》 1989年,Ackoff R L在论文:《From data to wisdom》中正式提出DIKW概念链模型,在该模型提出后的20年间,在计算机学科、信息管理学科、图书情报…

数据结构--线段树

写在前面: 学习之前需要知道以下内容: 1. 递归 2. 二叉树 文章目录 线段树介绍用途建树修改单点修改区间修改 查询 代码实现。建树更新lazy传递查询 练习洛谷 P3372 【模板】线段树 1题目描述题解 线段树 介绍 线段树是一种二叉树,也可以…

【5G RRC】5G中的服务小区和邻区测量方法

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

STL配接器(容器适配器)—— stack 的介绍使用以及模拟实现。

注意 : 以下所有文档都来源此网站 : http://cplusplus.com/ 一、stack 的介绍和使用 stack 文档的介绍:https://cplusplus.com/reference/stack/stack/ 1. stack是一种容器适配器,专门用在具有后进先出操作的上下文环境中&…

Matlab进阶绘图第20期—带类别标签的三维柱状图

带类别标签的三维柱状图是一种特殊的三维柱状图。 与三维柱状图相比,带类别标签的三维柱状图通过颜色表示每根柱子的所属类别,从而可以更加直观地表示四维/四变量数据。 由于Matlab中未收录带类别标签的三维柱状图的绘制函数,因此需要大家自…

Java 使用 jdbc 连接 mysql

简介 Java JDBC 是 Java Database Connectivity 的缩写,它是一种用于连接和操作数据库的标准 API。Java JDBC 可以让 Java 程序通过 JDBC 驱动程序连接到各种不同类型的数据库,并且执行 SQL 语句来实现数据的读取、插入、更新、删除等操作。在本篇文章中…

Springboot整合Flowable流程引擎

文章目录 前言1. Flowable的主要表结构1.1 通用数据表(通用表)1.2运行时数据表(runtime表)1.3.历史数据表(history表)1.4. 身份数据表(identity表)1.5. 流程定义数据表(r…

C++: 并行加速图像读取和处理的过程

文章目录 1. 目的2. 设计3. 串行实现4. 并行实现5. 比对:耗时和正确性6. 加速比探讨 1. 目的 读取单张图像,计算整图均值,这很好实现,运行耗时很短。 读取4000张相同大小的图像,分别计算均值,这也很好实现…

【OpenCv • c++】形态学技术操作 —— 开运算与闭运算

🚀 个人简介:CSDN「博客新星」TOP 10 , C/C 领域新星创作者💟 作 者:锡兰_CC ❣️📝 专 栏:【OpenCV • c】计算机视觉🌈 若有帮助,还请关注➕点赞➕收藏&#xff…

openGauss5.0.0在vscode成功调试

之前在虚拟机上编译成功过,但今天启动数据库的时候出现权限错误问题,我重新删除了data文件夹,重新初始化启动数据库还是不成功,后来对报错文件进行赋权,成功解决! 问题(一) 1.启动…

图像水印MATLAB实验

文章目录 一、实验目的二、实验内容1. 简单的可见水印嵌入实验2. 不可见脆弱水印实验3. 不可见鲁棒水印实验 一、实验目的 了解数字图像水印技术的基本原理、分类和应用。掌握简单的可见水印和不可见水印的嵌入方法。实现一种基于DCT的不可见鲁棒水印,并进行水印鲁…

Dubbo 服务端源码深入分析 (7)

目录 1. 前提 2. 认识 Protocol 和 ProxyFactory Protocal ProxyFactory Dubbo服务流程 服务端源码分析 测试代码: Protocal代理的源码 ProxyFactory源码: 获取invoker对象 具体步骤 1. 我们调用的是ProxyFactory的代理对象的getInvoker方法…

Linux线程同步(6)——更高并行性的读写锁

互斥锁或自旋锁要么是加锁状态、要么是不加锁状态,而且一次只有一个线程可以对其加锁。读写锁有 3 种状态:读模式下的加锁状态(以下简称读加锁状态)、写模式下的加锁状态(以下简称写加锁状态)和不加锁状态&…

django视图(request请求response返回值)

一、视图函数介绍 视图就是应用中views.py中定义的函数,称为视图函数 def index(request):return HttpResponse("hello world!") 1、视图的第一个参数必须为HttpRequest对象,还可能包含下参数如通过正则表达式组获取的位置参数、通…

VBA——01篇(入门篇——简单基础语法)

VBA——01篇(入门篇——简单基础语法) 1. 语法格式1.1 简单语法1.2 简单例子 2. 变量2.1 常用数据类型2.2 声明变量的常用方式2.3 简单例子 3. 单元格赋值3.1 直接赋值3.2 拷贝单元格 4. 简单的逻辑语法4.1 简单if4.2 简单for循环4.2.1 简单语法例子4.2.…