第54步 深度学习图像识别:MLP-Mixer建模(Pytorch)

news2024/12/26 21:24:58

基于WIN10的64位系统演示

一、写在前面

(1)MLP-Mixer

MLP-Mixer(Multilayer Perceptron Mixer)是Google在2021年提出的一种新型的视觉模型结构。它的主要特点是完全使用多层感知机(MLP)来处理图像,而不是使用常见的卷积(Convolution)或者自注意力(Self-Attention)机制。

MLP-Mixer的结构主要包括两种类型的层:Token Mixing层和Channel Mixing层。在Token Mixing层中,模型会将图像分割成若干个patch(类似于像素块),然后对这些patch进行处理。在Channel Mixing层中,模型会对每个patch的通道进行处理。这两种类型的层交替堆叠,形成了最终的模型结构。

MLP-Mixer的设计目标是探索除卷积和自注意力之外的其他可能的模型结构,以期在保持性能的同时,降低模型的复杂性和计算成本。实验结果显示,MLP-Mixer在一些图像分类任务上的性能可以与ResNet和Transformer等主流模型相媲美。

然而,需要注意的是,虽然MLP-Mixer在某些方面展现出了很好的性能,但它并不意味着会替代卷积或者自注意力模型。实际上,每种模型都有其适用的场景和优势,MLP-Mixer提供了一个新的视角和工具,供我们处理视觉任务。

(2)MLP-Mixer的码源

本文使用 mlp-mixer-pytorch 库来实现MLP-Mixer。

当然,得先安装这个库:

(a)首先,打开Anaconda Prompt。在开始菜单中找到它,或者直接在搜索栏中输入"Anaconda Prompt"。在打开的Anaconda Prompt中,如果你想在一个特定的环境中安装mlp_mixer_pytorch,你需要先激活这个环境。假设你的环境名为myenv,你可以使用以下命令来激活这个环境:

conda activate myenv

(b)接下来,使用pip来安装mlp_mixer_pytorch库。在Anaconda Prompt中输入以下命令并按回车键:

pip install mlp-mixer-pytorch

二、MLP-Mixer迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

import torch
from torchvision import datasets, transforms
import os

# 数据集路径
data_dir = "./MTB"

# 图像的大小
img_height = 256
img_width = 256

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)

# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # 假设训练集占80%
val_size = full_size - train_size  # 验证集的大小

# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']

# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c)导入MLPMixer

from mlp_mixer_pytorch import MLPMixer

num_classes = len(class_names)  # 根据数据集的类别数量来设置模型的输出类别数量

# 构建MLP-Mixer模型
model = MLPMixer(
    image_size = img_height,  # 图像的高和宽
    channels = 3,  # 图像的通道数
    patch_size = 16,  # MLP-Mixer的patch大小
    dim = 512,  # MLP-Mixer的维度
    depth = 12,  # MLP-Mixer的深度
    num_classes = num_classes  # 输出类别数量
)

# 将模型移动到GPU
model = model.to(device)

# 打印模型摘要
print(model)

说明:mlp-mixer-pytorch库的主要功能就是提供了一个MLP-Mixer的类,可以通过实例化这个类来创建一个MLP-Mixer模型。在创建模型时,可以通过参数来设置图像的大小、通道数、patch的大小、模型的维度、深度以及输出类别的数量等。

需要注意的是,mlp-mixer-pytorch库提供的MLP-Mixer模型默认是随机初始化的,也就是说并没有加载预训练权重。如果你有MLP-Mixer的预训练权重,可以在创建模型后加载。

(d)编译模型

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.Adam(model.parameters())

# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 开始训练模型
num_epochs = 20
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 每个epoch都有一个训练和验证阶段
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # 遍历数据
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 零参数梯度
            optimizer.zero_grad()

            # 前向
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # 只在训练模式下进行反向和优化
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # 统计
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # 记录每个epoch的loss和accuracy
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # 深拷贝模型
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

print('Best val Acc: {:4f}'.format(best_acc))

# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")

(e)Accuracy和Loss可视化

epoch = range(1, len(train_loss_history)+1)

fig, ax = plt.subplots(1, 2, figsize=(10,4))
ax[0].plot(epoch, train_loss_history, label='Train loss')
ax[0].plot(epoch, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()

ax[1].plot(epoch, train_acc_history, label='Train acc')
ax[1].plot(epoch, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()

#plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

 蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('Confusion matrix',fontsize=15)
    plt.ylabel('Actual value',fontsize=14)
    plt.xlabel('Predictive value',fontsize=14)
    
def evaluate_model(model, dataloader, device):
    model.eval()   # 设置模型为评估模式
    true_labels = []
    pred_labels = []
    # 遍历数据
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 前向
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        
    return true_labels, pred_labels

# 获取预测和真实标签
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)

# 计算混淆矩阵
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0,0]
b_val = cm_val[0,1]
c_val = cm_val[1,0]
d_val = cm_val[1,1]

# 计算各种性能指标
acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val)  # 准确率
error_rate_val = 1 - acc_val  # 错误率
sen_val = d_val/(d_val+c_val)  # 灵敏度
sep_val = a_val/(a_val+b_val)  # 特异度
precision_val = d_val/(b_val+d_val)  # 精确度
F1_val = (2*precision_val*sen_val)/(precision_val+sen_val)  # F1值
MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val)))  # 马修斯相关系数

# 打印出性能指标
print("验证集的灵敏度为:", sen_val, 
      "验证集的特异度为:", sep_val,
      "验证集的准确率为:", acc_val, 
      "验证集的错误率为:", error_rate_val,
      "验证集的精确度为:", precision_val, 
      "验证集的F1为:", F1_val,
      "验证集的MCC为:", MCC_val)

# 绘制混淆矩阵
plot_cm(true_labels, pred_labels)

    
# 获取预测和真实标签
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
# 计算混淆矩阵
cm_train = confusion_matrix(train_true_labels, train_pred_labels)  
a_train = cm_train[0,0]
b_train = cm_train[0,1]
c_train = cm_train[1,0]
d_train = cm_train[1,1]
acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
error_rate_train = 1 - acc_train
sen_train = d_train/(d_train+c_train)
sep_train = a_train/(a_train+b_train)
precision_train = d_train/(b_train+d_train)
F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train))) 
print("训练集的灵敏度为:",sen_train, 
      "训练集的特异度为:",sep_train,
      "训练集的准确率为:",acc_train, 
      "训练集的错误率为:",error_rate_train,
      "训练集的精确度为:",precision_train, 
      "训练集的F1为:",F1_train,
      "训练集的MCC为:",MCC_train)

# 绘制混淆矩阵
plot_cm(train_true_labels, train_pred_labels)

效果不错:

 (g)AUC曲线绘制

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = metrics.roc_curve(labels, predictions)

    plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.xlabel('False positives rate')
    plt.ylabel('True positives rate')
    ax = plt.gca()
    ax.set_aspect('equal')


# 确保模型处于评估模式
model.eval()

train_ds = dataloaders['train']
val_ds = dataloaders['val']

val_pre_auc   = []
val_label_auc = []

for images, labels in val_ds:
    for image, label in zip(images, labels):      
        img_array = image.unsqueeze(0).to(device)  # 在第0维增加一个维度并将图像转移到适当的设备上
        prediction_auc = model(img_array)  # 使用模型进行预测
        val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])
        val_label_auc.append(label.item())  # 使用Tensor.item()获取Tensor的值
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)


train_pre_auc   = []
train_label_auc = []

for images, labels in train_ds:
    for image, label in zip(images, labels):
        img_array_train = image.unsqueeze(0).to(device) 
        prediction_auc = model(img_array_train)
        train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])  # 输出概率而不是标签!
        train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)

plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300,format="pdf")

print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

ROC曲线如下:

 这个ROC曲线也是不错的!全部大于95%!

三、写在最后

截至目前,图像分类领域基本就是CNN、Transformer和MLP三足鼎立了。孰优孰劣,还不好说,中庸之道那就是各有千秋。他们之间的两两组合或者一起融合的话,效果又会如何?

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/784195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

seaborn笔记 pairplot PairGrid

1 数据集 鸢尾花数据集 # Visual Python: Data Analysis > File vp_df pd.read_csv(https://raw.githubusercontent.com/visualpython/visualpython/main/visualpython/data/sample_csv/iris.csv) vp_df 1.1 基本pairplot import seaborn as snsg sns.pairplot(vp_df) …

前端随笔:HTML/CSS/JavaScript和Vue

前端随笔 1:HTML、JavaScript和Vue 最近因为工作需要,需要接触一些前端的东西。之前虽然大体上了解过HTML、CSS和JavaScript,也知道HTML定义了内容、CSS定义了样式、JavaScript定义了行为,但是却没有详细的学习过前端三件套的细节…

2023 年第二届钉钉杯大学生大数据挑战赛初赛 初赛 A:智能手机用户监测数据分析 问题二分类与回归问题Python代码分析

2023 年第二届钉钉杯大学生大数据挑战赛初赛 初赛 A:智能手机用户监测数据分析 问题二分类与回归问题Python代码分析 相关链接 【2023 年第二届钉钉杯大学生大数据挑战赛初赛】 初赛 A:智能手机用户监测数据分析 问题一Python代码分析 【2023 年第二届…

RocketMQ 5.0 无状态实时性消费详解

作者:绍舒 背景 RocketMQ 5.0 版本引入了 Proxy 模块、无状态 pop 消费机制和 gRPC 协议等创新功能,同时还推出了一种全新的客户端类型:SimpleConsumer。 SimpleConsumer 客户端采用了无状态的 pop 机制,彻底解决了在客户端发布…

SpringBoot原理分析 | Redis集成

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; Springboot集成Redis 依赖导入 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis<…

九、数据结构——顺序队列中的循环队列

目录 一、循环队列的定义 二、循环队列的实现 三、循环队列的基本操作 ①初始化 ②判空 ③判满 ④入队 ⑤出队 ⑥获取长度 ⑦打印 四、循环队列的应用 五、全部代码 数据结构中的循环队列 在数据结构中&#xff0c;队列&#xff08;Queue&#xff09;是一种常见的线性数据结…

防火墙NAT地址转换的四种应用实验与防火墙的双机热备实验

一、NAT实验 一、源地址转换 1、首先搭建NAT实验环境的拓扑&#xff1a; 这里需要配置各个设备的ip、掩码、网关&#xff1b;省略 2、登录防火墙设备并且为防火墙设备的0/0/0接口配置与虚拟网卡一个网段的ip&#xff0c;并且开启该接口的全部服务 [USG6000V1]int gi 0/0/0…

Keil系列教程11_工程窗口图标说明

1写在前面 很多朋友看到如下工程窗口的图标&#xff08;如&#xff1a;带有“叹号”、“星号”、“钥匙”、“禁止驶入”标志&#xff09;&#xff0c;就会产生疑问&#xff1a;这些图标到底是啥意思呢&#xff1f; 其实&#xff0c;这些不同标志的图标是代表着不同的含义&…

AWVS 15.6 使用教程

目录 介绍 版本 AWVS具有以下特点和功能&#xff1a; 功能介绍&#xff1a; Dashboard功能&#xff1a; Targets功能&#xff1a; Scans功能&#xff1a; Vulnerabilities功能&#xff1a; Reports功能&#xff1a; Users功能&#xff1a; Scan Profiles功能&#x…

MyBatis查询数据库(2)

目录 前言&#x1f36d; 一、增删查改操作 1、查 Ⅰ、mapper接口&#xff1a; Ⅱ、UserMapper.xml 查询所有用户的具体实现 SQL&#xff1a; Ⅲ、进行单元测试 2、增、删、改操作 Ⅰ、增 添加用户 添加用户并且返回自增 id Ⅱ、改 根据id修改用户名 开启 MyBatis …

leetcode每日一练-第141题-环形链表

一、思路 双指针 二、解题方法 使用了正确的快慢环指针方法来判断链表。快指针每次向前移动两步&#xff0c;慢指针每次移动一步&#xff0c;如果链表中向前移动一步&#xff0c;它们最终会相遇。如果链表不存在环&#xff0c;快指针会先到达链表是否存在&#xff0c;此时存在…

【C#】using

文章目录 global 修饰符using 别名结合“global 修饰符”和“using 别名”static 修饰符来源 global 修饰符 向 using 指令添加 global 修饰符意味着 using 将应用于编译中的所有文件&#xff08;通常是一个项目&#xff09;。 global using 指令被添加到 C# 10 中。 其语法为…

怎么快速定位bug?怎么编写测试用例?

目录 01定位问题的重要性 02问题定位技巧 03初次怎么写用例 作为一名测试人员如果连常见的系统问题都不知道如何分析&#xff0c;频繁将前端人员问题指派给后端人员&#xff0c;后端人员问题指派给前端人员&#xff0c;那么在团队里你在开发中的地位显而易见 &#xff0c;口碑…

什么?按Home键SingleInstance Activity销毁了???

前段时间&#xff0c;突然有朋友询问&#xff0c;自己写的SingleInstance Activity在按home键的时候被销毁了&#xff0c;刚听到这个问题的时候&#xff0c;我直觉怀疑是Activity在onPause或者onStop中发生了Crash导致闪退了&#xff0c;但是安装apk查看现象&#xff0c;没有发…

摸索graphQL在前端vue中使用过程(四)

请求网址https://hasura.io/learn/graphql&#xff0c;他这个Authorization好像每天就会一次变化&#xff0c;需要注意。 之前用到了一种类型ID&#xff0c;也就是说&#xff0c;在GraphQL的查询标量的过程中。 标量:就是被查询的字段名称。 这里再补充一点知识&#xff0c;统…

Android 包体积资源优化实践

1 插件优化 插件优化资源在得物App最新版本上收益12MB。插件优化的日志在包体积平台有具体的展示&#xff0c;也是为了提供一个资源问题追溯的能力。 1.1 插件环境配置 插件首先会初始化环境配置&#xff0c;如果机器上未安装运行环境则会去oss下载对应的可执行文件。 1.2 图…

Windows 在VMware16.x安装Win11系统详细教程

文章目录 一、准备二、创建虚拟机1. 创建新的虚拟机2. 选择虚拟机硬件兼容性3. 安装客户机操作系统4. 选择客户机操作系统5. 命名虚拟机6. 固件类型7. 处理器配置8. 此虚拟机内存9. 网络类型10. 选择I/O控制器类型11. 选择磁盘类型12. 选择磁盘13. 指定磁盘容量14. 指定磁盘文件…

【深度学习】日常笔记15

训练集和测试集并不来⾃同⼀个分布。这就是所谓的分布偏移。 真实⻛险是从真实分布中抽取的所有数据的总体损失的预期&#xff0c;然⽽&#xff0c;这个数据总体通常是⽆法获得的。计算真实风险公式如下&#xff1a; 为概率密度函数 经验⻛险是训练数据的平均损失&#xff0c;⽤…

python机器学习(四)线性代数回顾、多元线性回归、多项式回归、标准方程法求解、线性回归案例

回顾线性代数 矩阵 矩阵可以理解为二维数组的另一种表现形式。A矩阵为三行两列的矩阵&#xff0c;B矩阵为两行三列的矩阵&#xff0c;可以通过下标来获取矩阵的元素&#xff0c;下标默认都是从0开始的。 A i j : A_{ij}: Aij​:表示第 i i i行&#xff0c;第 j j j列的元素。…

N位分频器的实现

N位分频器的实现 一、 目的 使用verilog实现n位的分频器&#xff0c;可以是偶数&#xff0c;也可以是奇数 二、 原理 FPGA中n位分频器的工作原理可以简要概括为: 分频器的作用是将输入时钟频率分频,输出低于输入时钟频率的时钟信号。n位分频器可以将输入时钟频率分频2^n倍…