Pytorch-ResNet50-MINIST Classify 网络实现流程

news2024/9/24 3:30:06

分两个文件讲解:1、train.py训练文件     2、test.py测试文件.

1、train.py训练文件

1)从主函数入口开始,设置相关参数

# 主函数入口
if __name__ == '__main__':
    # ----------------------------#
    #   是否使用Cuda
    #   没有GPU可以设置成Fasle
    # ----------------------------#
    cuda = True
    # ----------------------------#
    #   是否使用预训练模型
    # ----------------------------#
    pre_train = True
    # ----------------------------#
    #   是否使用余弦退火学习率
    # ----------------------------#
    CosineLR = True
    # ----------------------------#
    #   超参数设置
    #   lr:学习率
    #   Batch_size:batchsize大小
    # ----------------------------#
    lr = 1e-3
    Batch_size = 2
    Init_Epoch = 0
    Fin_Epoch = 100

 2)创建模型

# 创建模型
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)

#判断是否需要预训练模型,在1)已经设置pre_train=True,这里会加载预训练模型,
#为"logs/resnet50-mnist.pth"。
#这里加载的是预训练模型的权重参数,实例化到本地模型ResNet上
if pre_train:
    model_path = 'logs/resnet50-mnist.pth'
    model.load_state_dict(torch.load(model_path))

#判断cuda是否可用,如果cuda可用,模型将调用GPU,否则将调用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

3)创建数据集

# ----------------------------#
root='data/' :路径
train=True   :训练设置为True
transform=transforms.ToTensor() :转化成Tensor
download=True :下载
# ----------------------------#
train_dataset = datasets.MNIST(root='data/', train=True,
                               transform=transforms.ToTensor(), download=True)
#这里train = False, download=False,此时下载验证集
test_dataset = datasets.MNIST(root='data/', train=False,
                              transform=transforms.ToTensor(), download=False)

4)加载数据集

# ----------------------------#
#DataLoader加载数据集
batch_size=Batch_size 批量输入
shuffle=True 打乱数据
num_workers=0 单个工作进程
# ----------------------------#
gen = DataLoader(dataset=train_dataset, batch_size=Batch_size, shuffle=True, num_workers=0)
gen_test = DataLoader(dataset=test_dataset, batch_size=Batch_size // 2, shuffle=True, num_workers=0)

5)设置损失函数和优化器

#损失函数为交叉熵损失
softmax_loss = torch.nn.CrossEntropyLoss()
#优化器选择Adams
optimizer = torch.optim.Adam(model.parameters(), lr)

6)设置学习率

#如果CosineLR = True,学习率为CosineAnnealingLR,否则为StepLR
if CosineLR:
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-10)
else:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

7)训练

# ----------------------------#
epoch_size 训练集一次加载多少个batch
epoch_size_val 验证集一次加载多少个batch
# ----------------------------#
epoch_size = len(gen) 
epoch_size_val = len(gen_test)

# ----------------------------#
Init_Epoch 起始训练为0
Fin_Epoch  终止训练为100次
fit_one_epoch()函数进行训练数据
lr_scheduler.step()一次训练结束后,学习率进行更新
# ----------------------------#
for epoch in range(Init_Epoch, Fin_Epoch):
    fit_one_epoch(net=model, softmaxloss=softmax_loss, epoch=epoch, epoch_size=epoch_size,epoch_size_val=epoch_size_val, gen=gen, gen_test=gen_test, Epoch=Fin_Epoch, cuda=cuda)
    lr_scheduler.step()

2、test.py测试文件

展示运行结果

1)整段讲解

 

import torch
from nets.resnet50 import ResNet,Bottleneck
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
import cv2
import time

# 设置权重文件路径
PATH = './logs/resnet50-mnist.pth'
# 谁知手动输入单次识别字数
Batch_Size = int(input('每次预测手写字体图片个数:'))
# 加载模型
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)
model.load_state_dict(torch.load(PATH))
model = model.cuda()

# 进入测试程序
model.eval()
# 设置测试数据集并加载
test_dataset = datasets.MNIST(root='data/', train=False,
                                    transform=transforms.ToTensor(), download=False)
gen_test = DataLoader(dataset=test_dataset, batch_size=Batch_Size, shuffle=True)

# 进入循环
while True:
    # 获取图片和标签
    images, lables = next(iter(gen_test))
    img = torchvision.utils.make_grid(images, nrow=Batch_Size)
    img_array = img.numpy().transpose(1, 2, 0)
    # 获取开始时间
    start_time = time.time()
    
    # 输出预测结果
    outputs = model(images.cuda())
    _, id = torch.max(outputs.data, 1)
    end_time = time.time()
    
    # 打印用时和预测结果,由于输出的id为tensor,这里必须转换为numpy
    print('预测用时:', end_time-start_time)
    print('预测结果为', id.data.cpu().numpy())
    # 展示图片
    cv2.imshow('img', img_array)
    cv2.waitKey(0)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/735238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IDEA+SpringBoot+mybatis+bootstrap+jquery+Mysql车险理赔管理系统

IDEASpringBootmybatisbootstrapjqueryMysql车险理赔管理系统 一、系统介绍1.环境配置 二、系统展示1. 管理员登录2.编辑个人信息3.用户管理4.添加用户5.申请理赔管理6.赔偿金发放管理7.待调查事故保单8.已调查记录9.现场勘察管理10.勘察记录11.我的保险管理12.我的理赔管理 三…

Atcoder Beginner Contest 309——D-F讲解

前言 由于最近期末考试,所以之前几场都没打,给大家带了不便,非常抱歉。 这个暑假,我将会持续更新,并给大家带了更好理解的题解!希望大家多多支持。 由于, A ∼ C A\sim C A∼C 题比较简单&am…

现代C++新特性 扩展的聚合类型(C++17 C++20)(PC浏览效果更佳)

文字版PDF文档链接:现代C新特性(文字版)-C文档类资源-CSDN下载 1.聚合类型的新定义 C17标准对聚合类型的定义做出了大幅修改,即从基类公开且非虚继承的类也可能是一个聚合。同时聚合类型还需要满足常规条件。 1.没有用户提供的构造函数。…

用C语言写一个压缩文件的程序

本篇目录 数据在计算机中的表现形式huffman 编码将文件的二进制每4位划分,统计其值在文件中出现的次数构建二叉树搜索二叉树的叶子节点运行并输出新的编码文件写入部分写入文件首部写入数据部分压缩运行调试解压缩部分解压缩测试为可执行文件配置环境变量总结完整代…

23数字图像置乱技术(matlab程序)

1.简述 一、引言 所谓“置乱”,就是将图像的信息次序打乱,a像素移动到b像素位置上,b像素移动到c像素位置上,……,使其变换成杂乱无章难以辨认的图片。数字图像置乱技术属于加密技术,是指发送发借助数学或者…

Python实现PSO粒子群优化算法优化Catboost分类模型(CatBoostClassifier算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 PSO是粒子群优化算法(Particle Swarm Optimization)的英文缩写,是一…

《低代码指南》——轻流5.0发布,无代码引擎矩阵全面升级

7月6日,由轻流主办「无代码无边界 202376Day|轻流无代码探索者大会」于上海顺利举行。轻流也在会上重磅发布了更加开放、灵活、低门槛的轻流5.0,和全面升级的专有轻流。 轻流5.0全面迭代升级了轻流的无代码引擎矩阵(表单引擎、流程引擎、报表引擎、门户引擎、数据引擎)。…

软件测试项目实战,电商项目测试实例 - 业务测试(重点)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 支付功能怎么测试…

pytest自动化测试实战之执行参数

上一篇介绍了如何运行pytest代码,以及用例的一些执行规则,执行用例发现我们中间print输出的内容,结果没有给我们展示出来,那是因为pytest执行时,后面需要带上一些参数。 参数内容 我们可以在cmd中通过输入 pytest -h…

域名捡漏的好方法,希望能够帮到你:域霸扫描器 V0.44 绿色免费版,供大家学习研究参考

高速扫描域名的工具,一均程序每小时五万条。 扫描域名是否注册,注册商是谁,域名的注册日期与过期日期。 供大家学习研究参考! 下载:https://download.csdn.net/download/weixin_43097956/88025564

【SpringBoot——Error记录】

IDEA正常安装后,运行按钮为灰色解决方法尝试 解决方法一(本人适用)解决方法二 解决方法一(本人适用) 检查创建项目时JDK是否添加,版本是否正确。 解决方法二 点击左下角的Structure 参考链接&#xff1…

回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测

回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测预测效果基本介绍模型描述程序设计学习总结参考资料 预测效果 基本介绍 回归预测 …

node中的数据持久化之mongoDB

一、什么是mongoDB MongoDB是一种开源的非关系型数据库,正如它的名字所表示的,MongoDB支持的数据结构非常松散,是一种以bson格式(一种json的存储形式)的文档存储方式为主,支持的数据结构类型更加丰富的NoS…

mysql多表查询练习题

创建表及插入数据 create table if not exists dept3( deptno varchar(20) primary key , -- 部门号 name varchar(20) -- 部门名字 ); -- 创建员工表 create table if not exists emp3( eid varchar(20) primary key , -- 员工编号 ename varchar(20), -- 员工名字 age int, -…

换零钱——最小钱币张数(贪心算法)

贪心算法:根据给定钱币面值列表,输出给定钱币金额的最小张数。 (本笔记适合学完python基本数据结构,初通 Python 的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣…

CS EXE上线主机+文件下载上传键盘记录

前言 书接上文,CobaltStrike_1_部署教程及CS制作office宏文档钓鱼教程,该篇介绍【使用CS生成对应exe木马,上线主机;对上线主机进行,文件下载,文件上传,键盘记录】。 PS:文章仅供学习…

unseping

代码审计 <?php highlight_file(__FILE__);class ease{private $method;private $args;function __construct($method, $args) {$this->method $method;$this->args $args;}function __destruct(){if (in_array($this->method, array("ping"))) {call…

关于 colab Tutorial的介绍

&#xff08;一&#xff09;常用的快捷键 (二) 网上环境的配置 按照官网上所给的提示一步一步操作即可 注意&#xff1a;此平台需要科学的上网

word因导入mathtype不能使用复制粘贴快捷键的解决方法

1. 我们安装完mathtype后&#xff0c;有时会有两个mathtype显示&#xff0c;其中一个是属于office文件夹下的&#xff0c;另一个是win文件夹下的。如图&#xff1a; 2. 如果word中的复制粘贴快捷键&#xff08;CTRLC和CTRLV&#xff09;不能用&#xff0c;通常是因为office路径…

Arduino STM32F103C8+ST7735 1.8‘‘3D矢量图形demo

Arduino STM32F103C8ST7735 1.8’3D矢量图形demo &#x1f4cc;开源项目地址&#xff1a;https://github.com/cbm80amiga/ST7735_3d_filled_vector&#x1f527;所需库&#xff1a;https://github.com/cbm80amiga/Arduino_ST7735_STM&#x1f516;本开源工程基于Arduino开发平台…