第52步 深度学习图像识别:Transformer in Transformer建模(Pytorch)

news2025/1/22 14:43:22

基于WIN10的64位系统演示

一、写在前面

(1)Transformer in Transformer

Transformer in Transformer(TNT)模型是一种新的图像分类模型,由研究者在2021年提出。这种模型的特点是在传统的Vision Transformer模型的基础上,引入了一种新的结构,使得模型可以更好地处理图像的局部和全局信息。

在传统的Vision Transformer模型中,输入图像会被划分为一系列的小块或者"patches",然后用Transformer处理这些独立的patches。这种方法的一个问题是,它并未充分利用图像的局部信息,因为在每个patch内部的像素被平均处理,忽视了它们之间的关系。

TNT模型提出了一种解决方案,即在每个patch内部再次应用Transformer,形成一种嵌套的Transformer结构,也就是"Transformer in Transformer"。这种设计使得模型在处理每个patch时,首先会考虑到patch内部的像素之间的关系,然后再处理patch之间的关系。这样,TNT模型能更好地捕捉图像的局部和全局信息。

(2)Transformer in Transformer的码源

本文继续使用Facebook的高级深度学习框架PyTorchImageModels (timm),网址为:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/tnt.py

 可以看到,有两种可使用的TNT版本:tnt_s_patch16_224以及tnt_b_patch16_224,主要区别在于模型的规模和复杂性。

"tnt_s_patch16_224":这是一个小型的TNT模型版本,其中的"s"表示"small",意为小型。该模型的结构相对较小,因此参数数量较少,计算成本和内存需求相对较低。然而,由于其结构简单,其性能可能略低于更大的模型版本。

"tnt_b_patch16_224":这是一个基本的TNT模型版本,其中的"b"表示"base",意为基本。该模型的结构比小型版本更复杂,参数数量更多,因此其计算成本和内存需求也相对较高。然而,由于其更复杂的结构,该模型可能会提供更高的性能。

二、Transformer in Transformer迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

import torch
from torchvision import datasets, transforms
import os

# 数据集路径
data_dir = "./MTB"

# 图像的大小
img_height = 100
img_width = 100

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)

# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占80%
val_size = full_size - train_size  # 验证集的大小

# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']

# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c)导入Transformer in Transformer

# 导入所需的库
import torch.nn as nn
import timm

# 定义Transformer in Transformer模型
model = timm.create_model('tnt_s_patch16_224', pretrained=True)  # 你可以选择适合你需求的TNT版本,这里以tnt_s_patch16_224为例
num_ftrs = model.head.in_features

# 根据分类任务修改最后一层
model.head = nn.Linear(num_ftrs, len(class_names))

# 将模型移至指定设备
model = model.to(device)

# 打印模型摘要
print(model)

(d)编译模型

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters())

# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 开始训练模型
num_epochs = 10
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 每个epoch都有一个训练和验证阶段
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # 遍历数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 零参数梯度
            optimizer.zero_grad()

            # 前向
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # 只在训练模式下进行反向和优化
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # 统计
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # 记录每个epoch的loss和accuracy
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # 深拷贝模型
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

print('Best val Acc: {:4f}'.format(best_acc))

(e)Accuracy和Loss可视化

epoch = range(1, len(train_loss_history)+1)

fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()

ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()

#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

 蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('Confusion matrix',fontsize=15)
    plt.ylabel('Actual value',fontsize=14)
    plt.xlabel('Predictive value',fontsize=14)
    
def evaluate_model(model, dataloader, device):
    model.eval()   # 设置模型为评估模式
    true_labels = []
    pred_labels = []
    # 遍历数据
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 前向
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        
    return true_labels, pred_labels

# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)

# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]

# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数

# 打印出性能指标
print("验证集的灵敏度为:", sen_val, 
      "验证集的特异度为:", sep_val,
      "验证集的准确率为:", acc_val, 
      "验证集的错误率为:", error_rate_val,
      "验证集的精确度为:", precision_val, 
      "验证集的F1为:", F1_val,
      "验证集的MCC为:", MCC_val)

# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)

    
# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, 
      "训练集的特异度为:",sep_train,
      "训练集的准确率为:",acc_train, 
      "训练集的错误率为:",error_rate_train,
      "训练集的精确度为:",precision_train, 
      "训练集的F1为:",F1_train,
      "训练集的MCC为:",MCC_train)

# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)

效果不错:

 (g)AUC曲线绘制

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = metrics.roc_curve(labels, predictions)

    plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.xlabel('False positives rate')
    plt.ylabel('True positives rate')
    ax = plt.gca()
    ax.set_aspect('equal')


# 确保模型处于评估模式
model.eval()

train_ds = dataloaders['train']
val_ds = dataloaders['val']

val_pre_auc   = []
val_label_auc = []

for images, labels in val_ds:
    for image, label in zip(images, labels):      
        img_array = image.unsqueeze(0).to(device)  # 在第0维增加一个维度并将图像转移到适当的设备上
        prediction_auc = model(img_array)  # 使用模型进行预测
        val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])
        val_label_auc.append(label.item())  # 使用Tensor.item()获取Tensor的值
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)


train_pre_auc   = []
train_label_auc = []

for images, labels in train_ds:
    for image, label in zip(images, labels):
        img_array_train = image.unsqueeze(0).to(device) 
        prediction_auc = model(img_array_train)
        train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])  # 输出概率而不是标签!
        train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)

plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300,format="pdf")

print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

ROC曲线如下:

很优秀的ROC曲线!

三、写在最后

略~

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/761516.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FreeRTOS源码分析-1创建任务

目录 1 任务的句柄(结构体) 2 创建任务主要工作 2.1 创建任务初始化源码分析 2.2 任务添加到就绪列表源码分析 2.3任务堆栈的初始化源码分析 问:R0为什么要入栈保存?因为作为函数的第一个传入参数,必须也要保存。…

spring复习:(37)ProxyFactoryBean之getObject

该工厂bean的getObject代码如下: public Object getObject() throws BeansException {initializeAdvisorChain();if (isSingleton()) {return getSingletonInstance();}else {if (this.targetName null) {logger.info("Using non-singleton proxies with sing…

4.CSS图文样式

考点:line-height为200%时,font-size为40px

第十五章:DenseASPP for Semantic Segmentation in Street Scenes——在街景语义分割中的DenseASPP

0.摘要 语义图像分割是自动驾驶中的基本街景理解任务,在这个任务中,高分辨率图像中的每个像素被归类为一组语义标签。与其他场景不同,自动驾驶场景中的物体呈现出非常大的尺度变化,这给高级特征表示带来了巨大挑战,因为…

IDEA设置显示行号和方法间的分隔符

IDEA设置显示行号和方法间的分隔符 选择File--Settings--Edotor-General-Apperance,勾选上下图中的选项后点击 OK 即可。 每个函数不迷路~~ Show line numbers:显示行数 Show method separators: 显示方法分隔线。

央视赋能,强势出击——方圆出海与《品牌中国》栏目达成战略合作

2023 央视赋能,强势出击 方圆出海 “日前,深圳市方圆出海科技有限公司与《品牌中国》栏目携手,双方正式达成战略合作协议,央视《品牌中国》栏目负责人正式授予方圆出海“《品牌中国》重点推荐品牌”的荣誉称号。 此次签约标志着…

js的this绑定规则以及箭头函数

目录 调用位置默认绑定隐式绑定隐式丢失 显式绑定callapplybind new绑定装箱绑定优先级this规则之外忽略显式绑定间接函数引用 箭头函数 调用位置 从字面意思上来理解,this似乎是指向自己的 然而在JavaScript中,this并不是绑定到自身的 可以看这一个例子…

蓝牙HID模式下输出中文原理简介

目录 前言一、蓝牙和HID简介二、Unicode编码简介三、Windows下alt键code编码输出中文四、蓝牙HID模式下实现在手机上输入中文的原理 前言 最近在使用蓝牙模组,对于蓝牙模组如何输出中文的原理不太清楚,所以找了一些资料简单学习了下,总结如下…

目标检测——FasterRCNN原理与实现

目录 网络工作流程数据加载模型加载模型预测过程RPN获取候选区域FastRCNN进行目标检测 模型结构详解backboneRPN网络anchorsRPN分类RPN回归Proposal层 ROIPooling目标分类与回归 FasterRCNN的训练RPN网络的训练正负样本标记RPN网络的损失函数训练过程实现正负样本设置损失函数 …

Kubernetes 使用 helm 部署 NFS Provisioner

文章目录 1. 介绍2. 预备条件3. 部署 nfs4. 部署 NFS subdir external provisioner4.1 集群配置 containerd 代理4.2 配置代理堡垒机通过 kubeconfig 部署 1. 介绍 NFS subdir external provisioner 使用现有且已配置的NFS 服务器来支持通过持久卷声明动态配置 Kubernetes 持久…

大模型基础知识汇总

本文总结大模型相关基础知识,用于大模型学习入门 (持续更新中…) 文章目录 NLP 基础知识传统 NLP 知识NLU 与 NLG 各种任务的差异 Transformer 相关知识Pre Norm与Post Norm的区别?Bert 预训练过程手写 transformer 的 attention …

从0到1:跑团小程序开发心得笔记

背景介绍 随着健康意识的兴起,越来越多的人选择加入跑步俱乐部,不仅体验到了运动的乐趣,也感受到了人生的不同色,那么通过小程序,把俱乐部搬到手机上,通过小程序了解俱乐部动态和运动常识,可以…

C++自定义信号和QML的槽函数建立连接

0x00 在C代码在定义一个信号函数&#xff1a;“void sendData2UI(QString msg);”&#xff0c;该函数主要是将接收到的UDP消息发送到QML界面中 #ifndef UDPCLI_H #define UDPCLI_H#include <QObject> #include <QUdpSocket> #include <QString>class UdpCli …

【Netty】NIO基础(三大组件)

文章目录 三大组件Channel & BufferSelector ByteBufferByteBuffer 正确使用姿势ByteBuffer 内部结构ByteBuffer 常见方法分配空间向 buffer 写入数据从 buffer 读取数据mark 和 reset 字符串与 ByteBuffer 互转Scattering ReadsGathering Writes粘包、半包分析 附&#xf…

《啊哈算法》第一章--排序

文章目录 前言一、排序算法二、桶排序三、冒泡排序三、快速排序总结 前言 今年蓝桥杯没有拿到省一&#xff0c;所以就决定沉下心来学习算法&#xff0c;为了使得算法的学习更加稳固&#xff0c;所以就拿起了&#xff0c;最基础的且最经典的一本算法书《啊哈算法》&#xff0c;…

Redis进阶底层原理- 持久化

Redis作为基于内存的缓存数据库&#xff0c;就会存在断电即失的问题&#xff0c;所以数据的持久化是非常重要的。Redis随着版本升级迭代&#xff0c;持久化技术也在不断的升级&#xff0c;&#xff08;从最开始的RDB&#xff0c;到的Redis1.1版本加入AOF&#xff0c;3.0版本支持…

全志F1C200S嵌入式驱动开发(sd卡驱动)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 说是sd卡,其实是micro sd卡,或者称之为tf卡更合适。一般的soc都支持从tf卡启动,所以用tf卡来学习soc、驱动和linux,对新人来说是比较合适的。前面我们已经用sd卡构建了一个类似…

uniapp:针对与富文本解析的几种方法

第一章、富文本的解析方法 1.1 uniapp自带组件&#xff1a;rich-text <rich-text :nodes"nodes"></rich-text> 1.2 v-html <view v-html"item.content"></view> 1.3 uview组件&#xff1a;u-parse <u-parse :content&quo…

学习babylon.js --- [3] 开启https

babylonjs提供WebVR功能&#xff0c;但是使用这个功能得用https&#xff0c;本文讲述如何使用自签名证书来开启https&#xff0c;基于第二篇文章中搭建的工程。 一 生成自签名证书 首先要安装openssl&#xff0c;这个去网上搜下就行了。安装完之后在终端下输入openssl回车可以…

DeepC 实用教程(三)环境数据

目 录 一、前言二、风谱/风剖三、洋流四、波浪4.1 规则波浪4.2 随机波浪谱 五、方向六、海床属性七、位置7.1 创建位置7.2 规则波时域条件7.3 随机波时域条件7.4 波浪散布图7.4.1 散布图分块7.4.2 时域条件 八、参考文献 一、前言 SESAM &#xff08;Super Element Structure A…