血细胞分类项目

news2024/11/17 21:24:17

血细胞分类项目

    • 数据集:血细胞分类数据集
    • 数据处理 dataset.py
    • 网络 net.py
    • 训练 train.py
    • 拿训练集的几张图进行预测

数据集:血细胞分类数据集

https://aistudio.baidu.com/datasetdetail/10278
在这里插入图片描述
在这里插入图片描述

数据处理 dataset.py

from torchvision import transforms
import torchvision
import torch
import matplotlib.pyplot as plt
from PIL import Image
#一、数据转换
train_transformer=transforms.Compose(
[
   transforms.RandomHorizontalFlip(0.2),
   transforms.RandomRotation(68),
   transforms.RandomGrayscale(0.2),
   transforms.Resize((256,256)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.5,0.5,0.5],
                std=[0.5,0.5,0.5])
]
)
test_transformer=transforms.Compose(
[
   transforms.Resize((256,256)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.5,0.5,0.5],
                std=[0.5,0.5,0.5])
]
)
#二、读入数据
train_dataset=torchvision.datasets.ImageFolder(
  'E:/Jupytercode/血细胞分类/数据/blood-cells/dataset2-master/dataset2-master/images/TRAIN',
    transform=train_transformer
)

test_dataset=torchvision.datasets.ImageFolder(
  'E:/Jupytercode/血细胞分类/数据/blood-cells/dataset2-master/dataset2-master/images/TEST',
   transform=test_transformer
)

#进行编码
#原      {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}
#转换后  {0: 'EOSINOPHIL', 1: 'LYMPHOCYTE', 2: 'MONOCYTE', 3: 'NEUTROPHIL'}
id_to_class={}
for k,v in train_dataset.class_to_idx.items():
    #print(k,v)
    id_to_class[v]=k
#id_to_class #查看转换后的格式

#三、批次读入数据,可以作为神经网络的输入  一次性拿多少张图片进行训练
Batch_size=64#一次性训练64张
dl_train=torch.utils.data.DataLoader(
        train_dataset,
        batch_size=Batch_size,
        shuffle=True
)
dl_test=torch.utils.data.DataLoader(
        test_dataset,
        batch_size=Batch_size,
        shuffle=True
)
#取一个批次的数据
# img,label=next(iter(dl_train))
# plt.figure(figsize=(12,8))
# for i,(img,label) in enumerate(zip(img[:8],label[:8])):
#     img=(img.permute(1,2,0).numpy()+1)/2
#     plt.subplot(2,4,i+1)
#     plt.title(id_to_class.get(label.item())) #0: 'EOSINOPHIL', 1: 'LYMPHOCYTE', 2: 'MONOCYTE', 3: 'NEUTROPHIL'
#     plt.imshow(img)
# plt.show() #查看图片

print("数据处理已完成")

网络 net.py

import torch.nn as nn
import torch
#建立神经网络
class Net(nn.Module):  # 模仿VGG
    def __init__(self):
        super(Net, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)

        )
        self.fc = nn.Sequential(
            nn.Linear(256 * 14 * 14, 1024),
            nn.ReLU(),
            nn.Linear(1024, 4)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        #print(x.shape)

        # 将原来的张量 x (四维)重新塑造为一个二维张量。第一个维度的大小由 PyTorch 自动计算,而第二个维度的大小被设置为 256 * 14 * 14
        x = x.view(-1, 256 * 14 * 14)
        x = self.fc(x)

        return x
if __name__ == '__main__':
    x = torch.rand([8, 3, 256, 256])
    model = Net()
    y = model(x)

在这里插入图片描述

训练 train.py

import torch as t
import torch.nn as nn
from tqdm import tqdm  #进度条
import net
from dataset import *


device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
train_dataset=torchvision.datasets.ImageFolder(
  'E:/Jupytercode/血细胞分类/数据/blood-cells/dataset2-master/dataset2-master/images/TRAIN',
    transform=train_transformer
)
test_dataset=torchvision.datasets.ImageFolder(
  'E:/Jupytercode/血细胞分类/数据/blood-cells/dataset2-master/dataset2-master/images/TEST',
   transform=test_transformer
)

id_to_class={}
for k,v in train_dataset.class_to_idx.items():
    #print(k,v)
    id_to_class[v]=k

Batch_size=64#一次性训练64张
dl_train=torch.utils.data.DataLoader(
        train_dataset,
        batch_size=Batch_size,
        shuffle=True
)
dl_test=torch.utils.data.DataLoader(
        test_dataset,
        batch_size=Batch_size,
        shuffle=True
)
model=net.Net()
model = model.to(device)
optim=torch.optim.Adam(model.parameters(),lr=0.001)
loss_fn=nn.CrossEntropyLoss()


def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    model.train()  # 训练模式下  识别normalize层
    for x, y in tqdm(trainloader):
        x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()

    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total

    test_correct = 0
    test_total = 0
    test_running_loss = 0
    model.eval()  # 验证模式下   不识别normalize层
    with torch.no_grad():
        for x, y in tqdm(testloader):
            x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()

    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total

    if epoch_acc > 0.95:
        model_state_dict = model.state_dict()
        torch.save(model_state_dict, './{}{}.pth'.format(epoch_acc, epoch_test_acc))

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
if __name__ == '__main__':
    epochs = 20
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(epochs):
        epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                     model,
                                                                     dl_train,
                                                                     dl_test)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)

    plt.plot(range(1, epochs + 1), train_loss, label='train_loss')# 绘制训练损失曲线,使用range(1, epochs+1)生成横坐标轴上的点,train_loss为纵坐标轴上的点
    plt.plot(range(1, epochs + 1), test_loss, label='test_loss')# 绘制验证损失曲线,使用range(1, epochs+1)生成横坐标轴上的点,val_loss为纵坐标轴上的点
    plt.legend()# 添加图例,label参数在前面的plot中设置,用于区分不同曲线
    plt.xlabel('Epochs')  # 设置横坐标轴的标签为'Epochs'
    plt.ylabel('Loss')  # 设置纵坐标轴的标签为'Loss'
    plt.savefig('loss.png')
    plt.show()

    plt.plot(range(1, epochs + 1), train_acc, label='train_acc')
    plt.plot(range(1, epochs + 1), test_acc, label='test_acc')
    plt.title('Training and Validation Accuracy')  # 可以添加标题
    plt.xlabel('Epochs')  # 为x轴添加标签
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig('acc.png')
    plt.show()
    torch.save(model,'Bloodcell.pkl') #保存模型训练权重

在这里插入图片描述
在深度学习中,模型通常具有两种运行模式:训练模式和验证/测试模式。这两种模式的主要区别在于模型的行为和参数更新方式。

  1. 训练模式(Training Mode):
    在训练模式下,模型会执行以下操作:
    ①梯度计算: 计算模型参数关于损失函数的梯度,以便进行反向传播。
    ②参数更新: 根据梯度和优化算法,更新模型的参数以最小化损失函数。
    ③Dropout生效: 如果模型中使用了 Dropout 层,那么在训练模式下,Dropout 会生效,即在前向传播过程中会随机舍弃一些神经元,以防止过拟合。
    在 PyTorch 中,通过 model.train() 将模型设置为训练模式:
model.train()
  1. 验证/测试模式(Validation/Testing Mode):
    在验证/测试模式下,模型会执行以下操作:
    ①梯度计算: 不计算梯度,因为在验证/测试过程中不需要更新模型参数。
    ②Dropout不生效: 如果使用了 Dropout 层,那么在验证/测试模式下,Dropout 不生效,所有神经元都参与前向传播。
    ③评估模型性能: 使用模型进行预测,并评估模型在验证集或测试集上的性能。
    在 PyTorch 中,通过 model.eval() 将模型设置为验证/测试模式:
model.eval()

切换模型的运行模式是为了确保在不同阶段使用正确的行为。在训练模式下,模型需要进行梯度计算和参数更新,而在验证/测试模式下,模型不需要进行参数更新,而是专注于性能评估。

拿训练集的几张图进行预测

预测pred.py

from dataset import *

model=torch.load('Bloodcell.pkl')
img,label=next(iter(dl_test)) #选取一些图片进行预测
img=img.to('cuda')
model.eval()
pred=model(img)
pred_re=torch.argmax(pred, dim=1)

pred_re=pred_re.cpu().numpy()
pred_re=pred_re.tolist()

for i in pred_re[0:8]:
    print(id_to_class[i])
id_to_class[pred_re[0:8][1]]

plt.figure(figsize=(16,8))
img=img.cpu()#把图片重新放到CPU上
for i,(img,label) in enumerate(zip(img[:8],label[:8])):
    img=(img.permute(1,2,0).numpy()+1)/2
    plt.subplot(2,4,i+1)
    pred_title=id_to_class[pred_re[0:8][i]]
    plt.title('R:{},P:{}'.format(id_to_class.get(label.item()),pred_title))
    plt.imshow(img)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1419610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mysql使用命令行备份数据

目录 前言1. 基本知识2. 常用参数3. 拓展 前言 由于长期使用测试环境的数据库,时不时会有脏数据删除不干净,对此很需要一个实时将生产库的数据定期备份一份,防止生产库中会有脏数据进来。 1. 基本知识 mysqldump 是MySQL数据库管理系统提供…

HTML+CSS:3D卡片组件

效果演示 实现了一个名为“卡片”的效果,当鼠标悬停在一个特定的元素上时,该元素会变得更亮,并且会在其他元素上方显示一个卡片。当鼠标悬停在卡片上时,卡片会变得更亮,并且会在其他元素上方显示一个提示信息。这个效果…

开源:基于Vue3.3 + TS + Vant4 + Vite5 + Pinia + ViewPort适配..搭建的H5移动端开发模板

vue3.3-Mobile-template 基于Vue3.3 TS Vant4 Vite5 Pinia ViewPort适配 Sass Axios封装 vconsole调试工具,搭建的H5移动端开发模板,开箱即用的。 环境要求: Node:16.20.1 pnpm:8.14.0 必须装上安装pnpm,没装的看这篇…

力扣(leetcode)第118题杨辉三角(Python)

118.杨辉三角 题目链接:118.杨辉三角 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]] …

c++入门语法—————引用,内联函数,auto关键字,基于范围的for循环,nullptr

文章目录 一.引用1.引例2.注意事项3.应用场景1.做参数(a:输出型参数b:内容较大参数)2.做返回值(a:修改返回值,b:减少拷贝) 4.引用和指针的区别 二.内联函数1.为什么有内联函数2.用法和底层3.特性 三.auto关键字1.基础示…

CVE-2024-0352 likeshop v2.5.7文件上传漏洞分析

本次的漏洞研究基于thinkPHP开发开的一款项目..... 漏洞描述 Likeshop是Likeshop开源的一个社交商务策略的完整解决方案,开源免费版基于thinkPHP开发。Likeshop 2.5.7.20210311及之前版本存在代码问题漏洞,该漏洞源于文件server/application/api/contr…

数据库之一 基础概念、安装mysql、sql语句基础

数据库之 基础概念、安装mysql、sql语句基础 【一】存储数据的演变过程: 文件存储: 初始阶段随意存放数据到文件,格式任意。目录规范引入: 软件开发使用目录规范,限制数据位置,建立专门文件夹。本地数据存…

inside 的坑

最近代码里面有一句inside 判断语句,明明条件满足,但是就是判断失败,代码如下: xxx;if(i inside {[7:0]}) begin //i5xxx;end xxx; 翻看sv 手册才发现 inside 后面跟的是range value,必须是从小写到大,也就…

腾讯云Linux(OpenCloudOS)安装tomcat9(9.0.85)

腾讯云Linux(OpenCloudOS)安装tomcat9 下载并上传 tomcat官网 https://tomcat.apache.org/download-90.cgi 下载完成后上传至自己想要放置的目录下 解压文件 输入tar -xzvf apache-tomcat-9.0.85.tar.gz解压文件,建议将解压后的文件重新命名为tomcat,方便后期进…

【vue】defineModel在vue3.4中的最新用法和详解

在2023年12月28日,尤大发布了vue3.4版本,这个版本主要对一些实验性特性的改进(比如defineModel),大量重写了模板编译器并重构了响应式系统,可以说是大大提升了运行速度和效率。 之前在vue3.3中defineModel…

应急消防应用步入“繁花”时代,卓翼智能消防无人机顺势而行大有可为

近日,北京卓翼智能科技有限公司(以下简称“卓翼智能”)宣布完成超亿元B轮融资,融资金额高达2.5亿元。这个“智能无人系统”黑马品牌,凭什么出圈?重点发力在哪些领域呢?今天,带你走进…

Ubuntu 22.04.1 LTS 编译安装 nginx-1.22.1,Nginx动静分离、压缩、缓存、黑白名单、跨域、高可用、性能优化

1.Ubuntu 22.04.1 LTS 编译安装 nginx-1.22.1 1.1安装依赖 sudo apt install libgd-dev 1.2下载nginx wget http://nginx.org/download/nginx-1.22.1.tar.gz 1.3解压nginx tar -zvxf nginx-1.22.1.tar.gz 1.4编译安装 cd nginx-1.22.1 编译并指定安装位置,执行安装…

华为笔记本matebook pro X如何扩容 C 盘空间

一、前提条件 磁盘扩展与合并必须是相邻分区空间,且两个磁盘类型需要相同。以磁盘分区为 C 盘和 D 盘为例,如果您希望增加 C 盘容量,可以先将 D 盘合并到 C 盘,然后重新创建磁盘分区,分配 C 盘和 D 盘的空间大小。 访…

2024 新年HTML5+Canvas制作3D烟花特效(附源码)

个人名片: 🐼作者简介:一名大三在校生,喜欢AI编程🎋 🐻‍❄️个人主页🥇:落798. 🐼个人WeChat:hmmwx53 🕊️系列专栏:🖼️…

Ajax入门与使用

目录 ◆ AJAX 概念和 axios 使用 什么是 AJAX? 怎么发送 AJAX 请求? 如何使用axios axios 函数的基本结构 axios 函数的使用场景 1 没有参数的情况 2 使用params参数传参的情况 3 使用data参数来处理请求体的数据 4 上传图片等二进制的情况…

上海亚商投顾:创业板指创调整新低,全市场超4800只个股下跌

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日震荡调整,创业板指午后跌超3%,深成指跌超2%,北证50指数跌逾6%。中…

sqli-labs-master less-1 详解

目录 关于MySQL的一些常识 information_schema 常用的函数 sqli-labs-master less-1 分析PHP源码 测试 关于MySQL的一些常识 information_schema information_schema 是 MySQL 数据库中的一个元数据(metadata)数据库,它包含…

LLM之makeMoE:makeMoE的简介、安装和使用方法、案例应用之详细攻略

LLM之makeMoE:makeMoE的简介、安装和使用方法、案例应用之详细攻略 目录 makeMoE的简介 1、对比makemore 2、相关代码文件 makMoE_from_Scratch.ipynb文件 makeMoE_Concise.ipynb文件 makeMoE的安装和使用方法 1、基于Databricks使用单个A100进行开发 makeM…

Mybatis 获取自增主键ID的几种方式

Mybatis 获取添加的自增主键ID的几种方式 需求实现1. 使用 GeneratedKeys2. 获取 Sequence 序号3. 使用 selectKey 标签 需求 很多时候新增了一条数据之后,不仅要知道是否插入成功,还需要获取存入之后的主键id 以便后续使用。通常的办法是:先…

C# IP v4转地址·地名 高德

需求: IPv4地址转地址 如:输入14.197.150.014,输出河北省石家庄市 SDK: 目前使用SDK为高德地图WebAPI 高德地图开放平台https://lbs.amap.com/ 可个人开发者使用,不过有配额限制。 WebAPI 免费配额调整公告https://lbs.amap.com/news/…