第64步 深度学习图像识别:多分类建模误判病例分析(Pytorch)

news2024/9/23 15:26:15

基于WIN10的64位系统演示

一、写在前面

上期我们基于TensorFlow环境介绍了多分类建模的误判病例分析。

本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于Pytorch环境,构建SqueezeNet多分类模型,分析误判病例,因为它建模速度快。

同样,基于GPT-4辅助编程。

二、误判病例分析实战

使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,健康人900张,肺结核病人700张,COVID-19病人549张、细菌性(病毒性)肺炎组900张,分别存入单独的文件夹中。

直接分享代码:

######################################导入包###################################
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

################################导入数据集#####################################
from torchvision import datasets, transforms
from torch.nn.functional import softmax
from PIL import Image
import pandas as pd
import torch.nn as nn
import timm
from torch.optim import lr_scheduler

# 自定义的数据集类
class ImageFolderWithPaths(datasets.ImageFolder):
    def __getitem__(self, index):
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# 数据集路径
data_dir = "./MTB-1"

# 图像的大小
img_height = 256
img_width = 256

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
full_dataset = ImageFolderWithPaths(data_dir, transform=data_transforms['train'])

# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.8 * full_size)  # 假设训练集占70%
val_size = full_size - train_size  # 验证集的大小

# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# 应用数据增强到训练集和验证集
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']

# 创建数据加载器
batch_size = 8
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

# 获取数据集的类别
class_names = full_dataset.classes

# 保存预测结果的列表
results = []

###############################定义SqueezeNet模型################################
# 定义SqueezeNet模型
model = models.squeezenet1_1(pretrained=True)  # 这里以SqueezeNet 1.1版本为例
num_ftrs = model.classifier[1].in_channels

# 根据分类任务修改最后一层
# 这里我们改变模型的输出层为4,因为我们做的是四分类
model.classifier[1] = nn.Conv2d(num_ftrs, 4, kernel_size=(1,1))

# 修改模型最后的输出层为我们需要的类别数
model.num_classes = 4

model = model.to(device)

# 打印模型摘要
print(model)

#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = torch.optim.Adam(model.parameters())

# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 开始训练模型
num_epochs = 20

# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # 每个epoch都有一个训练和验证阶段
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # 设置模型为训练模式
        else:
            model.eval()   # 设置模型为评估模式

        running_loss = 0.0
        running_corrects = 0

        # 遍历数据
        for inputs, labels, paths in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 零参数梯度
            optimizer.zero_grad()

            # 前向
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # 只在训练模式下进行反向和优化
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # 统计
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # 记录每个epoch的loss和accuracy
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    print()

# 保存模型
torch.save(model.state_dict(), 'SqueezeNet_model-m-s.pth')

# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")
###########################误判病例分析#################################
import os
import pandas as pd
from collections import defaultdict

# 判定组别的字典
group_dict = {
    ("COVID-19", "Normal"): "B",
    ("COVID-19", "Pneumonia"): "C",
    ("COVID-19", "Tuberculosis"): "D",
    ("Normal", "COVID-19"): "E",
    ("Normal", "Pneumonia"): "F",
    ("Normal", "Tuberculosis"): "G",
    ("Pneumonia", "COVID-19"): "H",
    ("Pneumonia", "Normal"): "I",
    ("Pneumonia", "Tuberculosis"): "J",
    ("Tuberculosis", "COVID-19"): "K",
    ("Tuberculosis", "Normal"): "L",
    ("Tuberculosis", "Pneumonia"): "M",
}

# 创建一个字典来保存所有的图片信息
image_predictions = {}

# 循环遍历所有数据集(训练集和验证集)
for phase in ['train', 'val']:
    # 设置模型的状态
    model.eval()

    # 遍历数据
    for inputs, labels, paths in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 计算模型的输出
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        # 循环遍历每一个批次的结果
        for path, pred in zip(paths, preds):
            # 提取图片的类别
            actual_class = os.path.split(os.path.dirname(path))[-1] 
            # 提取图片的名称
            image_name = os.path.basename(path)
            # 获取预测的类别
            predicted_class = class_names[pred]

            # 判断预测的分组类型
            if actual_class == predicted_class:
                group_type = 'A'
            elif (actual_class, predicted_class) in group_dict:
                group_type = group_dict[(actual_class, predicted_class)]
            else:
                group_type = 'Other'  # 如果没有匹配的条件,可以归类为其他

            # 保存到字典中
            image_predictions[image_name] = [phase, actual_class, predicted_class, group_type]

# 将字典转换为DataFrame
df = pd.DataFrame.from_dict(image_predictions, orient='index', columns=['Dataset Type', 'Actual Class', 'Predicted Class', 'Group Type'])

# 保存到CSV文件中
df.to_csv('result-m-s.csv')

四、改写过程

先说策略:首先,先把二分类的误判病例分析代码改成四分类的;其次,用咒语让GPT-4帮我们续写代码已达到误判病例分析。

提供咒语如下:

①改写{代码1},改变成4分类的建模。代码1为:{XXX};

在{代码1}的基础上改写代码,达到下面要求:

(1)首先,提取出所有图片的“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”;文件的路劲格式为:例如,“MTB-1\Normal\XXX.png”属于Normal,“MTB-1\COVID-19\XXX.jpg”属于COVID-19,“MTB-1\Pneumonia\XXX.jpeg”属于Pneumonia,“MTB-1\Tuberculosis\XXX.png”属于Tuberculosis;

(2)其次,根据样本预测结果,把样本分为以下若干组:(a)预测正确的图片,全部判定为A组;(b)本来就是COVID-19的图片,预测为Normal,判定为B组;(c)本来就是COVID-19的图片,预测为Pneumonia,判定为C组;(d)本来就是COVID-19的图片,预测为Tuberculosis,判定为D组;(e)本来就是Normal的图片,预测为COVID-19,判定为E组;(f)本来就是Normal的图片,预测为Pneumonia,判定为F组;(g)本来就是Normal的图片,预测为Tuberculosis,判定为G组;(h)本来就是Pneumonia的图片,预测为COVID-19,判定为H组;(i)本来就是Pneumonia的图片,预测为Normal,判定为I组;(j)本来就是Pneumonia的图片,预测为Tuberculosis,判定为J组;(k)本来就是Tuberculosis的图片,预测为COVID-19,判定为H组;(l)本来就是Tuberculosis的图片,预测为Normal,判定为I组;(m)本来就是Tuberculosis的图片,预测为Pneumonia,判定为J组;

(3)居于以上计算的结果,生成一个名为result-m.csv表格文件。列名分别为:“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”、“判定的组别”。其中,“原始图片的名称”为所有图片的图片名称;“属于训练集还是验证集”为这个图片属于训练集还是验证集;“预测为分组类型”为模型预测该样本是哪一个分组;“判定的组别”为根据步骤(2)判定的组别,从A到J一共十组选择一个。

(4)需要把所有的图片都进行上面操作,注意是所有图片,而不只是一个批次的图片。

代码1为:{XXX}

③还需要根据报错做一些调整即可,自行调整。

最后,看看结果:

模型只运行了2次,所以效果很差哈,全部是预测成了COVID-19。

四、数据

链接:https://pan.baidu.com/s/1rqu15KAUxjNBaWYfEmPwgQ?pwd=xfyn

提取码:xfyn

五、结语

深度学习图像分类的教程到此结束,洋洋洒洒29篇,涉及到的算法和技巧也够发一篇SCI了。当然,图像识别还有图像分割和目标识别两块内容,就放到最后再说了。下一趴,我们来介绍时间序列建模!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/958806.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小黑受到了未来的焦虑,周四继续参加团跑活动仰山跑,跑奥森的坡,越跑越上瘾更加热爱生活的leetcode之旅:LCR 008. 长度最小的子数组

小黑代码1 class Solution:def minSubArrayLen(self, target: int, nums: List[int]) -> int:# 数组长度n len(nums)# 双指针head 0tail 0# 中间变量sum_ 0# 结果变量res n1# 开始双指针迭代while tail < n:sum_ nums[tail]tail 1while sum_ > target:if tail…

源码安装cv_bridge

1. 下载源码 1去github上下载GitHub - ros-perception/vision_opencv&#xff0c;进去后注意选择与自己的ros对应的版本&#xff1a;&#xff08;我的为noetic&#xff09; 如果你直接使用 git clone https://github.com/ros-perception/vision_opencv.git 来拉取的话cmake的…

CTFshow 菜狗杯 web方向 全

文章目录 菜狗杯 web签到菜狗杯 web2 c0me_t0_s1gn菜狗杯 我的眼里只有$菜狗杯 抽老婆菜狗杯 一言既出菜狗杯 驷马难追菜狗杯 TapTapTap菜狗杯 Webshell菜狗杯 化零为整菜狗杯 无一幸免菜狗杯 无一幸免_FIXED菜狗杯 传说之下&#xff08;雾&#xff09;菜狗杯 算力超群菜狗杯 算…

webdriver对应浏览器

Chrome for Testing availability 更多是在这里下载: http://chromedriver.storage.googleapis.com/index.html?path103.0.5060.53/ 但是这里截止2023年8月31日14:31分没有115以后得版本 但是我的浏览器已经到了116的版本 为了关闭掉Ch 所以下载webdriver的地址是: http…

记录学习--字节码解析try catch

1.示例代码 Testpublic void someTest() {String s "111";try {s "222";int i 1/0;} catch (Exception e){e.printStackTrace();System.out.println(s);}System.out.println(s);}2.示例代码对应的字节码 0 ldc #2 <111>2 astore_13 ldc #3 <22…

python实现语音识别

1. 首先安装依赖库 pip install playsound # 该库用于播放音频文件 pip install speech_recognition # 该库用于语音识别 pip install PocketSphinx # 语音识别模块中只有sphinx支持离线的&#xff0c;使用该模块需单独安装 pip install pyttsx3 # 该库用于将文本转换为语音播…

如何使用ArcGIS Earth制作地图动画视频

通常情况下&#xff0c;我们所看到的地图都是静态展示&#xff0c;对于信息的传递&#xff0c;视频比图片肯定会更加丰富&#xff0c;所以制作地图动画视频更加有利于信息的传递&#xff0c;这里我们讲解一下ArcGIS Earth 2.0如何制作地图动画视频&#xff0c;希望能对你有所帮…

TikTok网红营销之谜:为何成功程度参差不齐?

近年来&#xff0c;随着社交媒体的迅猛发展&#xff0c;TikTok作为一款以短视频为主要内容形式的应用&#xff0c;在全球范围内迅速走红。不仅个人用户在TikTok上分享自己的创意&#xff0c;越来越多的品牌也开始借助TikTok网红进行营销推广。然而&#xff0c;尽管众多人都在尝…

2023年9月数据治理/项目管理/产品管理/商务礼仪企业内训定制

在节奏飞驰、风起云涌的企业世界中&#xff0c;为了企业的蓬勃发展&#xff0c;可以在内部或者外部挑选有经验的老师进行培训和学习。简而言之&#xff0c;任何一个企业想要发展&#xff0c;都少不了进行内训。 企业内训的好处 提高员工的技能和知识水平 通过不断地学习和培训…

分库分表篇-2.1 Mycat-配置文件篇

文章目录 前言一、Mycat server.xml作用&#xff1a;1.1 server.xml 作用&#xff1a;1.2 定义数据库逻辑模式&#xff1a; 二、Mycat schema.xml作用&#xff1a;2.1 schema 标签&#xff1a;2.1.1 schema 中table 标签&#xff1a; 2.2 dataNode 标签&#xff1a;2.3 dataHos…

骨传导耳机十大品牌怎么选,骨传导耳机十大品牌排行榜分享

作为一个拥有20多款骨传导耳机来说&#xff0c;我也算是资深的使用者了&#xff0c;在骨传导耳机刚开始兴起的时候&#xff0c;我就开始接触了&#xff0c;近几年越来越多的骨传导耳机品牌诞生&#xff0c;我也是入手了不少&#xff0c;所以也算是对骨传导耳机非常熟悉了&#…

Error obtaining UI hierarchy Error taking device screenshot: EOF/NULL 解决办法

RT&#xff1a;Error obtaining UI hierarchy Error taking device screenshot: EOF/NULL 解决办法 关于monitor开发神器我就不多说了&#xff0c;但是假如我们在开发中遇到如上问题该怎么处理呢&#xff1f;别慌下面会有方法&#xff0c;不过不是对任何机型都有效&#xff0c…

【送书活动】深入浅出SSD:固态存储核心技术、原理与实战

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 「推荐专栏」&#xff1a; ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄&#xff0c;vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…

数据库中的表和Json

目录 一、表转Json 1.使用 for json path 2.如何返回单个Json 3.如何给返回的Json增加一个根节点呢 4.如何给返回的Json增加上一个节点 二、对Json基本操作 1.判断给的字符串是否是Json格式 2.从 JSON 字符串中提取标量值 3. 从 JSON 字符串中提取对象或数组 4. 更…

价格战“杀疯了”?「智驾」系统级降本增效,才是更优解

从3000元到1500元&#xff0c;再到千元级别&#xff0c;今年以来&#xff0c;行泊一体域控产品不断刷新降本底线。 以上&#xff0c;也反映出今年中国智驾规模量产赛道的竞争激烈程度。当前&#xff0c;各路 Tier1甚至Tier2供应商们还在加速“内卷”&#xff0c;从早期的卷技术…

7个公认的wordpress外贸独立站优点

随着全球化进程的加快&#xff0c;越来越多的企业开始将业务拓展到国际市场。对于外贸企业而言&#xff0c;拥有一个专业且易于管理的网站非常重要。WordPress外贸独立站恰好满足了这一需求&#xff0c;它不仅具备开源、可定制等特点&#xff0c;还有以下几个优点&#xff1a;​…

idea中设置style固定样式

一、样式设置首先打开IDEA之后,点击任务栏的“File”→Settings 二、设置style行内样式 1.首先打开IDEA之后,点击任务栏的“File”。 2.在下拉列表中中选择“Settings” 3.在弹出的设置页面中找到Editor-LiveTemplates 点击号&#xff0c;先选中Template Group...创建 三、详…

【01】弄懂共识机制PoW

基于工作量证明机制的共识机制PoW&#xff08;Proof of Work&#xff09; 特点就是多劳多特 共识过程 一个区块链系统中&#xff0c;交易历经多个步骤才能得以上链&#xff0c;并且需要经过多个节点的验证。以下是这些步骤的详细叙述&#xff1a; 交易进入交易池&#xff08;内…

VS2022 C语言课程设计学生成绩管理系统

C语言课程设计题目及要求 学生成绩管理系统 此成绩管理系统主要利用单链表或者结构数组实现&#xff08;最好用单链表实现&#xff09;&#xff0c;具有如下的五大功能模块。学生成绩管理系统功能模块图如图1所示。 说明&#xff1a; 1.输入记录模块 从键盘逐个输入学生记录…

leetcode236. 二叉树的最近公共祖先(java)

二叉树的最近公共祖先 题目描述递归法代码演示 上期经典 题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q …