PyTorch学习笔记:新冠肺炎X光分类

news2024/10/6 16:29:27

前言

目的是要了解pytorch如何完成模型训练
https://github.com/TingsongYu/PyTorch-Tutorial-2nd参考的学习笔记


数据准备

由于本案例目的是pytorch流程学习,为了简化学习过程,数据仅选择了4张图片,分为2类,正常与新冠,训练集2张,
验证集2张。标签信息存储于TXT文件中。具体目录结构如下:

注意:covid-19的图可以找到但是no-finding两张图没有找到
covid-19-1
covid-19-2
no-finding的图随便照两张看着正常的,别问我哪个是正常的,我也不知道(❍ᴥ❍ʋ),需要改名字为00001215_000.png00001215_001.png

├─imgs
│  ├─covid-19
│  │      auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg
│  │      ryct.2020200028.fig1a.jpeg
│  │
│  └─no-finding
│         00001215_000.png
│         00001215_001.png
│
└─labels
       train.txt
       valid.txt

创建标签文件:

创建 train.txt 和 valid.txt 文件,并填入图片路径和标签信息

  • train.txt:
covid-19/auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg 1
no-finding/00001215_000.png 0

  • valid.txt:
covid-19/ryct.2020200028.fig1a.jpeg 1
no-finding/00001215_001.png 0

完整代码示例:

以下是准备数据集、定义模型和训练模型的完整代码示例:

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 自定义数据集类
class COVID19Dataset(Dataset):
    def __init__(self, img_dir, label_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = []

        with open(label_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                self.img_labels.append(line.strip().split())

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_path)
        image = Image.open(img_path).convert('RGB')
        label = int(label)

        if self.transform:
            image = self.transform(image)

        return image, label

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((8, 8)),
    transforms.ToTensor()
])

# 创建数据集和数据加载器
train_dataset = COVID19Dataset(img_dir='imgs', label_file='labels/train.txt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

valid_dataset = COVID19Dataset(img_dir='imgs', label_file='labels/valid.txt', transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)

# 定义简单卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3)  # 输入通道为3(RGB),输出通道为1,卷积核大小为3x3
        self.fc1 = nn.Linear(1 * 6 * 6, 2)  # 全连接层,输入大小为6*6*1,输出大小为2(2类)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = x.view(-1, 1 * 6 * 6)  # 展平操作
        x = self.fc1(x)
        return x

model = SimpleCNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练函数
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 10 == 9:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {running_loss / 10:.6f}')
            running_loss = 0.0

# 验证函数
def validate(model, valid_loader, criterion):
    model.eval()
    validation_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in valid_loader:
            output = model(data)
            validation_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    validation_loss /= len(valid_loader.dataset)
    print(f'\nValidation set: Average loss: {validation_loss:.4f}, Accuracy: {correct}/{len(valid_loader.dataset)} ({100. * correct / len(valid_loader.dataset):.0f}%)\n')

# 训练和验证
for epoch in range(1, 11):
    train(model, train_loader, criterion, optimizer, epoch)
    validate(model, valid_loader, criterion)

效果展示:

由于数据量少,随机性非常大,大家多运行几次,观察结果。不过本案例结果完全不重要!)可以观看Average loss变化,Accuracy由于训练数据过少几乎不会变化
在这里插入图片描述

知识点总结

1. 数据

  • Q:要知道pytorch需要模型的格式
    A:需要编写代码完成数据的读取,转换成模型能够读取的格式。在 PyTorch 中,读取数据通常通过自定义 Dataset 类和内置的 DataLoader 来实现。这种方法既灵活又高效,适用于各种类型的数据集。
  • Q:自己如何编写Dataset?
    A:编写一个自定义的 Dataset 类,需要继承 torch.utils.data.Dataset 并实现三个方法:__init____len__ __getitem__

2. 模型

可参考:
从“卷积”、到“图像卷积操作”、再到“卷积神经网络”,“卷积”意义的3次改变_哔哩哔哩_bilibili

  • Q: 卷积层,全连接层的作用是什么?
    A: 卷积层提取特征,全连接层进行分类。
    1. 卷积层
    • 卷积层的作用是提取输入图像的特征。
    • 使用 3x3 的卷积核进行卷积操作,可以捕捉到局部的空间特征。
    • 卷积操作后的输出会产生一个新的特征图,这个特征图是卷积层提取到的特征表示。
    1. 全连接层
    • 全连接层的作用是将卷积层提取到的特征进行进一步的处理,最终输出分类结果。
    • 在这个例子中,全连接层有两个神经元,分别输出两个分类的概率。
    • 全连接层的输入被限制在 8x8,这意味着输入的特征图经过扁平化(flatten)后被映射到一个 8x8 的向量。

3. 优化

  • Q:根据什么规则对模型的参数进行更新学习呢?
    A:常用的方法:交叉熵损失函数(CrossEntropyLoss)、随机梯度下降法(SGD)和按固定步长下降学习率策略(StepLR)

4. 迭代

  • Q:怎么进行模型迭代?
    A: 有了模型参数更新的必备组件,接下来需要一遍又一遍地给模型喂数据,监控模型训练状态,这时候就需要for循环,不断地从dataloader里取出数据进行前向传播,反向传播,参数更新,观察loss、acc,周而复始。

总结

详细内容https://github.com/TingsongYu/PyTorch-Tutorial-2nd可查看,这是一篇读书笔记,与代码实现的分享。后续的笔记会以Q-A解决一些问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1700386.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang | Leetcode Golang题解之第114题二叉树展开为链表

题目: 题解: func flatten(root *TreeNode) {curr : rootfor curr ! nil {if curr.Left ! nil {next : curr.Leftpredecessor : nextfor predecessor.Right ! nil {predecessor predecessor.Right}predecessor.Right curr.Rightcurr.Left, curr.Righ…

95.网络游戏逆向分析与漏洞攻防-ui界面的设计-ui的设计与架构

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 如果看不懂、不知道现在做的什么,那就跟着做完看效果,代码看不懂是正常的,只要会抄就行,抄着抄着就能懂了 内容…

JVM之【运行时数据区】

JVM简图 运行时数据区简图 一、程序计数器(Program Counter Register) 1.程序计数器是什么? 程序计数器是JVM内存模型中的一部分,它可以看作是一个指针,指向当前线程所执行的字节码指令的地址。每个线程在执行过程中…

通过acme.sh和cloudflare实现免费ssl证书自动签发

参考使用acme.sh通过cloudflare自动签发免费ssl证书 | LogDicthttps://www.logdict.com/archives/acme.shshi-yong-cloudflarezi-dong-qian-fa-mian-fei-sslzheng-shu

服务器数据恢复—服务器正常断电重启后raid信息丢失的数据恢复案例

服务器数据恢复环境: 一台某品牌DL380 G4服务器,服务器通过该服务器品牌smart array控制器挂载了一台国产的磁盘阵列,磁盘阵列中有一组由14块SCSI硬盘组建的RAID5。服务器安装LINUX操作系统,搭建了NFSFTP,作为内部文件…

ROS添加GDB调试

文章目录 一、问题描述二、配置步骤1. debug 模式编译2. rosrun 添加GDB指令3. launch 添加GDB指令 三、GDB基本命令1. 基本2. 显示被调试文件信息3. 查看/修改内存4. 断点5. 调试运行 一、问题描述 在享受ROS带来便利的同时,但因每运行出现错误,ROS不会…

Python筑基之旅-文件(夹)操作和流

目录 一、文件操作 1、文件打开与关闭 2、文件读写 3、文件操作模式 4、文件编码 二、文件夹操作 1、创建文件夹 2、删除文件夹 3、改变当前工作目录 4、获取当前工作目录 5、检查文件/文件夹是否存在 6、遍历文件夹 三、文件路径操作 1、获取绝对路径 2、构建完…

Android 逆向学习【1】——版本/体系结构/代码学习

#Android 历史版本 参考链接:一篇文章让你了解Android各个版本的历程 - 知乎 (zhihu.com) 三个部分:api等级、版本号、代号(这三个东西都是指的同一个系统) API等级:在APP开发的时候写在清单列表里面的 版本号&…

【Springboot系列】SpringBoot 中的日志如何工作的,看完这一篇就够了

文章目录 强烈推荐引言Spring Boot 中的日志是怎么工作日志框架选择配置文件日志级别自定义日志配置集成第三方日志库实时监控和日志管理 Log4j2工作原理分析1. 核心组件2. 配置文件3. Logger的继承和层次结构4. 日志事件处理流程5. 异步日志 总结强烈推荐专栏集锦写在最后 强烈…

【MySQL进阶之路 | 基础篇】MySQL新特性 : 窗口函数

1. 前言 (1). MySQL8开始支持窗口函数. 其作用类似于在查询中对数据进行分组(GROUP BY),不同的是,分组操作会把分组的结果聚合成一条记录. 而窗口函数是将结果置于每一条数据记录中. (2). 窗口函数还可以分为静态窗口函数和动态窗口函数. 静态窗口函数…

堆(建堆算法,堆排序)

目录 一.什么是堆? 1.堆 2.堆的储存 二.堆结构的创建 1.头文件的声明: 2.向上调整 3.向下调整 4.源码: 三.建堆算法 1.向上建堆法 2.向下建堆法 四.堆排序 五.在文件中Top出最小的K个数 一.什么是堆? 1.堆 堆就…

AIGC产业链上下游解析及常见名词

文章目录 AIGC上游产业链 - 基础层AIGC中游产业链 - 大模型层与工具层AIGC下游产业链 - 应用层AIGC产业链常见的名词表 在上一章节为大家介绍了 “大模型的不足与解决方案” ,这一小节呢为大家针对AIGC的相关产业进行一个拆解,以及相关的一些专业名词做出…

RK3568笔记二十六:音频应用

若该文为原创文章,转载请注明原文出处。 一、介绍 音频是我们最常用到的功能,音频也是 linux 和安卓的重点应用场合。 测试使用的是ATK-DLR3568板子,板载外挂RK809 CODEC芯片,RK官方驱动是写好的,不用在自己重新写。…

C语言 | Leetcode C语言题解之第113题路径总和II

题目: 题解: int** ret; int retSize; int* retColSize;int* path; int pathSize;typedef struct {struct TreeNode* key;struct TreeNode* val;UT_hash_handle hh; } hashTable;hashTable* parent;void insertHashTable(struct TreeNode* x, struct Tr…

第八篇【传奇开心果系列】Python微项目技术点案例示例:以微项目开发为案例,深度解读Dearpygui 编写图形化界面桌面程序的优势

传奇开心果博文系列 系列博文目录Python微项目技术点案例示例系列 博文目录前言一、开发图形化界面桌面程序的优势介绍二、跨平台特性示例代码和解析三、高性能特性示例代码和解析四、简单易用特性示例代码和解析五、扩展性强示例代码和解析六、现代化设计示例代码和解析七、知…

【PB案例学习笔记】-09滚动条使用

写在前面 这是PB案例学习笔记系列文章的第8篇,该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习,提高编程技巧,以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码,小凡都上传到了gitee…

如何使用KNN

导入文件和库 加载数据集、拆分数据集 训练模型 预测 打印结果

用C#调用SAP 的WebServices接口

文章目录 用C#调用SAP 的WebServices接口创建C#的项目添加窗体添加引用在表单的装载事件里编写代码运行结果SAP的RFC函数 用C#调用SAP 的WebServices接口 创建C#的项目 添加窗体 添加引用 在表单的装载事件里编写代码 using System; using System.Collections.Generic; using …

MicroLED:苹果对知识产权的影响

Yole的洞察揭示,MicroLED IP在经历了七年的爆炸式增长后,已然屹立于行业之巅。苹果公司,作为微LED领域的先行者,早在2014年便敏锐地捕捉到Luxvue这家初创公司的潜力,将其纳入麾下,引发了业界的广泛关注。然…

204页 | MES项目需求案例方案:效率+精细化+品质+数据互联(免费下载)

【1】关注本公众号,转发当前文章到微信朋友圈 【2】私信发送 MES项目需求案例方案 【3】获取本方案PDF下载链接,直接下载即可。 如需下载本方案PPT/WORD原格式,请加入微信扫描以下方案驿站知识星球,获取上万份PPT/WORD解决方案&…