计算机视觉实战-----pytorch官方demo(Lenet网络)实现

news2024/7/6 19:20:25

系列文章目录

文章目录

  • 系列文章目录
  • 前言
  • 零、环境搭建
  • 一、下载CIFAR10数据集
  • 二、测试图片
  • 三、模型搭建
  • 四、开始train
  • 五、测试
  • 六、tensorboard可视化
  • 总结


前言

通过一个官方列子,清楚深度学习中图像的训练的整个流程


零、环境搭建

  1. pycharm下载:pycharm官网
  2. pycharm安装及使用教程【400MB左右】:PyCharm使用教程(详细版 - 图文结合)
  3. anaconda安装及使用教程【可以跳过】:最新Anaconda3的安装配置及使用教程(详细过程)
  4. 安装pytorch包:pytorch官网
    如果电脑没有GPU,可以只用安装CPU版本的pytorch
conda install pytorch torchvision torchaudio cpuonly -c pytorch

在这里插入图片描述
6. 下载大佬的源代码【3MB左右】:deep-learning-for-image-processing
7. 大佬b站视频:pytorch官方demo(Lenet)

一、下载CIFAR10数据集

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MTGvqaD8-1672647885101)(计算机视觉-图像处理结合深度学习.assets/image-20230101223814211.png)]

在代码download处改为:True,其余代码先注释掉。

#下载数据集
#下载数据集
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
def main():
    transform = transforms.Compose(
        [transforms.ToTensor(), #将图片转换为tensor
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化
    #torchvision.datasets. 下载数据集
    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # root表示将数据集下载到什么地方 train = True表示导入训练数据集
    # transform = transform 对数据进行预处理
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)#transform是对函数预处理的函数
    #torchvision.datasets.可以查看数据集
if __name__ == '__main__':
    main()

下载成功
在这里插入图片描述

二、测试图片

![在这里插入图片描述](https://img-blog.csdnimg.cn/b97704a27e32485bbafadf91136a88c1.png

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
#
#
def main():
    transform = transforms.Compose(
        [transforms.ToTensor(), #将图片转换为tensor
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化
    #torchvision.datasets. 下载数据集

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # root表示将数据集下载到什么地方 train = True表示导入训练数据集
    # transform = transform 对数据进行预处理
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False, transform=transform)#transform是对函数预处理的函数
    #torchvision.datasets.可以查看数据集
#     #导入训练集
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,          #导入训练集  shuffle = True 表示打乱数据集
                                               shuffle=True, num_workers=0)       #num_workers表示线程数 windows下只能设置为0,否则会报错
    #导入测试集 10000张图片
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,  #batch_size改成4,10000张图片看不了
                                             shuffle=False, num_workers=0)#num_workers=0线程个数,windows下只能为0
    test_data_iter = iter(testloader) #转换迭代器
    test_image, test_label = test_data_iter.next() #通过.next()获得图片和标签值
    #类别,元组类型 plane->0
    classes = ('plane', 'car', 'bird', 'cat',
                'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    #测试 需要numpy和plot包
    def imshow(img):
        img = img / 2 + 0.5  # unnormalize 对图像进行反标准化处理
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0))) # h= w= channel=0
        plt.show()

    # print labels 打印标签
    print(' '.join(f'{classes[test_label[j]]:5s}' for j in range(4)))
    # show images 查看图片
    imshow(torchvision.utils.make_grid(test_image))

if __name__ == '__main__':
    main()

三、模型搭建

在这里插入图片描述
在这里插入图片描述
单步测试查看尺寸变化
在这里插入图片描述

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):   #类Lenet继承nn.Module父类
    def __init__(self):
        super(LeNet, self).__init__()         #super函数解决多继承可能遇到的问题
        self.conv1 = nn.Conv2d(3, 16, 5)      #定义第一个卷积层 3个通道(RGB),16个卷积核,5代表尺度,5x5的卷积核
        self.pool1 = nn.MaxPool2d(2, 2)       #定义池化层(下采样层),只改变图片的高和宽。池化核2x2,步距为2的最大池化操作
        self.conv2 = nn.Conv2d(16, 32, 5)     #定义第二个卷积层,深度为16,卷积核32,5x5的卷积核
        self.pool2 = nn.MaxPool2d(2, 2)       #定义第二个池化层
        self.fc1 = nn.Linear(32*5*5, 120)     #定义第一个全连接层,全连接层输入是一维向量,需要将特征矩阵展平
        self.fc2 = nn.Linear(120, 84)         #定义第二个全连接层,120为上一层的输出
        self.fc3 = nn.Linear(84, 10)          #定义第三个全连接层,84为上一层的输出  10根据super修改,这里是10个类别

    def forward(self, x):            #正向传播 x代表输入的数据
        x = F.relu(self.conv1(x))    # input(3, 32, 32) 第一层输出 output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14) 通过最大池化后,高度和宽度降为原来的一半,深度不变16
        x = F.relu(self.conv2(x))    # output(32, 10, 10) N = (W-F+2P)/S +1 => N=(14-5+2x0)/1 +1 = 10
        x = self.pool2(x)            # output(32, 5, 5)   通过第二个池化层,高度和宽度降为原来的一半
        x = x.view(-1, 32*5*5)       # output(32*5*5)     将全连接层拼接,变成一维向量,-1表示第一个维度,展平后的个数:32*5*5
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x

#测试
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

四、开始train

在这里插入图片描述
第一个500步,训练损失率时1.747,测试准确率时0.436
迭代5个epoch后,测试准确率达到:0.652
最后生成模型权重文件:Lenet.pth文件
在这里插入图片描述

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


def main():
    transform = transforms.Compose(
        [transforms.ToTensor(), #将图片转换为tensor
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化
    #torchvision.datasets. 下载数据集

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # root表示将数据集下载到什么地方 train = True表示导入训练数据集
    # transform = transform 对数据进行预处理
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False, transform=transform)
    #导入训练集
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,          #导入训练集  shuffle = True 表示打乱数据集
                                               shuffle=True, num_workers=0)
    #导入测试集
    # 10000张验证图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # val_set = torchvision.datasets.CIFAR10(root='./data', train=True,
    #                                        download=False, transform=transform)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=10000,
    #                                          shuffle=False, num_workers=0)
    # val_data_iter = iter(val_loader)
    # val_image, val_label = val_data_iter.next()

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                             shuffle=False, num_workers=0)#num_workers=0线程个数,windows下只能为0
    test_data_iter = iter(testloader)
    test_image, test_label = test_data_iter.next() #通过.next()获得图片和标签值
    #类别,元组类型 plane->0
    classes = ('plane', 'car', 'bird', 'cat',
                'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()  #实例化模型
    loss_function = nn.CrossEntropyLoss() #定义损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.001)#使用Adam优化器 导入参数量,lr是学习率

    #训练过程
    for epoch in range(5):  # loop over the dataset multiple times 训练迭代多少轮 迭代5次

        running_loss = 0.0 #记录累积的训练损失
        for step, data in enumerate(train_loader, start=0): #遍历训练集样本
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data #输入图像和标签

            # zero the parameter gradients 清除历史损失梯度 如果不清除历史梯度,就会对计算的历史梯度进行累加
            #一般batchsize越大,训练效果越好,但由于设备硬件原因,batchsize可能不会太大,通过下面这个函数解决
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs) #将输入图片传入到网络中
            loss = loss_function(outputs, labels) #计算损失 outputs是网络预测的值,labels是输入图片对应的标签
            loss.backward() #将loss进行反向传播
            optimizer.step() #参数更新

            # print statistics 打印输出
            running_loss += loss.item() #累加损失
            if step % 500 == 499:    # print every 500 mini-batches 每隔500步打印信息
                with torch.no_grad(): #with是一个上下文管理器,这个函数在验证和测试阶段有用
                    outputs = net(test_image)  # [batch, 10] 进行正向传播
                    predict_y = torch.max(outputs, dim=1)[1] #得到预测最大值的标签类别 需要在第一个维度去寻找,只需要找到index
                    accuracy = (predict_y==test_label).sum().item() /test_label.size(0) #将预测标签与真实标签比较 ,前面得到的是tensor数据,需要使用.item()进行数据转换
                                                                                                 #除以测试样本的数量,除以测试样本的数量,就得到准确率
                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))#迭代到多少轮,多少轮的哪一步,每500步统计平均误差
                    running_loss = 0.0

    print('Finished Training')
    #保存模型
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)#保存网络中的所有参数


if __name__ == '__main__':
    main()

五、测试

测试得到的结果准确
在这里插入图片描述

#train.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),#缩放:将任意大小图片转换为规定的格式:32x32
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化处理

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet() #实例化
    net.load_state_dict(torch.load('Lenet.pth')) #载入权重文件

    im = Image.open('3.jpg') #载入图像 高度,宽度,channal
    im = transform(im)  # [C, H, W] 预处理 深度 高度 宽度
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W] 增加一个维度

    with torch.no_grad():
        outputs = net(im)#将图像传入到网络中
        predict = torch.max(outputs, dim=1)[1].data.numpy()#找到预测最大值的index(索引)第0个维度是batchsize,第1个维度才是channal
    print(classes[int(predict)])
if __name__ == '__main__':
    main()

六、tensorboard可视化

参考:(傻瓜教程)TensorBoard可视化工具简单教程及讲解(TensorFlow与Pytorch)

  • TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。特别是在训练网络的时候,我们可以设置不同的参数(比如:权重W、偏置B、卷积层数、全连接层数等),使用TensorBoader可以很直观的帮我们进行参数的选择。它通过运行一个本地服务器,来监听6006端口。在浏览器发出请求时,分析训练时记录的数据,绘制训练过程中的图像。
  • 安装tensorboard
pin install tensorboard
或者
conda install tensorboard

在这里插入图片描述

  • 测试例子1
#tensorboard.py
import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir='scalar')
for epoch in range(100):
    writer.add_scalar('scalar/test', np.random.rand(), epoch)
    writer.add_scalars('scalar/scalars_test', {'xsinx': epoch * np.sin(epoch), 'xcosx': epoch * np.cos(epoch)}, epoch)
writer.close()

运行tensorboard.py文件后,相关文件会保存到scalar文件中

  • 然后在终端输入命令
tensorboard --logdir scalar
  • 打开浏览器输入地址:http://localhost:6006/ 便得到可视化图
    在这里插入图片描述
    在这里插入图片描述
  • 测试例子2【直方图 (histogram)】

参考:详解PyTorch项目使用TensorboardX进行训练可视化

使用 add_histogram 方法来记录一组数据的直方图。

add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)

参数

tag (string): 数据名称
values (torch.Tensor, numpy.array, or string/blobname): 用来构建直方图的数据
global_step (int, optional): 训练的 step
bins (string, optional): 取值有 ‘tensorflow’、‘auto’、‘fd’ 等, 该参数决定了分桶的方式,详见这里。
walltime (float, optional): 记录发生的时间,默认为 time.time()
max_bins (int, optional): 最大分桶数

from tensorboardX import SummaryWriter
import numpy as np

writer = SummaryWriter('runs/embedding_example')
writer.add_histogram('normal_centered', np.random.normal(0, 1, 1000), global_step=1)
writer.add_histogram('normal_centered', np.random.normal(0, 2, 1000), global_step=50)
writer.add_histogram('normal_centered', np.random.normal(0, 3, 1000), global_step=100)

在这里插入图片描述

  • 在train.py代码中加入代码用tensorboard可视化训练结果
# print statistics 打印输出
            running_loss += loss.item() #累加损失
            writer = SummaryWriter(log_dir='scalar')
            if step % 500 == 499:    # print every 500 mini-batches 每隔500步打印信息
                with torch.no_grad(): #with是一个上下文管理器,这个函数在验证和测试阶段有用
                    outputs = net(test_image)  # [batch, 10] 进行正向传播
                    predict_y = torch.max(outputs, dim=1)[1] #得到预测最大值的标签类别 需要在第一个维度去寻找,只需要找到index
                    accuracy = (predict_y==test_label).sum().item() /test_label.size(0) #将预测标签与真实标签比较 ,前面得到的是tensor数据,需要使用.item()进行数据转换
                                                                                                 #除以测试样本的数量,除以测试样本的数量,就得到准确率
                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))#迭代到多少轮,多少轮的哪一步,每500步统计平均误差

                    writer.add_scalar('scalar/train_accuracy', accuracy, epoch)
                    writer.add_scalar('scalar/train_loss',running_loss / 500, epoch)
                    running_loss = 0.0
            writer.close()

在这里插入图片描述

总结

  • 通过对图像处理基础讲解,学习了卷积层、池化层、全连接层的概念。
  • 学习了训练数据集的整个流程。
  • 了解使用tensorboard可视化结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/134719.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】算法好题初阶(每日小细节010)

1.存在重复元素 力扣传送门、 这道题目的解题思路就是先排序然后比较相邻元素是不是有相等的,有就是true否者false 排序的算法比较多大,但是我用插入和快排plus版都没有过... 但是非递归的归并过了,对排序算法感兴趣的小伙伴可以去看我的博…

一款统计摸鱼时长的开源项目

对于我们程序员,在工作中一天8小时,不可能完全在写代码了,累了刷刷论坛、群里吹吹牛,这都是非常正常的。虽然一天下来,可能我们都可以按时完成工作,但是我们不知道,时间都花在哪里了&#xff0c…

小米万兆路由器里的Docker安装MySQL8.0

小米2022年12月份发布了万兆路由器,里面可以使用Docker。 今天尝试在小米的万兆路由器里安装MySQL8.0。 准备工作 请参考https://engchina.blog.csdn.net/article/details/128515422的准备工作。 创建存储 请参考https://engchina.blog.csdn.net/article/detail…

Faster RCNN网络源码解读(Ⅹ) --- FastRCNN部分正负样本采样及FastRCNN部分损失计算

目录 一、回顾以及本篇博客内容概述 二、代码解析 2.1 ROIHeads类(承接上篇博客的2.5节) 2.1.1 初始化函数 __init__回顾 2.1.2 正向传播forward回顾 2.1.3 select_training_samples 2.1.4 add_gt_proposals 2.1.5 assign_targets_to_proposals…

【Git】Git瘦身,清理Git历史提交/.git大文件清理(云效、UI 自动化项目)

目前项目是存在云效(codeup.aliyun.com)上 本地清理后,还需要到云效上清理「存储空间管理」 一、清理/瘦身效果二、到底是什么在占空间?1、先看一下项目里,什么最占空间?2、往下看在/.git里,什么最占空间?三…

车载诊断协议UDS——读取故障服务Service 19

汽车控制器诊断功能,可以通过诊断服务读取车内控制器故障信息。如本文所分享的内容,通过Service 19服务读取车内控制器故障信息。 一、DTC显示类型 在OEM定义的诊断需求规范中,会定义DTC(诊断故障码)与具体控制器具体故障类型相关联(一个DTC故障码对应一个具体故障)…

深度学习目标检测_YOLOV2超详细解读

文章目录YOLO v2概述Batch Normalization(批归一化)High Resolution Classifier(高分辨率预训练分类网络)New Network:Darknet-19神经网络中的filter (滤波器)与kernel(内核&#xf…

【Java语言】— 循环结构:while循环、do-while循环

while循环 1.while循环格式与执行流程 while循环格式 初始化语句; while (循环条件){循环体语句(被重复执行的代码);迭代语句; }示例&#xff1a; inti 0; while (i < 3){System.out.println("Hello World");i; }while循环执行流程 什么时候用for循环&#x…

蓝桥杯寒假集训第五天(子串分值和)

没有白走的路&#xff0c;每一步都算数&#x1f388;&#x1f388;&#x1f388; 题目描述&#xff1a; 输入一个字符串&#xff0c;然后计算所有连续子串中没有重复字母的个数 输入描述&#xff1a; 第一行&#xff1a; 一个字符串 输出描述&#xff1a; 所有子串中没有…

软件设计模式---结构型模式

结构型模式 结构型模式概述 结构型模式描述如何将类或者对象结合在一起形成更大的结构&#xff0c;就像搭积木&#xff0c;可以通过简单积木组合形成复杂的、功能更更为强大的结构 结构型模式可以分为类结构型模式和对象结构型模式 类结构型模式关心类的组合&#xff0c;由多…

1、Java多线程技能基础

文章目录第一章 Java多线程技能1.1进程和线程的定义以及多线程的优点1.2 使用多线程1.2.1继承Thread类1.2.2常见的3个命令分析线程的信息方法一\:cmdjsp方法二\:jmc.exe方法三&#xff1a;jvisualcm.exe1.2.3 线程随机性的展现1.2.4 执行start()的顺序不代表执行run()的顺序1…

hcip第五天实验

拓扑图 每台路由器都有两个环回&#xff0c;一个24的环回&#xff0c;一个32的环回&#xff1b;32的环回用于建邻&#xff0c;24的环回用于用户网段&#xff0c;最终所有24的环回可以ping通。 实验步骤 1.配置ip 2.让2,3,4号设备的IGP协议可以通信 3.两两之间建立BGP邻居关…

DR_CAN基尔霍夫电路题解法【自留用】

无目录如图所示电路&#xff0c;输入端电压eie_iei​&#xff0c;输出端电压eoe_oeo​&#xff0c;求二者之间关系。 对其中元件进行标号&#xff0c;并将电流环路标号&#xff0c;指出各元件的压降方向&#xff1a; v值得注意的是&#xff1a; 1&#xff09;电阻R2R_2R2​同时…

rabbitmq消息发送的可靠性:结合mysql来保证消息投递的可靠性

消息从生产者到Broker&#xff0c;则会触发confirmCallBack回调消息从exchange到Queue&#xff0c;投递失败则会调用returnCallBack 用一张表来记录发送到mq的每一条消息&#xff0c;方便发送失败需要重试。status&#xff1a; 1-正常&#xff0c;0-重试&#xff0c;2-失败。发…

【计算机视觉】OpenCV 4高级编程与项目实战(Python版)【1】:图像处理基础

目录 1. OpenCV简介 2. OpenCV开发环境搭建 3. 读取图像 4. 读取png文件出现警告 5. 显示图像 6. 保存图像 7. 获取图像属性 本系列文章会深入讲解OpenCV 4&#xff08;Python版&#xff09;的核心技术&#xff0c;并提供了大量的实战案例。这是本系列文章的第一篇&…

简单了解计算机的工作原理

文章目录一.计算机操作系统二.进程/任务三、进程控制块抽象(PCB)四、进程调度相关属性五、内存管理一.计算机操作系统 概念:操作系统是一组做计算机资源管理的软件的统称. 目前常见的操作系统有&#xff1a;Windows系列、Unix系列、Linux系列、OSX系列、Android系列、iOS系列…

百度安全在线查询,网站弹出风险提示怎么处理

站长们要避免网站打开弹出风险提示&#xff0c;需要要时刻关注自己的网站是否存在风险&#xff0c;时刻知道自己的网站是不是安全的。 百度安全在线查询步骤&#xff1a; 1、打开站长工具 2、添加需要查询的网站域名。 3、勾选百度安全。 4、点击开始查询。 等…

22个Python的万用公式分享

在大家的日常python程序的编写过程中&#xff0c;都会有自己解决某个问题的解决办法&#xff0c;或者是在程序的调试过程中&#xff0c;用来帮助调试的程序公式。小编通过几十万行代码的总结处理&#xff0c;总结出了22个python万用公式&#xff0c;可以帮助大家解决在日常的py…

再学C语言22:循环控制语句——循环嵌套和数组处理

嵌套循环&#xff08;nested loop&#xff09;&#xff1a;在一个循环内使用另一个循环 一、循环嵌套 示例代码&#xff1a; #include <stdio.h> int main(void) {int i;int j;for(i 0; i < 10; i){for(j 0; j < 9; j){printf("%5d", j); // 里面的…

共享模型之管程(二)

1.Moniter对象 1.1.Java对象头 1>.以32位虚拟机为例 ①.普通对象 Klass Word表示对象的类型,它是一个指针,指向了对象所从属的class; ②.数组对象 在32位虚拟机中,integer包装类型的长度为12个字节,而int基本数据类型的长度为4个字节; 其中Mark Word结构为: 2>.64位…