01——LenNet网络结构,图片识别

news2024/11/18 18:35:04

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的
        self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的
        self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2
        self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。
                                           # 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的

    def forward(self, x):   # 正向传播
        # Pytorch Tensor的通道排序:[channel,height,width]
        '''
            卷积后的尺寸大小计算:
                N = (W-F+2P)/S + 1
                其中,默认的padding:0   stride:1
                    ①输入图片大小:WxW
                    ②Filter大小 FxF  (卷积核大小)
                    ③步长S
                    ④padding的像素数P
        '''
        x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)
        x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;
        x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32
        x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)

        x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1表示第一个维度batch需要自动推理
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transforms

from pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet

# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose(
    [transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]
     transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=transform,
                                        download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=transform,
                                       download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)

test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

'''
    标准化处理:output = (input - 0.5) / 0.5
  反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))


# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮
    running_loss = 0.0  # 叠加,训练过程中的损失
    for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本
        inputs, labels = data   # 获取图像及其对应的标签
        optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加

        outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播
        loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():  # with 是一个上下文管理器
                outputs = net(test_image)  # [batch,10]
                predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值
                print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0
print('Finished Training')

save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

测试下展示图片:

运行下,train.py文件,看下正确率、损失率:

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。

im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。

with torch.no_grad():  # 不需要计算损失梯度
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;
                                    # torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。
                                    # .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()
                                    # .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])

#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1521558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c语言的字符串函数详解

文章目录 前言一、strlen求字符串长度的函数二、字符串拷贝函数strcpy三、链接或追加字符串函数strcat四、字符串比较函数strcmp五、长度受限制字符函数六、找字符串2在字符串1中第一次出现的位置函数strstr七、字符串切割函数strtok(可以切割分隔符)八、…

基于springboot实现酒店客房管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现酒店客房管理平台系统演示 摘 要 随着人们的物质水平的提高,旅游业和酒店业发展的速度越来越快。近年来,市面上酒店的数量和规模都在不断增加,如何提高酒店的管理效率和服务质量成为了一个重要的问题。伴随着信息技术的发…

CSS中如何设置单行或多行内容超出后,显示省略号

1. 设置超出显示省略号 css设置超出显示省略号可分两种情况: 单行文本溢出显示省略号…多行文本溢出显示省略号… 但使用的核心代码是一样的:需要先使用 overflow:hidden;来把超出的部分隐藏,然后使用text-overflow:ellipsis;当文本超出时…

mybatis源码阅读系列(一)

源码下载 mybatis 初识mybatis MyBatis 是一个优秀的持久层框架,它支持定制化 SQL、存储过程以及高级映射。MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以使用简单的 XML 或注解用于配置和原始映射,将接口和 Java 的…

JDK8和JDK11在Ubuntu18上切换(解决nvvp启动报错)

本文主要介绍JDK8和JDK11在Ubuntu18上切换,以供读者能够理解该技术的定义、原理、应用。 🎬个人简介:一个全栈工程师的升级之路! 📋个人专栏:计算机杂记 🎀CSDN主页 发狂的小花 🌄人…

docker login 阿里云失败??

docker login 阿里云失败?? 首先参考 阿里云官方文档《Docker登录、推送和拉取失败常见问题》 看看是否是下面提到的情况: 我遇到的情况是超时: [rootk8snode1 software]# sudo docker login --usernametyleryun registry.cn-hangzhou.ali…

sqllab第十八关通关笔记

知识点: UA注入 不进行url解析,不能使用 %20 编码等操作出现在User-agent字段中一般为insert语句 insert 表名(字段1,字段2,。。。) values(数据1,数据2,。。。) 通过admin admin进行登录发现页面打印出了…

Oracle数据库:使用 bash脚本 + 定时任务 自动备份数据

Oracle数据库:使用 bash脚本 定时任务 自动备份数据 1、前言2、为什么需要自动化备份?3、编写备份脚本4、备份脚本授权5、添加定时任务6、重启 crond / 检查 crond 服务状态7、备份文件检查 💖The Begin💖点点关注,收…

Golang实现Redis分布式锁(Lua脚本+可重入+自动续期)

Golang实现Redis分布式锁(Lua脚本可重入自动续期) 1 概念 应用场景 Golang自带的Lock锁单机版OK(存储在程序的内存中),分布式不行 分布式锁: 简单版:redis setnx》加锁设置过期时间需要保证原…

3.Redis命令

Redis命令 Redis 根据命令所操作对象的不同, 可以分为三大类: 对 Redis 进行基础性操作的命令,对 Key 的操作命令,对 Value 的操作命令。 1.1 Redis 首先通过 redis-cli 命令进入到 Redis 命令行客户端,然后再运行下…

横向移动 --> PTT(Kerberos)

好不容易到了周末,终于有时间来写自己的东西了,那么就来讲一下PTT吧 目录 1.PTT(Past The Ticket) 2.Golden Ticket 1.Krbtgt的NTLM hash 2.获取域的sid 3.查看要伪造的管理员 4.查看域控名字 5.查看并且清除票据 6.制造黄金票据 3.Sliver Ticke…

Python 基础语法:基本数据类型(字典)

为什么这个基本的数据类型被称作字典呢?这个是因为字典这种基本数据类型的一些行为和我们日常的查字典过程非常相似。 通过汉语字典查找汉字,首先需要确定这个汉字的首字母,然后再通过这个首字母找到我们所想要的汉字。这个过程其实就代表了…

【Algorithms 4】算法(第4版)学习笔记 18 - 4.4 最短路径

文章目录 前言参考目录学习笔记0:引入介绍1:APIs1.1:API:加权有向边1.2:Java 实现:加权有向边1.3:API:加权有向图1.4:Java 实现:加权有向图1.5:AP…

NVidia NX 中 ROS serial软件包的安装

自己装的ROS是noetic版本,受限于网络,直接用命令安装串口包不行。于是手动安装了一次。 1 下载源码 git clone https://github.com/wjwwood/serial.git 或者直接在浏览器里面输入 https://github.com/wjwwood/serial.git 2 解压 然后在serial&#xf…

【考研数学】高等数学总结

文章目录 第一章 极限 函数 连续1.1 极限存在准则及两个重要极限1.1.1 夹逼定理1.1.1.1 数列夹逼定理1.1.1.2函数夹逼定理 1.1.2 两个重要极限1.1.2.1 极限公式11.1.2.1.1 证明1.1.2.1.2 数列的单调有界收敛准则1.1.2.1.2.1 二项式定理1.1.2.1.2.2 证明 1.1.2.2 极限公式21.1.2…

未来洞见:亚信安慧AntDB在数据可靠性上的愿景

和国外成熟稳定的商业数据库相比,国产数据库在性能、稳定性、生态等方面存在一定差距,我国数据库的自主可控替换,也不是简单的以库换库,而是用新体系替换旧体系,在架构、研发、上线、运维等方面,全面降低对…

Pyqt5中,QGroupBox组件标题字样(标题和内容样式分开设置)相对于解除继承

Python代码示例: import sys from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QGroupBox, QLabelclass MyApp(QWidget):def __init__(self):super().__init__()# 创建一个 QVBoxLayout 实例layout QVBoxLayout()# 创建 QGroupBox 实例self.grou…

系统分析与设计作业 --- 酒店管理系统(2~3周)

第二周 作业一: (1)需求分析NABCD 我们的项目是一个酒店管理系统,所i对应的NABCD描述 NABCD是一种产品描述框架,用于全面阐述产品的各个方面。其中,N代表需求(Need),描…

5_springboot_shiro_jwt_多端认证鉴权_禁用Cookie

1. Cookie是什么 ​ Cookie是一种在客户端(通常是用户的Web浏览器)和服务器之间进行状态管理的技术。当用户访问Web服务器时,服务器可以向用户的浏览器发送一个名为Cookie的小数据块。浏览器会将这个Cookie存储在客户端,为这个Co…

字符串分割(C++)

经常碰到字符串分割的问题,这里总结下,也方便我以后使用。 一、用strtok函数进行字符串分割 原型: char *strtok(char *str, const char *delim); 功能:分解字符串为一组字符串。 参数说明:str为要分解的字符串&am…