基于vgg16和pytorch框架进行cifar10数据集的图像分类

news2025/1/17 13:59:11

vgg16网络模型的实现

这里只讲怎么实现
百度搜到vgg16的网络模型图,用pytorch框架进行实现
在这里插入图片描述
图是这样,用pytorch实现就行,pyotrch不太熟悉的话可以去看小土堆的视频
命名mode.py 也可以使用其他名字,在后面的train.py里面改一下也行

import torch
import torch.nn as nn
class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout()
        )
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
if __name__ == '__main__':
    VGG16 = VGG16()
    input = torch.ones((64, 3, 32, 32))
    output = VGG16(input)
    print(output.shape)

训练文件的编写

具体的都在注释里,需要注意的是 命名train.py

  • 模型文件与训练文件在同一文件夹下
  • batch_size以及epoch和学习率lr根据自己的需要进行更改,同时更改生成的模型文件名以及日志名(在代码里)要不会乱套
  • 数据集设置的是假如没有会自动下载,因此不用担心下载数据集
  • 本人的实验环境 ubuntu 18.04 显卡:2080ti window下的可以尝试应该没问题
import torchvision
from torch import optim
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *
import matplotlib.pyplot as plt
import time
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda:0")

train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
train_dataloader = DataLoader(train_data, batch_size=128)
test_dataloader = DataLoader(test_data, batch_size=128)

file = open('logs_epoch50_bt128_lr0.015.txt', 'w')


# 查看测试集,训练集的大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
file.write("训练数据集的长度为:{}".format(train_data_size) + '\n')
print("测试数据集的长度为:{}".format(test_data_size))
file.write("测试数据集的长度为:{}".format(test_data_size) + '\n')

# 创建网络模型
vgg16 = VGG16()
vgg16 = vgg16.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 优化器
learning_rate = 0.015  # 设置学习速率
optimizer = torch.optim.SGD(vgg16.parameters(), lr=learning_rate)
# 设置训练网络的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 50
# 添加tensorboard画图可视化
writer = SummaryWriter("logs_train")
for i in range(epoch):
    print("--------第{}轮训练开始---------".format(i + 1))
    file.write("--------第{}轮训练开始---------".format(i + 1) + '\n')
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = vgg16(imgs)
        loss = loss_fn(outputs, targets)
        # 梯度调0
        optimizer.zero_grad()
        # 反向传播 梯度
        loss.backward()
        # 调优
        optimizer.step()
        # 记录训练次数
        total_train_step = total_train_step + 1
        # 每100打印loss
        if total_train_step % 100 ==0:
            print("训练次数:{},Loss:{}".format(total_train_step, 
                                           loss.item()))
            file.write("训练次数:{},Loss:{}".format(total_train_step, 
                                           loss.item()) + '\n')
            writer.add_scalar("train_loss", loss.item(), total_train_step)
    # 测试,没梯度没有调优代码
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = vgg16(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            # 计算整体测试集上的正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print("整体测试集上的loss:".format(total_test_loss))
    file.write("整体测试集上的loss:".format(total_test_loss) + '\n')
    # 使用/进行的tensor整数除法不再支持,可以使用true_divide代替
    print("整体测试集上的正确率:{}".format(total_accuracy.true_divide(test_data_size)))
    file.write("整体测试集上的正确率:{}".format(total_accuracy.true_divide(test_data_size)) + '\n')
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    total_test_step = total_test_step + 1
    # 可视化正确率
    writer.add_scalar("test_accuracy", total_accuracy.true_divide(test_data_size), total_test_step)
    # 保存每一轮的模型 这是第一种保存方式非官方推荐
    torch.save(vgg16, "epoch50_bt128_lr0.015/vgg16_{}.pth".format(i+1))
    print("模型已保存")
writer.close()

训练日志如下(具体大家自行训练)

训练数据集的长度为:50000
测试数据集的长度为:10000
--------第1轮训练开始---------
训练次数:100,Loss:1.8452614545822144
训练次数:200,Loss:1.3886604309082031
训练次数:300,Loss:1.4469469785690308
整体测试集上的loss:
整体测试集上的正确率:0.5472999811172485
--------第2轮训练开始---------
训练次数:400,Loss:1.108359694480896
训练次数:500,Loss:1.1589107513427734
训练次数:600,Loss:1.156502604484558
训练次数:700,Loss:0.886533796787262
整体测试集上的loss:
整体测试集上的正确率:0.6538000106811523
--------第3轮训练开始---------
训练次数:800,Loss:0.884425163269043
训练次数:900,Loss:0.8080787062644958
训练次数:1000,Loss:0.7888829112052917
训练次数:1100,Loss:0.6955403089523315
整体测试集上的loss:
整体测试集上的正确率:0.7030999660491943

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/155738.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言进阶】内存函数和结构体内存对齐

目录一.strerror函数1.错误码变量errno2.strerror函数的使用3.perror函数二.memcpy函数1.函数介绍2.模拟实现三.memmove函数1.函数介绍2.模拟实现四.结构体的内存对齐一.strerror函数 1.错误码变量errno 规定: C语言库函数如果出现运行错误,会将对应错误信息的错误…

联邦学习 (FL) 中常见的3种模型聚合方法的 Tensorflow 示例

联合学习 (FL) 是一种出色的 ML 方法,它使多个设备(例如物联网 (IoT) 设备)或计算机能够在模型训练完成时进行协作,而无需共享它们的数据。 “客户端”是 FL 中使用的计算机和设备,它们可以彼此完全分离并且拥有各自不…

基于Java springmvc+mybatis酒店信息管理系统设计和实现

基于Java springmvcmybatis酒店信息管理系统设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取…

程序员接私活的几个平台和建议,避免掉坑!

大家对于程序员接私活这件事的看法,褒贬不一。但是你如果确实用钱,价格又合适,那就大胆去接。 如果不那么缺钱,那么接私活之前先考虑清楚,如果自己将空余时间用在接私活所产生的价值是不是大于提升自己。如果是的话&a…

2022年 大学生工程训练比赛[物料搬运]

本人和团结参加了2022年大学生工程训练(简称工训赛)校赛选拔,准备了几个月的时间和花费了较多的资金,由于疫情等多种情况,很遗憾未能参加湖南省省赛,过了这么久还是写个博客记录参赛准备和调试过程。 目录 一、比赛要求 二、整体…

第十章面向对象编程(高级部分)

10.1 类变量和类方法(关键字static) 10.1.31类变量快速入门 思考: 如果,设计一个 int count 表示总人数,我们在创建一个小孩时,就把 count 加 1,并且 count 是所有对象共享的就 ok 了! package com.hspedu.static_;public class ChildGame {…

MS【1】:Metric

文章目录前言1. Dice Loss1.1. Dice coefficient1.2. F1 score - Dice1.3. Dice Loss2. Sensitivity & Specificity2.1. Sensitivity2.2. Specificity3. Hausdorff distance3.1. 概念3.2. 单向 Hausdorff distance3.3. 双向 Hausdorff distance3.4. 部分 Hausdorff distanc…

使用ResNet18实现CIFAR100数据集的训练

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄 使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的…

二、数据仓库模型设计

数据仓库模型设计一、数据模型二、关系模型三、维度模型1、事实表(1)事务事实表(2)周期快照事实表(3)累计快照事实表(4)无事实的事实表2、维度表3、维度模型类型(1&#…

LVGL学习笔记16 - 进度条Bar

目录 1. Parts 2. 模式 2.1 LV_BAR_MODE_SYMMETRICAL:对称模式 2.2 LV_BAR_MODE_RANGE:范围模式 3. 动画 4. 样式 4.1 方向 4.2 渐变色 4.3 增加边框 4.4 滚动条方向 进度条有一个背景和一个指示器组成,通过lv_bar_create创建对象。…

mysql多表查询

一、关联查询(联合查询) 1.1 什么是关联查询 关联查询:两个或者多个表,一起查询。 前提条件: 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段&#x…

初识IL2CPP

在Unity中进行打包时,有两种打包方式选择:Mono和IL2CPP Mono和IL2Cpp是Unity的脚本后处理方式,通过脚本后处理实现Unity的跨平台 1.Mono (1). Mono组成组件: C#编辑器,CLI虚拟机,以及核心类别程序库 (2).跨平台过程 Mo…

【Linux】多线程概念

目录🌈前言🌸1、Linux线程概念🍡1.1、概念🍢1.2、线程的优点🍧1.3、线程的缺点🍨1.4、线程的异常和用途🌺2、Linux下进程 vs 线程🌈前言 这篇文章给大家带来线程的学习!…

PID算法入门(一)

1.简介 PID是Proportional(比例), Integral(积分), Differential(微分)的首字母缩写,他是一种结合比例,积分,微分三个环节于一体的闭环控制算法. 2.PID各环节 2.1比例环节 成比例地反应控制系统的偏差信号,即输出&a…

Codeforces Round #843 (Div. 2) A1 —— D

题目地址:Dashboard - Codeforces Round #843 (Div. 2) - Codeforces一个不知名大学生,江湖人称菜狗 original author: jacky Li Email : 3435673055qq.com Time of completion:2023.1.11 Last edited: 2023.1.11 目录 ​编辑 A1. Gardener…

读论文——day61 目标检测模型的决策依据与可信度分析

目标检测模型的决策依据与可信度分析本文贡献及原文1 相关工作(略看)1.3 目标检测模型2 背景知识(LIME)2.2 LIME3 目标检测决策依据及可信度分析3.1 决策依据3.2 对目标检测模型的预测进行可信度评价4 基于 LIME 的目标检测模型解…

(第四章)OpenGL超级宝典学习:必要的数学知识

必要的数学知识 前言 在本章当中,作者着重介绍了几个和3D图形学重要的数学知识,线性代数基础好的同学可以直接绕过本章,说实话这篇博客写到这里,我是非常犹豫的,本章节的内容可以说是很基础,但是相当…

SSM框架01_Spring

有一个效应叫知识诅咒:自己一旦知道了某事,就无法想象这件事在未知者眼中的样子。00-Spring课程介绍01-初识Spring今天所学的Spring其实是Spring家族中的Spring Framework;Spring Fra是Spring家族中其他框架的底层基础,学好Spring可以为其他S…

Morse1题解

原理摩尔斯电码和电报简单说一下电报和摩尔斯电码的原理最简单的电报模型就是一个电源,一个开关和一个电磁铁当需要长距离使用时候,需要用到继电器按下开关,电磁铁会吸引磁铁长按开关,电磁铁就会闭合一段时间,留下一划…

Jenkins集成GitLab Webhooks自动化构建

JenkinsGitLab Webhooks自动构建项目1 构建步骤1.1 Jenkins中设置构建触发器1.2 Build Authorization Token Root插件安装1.3 GitLab配置Webhooks2 测试webhooks2.1 测试推送事件2.2 测试合并请求事件2.3 代码修改提交测试1 构建步骤 1.1 Jenkins中设置构建触发器 这里先随便写…