使用SwinTransformer进行图片分类

news2024/11/24 23:05:55

f2da30eb3a4453fa3639013e9b073d63.png

SwinTransformer 是微软亚洲研究院在2021年提出的适用于CV领域的一种基于Tranformer的backbone结构。

2af577c974fa45825ea3150c66699e2f.png

它是 Shift Window Transformer 的缩写,主要创新点如下。

  • 1,分Window进行Transformer计算,将自注意力计算量从输入尺寸的平方量级降低为线性量级。

  • 2,使用Shift Window 即窗格偏移技术 来 融合不同窗格之间的信息。(SW-MSA)

  • 3,使用类似七巧板拼图技巧 和Mask 技巧 来对 Window偏移后不同大小的窗格进行注意力计算以提升计算效率。

  • 4,在经典的QKV注意力公式中引入 Relative Position Bias 项来非常自然地表达位置信息的影响。

  • 5,使用Patch Merging技巧来 实现特征图的下采样,作用类似池化操作但不易丢失信息。

  • 6,使用不同大小的Window提取不同层次的特征并进行融合。

86cd19c0056f50dee21ba76fb9329981.png

SwinTransformer虽然采用了Transformer的实现方法,但在整体设计上借鉴了非常多卷积的设计特点。

如:局域性,平移不变性,特征图逐渐减小,通道数逐渐增加,多尺度特征融合等。

同时它还应用了非常多的trick来弥补Transformer的不足,如效率问题,位置信息表达不充分等。

B站上有UP主说SwinTransformer是披着Transformer皮的CNN。但毕竟它的主要内在计算是Transformer,所以我感觉它更像是叠加了卷积Buff的Transformer

SwinTransformer这个backbone结构表达能力非常强,同时适用性广泛,可适用于图片分类,分割,检测等多种任务,而且结构设计和实验工作都做得比较touch,所以被评为了2021年的ICCV best paper.

下面的范例我们微调 timm库中的 SwinTransformer模型来 做一个猫狗图片分类任务。

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和数据集下载链接。

#!pip install -U  timm, torchkeras

〇,预训练模型

import timm 
from urllib.request import urlopen
from PIL import Image
import timm
import torch 

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
img

752eb8ca6342e372b9c28a8beaf3ff04.png


model = timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k", pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)
info = timm.data.ImageNetInfo()
class_codes = info.__dict__['_synsets']
class_names = [info.__dict__['_lemmas'][x] for x in class_codes]
{class_names[i]:v for i,v in zip(top5_class_indices.tolist()[0],
                                top5_probabilities.tolist()[0])}
{'espresso': 0.1655443161725998,
 'cup': 0.12100766599178314,
 'chocolate sauce, chocolate syrup': 0.11809349805116653,
 'eggnog': 0.06144588068127632,
 'tray': 0.03965265676379204}
识别出来的主要是 espresso(蒸馏咖啡),cup 啥的,跟图片差不多,么得问题。

一,准备数据

import torch
import os
data_path = './datasets/cats_vs_dogs'

train_cats = os.listdir(os.path.join(data_path,"train","cats"))
img = Image.open(os.path.join(os.path.join(data_path,"train","cats",train_cats[0])))
img

84910346a10e4b5fa91aa608f0312ce1.png

train_dogs = os.listdir(os.path.join(data_path,"train","dogs"))
img = Image.open(os.path.join(os.path.join(data_path,"train","dogs",train_dogs[0])))
img

457329b334acbdeb122a6162b404db2f.png

from torchvision.datasets import ImageFolder


ds_train = ImageFolder(os.path.join(data_path,"train"),transforms)

ds_val = ImageFolder(os.path.join(data_path,"val"),transforms)


dl_train = torch.utils.data.DataLoader(ds_train, batch_size=4 ,
                                             shuffle=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2,
                                             shuffle=True)

class_names = ds_train.classes

print(len(ds_train))
print(len(ds_val))
2000
995
for batch in dl_val:
    break
batch[1]
tensor([0, 1])

二,定义模型

model.reset_classifier(num_classes=2)
model(batch[0])
tensor([[ 0.1698, -0.3366],
        [ 0.4805,  0.1415]], grad_fn=<AddmmBackward0>)
model.cuda();

三,训练模型

from torchkeras import KerasModel 
from torchmetrics import Accuracy

loss_fn = torch.nn.CrossEntropyLoss()
metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=2)}

optimizer = torch.optim.Adam(model.parameters(),
                            lr=1e-5)

keras_model = KerasModel(model,
                   loss_fn = loss_fn,
                   metrics_dict= metrics_dict,
                   optimizer = optimizer
                  )
features,labels = batch
loss_fn(model(features.cuda()),labels.cuda())
tensor(0.6743, device='cuda:0', grad_fn=<NllLossBackward0>)
dfhistory= keras_model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=100, 
                    ckpt_path='checkpoint.pt',
                    patience=10, 
                    monitor="val_acc",
                    mode="max",
                    mixed_precision='no',
                    plot = True,
                    quiet=True
                   )

3a0ef9e271bf187a5f594c9929e79e4d.png

可以看到SwinTransformer的拟合能力非常逆天,在这个简单的数据集上,finetune两个Epoch就直接把训练集上的Acc打到了100%,最后的验证集结果也是高达99.8%,非常强大~

四,评估模型

keras_model.evaluate(dl_val)

五,使用模型

from PIL import Image 
img = Image.open('./datasets/cats_vs_dogs/val/dogs/dog.2005.jpg')
model.eval();
model(transforms(img)[None,...].cuda()).softmax(axis=1)
tensor([[1.1537e-04, 9.9988e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

六,保存模型

torch.save(model.state_dict(),'swin_transformer.pt')

更多有趣范例,公众号算法美食屋后台回复关键词:torchkeras,可在tochkeras仓库获取范例源码。

49431072a4bbaa0f580555396752de60.png

619466906f396ad67a6670bf624ac227.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/644483.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣刷题记录(一)剑指Offer(第二版)

1、本栏用来记录社招找工作过程中的内容,包括基础知识学习以及面试问题的记录等,以便于后续个人回顾学习; 暂时只有2023年3月份,第一次社招找工作的过程; 2、个人经历: 研究生期间课题是SLAM在无人机上的应用,有接触SLAM、Linux、ROS、C/C++、DJI OSDK等; 3、参加工作后…

【JVM系列】jvm内存结构详解

文章目录 前言HotSpot VM介绍内存结构程序计数器JVM栈本地方法栈Java堆方法区运行时常量池 永久代和元空间永久代&#xff08;PermGen&#xff09;元空间&#xff08;Metaspace&#xff09; 直接内存总结 前言 我们为什么要学习JVM&#xff1f; 面试的需要中高级程序员的必备技…

java的字符流

字符流的底层也是字节流。字符流字节流字符集。 特点是输入流一次读一个字节&#xff0c;遇到中文时&#xff0c;一次读多个字节&#xff08;读多少个与字符集有关&#xff09;&#xff1b;输出流底层会把数据按照指定的编码方式进行编码&#xff0c;变成字节再写到文件中。 字…

AI绘图高级篇 第7篇 MJ以图换图-卡通头像

大家好&#xff0c;我是菜鸟哥 这个是我们MJ系列的第7篇&#xff0c;以前在会员群里发过&#xff0c;就是把头像做成卡通或者3D的效果还是很酷&#xff0c;或者是迪斯尼风格的。其实非常简单&#xff0c;就是用了一个MJ的以图换图的功能&#xff0c;今天给大家详细的说一下。 前…

ChatGPT 指南:如何与人工智能模型进行对话与互动

人工智能技术的快速发展使得我们能够与智能机器进行对话和互动。 ChatGPT 是一种基于 GPT-3.5 架构的强大语言模型&#xff0c;它能够进行自然语言处理&#xff0c;理解我们的问题并提供相应的回答。本文将为您提供使用 ChatGPT 进行对话和互动的详细指南。 1、提出问题 与 Cha…

Python爬虫之基础知识

爬虫基础知识 一、爬虫的概念 模拟浏览器&#xff0c;发送请求&#xff0c;获取响应 网络爬虫&#xff08;又被称为网页蜘蛛&#xff0c;网络机器人&#xff09;就是模拟客户端(主要指浏览器)发送网络请求&#xff0c;接收请求响应&#xff0c;一种按照一定的规则&#xff0c;…

基于springboot+mybatis+mysql+html企业人事管理系统

基于springbootmybatismysqlhtml企业人事管理系统 一、系统介绍二、功能展示1.用户登陆2.员工奖惩--员工3.合同管理--员工4.个人薪酬--员工5.培训管理--员工6.个人绩效--员工7.员工管理&#xff08;管理员&#xff09;8.奖惩管理&#xff08;管理员&#xff09;9.薪酬管理&…

华为OD机试真题B卷 JavaScript 实现【公共子串计算】,附详细解题思路

一、题目描述 给定两个只包含小写字母的字符串&#xff0c;计算两个字符串的最大公共子串的长度。 注&#xff1a;子串的定义指一个字符串删掉其部分前缀和后缀&#xff08;也可以不删&#xff09;后形成的字符串。 二、输入描述 输入两个只包含小写字母的字符串。 三、输…

chatgpt赋能python:Python使用的排大小方式全解析:关于SEO的学习!

Python使用的排大小方式全解析&#xff1a;关于SEO的学习&#xff01; 对于一个Python工程师来说&#xff0c;深度理解编程语言的基础知识总是非常重要的&#xff0c;包括了语法、函数、模块、数据结构以及算法等等。而在SEO领域&#xff0c;Python所采用的排大小方式&#xf…

jdk动态代理和cglb动态代理

目录 概述 JDK动态代理 cglb动态代理 概述 动态代理和静态代理都是代理模式的实现方式&#xff0c;其主要区别在于代理类生成的时机和方式。 静态代理是在编译时就确定了代理类的代码&#xff0c;在程序运行前就已经存在了代理类的class文件。代理类与委托类的关系在编译时就…

STM32F407移植1588v2(ptpd)

硬件&#xff1a; STM32F407ZGT6开发板 软件&#xff1a; VSCode arm-none-eabi-gcc openOCD st-link 在github搜到一个在NUCLEO-F429ZI开发板上移植ptpd的example&#xff0c;因为和F407差别很小&#xff0c;所以就打算用这个demo移植到手头的开发板上。因为目前只需要…

【C语言】VScode中配置C语言/C++运行环境(保姆级图文)

目录 省流助手1. 下载安装VScode2. 下载编译器MinGW并解压3. 将MinGW添加至环境变量4. 安装vscode的插件5. 运行代码6. 调整和优化&#xff08;这部分转自零流火星动力猿 2022.4.12&#xff09;总结 欢迎关注 『C语言』 系列&#xff0c;持续更新中 欢迎关注 『C语言』 系列&am…

初识滴滴交易策略之三:供需调节

本篇文章分为&#xff1a; 1.什么是交易市场中的供需&#xff1f; 供需的动态性供需的相互作用 2.滴滴业务场景涉及的供需调节技术 供需感知和供需预测 时序预测供需调节以提升市场匹配程度&#xff0c;保持供需平衡 整数规划为司机规划更好的出车方式 模仿学习&#xff08;Imi…

【工具篇】Maven使用${revision}实现多模块版本统一管理

背景说明 在使用Maven多模块结构工程时&#xff0c;版本管理是一件很繁琐且容易出错的事情。每次升级版本号都要手动调整或者通过mvn versions:set -DnewVerionxx命令去更改每一个子模块的版本号&#xff0c;非常的不方便&#xff0c;而且会改动所有的模块&#xff0c;出现如下…

Vue语法(4)

目录 1. 自定义指令 1.1 使用方法 1.2 实际案例 1.3 全局指令和局部指令 2. 组件对象 2.1 组件基础 2.2 组件对象 2.3 组件的属性——props 2.4 全局组件和局部组件 2.5 单文件组件 1. 概念&#xff1a; 2. 配置开发环境的指令&#xff1a; 3. 单文件组…

云迁移为业务赋能,跨出数字化转型第一步

新钛云服已累计为您分享752篇技术干货 云迁移如何赋能企业业务 随着科技的迅猛发展和数字化转型的浪潮席卷全球&#xff0c;越来越多的企业开始意识到云计算的重要性和潜力。在这个数字化时代&#xff0c;企业不再满足于传统的IT架构和数据中心&#xff0c;而是追求更高效、灵活…

《JAX可微分编程》包邮送书五本

文章目录 前言JAX到底是什么&#xff1f;书籍内容介绍包邮送书5本 前言 2015年&#xff0c;Google Brain开放了一个名为「TensorFlow」的研究项目&#xff0c;这款产品迅速流行并成为人工智能业界的主流深度学习框架&#xff0c;塑造了现代机器学习的生态系统。 7年后&#x…

Shell脚本文本三剑客之sed编辑器

目录 一、sed编辑器简介 二、sed工作流程 三、sed命令 四、sed命令的使用 1.sed打印文件内容&#xff08;p&#xff09; &#xff08;1&#xff09;打印文件所有行 &#xff08;2&#xff09;打印文件指定行 2.sed增加、插入、替换行&#xff08;a、i、c&#xff09; …

CVPR 2023 | 视频AIGC,预测/插帧/生成/编辑

1、A Dynamic Multi-Scale Voxel Flow Network for Video Prediction 视频预测&#xff08;video prediction&#xff09;的性能已经通过先进的深度神经网络大幅提高。然而&#xff0c;大多数当前的方法存在着大的模型尺寸和需要额外的输入&#xff08;如&#xff0c;语义/深度…

【Instruction Tuning】ChatGLM 微调实战(附源码)

在之前的文章中&#xff0c;我们已经讲过了 ChatGPT 的三个主要流程&#xff1a; SFT&#xff1a;通过 Instruction Tuning 来微调一个监督学习模型。Reward Model&#xff1a;通过排序序列来训练一个打分模型。Reinforcement Learning&#xff1a;通过强化学习来进一步优化模…