Pytorch入门(四)使用VGG16网络训练CIFAR10数据集

news2024/11/25 20:43:27

本文使用Pytorch+VGG16+官方CIFAR10数据集完成图像分类。识别效果如下:
在这里插入图片描述

文章目录

  • 一、VGG16 神经网络结构
  • 二、VGG16 模型训练
  • 三、预测CIFAR10中的是个类别

一、VGG16 神经网络结构

VGG,又叫VGG-16,顾名思义就是有16层,包括13个卷积层和3个全连接层,是由Visual Geometry Group组的Simonyan和Zisserman在文献《Very Deep Convolutional Networks for Large Scale Image Recognition》中提出卷积神经网络模型,该模型主要工作是证明了增加网络的深度能够在一定程度上影响网络最终的性能。其年参加了ImageNet图像分类与定位挑战赛,取得了在分类任务上排名第二,在定位任务上排名第一的优异成绩。


下图给出了VGG16的具体结构示意图:
在这里插入图片描述
下面给出按照块划分的VGG16的结构图,可以结合上图进行理解:
在这里插入图片描述
下图给出了VGG的六种结构配置:
在这里插入图片描述
我们针对VGG16进行具体分析发现,VGG16共包含:

  • 13个卷积层(Convolutional Layer),分别用conv3-XXX表示
  • 3个全连接层(Fully connected Layer),分别用FC-XXXX表示
  • 5个池化层(Pool layer),分别用maxpool表示

其中,卷积层和全连接层具有权重系数,因此也被称为权重层,总数目为13+3=16,这即是VGG16中16的来源。(池化层不涉及权重,因此不属于权重层,不被计数)。


import torch
import torch.nn as nn
class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, 3, 1, 1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2, 2))
        self.layer2 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 128, 3, 1, 1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2, 2))
        self.layer3 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 256, 3, 1, 1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2, 2))
        self.layer4 = nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2))
        self.layer5 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, 1, 1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2))
        self.fc1 = nn.Sequential(nn.Flatten(),nn.Linear(512, 512),nn.ReLU(),nn.Dropout())
        self.fc2 = nn.Sequential(nn.Linear(512, 256),nn.ReLU(),nn.Dropout())
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
if __name__ == '__main__':
    VGG16 = VGG16()
    input = torch.ones((64, 3, 32, 32))
    output = VGG16(input)
    print(output.shape)

二、VGG16 模型训练

下面代码给出了使用VGG16网络训练CIFAR10数据集时的代码。

# http://www.syrr.cn/news/10118.html?action=onClick
import torchvision
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn
from vgg16_model import *
import matplotlib.pyplot as plt
import time
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
train_dataloader = DataLoader(train_data, batch_size=128)
test_dataloader = DataLoader(test_data, batch_size=128)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
# 创建网络模型
vgg16 = VGG16()
vgg16 = vgg16.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 优化器
learning_rate = 0.015  # 设置学习速率
optimizer = torch.optim.SGD(vgg16.parameters(), lr=learning_rate)
# 设置训练网络的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 50
# 添加tensorboard画图可视化
writer = SummaryWriter("logs_train")
for i in range(epoch):
    print("--------第{}轮训练开始---------".format(i + 1))
    for data in train_dataloader:
        imgs, targets = data
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        outputs = vgg16(imgs)
        loss = loss_fn(outputs, targets)
        # 梯度调0
        optimizer.zero_grad()
        # 反向传播 梯度
        loss.backward()
        # 调优
        optimizer.step()
        #记录训练次数
        total_train_step = total_train_step + 1
        # 每100打印loss
        if total_train_step % 100 ==0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))
        writer.add_scalar("train_loss", loss.item(), total_train_step)
    # 测试,没梯度没有调优代码
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            if torch.cuda.is_available():
                imgs=imgs.cuda()
                targets=targets.cuda()
            outputs = vgg16(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            # 计算整体测试集上的正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

        print("整体测试集上的loss:".format(total_test_loss))
        # 使用/进行的tensor整数除法不再支持,可以使用true_divide代替
        print("整体测试集上的正确率:{}".format(total_accuracy.true_divide(test_data_size)))
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        total_test_step = total_test_step + 1
        # 可视化正确率
        writer.add_scalar("test_accuracy", total_accuracy.true_divide(test_data_size), total_test_step)
        # 保存每一轮的模型 这是第一种保存方式非官方推荐
        torch.save(vgg16, "pth/vgg16_{}.pth".format(i+1))
        print("模型已保存")
writer.close()

训练日志如下(具体大家自行训练):

训练数据集的长度为:50000
测试数据集的长度为:10000
--------1轮训练开始---------
训练次数:100,Loss:1.8452614545822144
训练次数:200,Loss:1.3886604309082031
训练次数:300,Loss:1.4469469785690308
整体测试集上的loss:
整体测试集上的正确率:0.5472999821172485
--------2轮训练开始---------
训练次数:400,Loss:1.108359356482321
训练次数:500,Loss:1.1589107513427734
训练次数:600,Loss:1.156502604484558
训练次数:700,Loss:0.886533796787262
整体测试集上的loss:
整体测试集上的正确率:0.6538000106811523
--------3轮训练开始---------
训练次数:800,Loss:0.884425163269043
训练次数:900,Loss:0.8080787062644958
训练次数:1000,Loss:0.7888829112052917
训练次数:1100,Loss:0.6955403081231316
整体测试集上的loss:
整体测试集上的正确率:0.7230999660491943

三、预测CIFAR10中的是个类别

import os

import torch
from PIL import Image
from torchvision import transforms
from vgg16_model import *
'''
CIFAR10包含哪几类 这10类分别是airplane (飞机),automobile(汽车),bird(鸟),cat(猫),deer(鹿),
dog(狗),frog(青蛙),horse(马),ship(船)和truck(卡车)
'''
basepath=os.path.split(os.path.split(os.getcwd())[0])[0]

# 定义加载图片的方式
transformed=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])
# 加载模型
myModel=torch.load("pth/vgg16_50.pth",map_location="cpu")

while 1:
    img_path=input("请输入检测图片的名称:")
    img=Image.open(basepath+rf"\imgs\{img_path}.png")
    img=img.convert("RGB")
    img=transformed(img)
    img=torch.reshape(img,(1,3,32,32))

    myModel.eval()
    with torch.no_grad():
        output=myModel(img)
        x=output.argmax(1)
        if x==0:
            print("飞机")
        elif x==1:
            print("汽车")
        elif x==2:
            print("鸟")
        elif x==3:
            print("猫")
        elif x==4:
            print("鹿")
        elif x==5:
            print("狗")
        elif x==6:
            print("青蛙")
        elif x==7:
            print("马")
        elif x==8:
            print("船")
        elif x==9:
            print("卡车")

预测结果在文章开头。


在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/616314.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

地震勘探基础(十)之地震速度关系

地震速度 地震勘探中引入了多种速度的概念,如下图所示。 层速度、平均速度和均方根速度之间的关系 层速度指的是某一套地层垂向上,由于地质条件相对稳定,地层顶底厚度比上地震波的传播时间为层速度,用 v n v_n vn​ 表示。 如下…

一文看懂软件架构4+1视图

目录 一、概述 二、各视图详解 1. 场景视图 2. 逻辑视图 3. 开发视图 4. 处理视图 5. 物理视图 葵花宝典:一看就懂的理解方式 一、概述 41视图包括: 场景视图(也叫用例视图):黑盒视图。从外部视角&#xff…

chatgpt赋能python:Python如何分段数据的平均数

Python如何分段数据的平均数 Python是一门极其流行的编程语言,广泛应用于数据分析与科学计算领域。在数据分析中,计算各个数据段的平均数是一项常见的任务。本文将介绍如何使用Python分段计算数据的平均数,以及如何优化这一过程以使速度更快…

Linux中的lrzsz

一、介绍 lrzsz是一款在Linux里可代替ftp上传和下载的程序,也就是一款软件。它是开发者常用的一款工具,这个工具用于windows机器和远端的Linux机器通过XShell传输文件。 二、lrzsz的安装 在安装之前,我们可以使用下述命令先查看yum仓库中是否存在我们要安装的软件: yum…

CentOS7使用Docker快速安装Davinci

环境信息 操作系统:CentOS7Docker : 23.0.6 (已配置阿里云镜像加速) 安装步骤 安装docker-compose-plugin 官方的例子使用的是docker-compose,但是由于yum能够安装的最新斑斑是1.x,而且官方的docker-compose要求最低版本为2.2以…

首个区块链技术领域国家标准出台 ,中创助力打造区块链技术和应用创新高地

区块链作为数字中国的重要技术底座,正在深刻改变着我国社会生产方式。何谓区块链,对大众来说,也许尚陌生,殊不知,这一产业已稳稳起跑在我国高质量发展的“赛道”上。 近日,《区块链和分布式记账技术参考架…

【JavaScript】超全基础万字大总结

目录 一、初识 JavaScript 1.1 JavaScript 是什么? 1.2 发展历史 1.3 JavaScript 和 HTML 和 CSS 之间的关系 1.4 JavaScript 运行过程 1.5 JavaScript 的组成 二、前置知识 2.1 第一个程序 2.2 JavaScript 的书写形式 2.3 输入输出 三、语法概览 3.1 变…

Linux(CentOS 7) 安装 Mysql8 、Java 以及 mycat2 详细流程

目录 一、Mysql8 安装 1.下载mysql8 2. 解压Mysql 压缩包 3.重名命mysql 文件 4.创建data文件夹 储存文件 5.创建用户组以及用户 6.授权用户 将mysql文件夹的所有者和所有组都改为mysql 7.mysql初始化进入bin目录执行mysqld文件进行初始化 8.编辑my.cnf 9.添加mysqld…

有哪些虚拟化和容器化工具推荐? - 易智编译EaseEditing

以下是几个常用的虚拟化和容器化工具推荐: VMware vSphere: VMware vSphere 是一套完整的虚拟化平台,包括虚拟化服务器、虚拟化存储和虚拟化网络。 它提供了高性能的虚拟机管理和资源调度功能,适用于企业级的虚拟化部署。 Docke…

IT知识百科:什么是跨站脚本(XSS)攻击?

跨站脚本(Cross-Site Scripting,XSS)是一种常见的网络安全漏洞,攻击者利用该漏洞在受害者的网页中插入恶意脚本,从而能够获取用户的敏感信息、劫持会话或进行其他恶意活动。本文将详细介绍跨站脚本攻击的原理、类型、常…

vue props传值层级多,子级孙子级怎么修改传参

vue props传值层级多了,子级孙子级怎么修改传参 1.出现背景2.怎么在孙组件里改变传过来的值呢2.1这样改是不行的2.2可行的方法2.2.1 引用对象只改变单属性2.2.2 provide和inject 1.出现背景 本来自己写页面的话是直接全部写在一个vue文件里,一个vue文件…

【solidworks】此文档 templates\gba0.drwdot 使用字体长仿宋体,而该字体不可用

一、问题背景 在SolidWorks中绘制工程图纸时,新建一个图纸,打开后就弹出字体错误 此文档 templates\gba0.drwdot 使用字体长仿宋体,而该字体不可用。 二、解决办法 点击选择新的字体,拖到最下面选择汉仪长仿宋体。 上面之所…

41 管理虚拟机可维护性-虚拟机NMI Watchdog

文章目录 41 管理虚拟机可维护性-虚拟机NMI Watchdog41.1 概述41.2 注意事项41.3 操作步骤 41 管理虚拟机可维护性-虚拟机NMI Watchdog 41.1 概述 NMI Watchdog是一种用来检测Linux出现hardlockup(硬死锁)的机制。通过产生NMI不可屏蔽中断,…

win10+tf2.x+cuda+cudnn踩坑记录( Loaded cuDNN version 8400)

项目场景: 项目用到了tensorflow2.x: 想要用GPU跑算法win10系统下需要安装cuda和cudnn配置带有tenserflow-gpu的环境 问题描述 jyputer运行错误提示:Loaded cuDNN version 8400 Could not locate zlibwapi.dll. Please make sure it is in…

智安网络|保护企业网络空间资产安全的重要性

在数字化时代,企业网络空间资产的安全和保护变得越来越重要,并且拥有安全性能优越、系统完整的企业网络系统,是企业发展的必要条件。但想要实现网络空间安全首先需要关注网络漏洞问题。 保护企业网络空间资产的重要性 网络空间资产安全是企…

【深度学习】跌倒识别(带数据集和源码)从0到1,内含很多数据处理的坑点和技巧,收获满满

文章目录 前言1. 数据集1.1 数据初探1.2 数据处理1.3 训练前验证图片1.4 翻车教训和进阶知识 2. 训练3.效果展示 前言 又要到做跌倒识别了。 主流方案有两种: 1.基于关键点的识别,然后做业务判断,判断跌倒,用openpose可以做到。…

Neural Architecture Search: A Survey

本文是神经架构搜索相关主题的第一篇文章,针对《Neural Architecture Search: A Survey》的一个翻译。 神经架构搜索:综述 摘要1 引言2 搜索空间3 搜索策略4 性能评估策略5 未来方向 摘要 过去几年,深度学习在图像识别、语音识别和机器翻译…

网络故障排除

计算机网络构成了数字业务的基础。为了确保业务连续性,需要日夜监控和管理这些网络背后的 IT 基础架构。IT 管理员在管理 IT 基础架构时经常遇到问题,这是他们工作的关键部分,更重要的部分是解决网络问题。 什么是网络故障排除 网络故障排除…

推动开源与商业共生共赢 | 2023开放原子全球开源峰会开源商业化创新发展分论坛即将启幕

开源具有利他性,专有软件或私有软件具有利己性,而开源的商业模式也具有利己性。利他性的开源与利己性的商业模式相结合,如何真正为开源做贡献? 由开放原子开源基金会主办,软通动力信息技术(集团&#xff0…

【MySQL高级篇笔记-索引优化与查询优化(中) 】

此笔记为尚硅谷MySQL高级篇部分内容 目录 一、索引失效案例 二、关联查询优化 1、采用左外连接 2、采用内连接 3、join语句原理 1.驱动表和被驱动表 2.Simple Nested-Loop Join(简单嵌套循环连接) 3.Index Nested-Loop Join(索引嵌套循环连接) 4.Block Nested-Loop J…