Pytorch实现CIFAR10训练模型

news2025/1/19 17:23:37

文章目录

  • 简述
  • 模型结构
  • 模型参数、优化器、损失函数
    • 参数初始化
    • 优化器
    • 损失函数
  • 模型训练、测试集预测、模型保存、日志记录
    • 训练
    • 测试集测试
    • 模型保存
    • 模型训练完整代码
  • tensorboard训练可视化结果
    • train_loss
    • 测试准确率
    • 测试集loss
  • 模型应用
    • 模型独立应用代码`api.py`
    • 预测结果

简述

使用pytorch实现一个用于训练CIFAR10的模型,在训练过程中使用CIFAR10的测试数据集记录准确度。训练结束后,搜集一些图片,单独实现对训练后模型的应用代码。

另外会在文中尽量给出各种用法的官方文档链接。

代码分为:

  1. 模型训练代码train.py,包含数据加载、模型封装、训练、tensorboard记录、模型保存等;
  2. 模型应用代码api.py,包含对训练所保存模型的加载、数据准备、结果预测等;

注意:

本文目的是使用pytorch来构建一个结构完善的模型,体现出pytorch的各种功能函数、模型设计理念,来学习深度学习,而非训练一个高精度的分类识别模型。

不足:

  1. 参数初始化或许可以考虑kaiming(因为用的是ReLU);
  2. 可以加上k折交叉验证;
  3. 训练时可以把batch_size的图片加入tensorboard,文中batch_size=256,若每个batch_size都加的话数据太多了,所以文中是每逢整百的训练次数时记录一下该批次的loss值,加图片的话可以在该代码处添加;

模型结构

来源:https://www.researchgate.net/profile/Yiren-Zhou-6/publication/312170477/figure/fig1/AS:448817725218816@1484017892071/Structure-of-LeNet-5.png

在这里插入图片描述

在上述图片基础上增加了nn.BatchNorm2dnn.ReLU以及nn.Dropout,最终结构如下:

layers = nn.Sequential(  
    # shape(3,32,32) -> shape(32,32,32)  
    nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  
    nn.BatchNorm2d(32),  
    nn.ReLU(),  
    # shape(32,32,32) -> shape(32,16,16)  
    nn.MaxPool2d(kernel_size=2, stride=2),  
  
    # shape(32,16,16) -> shape(32,16,16)  
    nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  
    nn.BatchNorm2d(32),  
    nn.ReLU(),  
    # shape(32,16,16) -> shape(32,8,8)  
    nn.MaxPool2d(kernel_size=2, stride=2),  
  
  
    # shape(32,8,8) -> shape(64,8,8)  
    nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  
    nn.BatchNorm2d(64),  
    nn.ReLU(),  
    # shape(64, 8, 8) -> shape(64,4,4)  
    nn.MaxPool2d(kernel_size=2, stride=2),  
  
    # shape(64,4,4) -> shape(64 * 4 * 4,)  
    nn.Flatten(),  
    nn.Linear(64 * 4 * 4, 64),  
    nn.ReLU(),  
    nn.Dropout(0.5),  
    nn.Linear(64, 10)  
)

可以看看使用tensorboard的writer.add_graph函数实现的模型结构图:

在这里插入图片描述

模型参数、优化器、损失函数

参数初始化

模型参数使用nn.init.normal_作初始化,但模型中存在ReLU,应考虑使用kaiming He初始化。

apply函数:Module — PyTorch 2.4 documentation

参数初始化函数:torch.nn.init — PyTorch 2.4 documentation

def init_normal(m):  
    # 考虑使用kaiming  
    if m is nn.Linear:  
        nn.init.normal_(m.weight, mean=0, std=0.01)  
        nn.init.zeros_(m.bias)

# 定义模型、数据初始化  
net = CIFAR10Net()  
net.apply(init_normal)

优化器

优化器使用Adam,即MomentumAdaGrad的结合。

文档:Adam — PyTorch 2.4 documentation

# 优化器  
weight_decay = 0.0001

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

损失函数

分类任务,自然是用交叉熵损失函数了。

loss_fn = nn.CrossEntropyLoss()

模型训练、测试集预测、模型保存、日志记录

注意,代码前面部分代码有定义
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

训练

net.train()  
for images, labels in train_loader:  
    images, labels = images.to(device), labels.to(device)  
  
    outputs = net(images)  
    loss = loss_fn(outputs, labels)  
  
    # 优化器处理  
    optimizer.zero_grad()  
    loss.backward()  
    optimizer.step()  
  
    total_train_step += 1  
    if total_train_step % 100 == 0:  
        print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  
        writer.add_scalar('train_loss', loss.item(), total_train_step)  
        current_time = time.time()  
        writer.add_scalar('train_time', current_time-start_time, total_train_step)

测试集测试

net.eval()  
total_test_loss = 0  
total_test_acc = 0  # 整个测试集正确个数  
with torch.no_grad():  
    for images, labels in test_loader:  
        images, labels = images.to(device), labels.to(device)  
        outputs = net(images)  
        loss = loss_fn(outputs, labels)  
        total_test_loss += loss.item()  
  
        accuracy = (outputs.argmax(1) == labels).sum()  
        total_test_acc += accuracy  
  
print(f'整个测试集loss值和: {total_test_loss:.4f}, batch_size: {batch_size}')  
print(f'整个测试集正确率: {(total_test_acc / test_data_size) * 100:.4f}%')  
writer.add_scalar('test_loss', total_test_loss, epoch + 1)  
writer.add_scalar('test_acc', (total_test_acc / test_data_size) * 100, epoch + 1)

模型保存

torch.save(net.state_dict(),  
           './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))

模型训练完整代码

train.py

import torch  
import torchvision  
from torch.utils.tensorboard import SummaryWriter  
from torchvision import transforms  
from torch.utils import data  
from torch import nn  
import time  
from datetime import datetime  
  
def load_data_CIFAR10(resize=None):  
    """  
    下载 CIFAR10 数据集,然后将其加载到内存中  
    transforms.ToTensor() 转换为形状为C x H x W的FloatTensor,并且会将像素值从[0, 255]缩放到[0.0, 1.0]  
    """    trans = [transforms.ToTensor()]  
    if resize:  
        trans.insert(0, transforms.Resize(resize))  
    trans = transforms.Compose(trans)  
  
    cifar_train = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=False)  
    cifar_test = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=False)  
    return cifar_train, cifar_test  
  
  
class CIFAR10Net(torch.nn.Module):  
    def __init__(self):  
        super(CIFAR10Net, self).__init__()  
        layers = nn.Sequential(  
            # shape(3,32,32) -> shape(32,32,32)  
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(32),  
            nn.ReLU(),  
            # shape(32,32,32) -> shape(32,16,16)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
            # shape(32,16,16) -> shape(32,16,16)  
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(32),  
            nn.ReLU(),  
            # shape(32,16,16) -> shape(32,8,8)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
  
            # shape(32,8,8) -> shape(64,8,8)  
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(64),  
            nn.ReLU(),  
            # shape(64, 8, 8) -> shape(64,4,4)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
            # shape(64,4,4) -> shape(64 * 4 * 4,)  
            nn.Flatten(),  
            nn.Linear(64 * 4 * 4, 64),  
            nn.ReLU(),  
            nn.Dropout(0.5),  
            nn.Linear(64, 10)  
        )  
  
        self.layers = layers  
  
    def forward(self, x):  
        return self.layers(x)  
  
  
def init_normal(m):  
    # 考虑使用kaiming  
    if m is nn.Linear:  
        nn.init.normal_(m.weight, mean=0, std=0.01)  
        nn.init.zeros_(m.bias)  
  
  
if __name__ == '__main__':  
    # 超参数  
    epochs = 6  
    batch_size = 256  
    learning_rate = 0.01  
    num_workers = 0  
    weight_decay = 0  
  
    # 数据记录  
    total_train_step = 0  
    total_test_step = 0  
    train_loss_list = list()  
    test_loss_list = list()  
    train_acc_list = list()  
    test_acc_list = list()  
  
    # 准备数据集  
    train_data, test_data = load_data_CIFAR10()  
    train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)  
    test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)  
    train_data_size = len(train_data)  
    test_data_size = len(test_data)  
    print(f'训练测试集长度: {train_data_size}, 测试数据集长度: {test_data_size}, batch_size: {batch_size}\n')  
  
    # device = torch.device("cpu")  
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
    print(f'\ndevice: {device}')  
  
    # 定义模型、数据初始化  
    net = CIFAR10Net().to(device)  
    # net.apply(init_normal)  
    # 损失函数  
    loss_fn = nn.CrossEntropyLoss().to(device)  
    # 优化器  
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)  
  
    # now_time = datetime.now()  
    # now_time = now_time.strftime("%Y%m%d-%H%M%S")    # tensorboard    writer = SummaryWriter('./train_logs')  
    # 随便定义个输入, 好使用add_graph  
    tmp = torch.rand((batch_size, 3, 32, 32)).to(device)  
    writer.add_graph(net, tmp)  
  
    start_time = time.time()  
    for epoch in range(epochs):  
        print('------------Epoch {}/{}'.format(epoch + 1, epochs))  
  
        # 训练  
        net.train()  
        for images, labels in train_loader:  
            images, labels = images.to(device), labels.to(device)  
  
            outputs = net(images)  
            loss = loss_fn(outputs, labels)  
  
            # 优化器处理  
            optimizer.zero_grad()  
            loss.backward()  
            optimizer.step()  
  
            total_train_step += 1  
            if total_train_step % 100 == 0:  
                print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  
                writer.add_scalar('train_loss', loss.item(), total_train_step)  
                current_time = time.time()  
                writer.add_scalar('train_time', current_time-start_time, total_train_step)  
  
        # 测试  
        net.eval()  
        total_test_loss = 0  
        total_test_acc = 0  # 整个测试集正确个数  
        with torch.no_grad():  
            for images, labels in test_loader:  
                images, labels = images.to(device), labels.to(device)  
                outputs = net(images)  
                loss = loss_fn(outputs, labels)  
                total_test_loss += loss.item()  
  
                accuracy = (outputs.argmax(1) == labels).sum()  
                total_test_acc += accuracy  
  
        print(f'整个测试集loss值和: {total_test_loss:.4f}, batch_size: {batch_size}')  
        print(f'整个测试集正确率: {(total_test_acc / test_data_size) * 100:.4f}%')  
        writer.add_scalar('test_loss', total_test_loss, epoch + 1)  
        writer.add_scalar('test_acc', (total_test_acc / test_data_size) * 100, epoch + 1)  
  
        torch.save(net.state_dict(),  
                   './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))  
  
    writer.close()

tensorboard训练可视化结果

train_loss

纵轴为每个batch_size损失值,横轴为训练次数,其中batch_size = 256。

在这里插入图片描述

测试准确率

纵轴为整个CIFAR10测试集的准确率(%),横轴为epoch,其中epochs=50。

在这里插入图片描述

测试集loss

纵轴为CIFAR10整个测试集的每个batch_size的loss之和,batch_size = 256。横轴为epoch,其中epochs=50。

在这里插入图片描述

模型应用

模型训练过程中,每个epoch保存一次模型。

torch.save(net.state_dict(),  './save/epoch_{}_params_acc_{}.pth'.format(epoch+1, (total_test_acc / test_data_size)))  

这里实现一个,将保存的模型加载,并对自行搜集的图片进行预测。

项目结构:

  1. ./autodl_save/cuda_params_acc_75.pth:训练时保存的模型参数文件;

  2. ./test_images:网上搜集的卡车、狗、飞机、船图片,大小不一,保存时未作处理,如下:
    在这里插入图片描述

  3. api.py:实现图片的预处理(裁剪、ToTensor、封装为数据集等)、模型加载、图片推理等;

模型独立应用代码api.py

import os  
  
import torch  
import torchvision  
from PIL import Image  
from torch import nn  
  
  
class CIFAR10Net(torch.nn.Module):  
    def __init__(self):  
        super(CIFAR10Net, self).__init__()  
        layers = nn.Sequential(  
            # shape(3,32,32) -> shape(32,32,32)  
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(32),  
            nn.ReLU(),  
            # shape(32,32,32) -> shape(32,16,16)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
            # shape(32,16,16) -> shape(32,16,16)  
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(32),  
            nn.ReLU(),  
            # shape(32,16,16) -> shape(32,8,8)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
            # shape(32,8,8) -> shape(64,8,8)  
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  
            nn.BatchNorm2d(64),  
            nn.ReLU(),  
            # shape(64, 8, 8) -> shape(64,4,4)  
            nn.MaxPool2d(kernel_size=2, stride=2),  
  
            # shape(64,4,4) -> shape(64 * 4 * 4,)  
            nn.Flatten(),  
            nn.Linear(64 * 4 * 4, 64),  
            nn.ReLU(),  
            nn.Dropout(0.5),  
            nn.Linear(64, 10)  
        )  
  
        self.layers = layers  
  
    def forward(self, x):  
        return self.layers(x)  
  
  
def build_data(images_dir):  
    image_list = os.listdir(images_dir)  
    image_paths = []  
    for image in image_list:  
        image_paths.append(os.path.join(images_dir, image))  
  
    transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),  
                                                torchvision.transforms.ToTensor()])  
  
    # 存储转换后的张量  
    images_tensor = []  
    for image_path in image_paths:  
        try:  
            # 加载图像并转换为 RGB(如果它已经是 RGB,这步是多余的)  
            image_pil = Image.open(image_path).convert('RGB')  
            # 应用转换并添加到列表中  
            images_tensor.append(transform(image_pil))  
        except IOError:  
            print(f"Cannot open {image_path}. Skipping...")  
  
    # 转换列表为单个张量,如果需要的话  
    # 注意:这里假设所有图像都被成功加载和转换  
    if images_tensor:  
        # 使用 torch.stack 来合并张量列表  
        images_tensor = torch.stack(images_tensor)  
    else:  
        # 如果没有图像,返回一个空的张量或根据需要处理  
        images_tensor = torch.empty(0, 3, 32, 32)  
    return images_tensor, image_list  
  
  
def predict(state_dict_path, image):  
    net = CIFAR10Net()  
    net.load_state_dict(torch.load(state_dict_path))  
    net.cuda()  
  
    with torch.no_grad():  
        image = image.cuda()  
        output = net(image)  
  
    return output  
  
  
if __name__ == '__main__':  
    images, labels = build_data("./test_images")  
  
    outputs = predict("./autodl_save/cuda_params_acc_75.pth", images)  
  
    # 选取结果(即得分最大的下标)  
    res = outputs.argmax(dim=1)  
  
    kinds = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']  
  
    for i in range(len(res)):  
        classes_idx = res[i]  
        print(f'文件(正确标签): {labels[i]},  预测结果: {classes_idx}, {kinds[classes_idx]}\n')

预测结果

7个识别出4个。

在这里插入图片描述

注意这个索引和标签的对应关系可以从数据集中查看。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2074398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode1232一点小问题

解法 a x 2 − x 1 y 2 − y 1 , b y 1 − a x 1 a\frac{x_{2}-x_{1}}{y_{2}-y_{1}} ,by_{1}-ax_{1} ay2​−y1​x2​−x1​​,by1​−ax1​ d y n − y n − 1 x n − x n − 1 d\frac{y_{n}-y_{n-1}}{x_{n}-x_{n-1}} dxn​−xn−1​yn​…

【初阶数据结构】链表题的证明

环形链表题目方法的证明 证明1:为什么快指针每次⾛两步,慢指针⾛⼀步可以相遇,有没有可能遇不上,请推理证明! 证明二:为什么相遇点(meet)和头结点(head)到入环…

sql server导入mysql,使用工具SQLyog

概述 需要将sql server的数据导入到mysql中,由于2种数据库存在各种差异,比如表字段类型就有很多不同,因此需要工具来实现。 这里使用SQLyog来实现。 SQLyog安装 安装过程参考文档:https://blog.csdn.net/Sunshine_liang1/article/…

USART之串口发送+接收应用案例

文章目录 前言一、电路接线图二、应用案例代码三、应用案例分析3.1 USART模块初始化3.1.1 RCC开启时钟3.1.2 GPIO初始化3.1.3 配置USART3.1.4 开启中断、配置NVIC3.1.5 开启USART 3.2 USART串口收发模块3.2.1 Serial_SendByte(发送一个字节数据)3.2.2 US…

JVM对象创建和内存分配机制深度解析

一、对象创建方式 1、new关键字 这是最常见的创建对象的方式。通过调用类的构造方法(constructor)来创建对象。如:MyClass obj new MyClass()。这种方式会触发类的加载、链接、初始化过程(如果类还未被加载过的话)&…

递归搜索与回溯专题篇一

目录 组合 目标和 组合总和 字母大小全排列 组合 题目 思路 解决这道题利用DFS,决策树是怎样的?以n4,k3为例: 因为每个数只用到一次,因此需要剪枝,将出现重复数字的枝剪掉,因为组合中元素的…

Vue中的this.$emit()方法详解【父子组件传值常用】

​在Vue中,this.$emit()方法用于触发自定义事件。它是Vue实例的一个方法,可以在组件内部使用。 使用this.$emit()方法,你可以向父组件发送自定义事件,并传递数据给父组件。父组件可以通过监听这个自定义事件来执行相应的逻辑。 …

【PyQt6 应用程序】QTDesigner生成ui文件转成py源码并执行

要使用Qt Designer设计的UI界面生成Python代码并执行需要遵循几个步骤。确保已经安装了PyQt6和Qt Designer。Qt Designer是一个强大的工具,允许通过拖放组件来设计GUI界面,而不需要手写所有的代码。安装PyQt6时 Qt Designer通常会一起被安装。 文章目录 使用Qt Designer设计U…

米联客FDMA3.2源码分析以及控制BRAM、DDR3读写验证

文章目录 一、FDMA简介二、读写操作时序2.1 写时序2.2 读时序 三、FDMA源码分析四、源码仿真验证4.1 FDMA控制代码4.2 系统框图4.3 仿真结果4.3.1 写通道4.3.2 读通道 五、使用FDMA控制BRAM读写测试5.1 系统框图5.2 读写数据控制模块5.3 仿真结果5.4 下板验证 六、使用FDMA控制…

快讯 | 美军500天AI计划启动,“破解AI“与“反AI“策略亮相

在数字化浪潮的推动下,人工智能(AI)正成为塑造未来的关键力量。硅纪元视角栏目紧跟AI科技的最新发展,捕捉行业动态;提供深入的新闻解读,助您洞悉技术背后的逻辑;汇聚行业专家的见解,…

VBA之正则表达式(46)-- 解析业务逻辑公式

实例需求:某业务系统的逻辑公式如下所示(单行文本),保存在活动工作表的A1单元格中。 "DSO_90Day"->"FA_NoFunc"->"FCCS_No Intercompany"->"FCCS_Data Input"->"FCCS_…

<数据集>非洲动物识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:1504张 标注数量(xml文件个数):1504 标注数量(txt文件个数):1504 标注类别数:4 标注类别名称:[buffalo, elephant, rhino, zebra] 序号类别名称图片数框数1buffalo3…

Java生成一个5位的随机验证码(大小写字母和数字)

生成验证码 内容:可以是小写字母,也可以是大写字母,还可以是数字 规则:长度为5 内容中四位字母,一位数字 其中数字只有一位,但是可以出现在任意位置。 package test;impo…

arm-Pwn环境搭建+简单题目

前言 起因是看到一篇IOT CVE的分析文章。 正好也在学pwn,arm架构的也是IOT这些固件最常用的,所以先安一个arm-pwn的环境。 环境搭建/调试 1. 安装 gdb-multiarch sudo apt-get install gdb-multiarch2. 安装qemu ctf的arm_pwn只需要安装qemu-user就…

结构体内存的对齐

结构体的对齐规则 第一个成员在结构体变量偏移量为0的地址处。 其他成员变量要对齐到某个数字(对齐数)的整数倍的地址处 1) 对齐数 min( 编译器默认的一个对齐数, 该成员大小)。 2)默认的对齐数,可以通过宏…

kafka的12个重要概念

kafka的12个重要概念 1、服务器broker1.1、Broker 的主要功能1.2、Kafka Broker 的架构1.3、配置和管理1.4、高可用性和负载均衡1.5、总结 2、主题topic2.1、主要特点 3、事件Event4、生产者producer4.1、主要功能4.2、Producer 的配置选项4.3、Producer 的工作流程4.4、总结 5…

(javaweb)maven高级

目录 ​编辑 1.分模块设计与开发 2.继承与聚合--继承关系实现 3.继承与聚合--版本锁定 4.继承与聚合--聚合版本 5.私服 资源的上传与下载 1.分模块设计与开发 分模块:拆分成多个模块进行开发 不分模块:业务代码堆积成一个 不利于项目管理和维护并…

考研数学|零基础9月开始100天备考攻略

马上就要9月了,很多同学相比快要结束强化了,零基础的同学,进度可能会慢一些,但是别担心,考研数学的学习,进度不是最要紧的,学习效果才是!千万不要比进度,也不要赶进度&am…

Linux中的PCI配置空间

在计算机系统中,PCI(Peripheral Component Interconnect)总线是一种用于连接硬件设备的标准接口。PCI总线提供了一个通用的、高性能的数据传输通道,广泛应用于PC系统和服务器中。在Linux操作系统中,PCI设备的配置空间是…

Modern C++——不准确“类型声明”引发的非必要性能损耗

大纲 案例代码地址 C是一种强类型语言。我们在编码时就需要明确指出每个变量的类型,进而让编译器可以正确的编译。看似C编译器比其他弱类型语言的编译器要死板,实则它也做了很多“隐藏”的操作。它会在尝试针对一些非预期类型进行相应转换,以…