【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

news2025/1/22 12:31:23

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片

# 训练集
train_dataset = datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 

# 测试集
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # 灰度图
                out_channels=16,            # 要得到几多少个特征图
                kernel_size=5,              # 卷积核大小
                stride=1,                   # 步长
                padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
            ),                              # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(2),                # 输出 (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  
        output = self.out(x)
        return output
    
# 准确率作为评估标准
def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] 
    rights = pred.eq(labels.data.view_as(pred)).sum() 
    return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = []

    for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环
        net.train()  # 将模型设置为训练模式
        output = net(data)  # 使用模型进行前向传播
        loss = criterion(output, target)  # 计算损失
        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新参数
        right = accuracy(output, target)  # 计算当前批次的准确率
        train_rights.append(right)  # 将准确率保存起来

        if batch_idx % 500 == 0:  # 每500个批次进行一次验证
            net.eval()  # 将模型设置为评估模式
            val_rights = []  # 存储验证集的准确率

            for (data, target) in test_loader:  # 在测试集上进行验证
                output = net(data)  # 使用模型进行前向传播
                right = accuracy(output, target)  # 计算验证集上的准确率
                val_rights.append(right)  # 将准确率保存起来

            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1],
                100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1327447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

竞赛保研 YOLOv7 目标检测网络解读

文章目录 0 前言1 yolov7的整体结构2 关键点 - backbone关键点 - head3 训练4 使用效果5 最后 0 前言 世界变化太快,YOLOv6还没用熟YOLOv7就来了,如果有同学的毕设项目想用上最新的技术,不妨看看学长的这篇文章,学长带大家简单的…

YOLOv8改进 | 主干篇 | 利用MobileNetV2替换Backbone(轻量化网络结构)

一、本文介绍 本文给大家带来的改进机制是MobileNetV2,其是专为移动和嵌入式视觉应用设计的轻量化网络结构。其在MobilNetV1的基础上采用反转残差结构和线性瓶颈层。这种结构通过轻量级的深度卷积和线性卷积过滤特征,同时去除狭窄层中的非线性&#xff…

【K8s】4# 使用kuboard部署开源项目实战

文章目录 1.开源项目2.实战2.1.创建spring-blade命名空间2.2.导入 spring-blade 到 K8S 名称空间2.3.设置存储卷参数2.4.调整节点端口2.5.确认导入2.6.查看集群2.7.导入配置到 nacos2.8.启动微服务工作负载 3.验证部署结果3.1.Nacos3.2. web 4.问题汇总Q1:Nacos启动…

centos7安装开源日志系统graylog5.1.2

安装包链接:链接:https://pan.baidu.com/s/1Zl5s7x1zMWpuKfaePy0gPg?pwd1eup 提取码:1eup 这里采用的shell脚本安装,脚本如下: 先使用命令产生2个参数代入到脚本中: 使用pwgen生成password_secret密码 …

CSS(五) -- 动效实现(立体盒子旋转-四方体+正六边)

一. 四面立体旋转 正方形旋转 小程序中 wxss中 <!-- 背景 --><view class"dragon"><!--旋转物体位置--><view class"dragon-position"><!--旋转 加透视 有立体的感觉--><view class"d-parent"><view …

Backtrader 文档学习-Data Feeds(上)

Backtrader 文档学习-Data Feeds 1.数据载入 Quickstart中已经学习了基础的数据载入到cerebro中。 self.datas 是按插入顺序的数组数组对象的别名self.data 和 self.data0 一样&#xff0c;都是指向第一组数据self.dataX 指向第N组数据 import backtrader as bt import bac…

【PC电脑windows-学习样例generic_gpio-拓展GPIO-ESP32的GPIO程序-问题解决-GPIO输出实验-基础样例学习(2)】

【PC电脑windows-学习样例generic_gpio-拓展GPIO-ESP32的GPIO程序-基础样例学习&#xff08;2&#xff09;】 1、概述2、实验环境3、 问题说明1&#xff1a;问题说明&#xff1a;使用官方样例&#xff0c;增加IO&#xff0c;编译会重新改回去。2&#xff1a;解决方式&#xff1…

STM32 使用ARM仿真器设置

STM32单片机程序下载到单片机芯片中有两种方式&#xff0c;①编译生成HEX&#xff0c;使用程序烧录软件刷到单片机芯片里。②使用ARM仿真器下载程序。使用ARM仿真器的优势是&#xff0c;在工程编译没问题直接在Keil软件里就可以将程序下载到单片机里&#xff0c;并且程序可以在…

苏州耕耘无忧物联网:降本增效,设备维护管理数字化转型的引领者

随着科技的快速发展和工业4.0的推动&#xff0c;设备维护管理已经从传统的被动式、经验式维护&#xff0c;转向了更加积极主动、数据驱动的维护模式。在这个过程中&#xff0c;苏州耕耘无忧物联科技有限公司以其深厚的技术积累和丰富的管理经验&#xff0c;引领着设备维护管理数…

ASP.NET Core基础之定时任务(二)-Quartz.NET入门

阅读本文你的收获 了解任务调度框架QuartZ.NET的核心构成学会在ASP.NET Core 中使用QuartZ.NET 在项目的开发过程中&#xff0c;难免会遇见需要后台处理的任务&#xff0c;例如定时发送邮件通知、后台处理耗时的数据处理等&#xff0c;上次分享了ASP.NET Core中实现定时任务的…

4. 行为模式 - 中介者模式

亦称&#xff1a; 调解人、控制器、Intermediary、Controller、Mediator 意图 中介者模式是一种行为设计模式&#xff0c; 能让你减少对象之间混乱无序的依赖关系。 该模式会限制对象之间的直接交互&#xff0c; 迫使它们通过一个中介者对象进行合作。 问题 假如你有一个创建…

el-date-picker时间戳问题

最近用el-date-picker时间插件&#xff0c;没想到只能得到格式化的日期&#xff0c;那能不能得到时间戳呢&#xff1f;答案是肯定的&#xff0c;最恶心的来了&#xff0c;按照大多数人提供的方案得到了一个莫名其妙的字符串&#xff0c;看起来很奇怪 经过不懈的努力找到了最终的…

通过U盘:将电脑进行重装电脑

目录 一.老毛桃制作winPE镜像 1.制作准备 2.具体制作 下载老毛桃工具 插入U盘 选择制作模式 正式配置U盘 安装提醒 安装成功 具体操作 二.使用ultrasio制作U盘 1.具体思路 2.图片操作 三.硬盘安装系统 具体操作 示例图 ​编辑 一.老毛桃制作winPE镜像 1.制作准…

神经网络:深度学习优化方法

1.有哪些方法能提升CNN模型的泛化能力 采集更多数据&#xff1a;数据决定算法的上限。 优化数据分布&#xff1a;数据类别均衡。 选用合适的目标函数。 设计合适的网络结构。 数据增强。 权值正则化。 使用合适的优化器等。 2.BN层面试高频问题大汇总 BN层解决了什么问…

使用@jiaminghi/data-view实现一个数据大屏

<template><div class"content bg"><!-- 全局容器 --><!-- <dv-full-screen-container> --><!-- 第二行 --><div class"module-box" style"align-items: start; margin-top: 10px"><!-- 左 -->…

【IntelliJ IDEA】打开项目Git突然无法识别解决方案

这个问题也是我今天突然偶尔遇到的&#xff0c;当时没在意&#xff0c;项目打开之后又关闭&#xff0c;后来很久才又打开&#xff0c;发现项目明明有git版本控制的&#xff0c;咋突然开发工具右下角没有标识了&#xff0c;然后检查了一下git配置还报错了。 其实从图上我们可以看…

ctfshow sql 195-200

195 堆叠注入 十六进制 if(preg_match(/ |\*|\x09|\x0a|\x0b|\x0c|\x0d|\xa0|\x00|\#|\x23|\|\"|select|union|or|and|\x26|\x7c|file|into/i, $username)){$ret[msg]用户名非法;die(json_encode($ret));}可以看到没被过滤&#xff0c;select 空格 被过滤了&#xff0c;可…

【Week-P2】CNN彩色图片分类-CIFAR10数据集

文章目录 一、环境配置二、准备数据三、搭建网络结构四、开始训练五、查看训练结果六、总结3.1 ⭐ torch.nn.Conv2d()详解3.2 ⭐ torch.nn.Linear()详解3.3 ⭐torch.nn.MaxPool2d()详解3.4 ⭐ 关于卷积层、池化层的计算4.2.1 optimizer.zero_grad()说明4.2.2 loss.backward()说…

SQL---Zeppeline前驱记录与后驱记录查询

内容导航 类别内容导航机器学习机器学习算法应用场景与评价指标机器学习算法—分类机器学习算法—回归机器学习算法—聚类机器学习算法—异常检测机器学习算法—时间序列数据可视化数据可视化—折线图数据可视化—箱线图数据可视化—柱状图数据可视化—饼图、环形图、雷达图统…

油猴脚本教程案例【键盘监听】-编写 ChatGPT 快捷键优化

文章目录 1. 元数据1. name2. namespace3. version4. description5. author6. match7. grant8. icon 2. 编写函数.1 函数功能2.1.1. input - 聚焦发言框2.1.2. stop - 取消回答2.1.3. newFunction - 开启新窗口2.1.4. scroll - 回到底部 3. 监听键盘事件3.1 监听X - 开启新对话…