【pytorch23】MNIST测试实战

news2024/9/27 7:22:15

理解

训练完之后也需要做测试
为什么要做test?
在这里插入图片描述
上图蓝色代表train的accuracy
下图蓝色代表train的loss
基本上符合预期,随着epoch增大,train的accuracy也会上升,loss也会一直下降,下降到一个较小的程度

但是如果只看train的情况的话,就会被欺骗,虽然accuracy高并且loss很低,你会以为这个算法就很好了,但是做其他的事情就不是特别好,这就是过拟合(overfitting)

deep learning表达能力非常强表达效果很好,在模型在训练数据上表现非常好,但在新的、未见过的数据上表现不佳的情况。这是因为模型学习到了训练数据中的特定噪声和细节,而不是更通用的特征。

如何缓解这种情况?
在train的时候做一个test,这个test使用validation set 验证集做的,在刚开始的阶段,蓝色的线在上升的时候,验证集的accuracy也会上升loss也与train基本一致,只不过是在训练集上面train在验证集上测试不一定完全符合,所以波动会有点大,很明显train会的更好,validation的表现(包括accuracy和loss)也会变的更好

说明在刚开始的阶段确实学到了一些通用的特征,随着时间的推移,就开始over fitting了,开始去记住一些噪声和细节,这样的话泛化能力会变差,所以在训练集上训练后,在验证集上测试的时候,accuracy会保持不变或者可能下降同样的loss也会巨幅的波动

深度学习所以并不是越训练越好,数据量和架构是核心问题,有一个好的结构再加上足够的数据才能取得一个好的结果

在这里插入图片描述
logits是一个是十个节点的向量,经过cross entropy loss(包含softmax和log和nll_loss)训练,得到loss和accuracy(经过softmax之后就变成了Y=i,i代表第i号节点的概率,只需要argmax之后就能得到概率最大所在的位置),这里对softmax之前和之后都做了一下argmax,其实是一样的效果,因为softmax不会改变单调性,即原来大的数据在softmax之后也会大

这是计算accuracy的基本流程

在这里插入图片描述
什么时候计算test的accuracy和loss

不能够每做一个batch就训练一次,这样就会花大量的时间做测试,不合理,尤其是对于大型数据集

一般情况:

  • 训练若干个batch做一次测试
  • 训练一个epoch做一次测试

如何做测试
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


def load_data(batch_size):
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x


def training(train_loader, net, device):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()))


def testing(test_loader, net, device):
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


global net

if __name__ == '__main__':
    batch_size = 200
    learning_rate = 0.01
    epochs = 10

    train_loader, test_loader = load_data(batch_size)

    device = torch.device('cuda:0')
    net = MLP().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    criteon = nn.CrossEntropyLoss().to(device)

    for epoch in range(epochs):
        training(train_loader, net, device)
        testing(test_loader, net, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1909728.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

世上最懂交易原理的是佛学

僧肇《肇论不空真论》思想非有非無逻辑 价值观矛盾冲突时(不落两边),血性即行迹逻辑(俗谛),才气即逻辑心证(真谛);意气即是美,美即是意气;一切以…

使用bypy丝滑传递百度网盘-服务器文件

前言 还在为百度网盘的数据集难以给服务器做同步而痛苦吗,bypy来拯救你了!bypy是一个强大而灵活的百度网盘命令行客户端工具。它是基于Python开发的开源项目,为用户提供了一种通过命令行界面与百度网盘进行交互的方式。使用bypy,…

仿写SpringIoc

1.SpringIoc简单注解 1.1 Autowired package com.qcby.iocdemo1.annotation;import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;Target(ElementType.FIEL…

Git 快速上手

这个文档适用于需要快速上手 Git 的用户,本文尽可能的做到简单易懂 ❤️❤️❤️ git 的详细讲解请看这篇博客 Git 详解(原理、使用) 1. 什么是 Git Git 是目前最主流的一个版本控制器,并且是分布式版本控制系统,可…

音视频开发—FFmpeg处理流数据的基本概念详解

文章目录 多媒体文件的基本概念相关重要的结构体操作数据流的基本步骤1.解复用(Demuxing)2.获取流(Stream)3. 读取数据包(Packet)4. 释放资源(Free Resources)完整示例 多媒体文件的…

聚焦云技术,探讨 AGI 时代的云原生数据计算系统

6月22日,开源中国社区在上海举办了 OSC 源创会活动,本期活动以「云技术」为主题,邀请了来自华为 openEuler、字节跳动、AutoMQ 等厂商的技术大咖进行分享,拓数派作为云原生数据计算领域的引领者,受邀参与了本次活动&am…

智慧城市可视化页面怎么做?免费可视化工具可以帮你

智慧城市是一个综合性的概念,广泛应用于各个领域,如基础设施建设、信息化应用、产业经济发展、市民生活品质等。 可视化页面的制作也是一个综合性的过程,需要确定展示内容、数据收集与处理、设计可视化元素等多个环节紧密配合。 1. 明确展示…

MySQL Innodb存储引擎中,当页默认的大小是16K时,页中最多存放多少行的记录?

1、题目引入 Innodb存储引擎是面向行的(row-oriented),也就是说数据的存放按行进行,每页存放的行记录是有硬性定义的,当页默认的大小是16K时,页中最多存放多少行的记录? A、1600 行B、8192 行C、16383 行D、7992 行 …

occ geo

随笔 - 12 文章 - 18 评论 - 117 阅读 - 13万 opencascade造型引擎功能介绍 现今的CAD 系统大多通常都基于CAD 系统提供的二次开发包,用户根据要求定制符合自己要求的功能。AutoCAD就提供了AutoLISP、ADS 等都是比较通用的开发工具包。UG 也提供了多种二次开发…

温州海经区管委会主任、乐清市委书记徐建兵带队莅临麒麟信安调研

7月8日上午,温州海经区管委会主任、乐清市委书记徐建兵,乐清市委常委、副市长叶序锋,乐清市委办主任郑志坚一行莅临麒麟信安调研,乐清市投资促进服务中心及湖南省浙江总商会相关人员陪同参加。麒麟信安董事长杨涛、总裁刘文清热情…

grafana数据展示

目录 一、安装步骤 二、如何添加喜欢的界面 三、自动添加注册客户端主机 一、安装步骤 启动成功后 可以查看端口3000是否启动 如果启动了就在浏览器输入IP地址:3000 账号密码默认是admin 然后点击 log in 第一次会让你修改密码 根据自定义密码然后就能登录到界面…

机器学习笔记:初始化0的问题

1 前言 假设我们有这样的两个模型: 第一个是逻辑回归 第二个是神经网络 他们的损失函数都是交叉熵 sigmoid函数的导数: 他们能不能用0初始化呢? 2 逻辑回归 2.1 求偏导 2.1.1 结论 2.1.2 L对a的偏导 2.1.3 对w1,w2求偏导 w2同…

k8s record 20240708

一、PaaS 云平台 web界面 资源利用查看 Rancher 5台 CPU 4核 Mem 4g 100g的机器 映射的目录是指docker重启后,数据还在 Rancher可以创建集群也可以托管已有集群 先docker 部署 Rancher,然后通过 Rancher 部署 k8s 想使用 kubectl 还要yum install 安…

中国AI大模型论文数量全球第一,清华力压麻省理工、斯坦福

论文是研究新技术、开发新产品获取“图纸”的重要途径之一,OpenAI的研究人员正是借鉴了Transformer的论文(被引用超过9万次),才开发出了对全球各行业影响巨大的产品ChatGPT。 而论文的数量、通过率和被引用次数是衡量一个国家科技…

电脑文件夹怎么设置密码?让你的文件更安全!

在日常使用电脑的过程中,我们常常会有一些需要保护的个人文件或资料。为了防止这些文件被他人未经授权访问,对重要文件夹设置密码是一种有效的保护措施,可是电脑文件夹怎么设置密码呢?本文将介绍2种简单有效的方法帮助您为电脑文件…

红酒与运动后的恢复:健康的双重助力

在繁忙的都市生活中,运动已成为许多人追求健康与活力的方式。当汗水洒落,肌肉得到锻炼,一场酣畅淋漓的运动后,身心仿佛得到了洗礼。而在这份宁静与满足之余,你是否想过,一杯优雅的红酒也能为你的运动后恢复…

以SGET协会OSM标准首创有662引脚的OSM模组——凌华智能引领嵌入式运算市场

在可焊接的45 x 45mm尺寸上提升功率 开启嵌入式运算发展的新时代 摘要: 1.开放式标准模块(OSM™),最大尺寸仅45 x 45mm,采用零开销的模块化系统简化生产,并提供662个引脚以增强小型化和物联网应用。 2.凌华智能提供基于NXP i.M…

二叉树中的最大路径和(Java版)

二叉树中的 路径 被定义为一条节点序列,序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root &#xff0c…

AI自动生成PPT哪个软件好?高效制作PPT就用这4个

学生时期做各种小组作业需要做PPT,毕业后开始上班每周大大小小的各种会议和汇报,也少不了PPT的折磨。 倘若你也刚好有这种烦恼,那么不妨试试下面我给大家安利的这4款AI自动生成PPT免费软件~保准你用上以后可不再为PPT制作而发愁!…

面向过程编程详解

目录 前言1. 面向过程编程的定义2. 面向过程编程的特点2.1 过程和函数2.2 顺序执行2.3 全局变量2.4 控制结构 3. 面向过程编程的应用场景3.1 系统级编程3.2 科学计算3.3 小型项目 4. 面向过程编程的优缺点4.1 优点4.2 缺点 5. 代表性的编程语言5.1 C语言5.2 Pascal5.3 Fortran …