pytorch-MNIST测试实战

news2024/12/24 0:04:46

目录

  • 1. 为什么test
  • 2. 如何做test
  • 3. 什么时候做test
  • 4. 完整代码

1. 为什么test

如下图:上下两幅图中蓝色分别表示train的accuracy和loss,黄色表示test的accuracy和loss,如果单纯看train的accuracy和loss曲线就会认为模型已经train的很好了,accuracy一直在上升接近于1了,loss一直在下降已经接近于0了,殊不知此时可能已经出现了over fitting(本数据集准确率很高,其他数据准确率很低),此时就需要test了,从图中可以看出test在红色划线右侧的accuracy已经不变甚至下降了,loss曲线波动也比较大,甚至已经上升了。
在这里插入图片描述

2. 如何做test

如下图所示:
argmax找出概率最大的数字的index
softmax在这里使用与不使用结果是一样的,因为softmax不改变单调性(大的依然大,小的依然小)
使用torch.eq计算预测值与目标值是否相当,相等返回1不等返回0
correct.sum().float().item() /4是用来计算accuracy的,其他sum()是计算正确的个数,item是tensor转bumpy; /4是除以总样本数
在这里插入图片描述

3. 什么时候做test

  • 每几个batch做一次
  • 一个epoch做一次
    注意:为什么不一个batch做一次test呢?因为test的数据可能也比较大,每个batch都test会影响train的速度

4. 完整代码

从一下代码可知,test是一个epoch做一次,首先像train一样load test数据,并搬到GPU中,然后数据输入到网络中,计算loss,最后计算准确了并打印输出

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1620081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++初识--------带你从不同的角度理解引用的巧妙之处

1.对于展开的理解 我们这里的展开包括命名空间的展开和头文件的展开,两者的含义是不一样的: 头文件的展开就是把头文件拷贝到当前的文件里面; 命名空间的展开不是拷贝,而是因为编译器本身默认是到全局里面去找,当我…

【热议】硕士和读博士洗碗区别的两大理论

::: block-1 “时问桫椤”是一个致力于为本科生到研究生教育阶段提供帮助的不太正式的公众号。我们旨在在大家感到困惑、痛苦或面临困难时伸出援手。通过总结广大研究生的经验,帮助大家尽早适应研究生生活,尽快了解科研的本质。祝一切顺利!—…

SWOT分析法:知彼知己的战略规划工具

文章目录 一、什么是SWOT分析法二、SWOT分析法如何产生的三、SWOT分析法适合哪些人四、SWOT分析法的应用场景五、SWOT分析法的优缺点六、SWOT分析实例 一、什么是SWOT分析法 SWOT分析法是一种用于评估组织、项目、个人或任何其他事物的战略规划工具。SWOT是Strengths&#xff…

组态风格的工业可视化大屏,既同步状态又掌控数据,一箭双雕。

可视化大屏中加入了组态图,状态和数据一目了然了,我看还有谁说可视化大屏没啥用啦。 将组态图放入可视化大屏中可以起到以下几个作用: 1. 实时监控: 组态图可以用来实时监控设备、系统或者生产线的运行状态。通过大屏展示&#…

HackMyVM-Alzheimer

目录 信息收集 arp nmap FTP服务信息收集 匿名登陆 关键信息 knock WEB信息收集 信息收集 gobuster 目录爆破 ssh登录 提权 系统信息收集 提权 get root 信息收集 arp ┌──(root㉿0x00)-[~/HackMyVM] └─# arp-scan -l Interface: eth0, type: EN10MB, MAC…

应用于智能装备制造,钡铼IOy系列模块展现其强大的灵活性和实用性

随着科技的飞速发展,智能制造已经成为工业4.0时代的核心驱动力。在此背景下,钡铼技术推出的IOy系列模块以其独特的设计、卓越的性能以及无可比拟的灵活性与实用性,在智能装备制造领域展现出了强大的技术优势和应用价值。 首先,钡…

Excel 冻结前几行

Excel中有冻结首航和冻结首列的选项,但是如果想冻结前几行该怎么操作? 冻结首行或冻结首列 视图 -> 冻结窗格 -> 冻结首行或冻结首列 冻结前几行或前几列 视图 -> 冻结窗格 -> 冻结拆分窗格 具体冻结几行和几列取决于当前选中的单元格。…

力扣HOT100 - 114. 二叉树展开为链表

解题思路&#xff1a; class Solution {List<TreeNode> list new ArrayList<>();public void flatten(TreeNode root) {recur(root);for (int i 1; i < list.size(); i) {TreeNode pre list.get(i - 1);TreeNode cur list.get(i);pre.left null;pre.right…

不同伦敦金网上平台的投资者都在使用的平仓技术

现在几乎是百分之一百的伦敦金交易都在伦敦金网上平台进行。市面上有不同的伦敦金网上平台&#xff0c;那有没有一些交易技术&#xff0c;不论是什么伦敦金网上平台的投资者都喜欢用的呢&#xff1f;答案是肯定的&#xff0c;下面我们就从平仓这个角度来讨论一下伦敦金网上平台…

LeetCode - 11.盛最多水的容器

一. 题目链接 LeetCode - 11.盛最多水的容器 二. 思路解释 利用双指针的思想&#xff0c;定义一个left和reght&#xff0c;left指向首部&#xff0c;right指向尾部&#xff0c;计算当前两个指针所对应的高度构成容器的体积。根据当前双指针所指的高度的大小&#xff0c;然后让…

精益人效,实践为先|第四届狮山人力资源论坛圆满举办

4月19日 &#xff0c;在苏州日航酒店&#xff0c;由中国苏州人力资源服务产业园、苏州高新区人力资源服务产业园指导&#xff0c;盖雅工场、盖雅学苑和盖雅人效研究院主办的 「精益人效 实践为先——第四届狮山人力资源论坛」圆满结束。 700余位企业管理者与人力资源从业者&am…

【每日刷题】Day23

【每日刷题】Day23 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 138. 随机链表的复制 - 力扣&#xff08;LeetCode&#xff09; 2. 链表的回文结构_牛客题霸_牛客网 …

邂逅JavaScript逆向爬虫-------基础篇之面向对象

目录 一、概念二、对象的创建和操作2.1 JavaScript创建对象的方式2.2 对象属性操作的控制2.3 理解JavaScript创建对象2.3.1 工厂模式2.3.2 构造函数2.3.3 原型构造函数 三、继承3.1 通过原型链实现继承3.2 借用构造函数实现继承3.3 寄生组合式继承3.3.1 对象的原型式继承3.3.2 …

zabbix6.4告警配置(短信告警和邮件告警),脚本触发

目录 一、前提二、告警配置1.邮件告警脚本配置2.短信告警脚本配置3.zabbix添加报警媒介4.zabbix创建动作4.给用户添加报警媒介 一、前提 已经搭建好zabbix-server 在需要监控的mysql服务器上安装zabbix-agent2 上述安装步骤参考我的上篇文章&#xff1a;通过docker容器安装za…

软考-系统集成项目管理中级--合同管理

本章历年考题分值统计(16年11月及以后按新教材考的&#xff09; 本章重点常考知识点汇总清单(学握部分可直接理解记忆) 8、合同签订管理(掌握)10下53&#xff0c;14上53&#xff0c;15上53 考题 签订合同的前期调查&#xff0c;每一项合同在签订之前&#xff0c;应当做好以下几…

Python蜘蛛侠

目录 写在前面 蜘蛛侠 编写代码 代码分析 更多精彩 写在后面 写在前面 本期小编给大家推荐一个酷酷的Python蜘蛛侠&#xff0c;一起来看看叭~ 蜘蛛侠 蜘蛛侠&#xff08;Spider-Man&#xff09;是美国漫威漫画宇宙中的一位标志性人物&#xff0c;由传奇创作者斯坦李与艺…

MySQL主从结构搭建

说明&#xff1a;本文介绍如何搭建MySQL主从结构&#xff1b; 原理 主从复制原理如下&#xff1a; &#xff08;1&#xff09;master数据写入&#xff0c;更新binlog&#xff1b; &#xff08;2&#xff09;master创建一个dump线程向slave推送binlog&#xff1b; &#xff…

javaScript中的作用域和作用域链

作用域&#xff08;Scope&#xff09; 什么是作用域 作用域是在运行时代码中的某些特定部分中变量、对象和函数的可访问性。 换句话说&#xff0c;作用域决定了代码区块中变量和其他资源的可见性。 示例&#xff1a; function outFun2() {var inVariable "内层变量2…

通过本机端口映射VMware中虚拟机应用(例如同一局域网别人想远程连接你虚拟机中的数据库)

需要 虚拟机中安装一下达梦数据库&#xff0c;并且以后大家都连接你虚拟机中达梦数据库进行开发。。。。。。在不改动自己虚拟机配置&#xff0c;以及本地网卡任何配置的情况下如何解决&#xff1f;本虚拟机网络一直使用的NAT模式。 解决 找到NAT设置添加端口转发即可解决。…

springboot+springsecurity+vue前后端分离权限管理系统

有任何问题联系本人QQ: 1205326040 1.介绍 优秀的权限管理系统&#xff0c;核心功能已经实现&#xff0c;采用springbootvue前后端分离开发&#xff0c;springsecurity实现权限控制&#xff0c;实现按钮级的权限管理&#xff0c;非常适合作为基础框架进行项目开发。 2.效果图…