手搓基于CNN的Chest X-ray图像分类

news2025/3/12 14:16:50

数据集Chest X-ray PD Dataset 数据集介绍 - 知乎https://zhuanlan.zhihu.com/p/661311561

 

CPU版本

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from time import time

start_time = time()


# 定义数据集类
class ChestXRayDataset(Dataset):
    def __init__(self, csv_path, root_dir, transform=None):
        self.data_info = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data_info.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.data_info.iloc[idx, 1]
        if label == 'covid':
            label = 0
        elif label == 'normal':
            label = 1
        else:
            label = 2

        if self.transform:
            image = self.transform(image)

        return image, label

# 数据预处理
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据
root_dir = r'E:\NiuCode\LianXi\data\Chest X-ray\DataSet'
csv_path = r'E:\NiuCode\LianXi\data\Chest X-ray\metadata.csv'
dataset = ChestXRayDataset(csv_path, root_dir, transform=data_transform)
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

# 创建数据加载器
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

print("==========")

# 定义CNN模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(-1, 32 * 56 * 56)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)
        return out

model = CNNModel()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

# 测试模型
softmax = nn.Softmax(dim=1)  # 定义softmax函数,在类别维度上进行操作
correct = 0
total = 0
all_probabilities = []  # 用于存储所有样本的概率
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        probabilities = softmax(outputs)  # 将模型输出转换为概率分布
        all_probabilities.extend(probabilities.cpu().numpy())  # 存储概率
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test set: {100 * correct / total}%')

# 打印前几个样本的概率
print("Probabilities for the first few samples:")
for i in range(min(5, len(all_probabilities))):
    print(f"Sample {i+1}: {all_probabilities[i]}")

end_time = time()
print("消耗时间={}".format(end_time - start_time))

GPU版本

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from time import time

start_time = time()
# 定义数据集类
class ChestXRayDataset(Dataset):
    def __init__(self, csv_path, root_dir, transform=None):
        self.data_info = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data_info.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.data_info.iloc[idx, 1]
        if label == 'covid':
            label = 0
        elif label == 'normal':
            label = 1
        else:
            label = 2

        if self.transform:
            image = self.transform(image)

        return image, label

# 数据预处理
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据
root_dir = r'E:\NiuCode\LianXi\data\Chest X-ray\DataSet'
csv_path = r'E:\NiuCode\LianXi\data\Chest X-ray\metadata.csv'
dataset = ChestXRayDataset(csv_path, root_dir, transform=data_transform)
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

# 创建数据加载器
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

print("==========")

# 定义CNN模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(-1, 32 * 56 * 56)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)
        return out

# 检查 CUDA 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = CNNModel()
# 将模型移到设备上
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        # 将数据移到设备上
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

# 测试模型
softmax = nn.Softmax(dim=1)  # 定义softmax函数,在类别维度上进行操作
correct = 0
total = 0
all_probabilities = []  # 用于存储所有样本的概率
with torch.no_grad():
    for images, labels in test_loader:
        # 将数据移到设备上
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        probabilities = softmax(outputs)  # 将模型输出转换为概率分布
        all_probabilities.extend(probabilities.cpu().numpy())  # 存储概率
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test set: {100 * correct / total}%')

# 打印前几个样本的概率
print("Probabilities for the first few samples:")
for i in range(min(5, len(all_probabilities))):
    print(f"Sample {i+1}: {all_probabilities[i]}")

end_time = time()

print("消耗时间={}".format(end_time-start_time))

CPU版本耗时:1310.6518561840057

GPU版本耗时:70.60973024368286

正确率:100%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2295217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用java代码操作rabbitMQ收发消息

SpringAMQP 将来我们开发业务功能的时候,肯定不会在控制台收发消息,而是应该基于编程的方式。由于RabbitMQ采用了AMQP协议,因此它具备跨语言的特性。任何语言只要遵循AMQP协议收发消息,都可以与RabbitMQ交互。并且RabbitMQ官方也…

【数据结构】(7) 栈和队列

一、栈 Stack 1、什么是栈 栈是一种特殊的线性表,它只能在固定的一端(栈顶)进行出栈、压栈操作,具有后进先出的特点。 2、栈概念的例题 答案为 C,以C为例进行讲解: 第一个出栈的是3,那么 1、…

Composo:企业级AI应用的质量守门员

在当今快速发展的科技世界中,人工智能(AI)的应用已渗透到各行各业。然而,随着AI技术的普及,如何确保其可靠性和一致性成为了企业面临的一大挑战。Composo作为一家致力于为企业提供精准AI评估服务的初创公司,通过无代码和API双模式,帮助企业监测大型语言模型(LLM)驱动的…

Python数据分析案例71——基于十种模型的信用违约预测实战

背景 好久没写这种基础的做机器学习流程了,最近过完年感觉自己代码忘了好多.....复习一下。 本次带来的是信贷违约的预测,即根据这个人的特征(年龄收入什么的),预测他是不是会违约,会违约就拒绝贷款&…

python康威生命游戏的图形化界面实现

康威生命游戏(Conway’s Game of Life)是由英国数学家约翰何顿康威(John Horton Conway)在1970年发明的一款零玩家的细胞自动机模拟游戏。尽管它的名字中有“游戏”,但实际上它并不需要玩家参与操作,而是通…

区块链技术:Facebook 重塑社交媒体信任的新篇章

在这个信息爆炸的时代,社交媒体已经成为我们生活中不可或缺的一部分。然而,随着社交平台的快速发展,隐私泄露、数据滥用和虚假信息等问题也日益凸显。这些问题的核心在于传统社交媒体依赖于中心化服务器存储和管理用户数据,这种模…

UE求职Demo开发日志#25 试试网络同步和尝试打包

1 改了一些时序上的bug,成功运行了多端 (UE一些网络相关的功能都弄好了,只需要标记哪个变量或Actor需要复制) 2 以前遗留的bug太多了,改到晚上才打包好一个能跑的版本,而且有的特效还不显示(悲…

Win10环境使用ChatBox集成Deep Seek解锁更多玩法

Win10环境使用ChatBox集成Deep Seek解锁更多玩法 前言 之前部署了14b的Deep Seek小模型,已经验证了命令行及接口方式的可行性。但是纯命令行或者PostMan方式调用接口显然不是那么友好: https://lizhiyong.blog.csdn.net/article/details/145505686 纯…

第 26 场 蓝桥入门赛

2.对联【算法赛】 - 蓝桥云课 问题描述 大年三十,小蓝和爷爷一起贴对联。爷爷拿出了两副对联,每副对联都由 N 个“福”字组成,每个“福”字要么是正的(用 1 表示),要么是倒的(用 0 表示&#…

CodeGPT + IDEA + DeepSeek,在IDEA中引入DeepSeek实现AI智能开发

CodeGPT IDEA DeepSeek,在IDEA中引入DeepSeek 版本说明 建议和我使用相同版本,实测2022版IDEA无法获取到CodeGPT最新版插件。(在IDEA自带插件市场中搜不到,可以去官网搜索最新版本) ToolsVersionIntelliJ IDEA202…

【2025年更新】1000个大数据/人工智能毕设选题推荐

文章目录 前言大数据/人工智能毕设选题:后记 前言 正值毕业季我看到很多同学都在为自己的毕业设计发愁 Maynor在网上搜集了1000个大数据的毕设选题,希望对大家有帮助~ 适合大数据毕业设计的项目,完全可以作为本科生当前较新的毕…

什么是中间件中间件有哪些

什么是中间件? 中间件(Middleware)是指在客户端和服务器之间的一层软件组件,用于处理请求和响应的过程。 中间件是指介于两个不同系统之间的软件组件,它可以在两个系统之间传递、处理、转换数据,以达到协…

使用docker搭建FastDFS文件服务

1.拉取镜像 docker pull registry.cn-hangzhou.aliyuncs.com/qiluo-images/fastdfs:latest2.使用docker镜像构建tracker容器(跟踪服务器,起到调度的作用) docker run -dti --networkhost --name tracker -v /data/fdfs/tracker:/var/fdfs -…

天津三石峰科技——汽车生产厂的设备振动检测项目案例

汽车产线有很多传动设备需要长期在线运行,会出现老化、疲劳、磨损等 问题,为了避免意外停机造成损失,需要加装一些健康监测设备,监测设备运 行状态。天津三石峰科技采用 12 通道振动信号采集卡(下图 1)对…

Linux之文件IO前世今生

在 Linux之文件系统前世今生(一) VFS中,我们提到了文件的读写,并给出了简要的读写示意图,本文将分析文件I/O的细节。 一、Buffered I/O(缓存I/O)& Directed I/O(直接I/O&#…

半导体制造工艺讲解

目录 一、半导体制造工艺的概述 二、单晶硅片的制造 1.单晶硅的制造 2.晶棒的切割、研磨 3.晶棒的切片、倒角和打磨 4.晶圆的检测和清洗 三、晶圆制造 1.氧化与涂胶 2.光刻与显影 3.刻蚀与脱胶 4.掺杂与退火 5.薄膜沉积、金属化和晶圆减薄 6.MOSFET在晶圆表面的形…

深入理解进程优先级

目录 引言 一、进程优先级基础 1.1 什么是进程优先级? 1.2 优先级与系统性能 二、查看进程信息 2.1 使用ps -l命令 2.2 PRI与NI的数学关系 三、深入理解Nice值 3.1 Nice值的特点 3.2 调整优先级实践 四、进程特性全景图 五、优化实践建议 结语 引言 在操…

微信小程序案例2——天气微信小程序(学会绑定数据)

文章目录 一、项目步骤1 创建一个无AppID的weather项目2 进入index.wxml、index.js、index.wxss文件,清空所有内容,进入App.json,修改导航栏标题为“中国天气网”。3进入index.wxml,进行当天天气情况的界面布局,包括温…

【Linux网络编程】之守护进程

【Linux网络编程】之守护进程 进程组进程组的概念组长进程 会话会话的概念会话ID 控制终端控制终端的概念控制终端的作用会话、终端、bash三者的关系 前台进程与后台进程概念特点查看当前终端的后台进程前台进程与后台进程的切换 作业控制相关概念作业状态(一般指后…

MarkupLM:用于视觉丰富文档理解的文本和标记语言预训练

摘要 结合文本、布局和图像的多模态预训练在视觉丰富文档理解(VRDU)领域取得了显著进展,尤其是对于固定布局文档(如扫描文档图像)。然而,仍然有大量的数字文档,其布局信息不是固定的&#xff0…