基于CNN的FashionMNIST数据集识别3——模型验证

news2025/2/24 15:37:00

源码

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet



def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader


def test_model_process(model, test_dataloader):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # 讲模型放入到训练设备中
    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

    # 计算测试准确率
    test_acc = test_corrects.double().item() / test_num
    print("测试的准确率为:", test_acc)




if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

源码讲解

当模型训练完毕后,我们得到的是一组最优的参数配置。

最后要做的就是验证这组参数的表现。

数据准备

def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader

测试数据的准备和训练数据的准备有明显不同:

  1. 测试数据集是将MINIST里所有的数据当做验证集。并且训练模式设置为false。
  2. dataloader里面的batch大小设置为1,也就是说对每个样本进行验证,不再存在“分批”的概念。

循环验证

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

和之前训练模型时的验证逻辑基本相同。只进行前向传播,将预测正确的样本个数进行累加。

代码运行

if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

在代码运行时,需要将参数载入到模型里,再进行验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2304461.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

洛谷P1135多题解

解法1&#xff1a;BFS&#xff0c;有n个节点每个节点最多被访问一次&#xff0c;所以BFS时间复杂度为O(n)。注意ab的特判。 #include<iostream> #include<cstring> #include<queue> using namespace std; const int N 205; int n, a, b; int k[N], s[N]; b…

用AI写游戏3——deepseek实现kotlin android studio greedy snake game 贪吃蛇游戏

项目下载 https://download.csdn.net/download/AnalogElectronic/90421306 项目结构 就是通过android studio 建空项目&#xff0c;改下MainActivity.kt的内容就完事了 ctrlshiftalts 看项目结构如下 核心代码 MainActivity.kt package com.example.snakegame1// MainA…

论文解读 | AAAI'25 Cobra:多模态扩展的大型语言模型,以实现高效推理

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入&#xff01; 点击 阅读原文 观看作者讲解回放&#xff01; 个人信息 作者&#xff1a;赵晗&#xff0c;浙江大学-西湖大学联合培养博士生 内容简介 近年来&#xff0c;在各个领域应用多模态大语言模型&#xff08;MLLMs&…

DPVS-3: 双臂负载均衡测试

测试拓扑 双臂模式&#xff0c; 使用两个网卡&#xff0c;一个对外&#xff0c;一个对内。 Client host是物理机&#xff0c; RS host都是虚拟机。 LB host是物理机&#xff0c;两个CX5网卡分别在两个子网。 配置文件 用dpvs.conf.sample作为双臂配置文件&#xff0c;其中…

记一次复杂分页查询的优化历程:从临时表到普通表的架构演进

1. 问题背景 在项目开发中&#xff0c;我们需要实现一个复杂的分页查询功能&#xff0c;涉及大量 IP 地址数据的处理和多表关联。在我接手这个项目的时候,代码是这样的 要知道代码里面的 ipsList 数据可能几万条甚至更多,这样拼接的sql,必然是要内存溢出的,一味地扩大jvm参数不…

架构师面试(六):熔断和降级

问题 在千万日活的电商系统中&#xff0c;商品列表页服务通过 RPC 调用广告服务&#xff1b;经过统计发现&#xff0c;在最近10秒的时间里&#xff0c;商品列表页服务在对广告服务的调用中有 98% 的调用是超时的&#xff1b; 针对这个场景&#xff0c;下面哪几项的说法是正确的…

细说 Java 引用(强、软、弱、虚)和 GC 流程(二)

一、前文回顾 在 细说Java 引用&#xff08;强、软、弱、虚&#xff09;和 GC 流程&#xff08;一&#xff09; 我们对Java 引用有了总体的认识&#xff0c;本文将继续深入分析 Java 引用在 GC 时的一些细节。 还是从我们在前文中提到的引用流程图里说起&#xff0c;这里不清…

【深度学习】Unet的基础介绍

U-Net是一种用于图像分割的深度学习模型&#xff0c;特别适合医学影像和其他需要分割细节的任务。如图&#xff1a; Unet论文原文 为什么叫U-Net&#xff1f; U-Net的结构像字母“U”&#xff0c;所以得名。它的结构由两个主要部分组成&#xff1a; 下采样&#xff08;编码…

ROS2机器人开发--服务通信与参数通信

服务通信与参数通信 在 ROS 2 中&#xff0c;服务&#xff08;Services&#xff09;通信和参数&#xff08;Parameters&#xff09;通信是两种重要的通信机制。服务是基于请求和响应的双向通信机制。参数用于管理节点的设置&#xff0c;并且参数通信是基于服务通信实现的。 1 …

DeepSeek写贪吃蛇手机小游戏

DeepSeek写贪吃蛇手机小游戏 提问 根据提的要求&#xff0c;让DeepSeek整理的需求&#xff0c;进行提问&#xff0c;内容如下&#xff1a; 请生成一个包含以下功能的可运行移动端贪吃蛇H5文件&#xff1a; 要求 蛇和食物红点要清晰&#xff0c;不超过屏幕外 下方有暂停和重新…

Python安全之反序列化——pickle/cPickle

一&#xff0e; 概述 Python中有两个模块可以实现对象的序列化&#xff0c;pickle和cPickle&#xff0c;区别在于cPickle是用C语言实现的&#xff0c;pickle是用纯python语言实现的&#xff0c;用法类似&#xff0c;cPickle的读写效率高一些。使用时一般先尝试导入cPickle&…

Deepin(Linux)安装MySQL指南

1.下载 地址&#xff1a;https://downloads.mysql.com/archives/community/ 2.将文件解压到 /usr/local 目录下 先cd到安装文件所在目录再解压&#xff0c;本机是cd /home/lu01/Downloads sudo tar -xvJf mysql-9.2.0-linux-glibc2.28-x86_64.tar.xz -C /usr/local3.创建软链…

vue-fastapi-admin 部署心得

vue-fastapi-admin 部署心得 这两天需要搭建一个后台管理系统&#xff0c;找来找去 vue-fastapi-admin 这个开源后台管理框架刚好和我的技术栈所契合。于是就浅浅的研究了一下。 主要是记录如何基于原项目提供的Dockerfile进行调整&#xff0c;那项目文件放在容器外部&#xf…

计算机视觉算法实战——三维重建(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ 1. 三维重建领域简介 三维重建&#xff08;3D Reconstruction&#xff09;是计算机视觉的核心任务之一&#xff0c;旨在通过多视角图像、视频…

迎接DeepSeek开源周[Kimi先开为敬]发布开源最新Muon优化器可替代 AdamW计算效率直接翻倍

Muon优化器在小规模语言模型训练中表现出色&#xff0c;但在大规模模型训练中的可扩展性尚未得到证实。月之暗面通过系统分析和改进&#xff0c;成功将 Muon 应用于 3B/16B 参数的 MoE 模型训练&#xff0c;累计训练 5.7 万亿 token。结果表明&#xff0c;Muon 可以替代 AdamW …

【工作流】Spring Boot 项目与 Camunda 的整合

【工作流】Spring Boot 项目与 Camunda 的整合 【一】Camunda 和主流流程引擎的对比【二】概念介绍【1】Camunda 概念&#xff1a;【2】BPMN 概念 【三】环境准备【1】安装流程设计器CamundaModeler【画图工具】&#xff08;1&#xff09;下载安装 【2】CamundaModeler如何设计…

Grouped-Query Attention(GQA)详解: Pytorch实现

Grouped-Query Attention&#xff08;GQA&#xff09;详解 Grouped-Query Attention&#xff08;GQA&#xff09; 是 Multi-Query Attention&#xff08;MQA&#xff09; 的改进版&#xff0c;它通过在 多个查询头&#xff08;Query Heads&#xff09;之间共享 Key 和 Value&am…

docker基操

docker基操 首先就是安装docker使用docker:创建容器-制作一个镜像-加载镜像首先就是安装docker 随便找一个教程安装就可以,安装过程中主要是不能访问谷歌,下面这篇文章写了镜像的一些问题: 安装docker的网络问题 使用docker:创建容器-制作一个镜像-加载镜像 主要是参考:这篇…

SF-HCI-SAP问题收集1

最近在做HCI的集成&#xff0c;是S4的环境&#xff0c;发现很多东西都跑不通&#xff0c;今天开始收集一下错误点 如果下图冲从0001变成0010&#xff0c;sfiom_rprq_osi表就会存数据&#xff0c;系统检查到此表就会报错&#xff0c;这个选项的作用就是自定义信息类型也能更新&a…

当 OpenAI 不再 open,DeepSeek 如何掀起 AI 开源革命?

开源与闭源的路线之争成为了行业瞩目的焦点&#xff0c;DeepSeek掀起的 AI 开源风暴&#xff01; 一、硅谷“斯普特尼克时刻” 1957年&#xff0c;苏联将人类首颗人造卫星“斯普特尼克”送上太空&#xff0c;美国举国震动。 这颗“篮球”般的卫星&#xff0c;刺痛了自诩科技霸…