第51步 深度学习图像识别:Convolutional Vision Transformer建模(Pytorch)

news2024/10/6 22:34:32

基于WIN10的64位系统演示

一、写在前面

(1)Convolutional Vision Transformers

Convolutional Vision Transformer(ConViT)是一种结合了卷积神经网络(Convolutional Neural Networks,简称CNN)和Transformer结构的深度学习模型,由Facebook AI在2021年提出。该模型旨在利用CNN在局部特征提取方面的优势,同时保留Transformer在长距离依赖性处理上的优点,以此提升图像识别任务的性能。

在传统的Vision Transformer(ViT)模型中,输入图像会被划分为一系列的小块或者"patches",然后用Transformer处理这些独立的patches。这种方法的一个问题是,它并未充分利用图像的局部信息,这些信息在CNN模型中得到了很好的处理。

ConViT模型提出了一种解决方案,即在每个Transformer的自注意力(Self-Attention)机制中引入卷积的归纳偏置(Convolutional Inductive Bias)。这种设计使得模型在处理每个patch时,能更加考虑到它与周围patch的关系,从而更好地捕捉图像的局部信息。同时,由于Transformer的全局自注意力机制,ConViT也能处理图像的长距离依赖性。

(2)Convolutional Vision Transformers的码源

本文继续使用Facebook的高级深度学习框架PyTorchImageModels (timm),网址为:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convit.py

二、Convolutional Vision Transformers迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

import torch
from torchvision import datasets, transforms
import os

# 数据集路径
data_dir = "./MTB"

# 图像的大小
img_height = 100
img_width = 100

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)

# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占80%
val_size = full_size - train_size  # 验证集的大小

# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']

# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c)导入Convolutional Vision Transformers

# 导入所需的库
import torch.nn as nn
import timm

# 定义Convolutional Vision Transformer模型
model = timm.create_model('convit_base', pretrained=True)  # 你可以选择适合你需求的ConViT版本,这里以'convit_base'为例
num_ftrs = model.head.in_features

# 根据分类任务修改最后一层
model.head = nn.Linear(num_ftrs, len(class_names))

# 将模型移至指定设备
model = model.to(device)

# 打印模型摘要
print(model)

(d)编译模型

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters())

# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 开始训练模型
num_epochs = 10
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 每个epoch都有一个训练和验证阶段
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # 遍历数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 零参数梯度
            optimizer.zero_grad()

            # 前向
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # 只在训练模式下进行反向和优化
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # 统计
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # 记录每个epoch的loss和accuracy
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # 深拷贝模型
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

print('Best val Acc: {:4f}'.format(best_acc))

(e)Accuracy和Loss可视化

epoch = range(1, len(train_loss_history)+1)

fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()

ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()

#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

 蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('Confusion matrix',fontsize=15)
    plt.ylabel('Actual value',fontsize=14)
    plt.xlabel('Predictive value',fontsize=14)
    
def evaluate_model(model, dataloader, device):
    model.eval()   # 设置模型为评估模式
    true_labels = []
    pred_labels = []
    # 遍历数据
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 前向
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        
    return true_labels, pred_labels

# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)

# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]

# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数

# 打印出性能指标
print("验证集的灵敏度为:", sen_val, 
      "验证集的特异度为:", sep_val,
      "验证集的准确率为:", acc_val, 
      "验证集的错误率为:", error_rate_val,
      "验证集的精确度为:", precision_val, 
      "验证集的F1为:", F1_val,
      "验证集的MCC为:", MCC_val)

# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)

    
# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, 
      "训练集的特异度为:",sep_train,
      "训练集的准确率为:",acc_train, 
      "训练集的错误率为:",error_rate_train,
      "训练集的精确度为:",precision_train, 
      "训练集的F1为:",F1_train,
      "训练集的MCC为:",MCC_train)

# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)

效果不错:

 (g)AUC曲线绘制

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = metrics.roc_curve(labels, predictions)

    plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.xlabel('False positives rate')
    plt.ylabel('True positives rate')
    ax = plt.gca()
    ax.set_aspect('equal')


# 确保模型处于评估模式
model.eval()

train_ds = dataloaders['train']
val_ds = dataloaders['val']

val_pre_auc   = []
val_label_auc = []

for images, labels in val_ds:
    for image, label in zip(images, labels):      
        img_array = image.unsqueeze(0).to(device)  # 在第0维增加一个维度并将图像转移到适当的设备上
        prediction_auc = model(img_array)  # 使用模型进行预测
        val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])
        val_label_auc.append(label.item())  # 使用Tensor.item()获取Tensor的值
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)


train_pre_auc   = []
train_label_auc = []

for images, labels in train_ds:
    for image, label in zip(images, labels):
        img_array_train = image.unsqueeze(0).to(device) 
        prediction_auc = model(img_array_train)
        train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])  # 输出概率而不是标签!
        train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)

plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300,format="pdf")

print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

ROC曲线如下:

 很优秀的ROC曲线!

三、写在最后

运算量和消耗的计算资源依旧,但是这种组合模型在本数据集上跑出来的性能比ViT模型和DeiT是要好一些,说明这种组合策略还是有用的。

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/751183.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

飞机【频闪灯、导航灯】效果的设置——灯和灯的光晕

一、飞机外部灯光系统——频闪灯和防撞灯——闪烁效果 二、实现的原理 如下图所示,灯效果的组成包含两部分,一是灯本身,二是灯光产生的光晕 灯—— 就是一个球(Sphere),给它一个Emission(自发光)材质光晕——光晕的…

云苍穹各类参数使用说明

目录 公共参数 云参数 应用参数 单据参数 单据类型参数 用户选项参数 列表选项参数 公共参数 不推荐使用 参数值获取&#xff1a; // 获取整体公共参数 Map<String,Object> publicParamWhole SystemParamServiceHelper.loadPublicParametersFromCache();// 获取某…

vscode安装文件下载缓慢

官网下载vscode安装文件太慢大部分是因为vscode官网服务器跟我们国内的链接速度有关&#xff0c;当我们去官网下载&#xff08;Download Visual Studio Code - Mac, Linux, Windows&#xff09;一般都会出现下面的情况 下载速度几乎为0。 解决办法&#xff0c;鼠标右键复制下载…

C语言a---b

C语言的编译遵循贪心读法&#xff0c;也就是说&#xff0c;对于有歧义的符号&#xff0c;编译器会一直读取&#xff0c;直到它的意思完结&#xff1b; a---b&#xff0c;是a-- -b还是a- --b&#xff0c;根据贪心法则&#xff0c;读到第二个减号&#xff0c;意思完结&#xff0c…

今日分享——语音同声翻译软件

安娜和卡洛是一对在旅行时偶遇的年轻男女&#xff0c;他们互有好感&#xff0c;但他们来自不同的国家&#xff0c;说着不同的语言。每次面对彼此的时候&#xff0c;他们总是陷入语言的困扰&#xff0c;无法用自己熟悉的语言表达内心的情感。因此他俩都十分需要一款翻译语音的软…

【从零开始进行高精度手眼标定 eye in hand(小白向)3 非线性高精度标定法编程实现】

从零开始进行高精度手眼标定 eye in hand&#xff08;小白向&#xff09;1 原理推导 前言原理推导算法框图 MATLAB编程计算相关优化工具箱的安装&#xff08;不安装会报错&#xff09;数据读取目标函数计算完整代码实验验证 传送门&#xff1a; 1.【从零开始进行高精度手眼标定…

2023.7.13-【if】与【for】的配合使用:键入一个整数,输出结果用1234567890循环填充,填充的位数等于键入的整数

功能描述&#xff1a; 例如我们输入一个整数&#xff1a;25。输出的结果为1234567890123456789012345&#xff0c;共计25个数填充了这个输出结果。 程序&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include<stdio.h> int main() {int a;int b; int c;int i;prin…

fiddler抓包工具使用大全

目录 Fiddler基础知识 HTTP协议 Fiddler的使用 HTTPS抓包 Fiddler过滤会话 对request设置断点 对response设置断点 Fiddler的编码和解码 Fiddler基础知识 Fiddler是强大的抓包工具&#xff0c;它的原理是以web代理服务器的形式进行工作的&#xff0c;使用的代理地址是&…

Camtasia Studio 2023怎么导出mp4格式的视频的详细教程介绍

很多用户刚接触Camtasia Studio 2023&#xff0c;不熟悉如何保存mp4格式的视频。在今天的文章中小编为大家带来了Camtasia Studio 2023保存为mp4格式的视频的详细教程介绍。 Camtasia Studio 2023保存为mp4格式的视频的详细教程 1、 打开Camtasia Studio。 Camtasia Studio- …

Linux 删除 颜色转义字符 乱码 \x1b

目录 Linux颜色控制 方式一&#xff1a;添加sed正则命令 方式二&#xff1a;将输出写入文件再读取 Git颜色控制 使用Python paramiko ssh 获取 git 输出时&#xff0c;出现乱码&#xff0c;实际上是终端输出的ANSI颜色转义字符&#xff0c;用于控制终端颜色展示&#xff1a;…

CSS---CSS面试题

目录 1.盒模型 2.offsetHeight /clientheight/scrollHeight 3.left与offsetLeft 4.对BFC规范的理解 5.解决元素浮动导致的父元素高度塌陷的问题 6.CSS样式的先级 7.隐藏页面元素 8.display: none 与 visibility: hidden 的区别 9.页面引入样式时&#xff0c;使用link与import有…

水库大坝安全监测系统是由什么组成的?

水库大坝是防洪抗灾的重要设施&#xff0c;它们的安全性直接关系到人民群众的生命财产安全。因此&#xff0c;水库大坝的安全监测必不可少。水库大坝安全监测系统是一种集成了数据采集、传输、处理和分析的技术平台&#xff0c;能够实时、准确地监测大坝的状态&#xff0c;及时…

Python实现将pdf,docx,xls,doc,wps链接下载并将文件保存到本地

前言 本文是该专栏的第31篇,后面会持续分享python的各种干货知识,值得关注。 在工作上,尤其是在处理爬虫项目中,会遇到这样的需求。访问某个网页或者在采集某个页面的时候,正文部分含有docx,或pdf,或xls,或doc,或wps等链接。需要你使用python自动将页面上含有的这些信…

青岛大学_王卓老师【数据结构与算法】Week05_09_顺序栈的操作3_学习笔记

本文是个人学习笔记&#xff0c;素材来自青岛大学王卓老师的教学视频。 一方面用于学习记录与分享&#xff0c; 另一方面是想让更多的人看到这么好的《数据结构与算法》的学习视频。 如有侵权&#xff0c;请留言作删文处理。 课程视频链接&#xff1a; 数据结构与算法基础…

WebSocket理解

WebSocket理解 WebSocket定义与HTTP关系相同点:不同点&#xff1a;联系总体过程 HTTP问题长轮询Ajax轮询 WebSocket特点 WebSocket 定义 本质上是TCP的协议 持久化的协议 实现了浏览器和服务器的全双工通信&#xff0c;能更好的节省服务器资源和带宽 与HTTP关系 相同点: 基于…

路径规划算法:基于蜣螂优化的路径规划算法- 附代码

路径规划算法&#xff1a;基于蜣螂优化的路径规划算法- 附代码 文章目录 路径规划算法&#xff1a;基于蜣螂优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要&#xff1a;本文主要介绍利用智能优化算法蜣螂…

Linux与Windows:自由与便利的角力

引言&#xff1a; Linux与Windows作为两大常见的电脑操作系统&#xff0c;一直以来都在用户中有着激烈的争议。两者各有优势和劣势&#xff0c;也有着不同的定位和用户群体。Linux以其开源、自由和灵活的特性&#xff0c;吸引了一大批技术爱好者和专业人士&#xff1b;而Window…

数据采集专家----4通道AD采集子卡推荐

FMC136是一款4通道250MHz采样率16位AD采集FMC子卡&#xff0c;符合VITA57规范&#xff0c;可以作为一个理想的IO模块耦合至FPGA前端&#xff0c;4通道AD通过高带宽的FMC连接器&#xff08;HPC&#xff09;连接至FPGA从而大大降低了系统信号延迟。 该板卡支持板上可编程采样时钟…

【SLAM】Ceres优化库官方实例详细解析(SLAM相关问题)

在ceres-solver官方教程中提供了三个SLAM相关的优化问题&#xff0c;分别为&#xff1a;bundle_adjuster、pose_graph_2d、pose_graph_3d。本文就这三个问题进行拆解和分析。 在阅读本文之前&#xff0c;你需要确保已经知晓了Ceres的基础用法&#xff0c;这部分可以参考文章&a…

一个多线程并发执行且顺序进出的执行器设计

考虑这样一种场景&#xff0c;有一批数据需要处理&#xff0c;且每个数据是有顺序的&#xff0c;这时最简单的方式就是按顺序同步执行每个数据的处理。但是&#xff0c;在有多核cpu时这种方式无疑是对cpu资源的极大浪费&#xff0c;这时就需要多开几个线程同时对数据进行处理&a…