[chapter 26][PyTorch][MNIST 测试实战】

news2024/11/27 11:39:38

前言

      这里面结合手写数字识别的例子,讲解一下训练时候注意点

目录

  1.  训练问题
  2. 解决方案
  3. 参考代码

   一  训练问题

   训练的时候,我们的数据集分为Train Data 和 validation Data。

随着训练的epoch次数增加,我们发现Train Data 上精度

先逐步增加,但是到一定阶段就会出现过拟合现象。

validation Data 上面不再稳定,反而出现下降的趋势,泛化能力变差.


二  解决方案

   test once serveral batch(几个batch,验证一次)

   test once per epoch(每一轮训练完后,验证一次)

    test once serveral epoch(几轮训练后,验证一次)

   

   当发现验证集acc到达一定精度,且下降后,停止训练


    三  参考代码

# -*- coding: utf-8 -*-
"""
Created on Mon Apr 10 21:51:21 2023

@author: cxf
"""

import torch
import torch.nn.functional as F


def validation():
    
    logits = torch.rand(6,10)
    pred = F.softmax(logits, dim=1)
    print(pred.shape)
    
    
    pred_label= pred.argmax(dim=1)
    print(pred_label)
    
    label= torch.tensor([0,1,2,3,4,5])
    N = label.shape[0]
    
    correct = torch.eq(pred_label, label)
    
    print(correct)
    
    acc = correct.sum().float().item()/N
    
    print("\n acc %f"%acc)
    
validation()
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms

#超参数
batch_size=200
learning_rate=0.01
epochs=10

#获取训练数据
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,          #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ])),
    batch_size=batch_size, shuffle=True)                          #按batch_size分出一个batch维度在最前面,shuffle=True打乱顺序

#获取测试数据
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)


class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()
        
        # 定义网络的每一层,nn.ReLU可以换成其他激活函数,比如nn.LeakyReLU()
        self.model = nn.Sequential(     
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)
        return x
    
device = torch.device('cuda:0')                     #使用第一张显卡
net = MLP().to(device)
# 定义sgd优化器,指明优化参数、学习率
# net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28).to(device)          # 将二维的图片数据摊平[样本数,784]
        target = target.to(device)
        logits = net(data)                  # 前向传播
        loss = criteon(logits, target)       # nn.CrossEntropyLoss()自带Softmax

        optimizer.zero_grad()                # 梯度信息清空
        loss.backward()                      # 反向传播获取梯度
        optimizer.step()                     # 优化器更新

        if batch_idx % 100 == 0:             # 每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0                                         # correct记录正确分类的样本数
    for data, target in test_loader:
        data = data.view(-1, 28 * 28).to(device)
        target = target.to(device)
        logits = net(data)
        test_loss += criteon(logits, target).item()     # 其实就是criteon(logits, target)的值,标量

        pred = logits.data.max(dim=1)[1]                # 也可以写成pred=logits.argmax(dim=1)
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

参考:

    课时53 MNIST测试实战_哔哩哔哩_bilibili

https://www.cnblogs.com/douzujun/p/13323078.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/428954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

协议篇之以太网ARP协议

协议篇之以太网ARP协议一、什么是ARP协议?作用是什么?二、ARP请求与ARP应答三、以太网ARP数据报格式四、总结一、什么是ARP协议?作用是什么? ARP(Address Resolution Protocol),地址解析协议&am…

GUID分区与MBR分区有什么区别? 操作系统知识

GUID分区与MBR分区有什么区别? 操作系统知识 1、MBR分区表类型的磁盘 主引导记录(Master Boot Record,缩写:MBR),又叫做主引导扇区,它仅仅包含一个64个字节的硬盘分区表。由于每个分区信息需要…

以ChatGPT为例进行自然语言处理学习——入门自然语言处理

⭐️我叫忆_恒心,一名喜欢书写博客的在读研究生👨‍🎓。 如果觉得本文能帮到您,麻烦点个赞👍呗! 近期会不断在专栏里进行更新讲解博客~~~ 有什么问题的小伙伴 欢迎留言提问欧,喜欢的小伙伴给个三…

Python办公自动化之PostgreSQL篇2——利用Python连接PostgreSQL并读取一张表

在上一篇我们已经安装好了最新的PostgreSQL,以及最方便的可视化工具,Navicat 如果错过的小伙伴,可以去上一篇查看:点我查看 今天我们来用Python连接一下PostgreSQL,然后准备一张测试表,导入PostgreSQL&am…

elasticsearch 拼音分词器 自动补全。

elasticsearch 拼音分词器 & 自动补全。 文章目录elasticsearch 拼音分词器 & 自动补全。2. 自动补全。2.1. 拼音分词器。2.2. 自定义分词器。2.3. 自动补全查询。2.4. 实现酒店搜索框自动补全。2.4.1. 修改酒店映射结构。2.4.2. 修改 HotelDoc 实体。2.4.3. 重新导入。…

Shader Graph10-Min, Max, Clamp, Saturate节点

打开UE,新建Material叫做DemoMinMaxClamp,双击打开 一、Minimum节点,两个值比较取较小的。 Min的含义是,红框的0.5为参数B的值,1.0为白色圆形的值,下面的0.5为背景颜色值。图片中每个像素值与0.5进行比较&a…

java基于mvc的停车收费系统mysql

系统需要解决的主要问题有: (1)车位管理模块 添加车位、查看车位状态、车位信息查询等。 (2)客户信息管理模块 客户基本信息录入、客户信息查询等。 (3)卡业务办理 添加卡信息、查余额查询、卡充值。 (4)车辆信息管理模块 车牌信息录入等。 (5)收费管理 可以调整相应…

【Java 数据结构】集合类 (精华篇)

🎉🎉🎉点进来你就是我的人了 博主主页:🙈🙈🙈戳一戳,欢迎大佬指点!人生格言:当你的才华撑不起你的野心的时候,你就应该静下心来学习! 欢迎志同道合的朋友一起加油喔🦾&am…

一本通 3.4.3 图的连通性

1383:刻录光盘(cdrom) 【题目描述】 在FJOI2010夏令营快要结束的时候,很多营员提出来要把整个夏令营期间的资料刻录成一张光盘给大家,以便大家回去后继续学习。组委会觉得这个主意不错!可是组委会一时没有足够的空光盘&#xff…

数学术语——指数的发展历程

指数的发展历程 指数(exponents)的历史可以追溯到许多世纪以前,欧几里德(Euclid)被认为是第一个已知的指数用法。他用“幂(power)”这个词来表示我们今天所知的一个数自乘的次数(注:底数连同其右上角的指数一起的整体形式称为“幂”)。古希腊数学家使用…

寄存器:计算机中的小而强大的存储器件

目录 什么是寄存器? 寄存器的作用 提高计算机的性能 存储处理器需要快速访问的数据 存储函数调用时的参数和返回值 存储中间计算结果 寄存器的种类 程序计数器 指令寄存器 状态寄存器 通用寄存器 寄存器的进化过程 寄存器:计算机中的小而强大…

Linux操作基础(文件系统和日志分析)

文章目录一、inode与block1.1inode和block概述1.2 inode包含文件的元信息1.3 linux文件系统的三个时间戳1.4 inode的号码1.5 inode的大小1.6 inode号的特点1.7软连接与硬链接二 、文件恢复2.1 xfsdump恢复2.2 opic恢复方式三 、日志文件3.1 日志文件的分类3.2 日志的格式3.3 常…

大数据分析案例-基于决策树算法构建信用卡违约预测模型

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

定制你的专属大模型 Finetuner+体验开启!

如 ChatGPT、GPT4 这样的大型语言模型就像是你为公司请的一个牛人顾问,他在 OpenAI、Google 等大公司被预训练了不少的行业内专业知识,所以加入你的公司后,你只需要输入 Prompt 给他, 介绍一些业务上的背景知识,他就能…

Flink学习:Flink如何打印窗口的开始时间和结束时间

Window一、简介二、代码实现三、测试一、简介 大家知道,Flink用水位线和窗口机制配合来处理乱序事件,保证窗口计算数据的正确性,当水位线超过窗口结束时间的时候,就会触发窗口计算 水位线是动态生成的,根据进入窗口的最大事件时间-允许延迟时间 那么窗口的开始时间和结束时间…

力扣70爬楼梯:思路分析+优化思路+代码实现+补充思考

文章目录第一部分:题目描述第二部分:思路分析2.1 初步分析2.2 问题描述2.3 优化思路第三部分:代码实现第四部分:补充思考第一部分:题目描述 🏠 链接:70. 爬楼梯 - 力扣(LeetCode&am…

“衰老标志物”重磅综述:细胞衰老、器官衰老、衰老时钟及其应用

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 随着人口老龄化程度不断加深,实现“健康老龄化(healthy aging)”已成为我国乃至世界迫切需要解决的重大社会和科学问题。据测算,我国60岁及…

LVGL界面开发之模拟器环境搭建

前言 通常我们在使用 LVGL 进行界面开发时,会先在PC上搭建模拟器环境,而不是直接烧录到硬件板子上,使用模拟器是百利而无一害的,而且它是跨平台的,任何Windows,Linux或macOS系统都可以运行PC模拟器。每当界…

网上投票系统的设计与实现(论文+源码)_kaic

摘要 随着全球Internet的迅猛发展和计算机应用的普及,特别是近几年无线网络的广阔覆盖以及无线终端设备的爆炸式增长,使得人们能够随时随地的访问网络,以获取最新信息、参与网络活动、和他人在线互动。为了能及时地了解民情民意,把…

【高项】项目风险管理与采购管理(十大管理)

【高项】项目风险管理与采购管理(十大管理) 文章目录1、风险管理1.1 什么是风险管理?1.2 规划风险管理 & 识别风险(规划)1.3 实施定性风险分析(规划)1.4 实施定量风险分析(规划&…