第50步 深度学习图像识别:Data-efficient Image Transformers建模(Pytorch)

news2024/11/24 22:36:50

基于WIN10的64位系统演示

一、写在前面

(1)Data-efficient Image Transformers

Data-efficient Image Transformers (DeiT)是一种用于图像分类的新型模型,由Facebook AI在2020年底提出。这种方法基于视觉Transformer,通过训练策略的改进,使得模型能在少量数据下达到更高的性能。

在许多情况下,Transformer模型需要大量的数据才能得到好的结果。然而,这在某些场景下是不可能的,例如在只有少量标注数据的情况下。DeiT方法通过在训练过程中使用知识蒸馏,解决了这个问题。知识蒸馏是一种让小型模型学习大型模型行为的技术。

DeiT中的关键技术之一是使用学生模型预测教师模型的类别分布,而不仅仅是硬标签(原始数据集中的类别标签)。这样做的好处是,学生模型可以从教师模型的软标签(类别概率分布)中学习更多的信息。另外,DeiT还引入了一种新的训练方法,称为“硬标签蒸馏”,这种方法更进一步提高了模型的性能。通过这种方法,即使在ImageNet这样的大规模数据集上,DeiT也可以与更复杂的卷积神经网络(如ResNet和EfficientNet)相媲美或者超越,同时还使用了更少的计算资源。

(2)Data-efficient Image Transformers的预训练版本

本文继续使用Facebook的高级深度学习框架PyTorchImageModels (timm)。该库提供了多种预训练的模型,太多了,我还是给网址吧:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/deit.py

从第155行到第249行都是:

 

比如“deit_tiny_patch16_224”

"deit": 这是模型类型的缩写,代表 "Data-efficient Image Transformers"。

"tiny": 这个词说明了模型的大小。在此情况下,"tiny"意味着这是一个更小、计算成本更低的模型版本。相对的,还有"small","base"等不同规模的模型。

"patch16": 这是指在模型的输入阶段,原始图像被分割成大小为16x16像素的小方块(也被称为patch)进行处理。

"224": 这是指模型接受的输入图像的尺寸是224x224像素。这是在计算机视觉领域常用的图像尺寸。

二、Data-efficient Image Transformers迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

import torch
from torchvision import datasets, transforms
import os

# 数据集路径
data_dir = "./MTB"

# 图像的大小
img_height = 100
img_width = 100

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)

# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占80%
val_size = full_size - train_size  # 验证集的大小

# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']

# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c)导入Data-efficient Image Transformers

# 导入所需的库
import torch.nn as nn
import timm

# 定义Data-efficient Image Transformers模型
model = timm.create_model('deit_base_patch16_224', pretrained=True)  # 你可以选择适合你需求的DeiT版本,这里以deit_base_patch16_224为例
num_ftrs = model.head.in_features

# 根据分类任务修改最后一层
model.head = nn.Linear(num_ftrs, len(class_names))

# 将模型移至指定设备
model = model.to(device)

# 打印模型摘要
print(model)

(d)编译模型

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters())

# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 开始训练模型
num_epochs = 10
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 每个epoch都有一个训练和验证阶段
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # 遍历数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 零参数梯度
            optimizer.zero_grad()

            # 前向
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # 只在训练模式下进行反向和优化
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # 统计
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # 记录每个epoch的loss和accuracy
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # 深拷贝模型
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

print('Best val Acc: {:4f}'.format(best_acc))

(e)Accuracy和Loss可视化

epoch = range(1, len(train_loss_history)+1)

fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()

ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()

#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('Confusion matrix',fontsize=15)
    plt.ylabel('Actual value',fontsize=14)
    plt.xlabel('Predictive value',fontsize=14)
    
def evaluate_model(model, dataloader, device):
    model.eval()   # 设置模型为评估模式
    true_labels = []
    pred_labels = []
    # 遍历数据
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 前向
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        
    return true_labels, pred_labels

# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)

# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]

# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数

# 打印出性能指标
print("验证集的灵敏度为:", sen_val, 
      "验证集的特异度为:", sep_val,
      "验证集的准确率为:", acc_val, 
      "验证集的错误率为:", error_rate_val,
      "验证集的精确度为:", precision_val, 
      "验证集的F1为:", F1_val,
      "验证集的MCC为:", MCC_val)

# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)

    
# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, 
      "训练集的特异度为:",sep_train,
      "训练集的准确率为:",acc_train, 
      "训练集的错误率为:",error_rate_train,
      "训练集的精确度为:",precision_train, 
      "训练集的F1为:",F1_train,
      "训练集的MCC为:",MCC_train)

# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)

效果不错:

(g)AUC曲线绘制

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = metrics.roc_curve(labels, predictions)

    plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.xlabel('False positives rate')
    plt.ylabel('True positives rate')
    ax = plt.gca()
    ax.set_aspect('equal')


# 确保模型处于评估模式
model.eval()

train_ds = dataloaders['train']
val_ds = dataloaders['val']

val_pre_auc   = []
val_label_auc = []

for images, labels in val_ds:
    for image, label in zip(images, labels):      
        img_array = image.unsqueeze(0).to(device)  # 在第0维增加一个维度并将图像转移到适当的设备上
        prediction_auc = model(img_array)  # 使用模型进行预测
        val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])
        val_label_auc.append(label.item())  # 使用Tensor.item()获取Tensor的值
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)


train_pre_auc   = []
train_label_auc = []

for images, labels in train_ds:
    for image, label in zip(images, labels):
        img_array_train = image.unsqueeze(0).to(device) 
        prediction_auc = model(img_array_train)
        train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])  # 输出概率而不是标签!
        train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)

plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300,format="pdf")

print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

 ROC曲线如下:

 很优秀的ROC曲线!

三、写在最后

运算量和消耗的计算资源还是大,在这个数据集上跑出来的性能比ViT模型要好一些,说明优化策略还是起到效果的。

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/747038.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode Java实现 222. 完全二叉树的节点个数

文章目录 题目递归遍历左右子树个数实现思路具体代码实现缺点 根据完全二叉树性质优化思路具体代码实现优点 结语 题目 给你一棵 完全二叉树 的根节点 root ,求出该树的节点个数。 完全二叉树 的定义如下:在完全二叉树中,除了最底层节点可能…

串口的再认识

常用函数介绍 串口发送/接收函数 HAL_UART_Transmit(); 串口发送数据,使用超时管理机制(即在发送成功前一直阻塞) HAL_UART_Receive(); 串口接收数据,使用超时管理机制 HAL_UART_Transmit_IT(); 串口中断模式发送 HAL_UART…

DAY45:动态规划(六)背包问题优化:一维DP解决01背包问题

文章目录 一维DP数组的解法二维DP递推思路滚动数组优化思路(重要)一维DP数组的含义一维DP递推公式一维DP的初始化遍历顺序(重要)举例推导DP数组 一维DP数组完整版写法面试问题为什么一维DP背包的for循环一定要倒序遍历为什么一维D…

Qt + QR-Code-generator 生成二维码

0.前言 之前使用 libgrencode 生成二维码,LGPL 协议实在不方便,所以需要找一个 github 星星多的,代码简单最好 header-only,协议最好是 MIT 或者兼容协议而不是 GPL 或者 LPGL。 QR-Code-generator 正好符合这个要求&#xff0c…

JMeter 如何模拟不同的网络速度

目录 前言: 限制输出带宽以模拟不同的网络速度 将这两行添加到user.properties文件中(可以在JMeter安装的bin文件夹中找到此行) 通过-J 命令行参数传递属性的值,如下所示: 前言: JMeter可以通过使用不同…

【网络安全带你练爬虫-100练】第12练:pyquery解析库提取指定数据

目录 一、目标1、基础/环境的准备工作 二、目标2:开始使用pyquery 三、目标3:提取到指定的数据 四、目标3:通过列表的形式获取指定数据 五、扩展:其他方法 六、网络安全O 一、目标1、基础/环境的准备工作 1、文档&#xff1…

wordpress的wp_trim_words摘要没有截取到指定的字符长度的问题解决

一、问题描述 在imqd.cn文章列表中&#xff0c;用于将正文内容截取保留前75个字的方法突然没效果了。 <p class"post-desc text-black-50 d-none d-md-block"><?phpecho wp_trim_words( get_the_content(), 75, ...);?> </p>直接展示了所有的正…

计算机网络实验(1)--Windows网络测试工具

&#x1f4cd;实验目的 理解知识点&#xff08;ping&#xff0c;netstat&#xff0c;ipconfig&#xff0c;arp&#xff0c;tracert&#xff0c;route&#xff0c;nbtstat&#xff0c;net&#xff09;所涉及的基本概念&#xff0c;并学会使用这些工具测试网络的状态及从网上获取…

java后端开发环境搭建 mac

在mac pro上搭建一套java 后端开发环境&#xff0c;主要安装的内容有&#xff1a;jdk、maven、git、tomcat、mysql、navicat、IntelliJ、redis。 本人mac pro的系统为mac OS Monterey 12.6.7&#xff0c;主机的硬件架构为x86_64。 左上角关于本机查看系统版本&#xff1b;终端…

LLM - Baichuan7B Lora 训练详解

目录 一.引言 二.环境准备 三.模型训练 1.依赖引入与 tokenizer 加载 2.加载 DataSet 与 Model 3.Model 参数配置 4.获取 peft Model 5.构造 Trainer 训练 6.训练完整代码 四.Shell 执行 1.脚本构建 2.训练流程 3.训练结果 五.总结 一.引言 LLM - Baichuan7B Tok…

Ubuntu18.04 docker kafka 本地测试环境搭建

文章目录 一、kafka 介绍二、Ubuntu docker kafka 本地测试环境搭建2.1 docker kafka 启动2.1.1 下载镜像2.1.2 启动 wurstmeister/zookeeper2.1.3 启动 wurstmeister/kafka 2.2 docker kafka 测试数据收发2.2.1 docker kafka 测试数据收发2.2.2 windows验证 三、嵌入式集成四、…

加密劫持者攻击教育机构

我们的专家分析了2023年第一季度的当前网络威胁。研究表明&#xff0c;独特事件的数量增加&#xff0c;勒索软件活动激增&#xff0c;特别是针对学术和教育机构。我们记录了大量与就业有关的网络钓鱼邮件&#xff0c;出现了QR网络钓鱼和恶意广告的增加。 我们的研究表明&#…

「2023 最新版」Java 工程师面试题总结 (1000 道题含答案解析)

作为一名优秀的程序员&#xff0c;技术面试都是不可避免的一个环节&#xff0c;一般技术面试官都会通过自己的方式去考察程序员的技术功底与基础理论知识。 如果你参加过一些大厂面试&#xff0c;肯定会遇到一些这样的问题&#xff1a; 1、看你项目都用的框架&#xff0c;熟悉…

AIGC:文生图stable-diffusion-webui部署及使用

1 stable-diffusion-webui介绍 Stable Diffusion Web UI 是一个基于 Stable Diffusion 的基础应用&#xff0c;利用 gradio 模块搭建出交互程序&#xff0c;可以在低代码 GUI 中立即访问 Stable Diffusion Stable Diffusion 是一个画像生成 AI&#xff0c;能够模拟和重建几乎…

Linux(centos7)下安装mariadb10详解

MariaDB 和 MySQL 之间存在紧密的关系。 起源&#xff1a;MariaDB 最初是作为 MySQL 的一个分支而创建的。它的初始目标是保持与 MySQL 的兼容性&#xff0c;并提供额外的功能和性能改进。 共同的代码基础&#xff1a;MariaDB 使用了 MySQL 的代码基础&#xff0c;并在此基础上…

PYTHON 解码 IP 层

PYTHON 解码 IP 层 引言1.编写流量嗅探器1.1 Windows 和 Linux 上的包嗅探2.解码 IP 层2.1 struct 库3.编写 IP 解码器4.解码 ICMP5.总结 作者&#xff1a;高玉涵 时间&#xff1a;2023.7.12 环境&#xff1a;Windows 10 专业版 22H2&#xff0c;Python 3.10.4 引言 IP 是 …

JWT的深入理解

1、JWT是什么 JWT&#xff08;JSON Web Token&#xff09;是一种开放标准&#xff08;RFC 7519&#xff09;&#xff0c;用于在不同实体之间安全地传输信息。它由三部分组成&#xff0c;即头部&#xff08;Header&#xff09;、载荷&#xff08;Payload&#xff09;和签名&…

获取QT界面坐标的各种方法

链接 ract() 获取rect所在部件的尺寸。 rect()返回的QRect对象可以用来做什么

openResty的Redis模块踩坑记录

OpenResty提供了操作Redis的模块&#xff0c;我们只要引入该模块就能直接使用。说是这样说&#xff0c;但是实践起来好像并不太顺利。 1.设置了密码的redis&#xff0c;lua业务逻辑中需要添加身份认证代码 网上很多资料、文章似乎都是没有设置redis密码&#xff0c;说来也奇怪…

JS区域滤镜

思路 简单一点的&#xff0c;像素点X坐标小于图宽1/3和大于2/3的点变灰&#xff0c;中间的点不变。 复杂的暂时不会搞。 原图 处理后 <html> <style> #canvas { width:100%; } </style> <body> <input id"file" type"file" …