用pytorch实现一个简单的图片预测类别

news2025/2/21 8:35:04

前言:

        在阅读本文之前,你需要了解Python,Pytorch,神经网络的一些基础知识,比如什么是数据集,什么是张量,什么是神经网络,如何简单使用tensorboard,DataLoader。

        本次模型训练使用的是cpu。

目录

创建python文件:

model_train.py文件

1、准备数据集

2、打印看看该数据集的大小

3、加载数据集

4、创建网络模型

model.py文件

5、定义损失函数

6、定义优化器

7、设置一些训练网络所需参数

8、可视化训练过程

9、训练过程

model_vertification文件


创建python文件:

        model.py 自定义神经网络模型。

        model_train.py 训练 CIFAR - 10 数据集上的自定义模型并保存参数。

        model_vertification.py 用一张图片验证网络模型进行预测。

下面从各个文件讲解。

model_train.py文件

完整的模型训练步骤如下:

1、准备数据集

       这里选用 CIFAR10 数据集,这个数据集是 torchvision 里面自带的,一个十分类问题的数据集,该数据集较小(160MB左右),使用torchvision.datasets模块加载 CIFAR10 数据集。

# 下载并加载 CIFAR-10 训练数据集
# root 指定数据集存储的根目录;train=True 表示加载训练集;
# transform 将数据转换为 Tensor 类型;download=True 表示如果数据集不存在则进行下载
train_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root= r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)

2、打印看看该数据集的大小

# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")

        可以看到该数据集有50000张训练图片,10000张测试图片。 

3、加载数据集

        使用DataLoader分别加载训练和测试数据集。

# 使用 DataLoader 对训练数据进行批量加载,batch_size=64 表示每个批次包含 64 个样本
train_dataloader = DataLoader(train_data, batch_size=64)
# 对测试数据进行批量加载
test_dataloader = DataLoader(test_data, batch_size=64)

4、创建网络模型

        将自定义的网络模型放在model.py文件,在train.py中导入使用。

        根据此图片的神经网络模型来自定义一个网络模型。(其中卷积层Conv2d中的参数stride和padding需要经过如下的公式计算得到,该计算并不复杂)。

        计算公式 (pytorch官网torch.nn中Conv2d中查看)

model.py文件

import torch
from torch import nn


class zzy(nn.Module):
    def __init__(self):
        super(zzy, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64,10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


if __name__ == '__main__':
    zzy1 = zzy()
    data = torch.ones((64,3,32,32)) # 网络模型的输入必须为四维张量
    output = zzy1(data)
    print(output.shape)

        运行此文件验证该网络模型是否能得到预想的结果 ,如下。

         在train.py中导入model.py后实例化网络模型。

# 创建网络模型
# 实例化自定义的网络模型 zzy
zzy1 = zzy()

5、定义损失函数

# 定义损失函数
# 使用交叉熵损失函数,常用于多分类问题
loss_fn = nn.CrossEntropyLoss()

6、定义优化器

# 优化器创建
# 学习率设置为 0.01
learning_rate = 1e-2
# 使用随机梯度下降(SGD)优化器,对 zzy1 模型的参数进行优化
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)

7、设置一些训练网络所需参数

# 设置训练网络参数
# 记录总的训练步数
total_train_step = 0
# 记录总的测试步数
total_test_step = 0
# 训练的轮数
epoch = 10

8、可视化训练过程

        为了可视化整个训练过程, 添加 TensorBoard 用于可视化训练过程。

# 在 '../logs_train' 目录下创建 SummaryWriter 对象
writer = SummaryWriter('../logs_train')

9、训练过程

        首先设置我们的训练轮次,这是外层循环,内层循环里分别进行训练数据集和测试数据集的训练。

        在训练集中,每次取出一批数据,即前面定义的64张图片,送入我们定义的网络模型得到输出,计算输出和真实值之间的损失,再进行反向传播更新模型参数进行优化,训练步数加一,再取出下一批数据,重复上面的过程。

        在测试集中,计算正确率。

# 开始训练循环,共训练 epoch 轮
for i in range(epoch):
    print(f'--------第{i + 1}次训练--------')
    # 遍历训练数据加载器中的每个批次
    for data in train_dataloader:
        # 从数据批次中解包图像和对应的标签
        imgs, target = data
        # 将图像输入到模型中进行前向传播,得到模型的输出
        outputs = zzy1(imgs)
        # 计算模型输出与真实标签之间的损失
        loss = loss_fn(outputs, target)

        # 梯度清零,防止梯度累积
        optimier.zero_grad()
        # 反向传播,计算梯度
        loss.backward()
        # 根据计算得到的梯度更新模型的参数
        optimier.step()

        # 训练步数加 1
        total_train_step += 1
        # 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoard
        if total_train_step % 100 == 0:
            print(f'训练次数:{total_train_step},Loss:{loss.item()}')
            # 将训练损失添加到 TensorBoard 中,用于后续可视化
            writer.add_scalar('train_loss', loss.item(), total_train_step)

    zzy.eval()
    # 测试步骤开始
    # 初始化总的测试损失为 0
    total_test_loss = 0
    # 初始化正确率为 0
    total_accuracy = 0
    # 上下文管理器,在测试过程中不进行梯度计算,减少内存消耗
    with torch.no_grad():
        # 遍历测试数据加载器中的每个批次
        for data in test_dataloader:
            # 从数据批次中解包图像和对应的标签
            imgs, targets = data
            # 将图像输入到模型中进行前向传播,得到模型的输出
            outputs = zzy1(imgs)
            # 计算模型输出与真实标签之间的损失
            loss = loss_fn(outputs, targets)
            # 计算正确率,outputs.argmax(1)表示横向看
            accuracy = (outputs.argmax(1) == targets).sum()
            # 累加测试损失
            total_test_loss += loss.item()
            total_accuracy += accuracy
    print(f'整体的测试损失: {total_test_loss}')
    print(f'整体的正确率: {total_accuracy/test_data_size}')
    # 将测试损失添加到 TensorBoard 中,用于后续可视化
    writer.add_scalar('test_loss', total_test_loss, total_test_step)
    writer.add_scalar('total_accuracy',total_accuracy/test_data_size,total_test_step)
    # 测试步数加 1
    total_test_step += 1
    torch.save(zzy,'zzy_{}.pth'.format(i))
    # 官方推荐的保存方式
    # torch.save(zzy.state_dict(),'zzy_{}.pth'.forma(i))
    print("模型已保存")
# 关闭 SummaryWriter,释放资源
writer.close()

完整的model_train.py代码如下:

import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.utils.data import DataLoader

# 下载并加载 CIFAR-10 训练数据集
train_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 下载并加载 CIFAR-10 测试数据集
test_data = torchvision.datasets.CIFAR10(root=r'D:\Desktop\数据集', train=False, transform=torchvision.transforms.ToTensor(), download=True)

# 计算训练集和测试集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为: {train_data_size}")
print(f"测试数据集的长度为: {test_data_size}")

# 加载数据
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
class zzy(nn.Module):
    def __init__(self):
        super(zzy, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

# 实例化自定义的网络模型 zzy
zzy1 = zzy()

# 定义损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器创建
learning_rate = 1e-2
optimier = torch.optim.SGD(zzy1.parameters(), lr=learning_rate)

# 设置训练网络参数
total_train_step = 0
total_test_step = 0
epoch = 10

# 添加 TensorBoard 用于可视化训练过程
writer = SummaryWriter('../logs_train')

# 开始训练循环,共训练 epoch 轮
for i in range(epoch):
    print(f'--------第{i + 1}次训练--------')
    # 遍历训练数据加载器中的每个批次
    for data in train_dataloader:
        # 从数据批次中解包图像和对应的标签
        imgs, target = data
        # 将图像输入到模型中进行前向传播,得到模型的输出
        outputs = zzy1(imgs)
        # 计算模型输出与真实标签之间的损失
        loss = loss_fn(outputs, target)

        # 梯度清零,防止梯度累积
        optimier.zero_grad()
        # 反向传播,计算梯度
        loss.backward()
        # 根据计算得到的梯度更新模型的参数
        optimier.step()

        # 训练步数加 1
        total_train_step += 1
        # 每训练 100 步,打印一次训练信息并将训练损失写入 TensorBoard
        if total_train_step % 100 == 0:
            print(f'训练次数:{total_train_step},Loss:{loss.item()}')
            # 将训练损失添加到 TensorBoard 中,用于后续可视化
            writer.add_scalar('train_loss', loss.item(), total_train_step)

    # 设置模型为评估模式
    zzy1.eval()
    # 测试步骤开始
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = zzy1(imgs)
            loss = loss_fn(outputs, targets)
            accuracy = (outputs.argmax(1) == targets).sum()
            total_test_loss += loss.item()
            total_accuracy += accuracy

    print(f'整体的测试损失: {total_test_loss}')
    print(f'整体的正确率: {total_accuracy / test_data_size}')
    # 将测试损失添加到 TensorBoard 中,用于后续可视化
    writer.add_scalar('test_loss', total_test_loss, total_test_step)
    writer.add_scalar('total_accuracy', total_accuracy / test_data_size, total_test_step)
    # 测试步数加 1
    total_test_step += 1

    # 保存模型状态字典
    torch.save(zzy1.state_dict(), f'zzy_{i}.pth')
    print("模型已保存")

# 关闭 SummaryWriter,释放资源
writer.close()

运行结果:

        从下图可以看到,最后训练出来的模型预测正确率为0.54左右,不算好,如果想继续优化,加大训练轮次,或者调整学习率。 

        打开tensorboard观察到整个训练过程的变化,图中深色线是经过平滑处(Smoothed)的训练损失值,能更清晰呈现损失总体变化趋势,减少波动干扰;浅色线代表原始的训练损失值,反映每个训练步骤上即时的损失情况,波动相对较大。

 

后续补充:

        想查看整个训练所用时间,可以导入time模块,设置一下开始训练时间和结束训练时间求差。

        当我把训练次数加大10倍(100次)后,模型预测的正确率为0.63左右,相对0.53没有提高很多,而且训练轮次较多时或者数据量较大时用cpu计算的时间花费就比较多了 (不要参考下面的时间,中途暂停过较长时间)。

        到中期训练了40轮次后,从正确率的变化可以看出,模型效果不佳 。

        当然可能有很多原因导致,数据量不足,网络模型结构不合理,优化器选择不合理等等,这里不过多赘述。

model_vertification文件 

        网上随便找了一个狗狗的图片保存为image.png,我们想要验证该网络模型的预测效果。

        注意:

        1、 模型训练时,模型保存方式和加载该模型的方式要对应。

        2、 需要将图片更改为网络模型能够处理的shape。

完整代码:

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter
from model import  *
# 打开图片,使用绝对路径
image_path = r'D:\Desktop\deep_learning\pytorch入门\images\image.png'
image = Image.open(image_path)
print(image.size)

# 保留颜色通道
image = image.convert('RGB')

# 定义图像变换
# 首先转化尺寸,再转化为tensor类型
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])

# 应用变换
image = transform(image)
print(f'变换后的{image.shape}')
# 增加批量维度,以匹配模型输入要求
# 图像数据,常见的输入形状是 (batch_size, channels, height, width)
# 示例说明:假设 image 是一个形状为 (3, 32, 32) 的张量,代表一张 3 通道、高度为 32、宽度为 32 的图像。
# 当调用 image.unsqueeze(0) 后,得到的新张量形状将变为 (1, 3, 32, 32),
# 这里的 1 就是新插入的批量维度,表示这个批量中只有一张图像。
# unsqueeze 只能插入一个大小为 1 的维度
image = image.unsqueeze(0)
# 或者使用这种方式来更改图像shape
# image = torch.reshape(image, (1, 3, 32, 32))

print(image.shape)



# 实例化模型
model = zzy()

# 加载模型的状态字典
model_path = r'D:\Desktop\deep_learning\model_train\zzy_9.pth'  # 确保路径正确
model.load_state_dict(torch.load(model_path))
model.eval()

# 进行前向传播
# 不要漏掉 with torch.no_grad():
# 我们的目标仅仅是根据输入数据得到模型的预测结果
# 并不需要更新模型的参数,计算梯度是不必要的开销。
with torch.no_grad():
    output = model(image)

# 获取预测的类别
# _,表示一个占位符,只关心另一个值的输出
_, predicted = torch.max(output.data, 1)

# CIFAR-10 数据集的类别名称
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 打印预测结果
print(f"预测的类别是: {classes[predicted.item()]}")

运行结果:

        正确预测到了这个类别为dog。

        作者水平有限,有任何问题或错误,欢迎留言,我将持续分享深度学习相关的内容,你的投币点赞是我最大的创作动力!

        本文代码也可以在我的github上直接下载。https://github.com/Zik-code/CIFAR-10-model_train/tree/main/model_train。
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2299442.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习框架探秘|TensorFlow:AI 世界的万能钥匙

在人工智能(AI)蓬勃发展的时代,各种强大的工具和框架如雨后春笋般涌现,而 TensorFlow 无疑是其中最耀眼的明星之一。它不仅被广泛应用于学术界的前沿研究,更是工业界实现 AI 落地的关键技术。今天,就让我们…

Linux 服务器部署deepseek

把手教你在linux服务器部署deepseek,打造专属自己的数据库知识库 正文开始 第一步:安装Ollama 打开官方网址:https://ollama.com/download/linux 下载Ollama linux版本 复制命令到linux操作系统执行 [rootpostgresql ~]# curl -fsSL http…

DeepSeek、Kimi、文心一言、通义千问:AI 大语言模型的对比分析

在人工智能领域,DeepSeek、Kimi、文心一言和通义千问作为国内领先的 AI 大语言模型,各自展现出了独特的特点和优势。本文将从技术基础、应用场景、用户体验和价格与性价比等方面对这四个模型进行对比分析,帮助您更好地了解它们的特点和优势。…

CSDN、markdown环境下如何插入各种图(流程图,时序图,甘特图)

流程图 横向流程图 mermaid graph LRA[方形] --> B{条件a}B -->|满足| C(圆角)B -->|不满足| D(圆角)C --> E[输出结果1]D --> E效果图: 竖向流程图 mermaid graph TDC{条件a} --> |a1| A[方形]C --> |a2| F[竖向流程图]A --> B(圆角)B …

unity学习40:导入模型的 Animations文件夹内容,动画属性和修改动画文件

目录 1 Animations文件夹内容 2 每个模型文件的4个标签 3 model 4 rig 动画类型 5 Animation 5.1 新增动画和修改动画 5.2 限制动画某个轴的变化,烘焙 6 material 材料 1 Animations文件夹内容 下面有很多文件夹每个文件夹都是不同的动作模型每个文件夹下…

web第三次作业

弹窗案例 1.首页代码 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>综合案例</title><st…

GMSL 实例1:当 MAX96717 遇上 MAX96724,打通 Camera 视频数据传输

新年伊始&#xff0c;继 Deepseek 在 AI 圈掀起风波之后。比亚迪在2月10日发布会上重磅官宣&#xff1a;全系车型将搭载自研的高阶智驾系统“天神之眼”&#xff0c;覆盖从10万元级入门车型到高端豪华车型的所有范围。此举如一颗重磅炸弹投向当前一卷再卷的新能源汽车赛道&…

DeepSeek 助力 Vue 开发:打造丝滑的侧边栏(Sidebar)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

基于opencv的 24色卡IQA评测算法源码-可完全替代Imatest

1.概要 利用24色卡可以很快的分析到曝光误差&#xff0c;白平衡误差&#xff0c;噪声&#xff0c;色差&#xff0c;饱和度&#xff0c;gamma值。IQA或tuning工程一般用Imatest来手动计算&#xff0c;不便于产测部署&#xff0c;现利用opencv实现了imatest的全部功能&#xff0c…

数据结构与算法之排序算法-(计数,桶,基数排序)

排序算法是数据结构与算法中最基本的算法之一&#xff0c;其作用就是将一些可以比较大小的数据进行有规律的排序&#xff0c;而想要实现这种排序就拥有很多种方法~ &#x1f4da; 非线性时间比较类&#xff1a; 那么我将通过几篇文章&#xff0c;将排序算法中各种算法细化的&a…

MATLAB图像处理:图像特征概念及提取方法HOG、SIFT

图像特征是计算机视觉中用于描述图像内容的关键信息&#xff0c;其提取质量直接影响后续的目标检测、分类和匹配等任务性能。本文将系统解析 全局与局部特征的核心概念&#xff0c;深入讲解 HOG&#xff08;方向梯度直方图&#xff09;与SIFT&#xff08;尺度不变特征变换&…

kibana es 语法记录 elaticsearch

目录 一、认识elaticsearch 1、什么是正向索引 2、什么是倒排索引 二、概念 1、说明 2、mysql和es的对比 三、mapping属性 1、定义 四、CRUD 1、查看es中有哪些索引库 2、创建索引库 3、修改索引库 4、删除索引库 5、新增文档 6、删除文档 5、条件查询 一、认识…

手写一个Java Android Binder服务及源码分析

手写一个Java Android Binder服务及源码分析 前言一、Java语言编写自己的Binder服务Demo1. binder服务demo功能介绍2. binder服务demo代码结构图3. binder服务demo代码实现3.1 IHelloService.aidl3.2 IHelloService.java&#xff08;自动生成&#xff09;3.3 HelloService.java…

【动态规划篇】:当回文串遇上动态规划--如何用二维DP“折叠”字符串?

✨感谢您阅读本篇文章&#xff0c;文章内容是个人学习笔记的整理&#xff0c;如果哪里有误的话还请您指正噢✨ ✨ 个人主页&#xff1a;余辉zmh–CSDN博客 ✨ 文章所属专栏&#xff1a;动态规划篇–CSDN博客 文章目录 一.回文串类DP核心思想&#xff08;判断所有子串是否是回文…

Windows 安装 GDAL 并配置 Rust-GDAL 开发环境-1

Rust-GDAL 是 Rust 语言的 GDAL&#xff08;Geospatial Data Abstraction Library&#xff09; 绑定库&#xff0c;用于处理地理数据。由于 GDAL 依赖较多&#xff0c;在 Windows 上的安装相对复杂&#xff0c;本文档将介绍如何安装 GDAL 并配置 Rust-GDAL 的开发环境。 1. 检…

第1期 定时器实现非阻塞式程序 按键控制LED闪烁模式

第1期 定时器实现非阻塞式程序 按键控制LED闪烁模式 解决按键扫描&#xff0c;松手检测时阻塞的问题实现LED闪烁的非阻塞总结补充&#xff08;为什么不会阻塞&#xff09; 参考江协科技 KEY1和KEY2两者独立控制互不影响 阻塞&#xff1a;如果按下按键不松手&#xff0c;程序就…

开源语音克隆项目 OpenVoice V2 本地部署

#本机环境 WIN11 I5 GPU 4060ti 16G 内存 32G #开始 git clone https://github.com/myshell-ai/OpenVoice.git conda create -n opvenv python3.9 -y conda activate opvenv pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/…

DeepSeek大模型一键部署解决方案:全平台多机分布式推理与国产硬件优化异构计算私有部署

DeepSeek R1 走红后&#xff0c;私有部署需求也随之增长&#xff0c;各种私有部署教程层出不穷。大部分教程只是简单地使用 Ollama、LM Studio 单机运行量化蒸馏模型&#xff0c;无法满足复杂场景需求。一些操作配置也过于繁琐&#xff0c;有的需要手动下载并合并分片模型文件&…

如何利用PLM软件有效地推进制造企业标准化工作?

在智能制造浪潮的推动下&#xff0c;中国制造业正面临从“规模扩张”向“质量提升”的关键转型。工信部数据显示&#xff0c;85%的制造企业在产品研发、生产过程中因标准化程度不足导致效率损失超20%&#xff0c;而标准化水平每提升10%&#xff0c;企业综合成本可降低5%-8%。如…

环境影响评价(EIA)中,土地利用、植被类型及生态系统图件的制作

在环境影响评价&#xff08;EIA&#xff09;中&#xff0c;土地利用、植被类型及生态系统图件的制作需依据科学、法规和技术规范&#xff0c;以确保数据的准确性和图件的规范性。以下是主要的制作依据&#xff1a; 1. 法律法规与政策依据 《中华人民共和国环境影响评价法》 明确…