Pytorch实战01——CIAR10数据集

news2025/1/15 13:26:05

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的
        self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的
        self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2
        self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。
                                           # 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的

    def forward(self, x):   # 正向传播
        # Pytorch Tensor的通道排序:[channel,height,width]
        '''
            卷积后的尺寸大小计算:
                N = (W-F+2P)/S + 1
                其中,默认的padding:0   stride:1
                    ①输入图片大小:WxW
                    ②Filter大小 FxF  (卷积核大小)
                    ③步长S
                    ④padding的像素数P
        '''
        x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)
        x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;
        x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32
        x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)

        x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1代表第一个维度
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transforms

from pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet

# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose(
    [transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]
     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=transform,
                                        download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=transform,
                                       download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)

test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

'''
    标准化处理:output = (input - 0.5) / 0.5
  反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))


# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮
    running_loss = 0.0  # 叠加,训练过程中的损失
    for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本
        inputs, labels = data   # 获取图像及其对应的标签
        optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加

        outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播
        loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():  # with 是一个上下文管理器
                outputs = net(test_image)  # [batch,10]
                predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值
                print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0
print('Finished Training')

save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。

im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。

with torch.no_grad():  # 不需要计算损失梯度
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;
                                    # torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。
                                    # .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()
                                    # .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])

#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1513831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day57 动态规划part17● 647. 回文子串 ● 516.最长回文子序列● 动态规划总结篇

如果大家做了很多这种子序列相关的题目,在定义dp数组的时候 很自然就会想题目求什么,我们就如何定义dp数组。 布尔类型的dp[i][j]:表示区间范围[i,j] (注意是左闭右闭)的子串是否是回文子串,如果是dp[i][j…

C++学习路线

C学习路线思维导图,肝了一个星期终于搞定,这么硬核求个赞不过分吧? 思维导图的内容,也是本文的内容框架,坐稳扶好, C 高速快车要发车了! 内容我会持续更新,点赞收藏,…

Window系统下Vscode配置C/Cpp运行+调试环境

Window系统下Vscode配置C/Cpp运行调试环境 文章目录 Window系统下Vscode配置C/Cpp运行调试环境1.安装Vscode2.安装C/Cpp插件3.配置gcc编译器4.配置Cpp运行环境5.配置Cpp调试环境 1.安装Vscode 安装VScode很简单,直接到官网进行下载,然后傻瓜安装即可。 …

读CDO代码

前置任务 module PIL.Image has no attribute ANTIALIASImage.ANTIALIAS 替换为 Image.LANCZOS,参考https://blog.csdn.net/fovever_/article/details/134690657 OSError: science is not a valid package style, path of style file, URL of这是对应可视化里面生…

GraphView实现测量工具

效果演示: 主模块代码: MeasureGraphView::MeasureGraphView(QWidget *parent): QWidget(parent) {ui.setupUi(this);m_measureType NoType;m_bll BllData::getInstance();m_scene new GraphicsScene;connect(m_bll, &BllData::pressMeasurePos…

力扣106 从中序与后续遍历序列构造二叉树

文章目录 题目描述解题思路代码 题目描述 给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入:inorder [9,3,15,20,7], …

【单片机毕业设计7-基于stm32c8t6的智能温室大棚系统】

【单片机毕业设计7-基于stm32c8t6的智能温室大棚系统】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 🔥这里是小殷学长,单片机毕业设计篇7基于stm32的智能衣柜系统 🧿创作不易,拒绝白嫖可私 一、功能介绍 ---------------…

Sqllab第一关通关笔记

知识点: 明白数值注入和字符注入的区别 数值注入:通过数字运算判断,1/0 1/1 字符注入:通过引号进行判断,奇数个和偶数个单引号进行识别 联合查询:union 或者 union all 需要满足字段数一致&…

基础小白快速入门opencv-------C++ 在opencv的应用以及opencv的下载配置

啥是opencv? OpenCV(开源计算机视觉库 Open Source Computer Vision Library)是一个跨平台的计算机视觉库,最初由Intel开发,现在由一个跨国团队维护。它免费提供给学术和商业用途,并且是用C语言写成的&…

doit,一个非常实用的 Python 库!

更多Python学习内容:ipengtao.com 大家好,今天为大家分享一个非常实用的 Python 库 - doit。 Github地址:https://github.com/pydoit/doit 在软件开发和数据处理过程中,经常会遇到需要执行一系列任务的情况,这些任务可…

华为数通方向HCIP-DataCom H12-821题库(多选题:161-180)

第161题 以下关于IPv6优势的描述,正确的是哪些项? A、底层自身携带安全特性 B、加入了对自动配置地址的支持,能够无状态自动配置地址 C、路由表相比IPv4会更大,寻址更加精确 D、头部格式灵活,具有多个扩展头 【参考答案】ABD 【答案解析】 第162题 在OSPF视图下使用Filt…

做伦敦银要等怎样的价格与行情?

对于不同的伦敦银投资者来说,合适的入市价格和好的行情机会,标准可能并不一样,因为不同人有不同的交易策略、风险偏好和盈利目标。对于喜欢做趋势跟踪的投资者来说,一波明显而持续的上涨或下跌趋势,可能就是最好的行情…

yolov5-v6.0详细解读

yolov5-v6.0详细解读 一、yolov5版本介绍二、网络结构2.1 Backbone特征提取部分2.1.1 ConvBNSiLU模块2.1.2 C3模块2.1.2.1 BottleNeck模块 2.1.3 SPPF模块 2.2 Neck特征融合部分2.2.1 FPN2.2.2 PANet 2.3Head模块 三、目标框回归3.1 yolo标注格式3.2 yolov4目标回归框3.3 yolov…

数据结构-链表(二)

1.两两交换列表中的节点 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 输入:head [1,2,3,4] 输出:[2…

框架漏洞Shiroweblogicfastjson || 免杀思路

继续来讲一下我们的框架漏洞,先讲一下Shiro 1.Shiro反序列化 1.原理 Shiro的漏洞形成呢,就是因为存在了RememberMe这样的一个字段 Shiro 框架在处理 "rememberMe" 功能时使用了不安全的反序列化方法,攻击者可以构造恶意序列化数据&#xff0…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的危险物品检测系统(深度学习模型+PySide6界面+训练数据集+Python代码)

摘要:本文深入介绍了一个采用深度学习技术的危险物品识别系统,该系统融合了最新的YOLOv8算法,并对比了YOLOv7、YOLOv6、YOLOv5等早期版本的性能。该系统在处理图像、视频、实时视频流及批量文件时,能够准确识别和分类各种危险物品…

9、组合模式(结构性模式)

组合模式又叫部分整体模式,它创建了对象组的树形结构,将对象组合成树状结构,以一致的方式处理叶子对象以及组合对象,不以层次高低定义类,都是结点类 一、传统组合模式 举例,大学、学院、系,它们…

勾八头歌之数据科学导论—数据预处理

第1关:引言-根深之树不怯风折,泉深之水不会涸竭 第2关:数据清理-查漏补缺 import numpy as np import pandas as pd import matplotlib.pyplot as pltdef student():# Load the CSV file and replace #NAME? with NaNtrain pd.read_csv(Tas…

人工智能的幽默“失误”

人工智能迷惑行为大赏 随着ChatGPT热度的攀升,越来越多的公司也相继推出了自己的AI大模型,如文心一言、通义千问等。各大应用也开始内置AI玩法,如抖音的AI特效~在使用过程中往往会遇到一些问题,让你不得不怀疑&#x…

反向传播 — 简单解释

一、说明 关于反向传播,我有一个精雕细刻的案例计划,但是实现了一半,目前没有顾得上继续充实,就拿论文的叙述这里先起个头,我后面将修改和促进此文的表述质量。 二、生物神经元 大脑是一个由大约100亿个神经元组成的复…