【深度学习实战—6】:基于Pytorch的血细胞图像分类(通用型图像分类程序)

news2025/1/24 17:33:27

✨博客主页:米开朗琪罗~🎈
✨博主爱好:羽毛球🏸
✨年轻人要:Living for the moment(活在当下)!💪
🏆推荐专栏:【图像处理】【千锤百炼Python】【深度学习】【排序算法】

目录

  • 😺一、数据集介绍
  • 😺二、工程文件夹目录
  • 😺三、option.py
  • 😺四、getdata.py
  • 😺五、utils.py
  • 😺六、model.py
  • 😺七、train.py
  • 😺八、evaluate.py
  • 😺九、pth2onnx.py
  • 😺十、onnx_inference.py

图像分类是搞深度学习一定要掌握的一个视觉任务,本文章将基于血细胞数据集实现图像分类!

本文程序已解耦,可当做通用型图像分类框架使用。

数据集下载地址:Blood Cell Images

😺一、数据集介绍

从 kaggle 上下载到数据集后解压可以得到两个文件夹,分别是dataset-masterdataset2-master

其中dataset-master的 JPEGImages 中包含了血细胞的原始图像,而且没有对血细胞进行分类,在 Annotations 文件夹内包含了对应 JPEGImages 中的每张图像血细胞的.xml格式的定位标签,也就是说,该文件夹是用来做目标检测的。

而在dataset2-master中的 images 文件夹中,包含了TRAINTESTTEST_SIMPLE三种文件夹,且这三种文件夹下包含了血细胞的四种类别,分别是:EOSINOPHIL、LYMPHOCYTE、MONOCYTE、NEUTROPHIL

但需要注意的是,在TRAINTEST文件夹下的图像,是已经经过数据增强之后的了,而TEST_SIMPLE文件夹下的图像并没有经过数据增强,因此我们将TRAINTESTTEST_SIMPLE三种文件夹分别用作训练集、验证集和测试集。即:

  • TRAIN——train(训练集)
  • TEST——val(验证集)
  • TEST_SIMPLE——test(测试集)

在这里插入图片描述

😺二、工程文件夹目录

我的工程文件夹目录如下,可以看到有很多的py文件,每个py文件具有不同的功能,这么写的好处是未来修改程序更加方便,而且每个py程序都没有很长。如果全部写到一个py程序里,则会显得很臃肿,修改起来也不轻松。
在这里插入图片描述
对每个文件的解释如下:

  • checkpoints:存放训练的模型权重;
  • datasets:存放数据集。并对数据集划分;
  • log_dir:存放训练日志。包括训练、验证时候的损失与精度情况;
  • option.py:存放整个工程下需要用到的所有参数;
  • utils.py:存放各种函数。包括文件夹创建、绘制精度与损失变化情况、结果预测等;
  • getdata.py:构建数据管道。其中定义了计算数据集中所有图形的均值和方差函数;
  • model.py:构建神经网络模型;
  • train.py:训练模型;
  • evaluate.py:评估训练模型。有三种预测方式可以选择,分别是:对单张图像进行预测,对多张图像进行预测,对整个目录下的图片进行预测;
  • pth2onnx:将pth模型转换到onnx模型;
  • onnx_inference.py:使用.onnx模型对数据进行推理。

😺三、option.py

为了方便了解这些参数代表什么意思,在help中,全部使用了中文解释。

import argparse


def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--device', type=str, default='cuda', help='可以选择cuda或者cpu训练,苹果电脑m1芯片也可以选择mps加速训练')
    parser.add_argument('--loadsize', type=int, default=224, help='统一图像尺寸')
    parser.add_argument('--epochs', type=int, default=3, help='总的训练次数')
    parser.add_argument('--batch_size', type=int, default=16, help='每次喂多少数据给到网络')
    parser.add_argument('--lr', type=float, default=1e-2, help='初始学习率')
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='训练集路径')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='验证集路径')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='测试集路径')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='模型存放路径')
    parser.add_argument('--log_dir', type=str, default='./log_dir', help='训练日志保存的路径')
    parser.add_argument('--logging_txt', type=str, default='./log_dir/logging.txt', help='训练日志位置')
    parser.add_argument('--pretrained', type=bool, default=False, help='是否要继续上次的训练')
    parser.add_argument('--which_epoch', type=str, default='best.pth', help='如果继续训练,需要加载哪一个模型')
    parser.add_argument('--test_model_path', type=str, default='./checkpoints/best.pth', help='选择一个模型用于测试')
    parser.add_argument('--onnx_path', type=str, default='./checkpoints/best.onnx', help='.onnx模型的存放路径')
    parser.add_argument('--test_img_path', type=str, default='./datasets/test/EOSINOPHIL/_0_5239.jpeg', help='选择一张测试图像')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='选择一个测试路径')
    return parser.parse_args()

😺四、getdata.py

getdata.py中各函数的解释:

  • data_augmentation:该函数用作数据增强,最常使用的是transforms.Resize()transforms.ToTensor()transforms.Normalize()。由于数据集中已经对原始图像进行了数据增强,因此部分参数在下面注释掉了。
    • transforms.Resize():将图像统一尺寸。
    • transforms.ToTensor():维度变换。从 HWC 到 CWH 。
    • transforms.Normalize():图像归一化。归一化的参数需要从get_mean_and_std函数计算得到。
  • MyData:构建数据管道。返回一个字典。
  • imshow:图像可视化。可在构建数据管道后,可视化部分数据。
  • get_mean_and_std:计算图像均值和方差。计算结果放到transforms.Normalize()中。
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from option import get_args
opt = get_args()


def data_augmentation():

    data_transform = {
        'train': transforms.Compose([
            # transforms.RandomRotation(45),  # 随机旋转,角度在-45到45度之间
            # transforms.RandomHorizontalFlip(p=0.5),  # 以0.5的概率水平翻转
            # transforms.RandomVerticalFlip(p=0.5),  # 以0.5的概率垂直翻转
            # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数依次为亮度、对比度、饱和度、色相
            # transforms.RandomGrayscale(p=0.025),  # 以0.025的概率变为灰度图像,3通道即R=G=B
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])  # 使用均值和标准差标准化三个通道的数据
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])
        ]),
        'test': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])
        ])
    }
    return data_transform


def MyData():

    data_transform = data_augmentation()

    # 读取数据集
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_test, data_transform['val']),
        'test': ImageFolder(opt.dataset_test, data_transform['test'])
    }
    # 构建管道
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=True),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=True),
        'test': DataLoader(image_datasets['test'], batch_size=opt.batch_size, shuffle=True)
    }
    return dataloaders


"""
图像可视化
"""
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.6786, 0.6413, 0.6605])
    std = np.array([0.2599, 0.2595, 0.2569])
    inp = inp * std + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.show()


# 计算数据集所有图像的均值和方差
def get_mean_and_std(dataset):
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


if __name__ == '__main__':
    mena_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_train, transform=mena_std_transform)
    print(dataset.class_to_idx)		# 每个类别的索引
    mean, std = get_mean_and_std(dataset)
    print(mean)
    print(std)
    dataloader = MyData()
    inputs, classes = next(iter(dataloader['train']))
    out = make_grid(inputs, nrow=4)     # nrow参数可以选择显示的列数
    class_names = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']
    imshow(out, title=[class_names[x] for x in classes])

运行main函数可以得到:

类别索引:  {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}
==> Computing mean and std..
tensor([0.6786, 0.6413, 0.6605])
tensor([0.2599, 0.2595, 0.2569])

将 opt.batchsize 设为8后,可以得到下图:
在这里插入图片描述

😺五、utils.py

utils.py中各函数的解释:

  • make_dir:创建文件夹。
  • draw_number:绘制损失与精度的变化情况。
  • visual_image_single:单张图像可视化预测。
  • visual_image_multi:多张图像可视化预测。
  • get_confusion_matrix:输出混淆矩阵。用于对整个文件夹进行预测的情况。
  • plot_confusion_matrix:混淆矩阵可视化。
  • get_roc_auc:绘制ROC曲线。
  • visual_img_dir:对整个文件夹进行预测。并得到分类报告、准确率、精确率、召回率、F1得分
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
from scipy import interp
from itertools import cycle
from option import get_args
opt = get_args()


"""
创建文件夹
"""
def make_dir():
    if os.path.exists(opt.log_dir) == True:
        pass
    else:
        os.mkdir(opt.log_dir)
    if os.path.exists(opt.checkpoints) == True:
        pass
    else:
        os.mkdir(opt.checkpoints)



"""
绘制损失与精度的变化情况
"""
def draw_number(epochs, train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt):

    color = ['red', 'blue', 'green', 'orange']
    marker = ['o', '*', 'p', '+']
    linestyle = ['-', '--', '-.', ':']

    plt.plot(epochs, train_loss_plt, color=color[0], marker=marker[0], linestyle=linestyle[0], label="trainingsets-loss")
    plt.plot(epochs, train_acc_plt, color=color[1], marker=marker[1], linestyle=linestyle[1], label="trainingsets-acc")
    plt.plot(epochs, val_loss_plt, color=color[2], marker=marker[2], linestyle=linestyle[2], label="validationsets-loss")
    plt.plot(epochs, val_acc_plt, color=color[3], marker=marker[3], linestyle=linestyle[3], label="validationsets-acc")

    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("value")
    plt.title("Loss and accuracy changes in training and validation sets")
    plt.savefig("Loss_Accuracy.jpg")
    plt.show()


"""
单张图像可视化预测
"""
def visual_image_single(img_path, transform_test, model, class_names):
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image)
    img = img.unsqueeze_(0)
    out = model(img)
    pred_softmax = F.softmax(out, dim=1)        # 对 logit 分数做 softmax 运算
    top_n = torch.topk(pred_softmax, len(class_names))
    confs = top_n[0].cpu().detach().numpy().squeeze().tolist()      # 所有类别的预测概率
    confs_max = max(confs)      # 最大概率值
    confs_max_position = confs.index(confs_max)     # 最大概率值所在的位置
    print('Pre:{}   Conf:{:.3f}'.format(class_names[confs_max_position], confs_max))
    plt.axis('off')
    plt.title('Pre:{}   Conf:{:.3f}'.format(class_names[confs_max_position], confs_max))
    plt.imshow(image)
    plt.show()


"""
多张图像可视化预测
"""
def visual_image_multi(dataloader, model, class_names):
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(images)):
                plt.subplot(4, 4, i + 1)
                plt.title("Prediction:{}\nTarget:{}".format(class_names[predicted[i]], class_names[labels[i]]), fontsize=8)
                img = images[i].swapaxes(0, 1)
                img = img.swapaxes(1, 2)
                plt.imshow(img)
                plt.axis('off')
            plt.show()


"""
对整个文件夹进行预测, 并输出混淆矩阵
"""
def get_confusion_matrix(trues, preds, labels):
    conf_matrix = confusion_matrix(trues, preds, labels=[i for i in range(len(labels))])
    return conf_matrix

def plot_confusion_matrix(conf_matrix, labels):
    plt.imshow(conf_matrix, cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    indices = range(conf_matrix.shape[0])
    plt.xticks(indices, labels)
    plt.yticks(indices, labels)
    plt.colorbar()
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    # 显示数据
    for first_index in range(conf_matrix.shape[0]):
        for second_index in range(conf_matrix.shape[1]):
          plt.text(first_index, second_index, conf_matrix[first_index, second_index])
    plt.savefig('heatmap_confusion_matrix.jpg')
    plt.show()


def get_roc_auc(trues, preds, labels):
    nb_classes = len(labels)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(nb_classes):
        fpr[i], tpr[i], _ = roc_curve(trues[:, i], preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    fpr["micro"], tpr["micro"], _ = roc_curve(trues.ravel(), preds.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(nb_classes)])) 

    mean_tpr = np.zeros_like(all_fpr)
    for i in range(nb_classes):
        mean_tpr += interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= nb_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
    lw = 2
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])
    for i, color in zip(range(nb_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw, label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig("ROC_多分类.jpg")
    plt.show()

def visual_img_dir(dataloader, model, class_names):
    """
    normalize: True:显示百分比, False: 显示个数
    """
    y_pred = []
    y_true = []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

        accuracy = accuracy_score(y_true, y_pred)  # 准确率 值所有判断正确的数据(TP+TN)占总量的比例。
        precision = precision_score(y_true, y_pred, average='macro')  # 精确率 所有被判定为正类(TP+FP)中,真实的正类(TP)占的比例。
        recall = recall_score(y_true, y_pred, average='macro')  # 召回率 所有真实为正类(TP+FN)中,被判定为正类(TP)占的比例。
        f1 = f1_score(y_true, y_pred, average='macro')  # f1-score 它赋予Precision score和Recall Score相同的权重,以衡量其准确性方面的性能,使其成为准确性指标的替代方案(它不需要我们知道样本总数)。
        conf_matrix = get_confusion_matrix(y_true, y_pred, labels=class_names)
        print('分类报告:\n', classification_report(y_true, y_pred))  # 分类报告
        print("[accuracy:{:.4f}]  [precision:{:.4f}]  [recall:{:.4f}]  [f1:{:.4f}]".format(accuracy, precision, recall, f1))
        plot_confusion_matrix(conf_matrix, labels=class_names)

        test_trues = label_binarize(y_true, classes=[i for i in range(len(class_names))])
        test_preds = label_binarize(y_pred, classes=[i for i in range(len(class_names))])
        get_roc_auc(test_trues, test_preds, class_names)

😺六、model.py

我们可以自定义一个分类网络,也可以使用现有的经典分类网络,如resnet50,在使用resnet50时,可以选择冻结部分网络层,即冻结的网络层不可再被训练,仅使用其网络结构,网络参数是早已学习好的;也可以选择冻结所有层;也可以选择不冻结任何层。在迁移学习的时候,需要注意最后的分类层。血细胞分类共有4类,而resnet50最后的全连接层有1000个神经元输出,所以需要修改最后一层全连接层,将其输出改为4。

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
from option import get_args
opt = get_args()

class My_CNN(nn.Module):
    def __init__(self):
        super(My_CNN, self).__init__()
        self.conv1_1 = nn.Sequential(nn.Conv2d(3, 16, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(16))
        self.conv1_2 = nn.Sequential(nn.Conv2d(16, 32, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(32))
        self.conv2_1 = nn.Sequential(nn.Conv2d(32, 32, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(32))
        self.conv2_2 = nn.Sequential(nn.Conv2d(32, 64, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(64))
        self.conv3_1 = nn.Sequential(nn.Conv2d(64, 64, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(64))
        self.conv3_2 = nn.Sequential(nn.Conv2d(64, 128, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(128))

        self.linear_1 = nn.Linear(28 * 28 * 128, 80)
        self.linear_2 = nn.Linear(80, 4)

    def forward(self, x):
        in_size = x.size(0)
        x = self.conv1_1(x)
        x = self.conv1_2(x)
        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = x.view(in_size, -1)
        x = self.linear_1(x)
        out = self.linear_2(x)
        return out


"""
使用预训练模型 1 ————微调模型
使用预训练的模型来初始化网络,而非随机初始化网络,并且权重可以随着训练的进行而发生改变,步骤如下:
--(1)替换输出层。将模型的最后一个全连接层替换为新的全连接层;
--(2)训练输出层。新的输出层会将前面的层所提取出的低级特征映射到我们所期望的类别的概率;
--(3)训练输出层之前的层。也就是将这些层的权重标记为需要求导。

固定模型的参数 2 ————微调模型
固定预训练模型的参数,将模型除了输出层之外的所有层看作一个特征提取器。在训练模型的时候,这些层的权重不参与训练,不可优化。
"""
def ResNet():
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    """
    可选择仅冻结某一层或者全部冻结
    """
    # for name, layer in model.named_children():  # 仅冻结layer1层
    #     if name == "layer1":
    #         for param in layer.parameters():
    #             param.requires_grad = False
    #
    # for param in model.parameters():    # 冻结所有层,锁定模型所有参数,所有层设置为不可训练的模式。
    #     param.requires_grad = False

    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 4)
    return model

if __name__ == '__main__':
    model = ResNet()
    print(summary(model.to(opt.device), (3, opt.loadsize, opt.loadsize), opt.batch_size))

😺七、train.py

train.py解释如下:

  • make_dir():从 utils 中调用函数,目的是如果当前工程目录下不存在相应的文件夹(log_dircheckpoints),则主动创建,如果已经存在,则不做处理。
  • file = open(opt.logging_txt, 'w'):创建.txt文件,后续将写入训练过程的相关信息,包括损失与精度的变化情况。
  • writer = SummaryWriter():SummaryWriter 类将条目直接写入指定文件夹中的事件文件,以供 TensorBoard 使用。在程序运行时,会在工程目录下自动新建一个 run 文件夹,用于存储训练过程。在 run 文件夹下使用终端,输入tensorboard –logdir=run可以在网页中查看网络训练过程。
  • train_best:定义训练过程的函数。
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.nn as nn
from model import My_CNN, ResNet
from getdata import MyData
from utils import draw_number, EarlyStopping, make_dir
from option import get_args
opt = get_args()

make_dir()
file = open(opt.logging_txt, 'w')
writer = SummaryWriter()

def train_best(model, num_epoch, dataloaders, optimizer, loss_function):

    model.to(opt.device)
    train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt = [], [], [], []  # 将训练和验证过程的损失和精度保留下来,用于绘制折线图

    for epoch in range(start_epoch, opt.epochs):
        print("---------开始第{}/{}轮训练---------".format(epoch, opt.epochs))
        for phase in ['train', 'val']:

            loss_sum, acc_sum = 0, 0
            step = 0            # 将数据全部取完, 记录每一个batch
            all_step = 0        # 记录取了多少个数据

            for (inputs, labels) in tqdm(dataloaders[phase], position=0):
                if phase == 'train':
                    model.train()
                if phase == 'val':
                    model.eval()
                inputs = inputs.to(opt.device)
                labels = labels.to(opt.device)
                optimizer.zero_grad()  # 梯度清零,防止累加

                a = inputs.size(0)  # 每一批次拿了多少张图像
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                _, pred = torch.max(outputs, 1)  # 返回每一行的最大值和其索引
                loss.backward()
                optimizer.step()

                loss_sum += loss.item() * inputs.size(0)  # 损失
                acc_sum += torch.sum(pred == labels.data)

                step += 1
                all_step += a

                print("[Epoch: {}/{}]  [step = {}]  [{}_loss = {:.3f}, {}_acc = {:.3f}]".
                      format(epoch, opt.epochs, all_step, phase, loss_sum / all_step, phase, acc_sum.double() / all_step))

            # 保留每一个epoch后的训练损失与精度
            if phase == 'train':
                train_loss = loss_sum / len(dataloaders[phase].dataset)
                train_acc = acc_sum.double() / len(dataloaders[phase].dataset)
                train_acc = np.float32(train_acc.cpu().numpy())
                train_loss_plt.append(train_loss)
                train_acc_plt.append(train_acc)

            else:
                val_loss = loss_sum / len(dataloaders[phase].dataset)
                val_acc = acc_sum.double() / len(dataloaders[phase].dataset)
                val_acc = np.float32(val_acc.cpu().numpy())
                val_loss_plt.append(val_loss)
                val_acc_plt.append(val_acc)

                writer.add_scalars('loss', {'train': train_loss, 'val': val_loss}, global_step=epoch + 1 - start_epoch)
                writer.add_scalars('acc', {'train': train_acc, 'val': val_acc}, global_step=epoch + 1 - start_epoch)
                writer.close()

        print("EPOCH = {}/{}  train_loss = {:.3f}, train_acc = {:.3f}, val_loss = {:.3f}, val_acc = {:.3f} \n".
              format(epoch, num_epoch, train_loss, train_acc, val_loss, val_acc))
        file.write("EPOCH = {}/{}  train_loss = {:.3f}, train_acc = {:.3f}, val_loss = {:.3f}, val_acc = {:.3f} \n".
              format(epoch, num_epoch, train_loss, train_acc, val_loss, val_acc))

        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        if epoch % 2 == 0:
            torch.save(state, opt.checkpoints + 'model_{}.pth'.format(epoch))

    draw_number(np.arange(0, opt.epoch-start_epoch, 1), train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt)


if __name__ == '__main__':
    model = ResNet()
    # model = nn.DataParallel(model)      # 多卡并行训练解开这句注释
    model.to(opt.device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if opt.pretrained:
        checkpoint = torch.load(opt.checkpoints + opt.which_epoch)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    dataloaders = MyData()

    train_best(model, opt.epochs, dataloaders, optimizer, loss_function)

😺八、evaluate.py

evaluate.py需要注意:

  • class_names:必须要和数据管道的标签对应。也就是getdata.py运行得到的类别索引。

    类别索引: {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}

  • main 函数内的visual_image_single:将每次弹出一张预测结果

  • main 函数内的visual_image_multi:将每次弹出opt.batch_size张预测结果,可以通过修改opt.batch_size改变预测数量,同时可以跳转到utils.py里的visual_image_multi函数中,通过修改plt.subplot()中的参数,可以控制预测结果的排列分布,例如 4 行 4 列 或者 2 行 8 列 等。

  • main 函数内的visual_img_dir:将得到ROC曲线图,混淆矩阵图、各种评估指标等。

from model import My_CNN, ResNet
from getdata import MyData, data_augmentation
import torch.utils.data
from option import get_args
from utils import visual_image_single, visual_image_multi, visual_img_dir


opt = get_args()

model = ResNet()
ckpt = torch.load(opt.test_model_path, map_location='cpu')
model.load_state_dict(ckpt, strict=False)
model.eval()

data_transform = data_augmentation()        # 测试单张图像使用
transform_test = data_transform['test']

dataloaders = MyData()                      # 测试多张图像和文件夹使用
dataloader = dataloaders['test']

class_names = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']

if __name__ == '__main__':

    # visual_image_single(opt.test_img_path, transform_test, model, class_names)
    # visual_image_multi(dataloader, model, class_names)
    visual_img_dir(dataloader, model, class_names=class_names)

程序运行结果如下所示:
visual_image_single
在这里插入图片描述
visual_image_multi
在这里插入图片描述
visual_img_dir
在这里插入图片描述
在这里插入图片描述

分类报告:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00         4
           3       0.68      1.00      0.81        48

    accuracy                           0.68        71
   macro avg       0.17      0.25      0.20        71
weighted avg       0.46      0.68      0.55        71

[accuracy:0.6761]  [precision:0.1690]  [recall:0.2500]  [f1:0.2017]

😺九、pth2onnx.py

evaluate.py需要注意:

模型转换时,需要指定模型的输入大小,即input变量。

import torch
from torch.autograd import Variable
import onnx
from model import My_CNN, ResNet
from option import get_args
opt = get_args()

model = ResNet()
ckpt = torch.load(opt.test_model_path, map_location='cpu')
model.load_state_dict(ckpt, strict=False)
model.eval()
input_name = ['input']
output_name = ['output']
input = Variable(torch.randn(1, 3, opt.loadsize, opt.loadsize))

torch.onnx.export(model, input, opt.onnx_path, input_names=input_name, output_names=output_name, verbose=True)

# check .onnx model
onnx_model = onnx.load(opt.onnx_path)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

程序运行后就可以在checkpoints文件夹下发现.onnx文件。

😺十、onnx_inference.py

使用onnx模型进行推理。
注意在推理前,要把opt.batch_size改为 1。

import numpy as np
import onnxruntime
import time
from getdata import MyData
from option import get_args
opt = get_args()

def infer_test(model_path, data_loader, device):
    if device == 'cpu':
        print("using CPUExecutionProvider")
        session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    else:
        print("using CUDAExecutionProvider")
        session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    total = 0.0
    correct = 0
    start_time = time.time()
    for batch, data in enumerate(data_loader):
        X, y = data
        X = X.numpy()
        y = y.numpy()

        output = session.run([output_name], {input_name: X})[0]
        y_pred = np.argmax(output, axis=1)

        if y[0] == y_pred[0]:
            correct += 1
        total += 1
    end_time = time.time()
    print(end_time - start_time)
    print("accuracy is {}%".format(correct / total * 100.0))


def main():
    input_model_path = opt.onnx_path
    device = input("cpu or gpu?")
    dataloaders = MyData()
    infer_test(input_model_path, dataloaders['test'], device)


if __name__ == "__main__":
    main()

推理结果如下所示:

cpu or gpu?cpu
using CPUExecutionProvider
1.8580236434936523
accuracy is 67.6056338028169%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/996053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL Server2022版本 + SSMS安装教程(手把手安装教程)

SqlServer安装步骤如下: 下载请点点我 1.选择Developer版 2.点击浏览器下载标志,找到SqlServer在文件夹的位置。 3.进入下面的界面:选择自定义版本 4.将下载的C盘的盘符改为D盘 5.点击安装–进入下一步,显示如下界面。 过…

构建知识库的核心要义,试试我的工具和方法吧!

对于企业来说,如果所有人员的知识和经验,都集中沉淀到统一的文档系统中,那么,久而久之,会形成一个丰富的知识和经验库。 在构建知识库之前,我们要先确定知识库范围包含:素材整理、问题提炼、知识…

【Selenium】webdriver.ChromeOptions()官方文档参数

Google官方Chrome文档,在此记录一下 Chrome Flags for Tooling Many tools maintain a list of runtime flags for Chrome to configure the environment. This file is an attempt to document all chrome flags that are relevant to tools, automation, benchm…

7、Spring之依赖注入源码解析(下)

resolveDependency()实现 该方法表示,传入一个依赖描述(DependencyDescriptor),该方法会根据该依赖描述从BeanFactory中找出对应的唯一的一个Bean对象。 @Nullable Object resolveDependency(DependencyDescriptor descriptor, @Nullable String requestingBeanName,@Null…

拦截器学习

什么是拦截器 Spring MVC 中的拦截器( Interceptor )类似于ServLet中的过滤器( Filter ),它主要用于拦截用户请求并作出相应的处理。例如通过拦截器可以进行权限验证、记录请求信息的日志、判断用户是否登录等。 工作原理 一个拦截器,只有 preHandle …

mac电脑安装paste教程以及重新安装软件后不能使用解决方法

问题背景 mac电脑安装paste教程以及重新安装软件后不能使用解决方法。 mac电脑安装paste失败,安装好后还是无法使用,paste显示还是历史粘贴信息,导致无法使用。新 copy的内容也无法进入历史粘贴版里面。 笔者电脑配置信息:MacB…

7.13 在SpringBoot中 正确使用Validation实现参数效验

文章目录 前言引入Maven依赖一、POST/PUT RequestBody参数校验1.1 Valid或Validated注解配合constraints注解1.2 测试运行 二、GET/DELETE RequestParam参数校验2.1 Validated注解配合constraints注解2.2 测试运行 三、GET 无注解参数校验3.1 Valid或Validated注解配合constrai…

【OpenCV • c++】直方图计算 | 绘制 H-S 直方图 | 绘制一维直方图 | 绘制 RGB 三色直方图

文章目录 一、什么是直方图二、直方图的相关函数1、计算直方图 calcHist()2、找寻最值 minMaxLoc() 三、程序演示1、色调 —— 饱和度直方图2、一维直方图3、RGB 三色直方图 一、什么是直方图 直方图广泛应用于很多计算机视觉处理当中。通过标记帧与帧之间显著的边缘和颜色的变…

【数据结构】 二叉搜索树的实现

文章目录 🍀二叉搜索树的概念🛬二叉搜索树功能实现🚩查找关键字key📌代码实现: 🚩插入关键字key📌代码实现: 🚩删除关键字key📌代码实现: &#x…

OpenCV(三十四):轮廓外接最大、最小矩形和多边形拟合

目录 1.轮廓外接最大矩形boundingRect() 2.轮廓外接最小矩形minAreaRect() 3.轮廓外接多边形approxPolyDP() 1.轮廓外接最大矩形boundingRect() Rect cv::boundingRect ( InputArray array ) array:输入的灰度图像或者2D点集&#xff0c;数据类型为vector<Point>或者M…

OpenCV实现图像的混合

原理 这其实也是加法&#xff0c;但是不同的是两幅图像的权重不同&#xff0c;这就会给人一种混合或者透明的感觉。 图像混合的计算公式如下: g(x)(1-a)f0(x) af1(x) 通过修改α的值(0→1) &#xff0c;可以实现非常炫酷的混合。 现在我们把两幅图混合在一起。 第一幅图…

C 风格文件输入/输出---无格式输入/输出---(std::fputc,std::putc,std::fputs)

C 标准库的 C I/O 子集实现 C 风格流输入/输出操作。 <cstdio> 头文件提供通用文件支持并提供有窄和多字节字符输入/输出能力的函数&#xff0c;而 <cwchar>头文件提供有宽字符输入/输出能力的函数。 无格式输入/输出 写字符到文件流 std::fputc, std::putc in…

软件测试面试遇到之redis要怎么测试?

软件测试面试遇到&#xff1a;redis要怎么测试&#xff1f; 首先我们需要知道&#xff0c;redis是什么&#xff1f;它能做什么&#xff1f; redis是一个 key-value 类型的高速存储数据库。redis常被用做&#xff1a;缓存、队列、发布订阅等。 所以&#xff0c;“redis要怎么…

第18章_瑞萨MCU零基础入门系列教程之GPT

本教程基于韦东山百问网出的 DShanMCU-RA6M5开发板 进行编写&#xff0c;需要的同学可以在这里获取&#xff1a; https://item.taobao.com/item.htm?id728461040949 配套资料获取&#xff1a;https://renesas-docs.100ask.net 瑞萨MCU零基础入门系列教程汇总&#xff1a; ht…

基于SSM的学生管理系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

日志平台搭建第一章:Linux 安装elasticsearch-7.5.1

相关链接 官⽹&#xff1a; https://www.elastic.co/cn/downloads/elasticsearch 下载&#xff1a; wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.5.1-linux-x86_64.tar.gz 分词器&#xff1a; https://github.com/medcl/elasticsearch-an…

来看看Sublime Text运行Python程序(包含下载和安装)

py Sublime Text 是一款流行的文本编辑器&#xff0c;它体积小、运行速度快、文本功能强大、可以运行在 Windows、Linux 和 Mac OS X 平台上。 在程序员眼里&#xff0c;Sublime Text 还是一款非常好用的代码编辑器&#xff0c;它支持运行 C/C、Python、Java 等多种语言编写的程…

PyCharm集成开发环境安装、启动与设置

作为非开发工程师职业,大家多多少少都会对编程有抵触,其实没有必要对Python有太大的“戒心" ,把Python当做你的一个工具就可以了。——扎克伯格 一、Python的定义&#xff1a; Python是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。Python的设计具有…

21.添加websocket模块

这里默认读者了解websocket协议&#xff0c;若是还不了解可以看下这篇文章wesocket协议。 websocket主要有三个步骤&#xff0c;1通过HTTP进行握手连接&#xff0c;2进行双向通信&#xff0c;3.协商断开连接 第一步的握手连接需要HTTP&#xff0c;所以还需要使用到上一节讲解…

《C++ primer plus》精炼(OOP部分)——对象和类(2)

“学习是人类成长的喷泉。” - 亚里士多德 文章目录 内联函数对象的方法和属性构造函数和析构函数构造函数的种类使用构造函数析构函数列表初始化 const成员函数this指针对象数组类作用域作用域为类的常量类作用域内的枚举 内联函数 定义位于类声明中的函数自动成为内联函数。…