笔记3:torch训练测试VGG网络

news2024/12/1 0:30:51

(1)利用Netron查看网络实际情况

在这里插入图片描述
上图链接
python生成上图代码如下,其中GETVGGnet是搭建VGG网络的程序GETVGGnet.py,VGGnet是该程序中的搭建网络类。netron是需要pip安装的可视化库,注意do_constant_folding=False可以防止Netron中不显示Batchnorm2D层,禁用参数隐藏。

import torch
from torch.autograd import Variable
from GetVGGnet import VGGnet
import netron

net = VGGnet()
x = Variable(torch.FloatTensor(1,3,28,28))
y = net(x)
print(y.data.shape)
onnx_path = "./save_model/VGGnet.onnx"
torch.onnx.export(net, x, onnx_path,do_constant_folding=False)
print(net)
netron.start(onnx_path)

(2)VGG训练测试全过程

此次训练在CPU上进行,迭代次epoch = 10,迭代内轮次batch=300,训练集10000张,测试集2000张。
train loss和train corre分别代表损失和正确率,横轴是不同迭代下每一个伦次的loss&corre累加,一个迭代进行33个轮次,每个迭代最后一个伦次数据不足被网络舍弃,10个迭代总共320次。test loss和test corre是每个一个迭代下所有伦次的正确率平均值。根据图可以看出,训练和测试结果都较好。
在这里插入图片描述
训练的损失和正确率在波动,但总体趋势较好。
在这里插入图片描述
数据集大小可以在此处修改:在这里插入图片描述

代码:cifar10_handle和GetVGGnet在上几篇文章有说明

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: 楠楠星球
@time: 2024/5/10 10:15 
@file: VGGTrain.py-->test
@project: pythonProject
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from GetVGGnet import VGGnet
from cifar10_handle import train_dataset,test_dataset
import matplotlib.pyplot as plt

epoch = 10  #迭代次数
learn_rate = 0.01 #初始学习率

net = VGGnet().to(device='cpu') #模型实例化
loss_fun = nn.CrossEntropyLoss() #调用损失函数
train_data_loder = DataLoader(dataset=train_dataset,
                              batch_size=300,  #每一次迭代的调用的波次
                              shuffle=True,    #这个波次是否打乱数据集
                              num_workers=4,   # 线程数
                              drop_last=True)  # 最后一个波次数据不足是否舍去

test_data_loder = DataLoader(dataset=test_dataset,
                             batch_size=300,
                             shuffle=False,
                             num_workers=4,
                             drop_last=True)

# optimizer = torch.optim.Adam(net.parameters(), lr=learn_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learn_rate, momentum=0.5) #优化器

# scheduler = torch.optim.lr_scheduler.StepLR(optijumizer, step_size=5, gamma=0.9) #step_size=1表示每迭代一次更新一下学习率
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7) #学习率调整器


def train(epoch_num,train_net):
	# ------------------------------------------()--------------------------------------
	loss_base = []
	corre_base = []
	
	test_loss_base = []
	test_corre_base =[]
	for epoch in range(epoch_num):
		# ------------------------------------------(TRAIN)--------------------------------------
		train_net.train()
		for i, data in enumerate(train_data_loder):
			input_tensor, label = data
			input_tensor = input_tensor.to(device='cpu')
			label = label.to(device='cpu')
			
			output_tensor = train_net(input_tensor)
			loss = loss_fun(output_tensor, label)
			
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			
			_, pred = torch.max(output_tensor.data, dim=1)
			correct = pred.eq(label.data).cpu().sum()
			
			print(f"训练中:第{epoch + 1}次迭代的小迭代{i}的损失率为:{1.00 * loss.item()},正确率为:{100.00 * correct / 300}")
			loss_base.append(loss.item())
			corre_base.append(100.00 * correct.item() / 300)
	
		scheduler.step()
		
		# ------------------------------------------(TEST)--------------------------------------
		sum_test_loss = 0
		sum_test_corre = 0
		train_net.eval()
		for i, test_data in enumerate(test_data_loder):
			input_tensor, label = test_data
			input_tensor = input_tensor.to(device='cpu')
			label = label.to(device='cpu')
			
			output_tensor = train_net(input_tensor)
			loss = loss_fun(output_tensor, label)
			_, pred = torch.max(output_tensor.data, dim=1)
			correct = pred.eq(label.data).cpu().sum()

			sum_test_loss += loss.item()
			sum_test_corre += correct.item()
			
		
		test_loss = sum_test_loss * 1.0 / len(test_data_loder)
		test_corre = sum_test_corre * 100.0 / len(test_data_loder) / 300
		test_loss_base.append(test_loss)
		test_corre_base.append(test_corre)
		print(f"测试中:当前迭代的测试集损失为:{test_loss},正确率为:{test_corre}")
	return loss_base,corre_base,test_loss_base,test_corre_base
	# ------------------------------------------()--------------------------------------

if __name__ == '__main__':
	[train_loss,train_corre,test_loss,test_corr] = train(epoch,net)
	fig, axes = plt.subplots(2, 2)
	
	axes[0, 0].plot(list(range(1, len(train_loss)+1 )), train_loss,color ='r')
	axes[0, 0].set_title('train loss')
	
	axes[0, 1].plot(list(range(1, len(train_corre) + 1)), train_corre, color ='r')
	axes[0, 1].set_title('train corre')
	
	axes[1, 0].plot(list(range(1, len(test_loss) + 1)), test_loss,color ='r')
	axes[1, 0].set_title('test loss')

	axes[1, 1].plot(list(range(1, len(test_corr) + 1)), test_corr,color ='r')
	axes[1, 1].set_title('test corre')
	plt.show()
	
	# torch.save(net.state_dict(), './save_model/example1.pt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1667626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis数据结构扩容源码分析

1 Redis数据结构 redis的数据存储在dict.中,其数据结构为(c源码) ypedef struct dict { dictType *type; //理解为面向对象思想,为支持不同的数据类型对应dictType抽象方法,不同的数据类型可以不同实现 void *privdata; //也可不同的数据类…

[AutoSar]BSW_Diagnostic_004 ReadDataByIdentifier(0x22)的配置和实现

目录 关键词平台说明背景一、配置DcmDspDataInfos二、配置DcmDspDatas三、创建DcmDspDidInfos四、创建DcmDspDids五、总览六、创建一个ASWC七、mapping DCM port八、打开davinci developer,创建runnabl九、生成代码 关键词 嵌入式、C语言、autosar、OS、BSW、UDS、…

Maven:继承和聚合

Maven高级 分模块设计和开发 如果在我们自己的项目中全部功能在同一个项目中开发,在其他项目中想要使用我们封装的组件和工具类并不方便 不方便项目的维护和管理 项目中的通用组件难以复用 所以我们需要使用分模块设计 分模块设计 在项目设计阶段,可以将大的项目拆分成若…

欢乐钓鱼大师攻略,兑换码怎么操作?

在努力钓鱼的同时,别忘了收获丰富的奖励和成就,这将是你在游戏中的最大动力和满足感。 完成任务和挑战: 游戏中有各种各样的任务和挑战等着你去完成。通过完成这些任务和挑战,你可以获得丰富的奖励和成就,提升自己的钓…

[Java EE] 文件IO(一):文件概念与文件系统操作

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏:🍕 Collection与数据结构 (91平均质量分)https://blog.csdn.net/2301_80050796/category_12621348.html?spm1001.2014.3001.5482 🧀Java …

android studio配置Http Proxy

1、问题描述: Error:Unable to tunnel through proxy. Proxy returns “HTTP/1.1 400 Bad Request” 解决:HTTP Proxy设置 1.File→Settings…→System Settings → HTTP Proxy → Auto-detect proxy settings”; 2.勾选下方“Automatic prox…

景源畅信电商:经营抖店需要电脑吗?

经营抖店是否需要电脑?这个问题看似简单,实则关乎着商家的运营效率和成本投入。在当前数字化、网络化的商业环境中,电脑已经成为了不可或缺的工具。那么,经营抖店究竟是否需要电脑呢?答案是肯定的。 一、高效处理订单 电脑能够高效地处理大…

【408真题】2009-03

“接”是针对题目进行必要的分析,比较简略; “化”是对题目中所涉及到的知识点进行详细解释; “发”是对此题型的解题套路总结,并结合历年真题或者典型例题进行运用。 涉及到的知识全部来源于王道各科教材(2025版&…

【Linux】- Linux环境变量[8]

目录 环境变量 $符号 自行设置环境变量 环境变量 环境变量是操作系统(Windows、Linux、Mac)在运行的时候,记录的一些关键性信息,用以辅助系统运行。在Linux系统中执行:env命令即可查看当前系统中记录的环境变量。 …

代数结构:5、格与布尔代数

16.1 偏序与格 偏序集:设P是集合,P上的二元关系“≤”满足以下三个条件,则称“≤”是P上的偏序关系(或部分序关系) (1)自反性:a≤a,∀a∈P; (2…

将Flutter程序打包为ios应用并进行安装使用

如果直接执行flutter build ios: Building com.example.myTimeApp for device (ios-release)...════════════════════════════════════════════════════════════════════════════════No vali…

Multisim 14简易三人抢答器电路设计

multisim multisim,即电子电路仿真设计软件。Multisim是美国国家仪器(NI)有限公司推出的以Windows为基础的仿真工具,适用于板级的模拟/数字电路板的设计工作。它包含了电路原理图的图形输入、电路硬件描述语言输入方式&#xff0…

汇昌联信:做拼多多网点需要具备什么能力?

在当前电商行业高速发展的背景下,拼多多以其独特的商业模式迅速崛起,成为众多创业者和商家关注的焦点。想要运营一家成功的拼多多网点,不仅需要对平台规则有深入的了解,还需要具备多方面的能力。这些能力是确保网点稳定运营并实现…

【安全每日一讲】加强数据安全保护 共享数字化时代便利

前言 数据安全是数据治理的核心内容之一,随着数据治理的深入,我不断的碰到数据安全中的金发姑娘问题(指安全和效率的平衡)。 DAMA说,降低风险和促进业务增长是数据安全活动的主要驱动因素,数据安全是一种资…

47.乐理基础-音符的组合方式-连线

连线与延音线长得一模一样 它们的区别就是延音线的第三点,延音线必须连接相同的音 连线在百分之九十九的情况下,连接的是不同的音,如下图的对比,连线里的百分之1,以现在的知识无法理解,后续再写 在乐谱中遇…

【MySQL】聊聊你不知道的前缀索引原理以及使用场景

背景 在本周的时候,接到一个需求,需要通过加密后的身份证 md5 去数据库里匹配。由于业务方存储的是身份证 md5username 构建的一列,并且没有加索引。 解决方案:1.新建一列 md5的列,加索引 2.对现有的列进行加前缀索引…

数组实现循环队列

1、分析 循环队列最主要的特点为当前面的空间被pop后,后面的数据可以插入到前面空余的数据中去; 所以最难的部分为判断什么时候为空什么时候为满: a、空满问题 我们先来分析当数据满时,head和tail相等(tail认为是指…

C++:9.scanf扩展——原来这么好用!

——scanf:我**不常用了? 有一天看到了一道题: C 输入一个时间,输出它属于,白天,下午还是黑夜。 输入样例: 15:20 00:00 13:14 05:20 11:45 14:00 ……??? 大胆题目小瞧我的编程水平!!!!!…

什么是 IIS

什么是 IIS 一、什么是 IIS二、IIS 的功能三、IIS 几点说明四、IIS 的版本五、IIS 常见的组合 欢迎关注【云边小网安】 一、什么是 IIS IIS:指 Internet Information Services ,是一种由微软公司开发的 Web 服务器应用程序。IIS:是一种 Web …

GPU prompt

提问: GPU是如何与CPU协调工作的? GPU也有缓存机制吗?有几层?速度差异是多少? GPU渲染流程有哪些阶段?他们的功能分别是什么? Early-Z技术是什么?发生在哪个阶段?这个…