学习pytorch20 pytorch完整的模型验证套路

news2024/11/13 8:03:13

pytorch完整的模型验证套路

  • 使用非数据集的测试数据,测试训练好模型的效果
  • 代码
  • 预测结果
  • 解决报错

B站小土堆pytorch学习视频 https://www.bilibili.com/video/BV1hE411t7RN/?p=32&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

在这里插入图片描述

使用非数据集的测试数据,测试训练好模型的效果

 测试:训练好的模型,提供对外真实数据的一个实际应用

从网上下载两张图片,整理图片的输入格式,输入模型测试模型效果
请添加图片描述
请添加图片描述

代码

import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import cv2

dog_path = './images/dog.jpg'
airplane_path = './images/airplane.jpg'
model_path = './images/net_epoch9_gpu.pth'

dog_pil = Image.open(dog_path)
airp_pil = Image.open(airplane_path)
print(dog_pil)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=258x174 at 0x237CC2CBE50>
# RGB 3通道 匹配模型输入的通道数
dog_pil = dog_pil.convert('RGB')  # def convert(self, mode=None, matrix=None, dither=None, palette=Palette.WEB, colors=256):
airp_pil = airp_pil.convert('RGB')
# dog_cv = cv2.imread(dog_path)  # numpy.array
# # print(dog_cv)
# img_trans = torchvision.transforms.ToTensor()  # 实例化转tensor的类
# dog_tensor = img_trans(dog_pil)
# dog_cv_tensor = img_trans(dog_cv)
# print(dog_tensor)
# print(dog_tensor.shape)
# print(dog_cv_tensor)
# 输入模型shape 需要是32*32大小的
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor()])
dog_tensor = transform(dog_pil)
airp_tensor = transform(airp_pil)
# print(dog_tensor)
print(dog_tensor.shape, airp_tensor.shape)


class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

# 加载模型要考虑是以哪种形式保存的模型  模型保存方式1:保存模型结构和参数 方式二:只保存模型参数
model = torch.load(model_path, map_location=torch.device('cpu'))
print(model)
dog_tensor = dog_tensor.reshape((1, 3, 32, 32))
airp_tensor = airp_tensor.reshape((1, 3, 32, 32))
model.eval() # 设置模型为测试状态 网络层有dropout batchNormal层不加eval函数会有问题
with torch.no_grad(): # 测试不做梯度计算 节省算力
    dog_output = model(dog_tensor)
    airp_output = model(airp_tensor)
print(dog_output)
print(dog_output.argmax())
print(dog_output.argmax(1))
print(airp_output)
print(airp_output.argmax(1))  # 概率值不便于解读 使用argmax 可以很方便的读出模型预测的是哪个类别

预测结果

在这里插入图片描述

tensor([[ 1.1317, -4.3441,  3.2116,  2.8930,  2.6749,  4.6079, -3.2860,  3.1357,
         -3.0432, -4.1703]])
tensor(5)
tensor([5])
tensor([[ 5.5993, -0.6140,  4.4758,  0.8463,  1.6311, -1.0217, -3.9990, -2.8343,
          1.1050, -1.6423]])
tensor([0])

预测结果和训练数据的标注一直,预测正确

解决报错

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x16 and 1024x64)
解决: dog_tensor = dog_tensor.reshape((1, 3, 32, 32)) 转换输入是4维的, 模型输入有一个batch-size维度

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device(‘cpu’) to map your storages to the CPU.
解决:model = torch.load(model_path, map_location=torch.device(‘cpu’))
在gpu上训练的模型,要在cpu上测试,模型加载的时候指定cpu设备

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1304970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智能优化算法应用:基于鸡群算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于鸡群算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于鸡群算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.鸡群算法4.实验参数设定5.算法结果6.参考文献7.MA…

西南科技大学数字电子技术实验五(用计数器设计简单秒表)预习报告

一、计算/设计过程 说明:本实验是验证性实验,计算预测验证结果。是设计性实验一定要从系统指标计算出元件参数过程,越详细越好。用公式输入法完成相关公式内容,不得贴手写图片。(注意:从抽象公式直接得出结果,不得分,页数可根据内容调整) 1.设计个位电路图 QA、QB、…

简单的实现 mybatisplus实现真实的批量插入

总所周知&#xff0c;mybatisplus 的saveBatch()是一个伪批量插入&#xff0c;性能比较差。真实的批量插入需要for循环读取value 拼装成一条insert语句才插入。下面我将简单的介绍 使用mybatisplus实现真实的批量的步骤。 1.引入依赖&#xff0c;3.4.0之上的版本都可以 <de…

正向代理 反向代理

正向代理&#xff08;Forward Proxy&#xff09;和反向代理&#xff08;Reverse Proxy&#xff09;都是代理服务器的两种形式&#xff0c;它们在网络中扮演着不同的角色&#xff0c;并具有不同的应用场景。 正向代理 正向代理位于客户端和目标服务器之间。客户端通常需要配置…

mysql 快捷登陆

要将 MySQL 的登录命令添加到环境变量中并为其创建别名&#xff0c;可以按照以下步骤进行操作&#xff1a; 1. 打开终端并编辑 /etc/profile 文件&#xff08;使用所有用户的全局设置&#xff09; vim /etc/profile 2. 在文件的末尾添加以下行来设置环境变量和别名 # 将 &q…

基于ssm乐购游戏商城系统论文

摘 要 随着社会的发展&#xff0c;游戏品种越来越多&#xff0c;计算机的优势和普及使得乐购游戏商城系统的开发成为必需。乐购游戏商城系统主要是借助计算机&#xff0c;通过对信息进行管理。减少管理员的工作&#xff0c;同时也方便广大用户对个人所需信息的及时查询以及管理…

vue的小练习-翻转单词

先将字符串转成数组&#xff0c;用reverse&#xff08;&#xff09;翻转数组&#xff0c;再转成字符串 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevic…

python 实现 AIGC 大模型中的概率论:生日问题的基本推导

在上一节中&#xff0c;我们对生日问题进行了严谨的阐述&#xff1a;假设屋子里面每个人的生日相互独立&#xff0c;而且等可能的出现在一年 365 天中的任何一天&#xff0c;试问我们需要多少人才能让某两个人的生日在同一天的概率超过 50%。 处理抽象逻辑问题的一个入手点就是…

Docker部署Mysql5.7x和Myslq8.x

Docker部署Mysql5.7x和Myslq8.x 文章目录 1.部署mysql5.7.x2.部署mysql8.x3.创建用户授权及远程登录3.1 mysql5.7创建用户授权及远程登录3.2 mysql8创建用户授权及远程登录 4.总结 1.部署mysql5.7.x 在D盘下的mysql目录下新建如下目录&#xff1a; D:\mysql\conf\my.cnf内容如下…

OpenVINS学习2——VIRAL数据集eee01.bag运行

前言 周末休息了两天&#xff0c;接着做上周五那个VIRAL数据集没有运行成功的工作。现在的最新OpenVINS需要重新写配置文件&#xff0c;不像之前那样都写在launch里&#xff0c;因此需要根据数据集情况配置好estimator_config.yaml还有两个标定参数文件。 VIRAL数据集 VIRAL…

【工具栏】idea安装翻译工具

然后重启idea 打开设置 翻译方式&#xff1a; 选中要翻译的文本 然后右键 运行项目的时候&#xff0c;方便查找错误

GPT-4「变懒」问题将被修复;英伟达选择越南成公司“第二故乡”丨 RTE 开发者日报 Vol.104

开发者朋友们大家好&#xff1a; 这里是 「RTE 开发者日报」 &#xff0c;每天和大家一起看新闻、聊八卦。 我们的社区编辑团队会整理分享 RTE &#xff08;Real Time Engagement&#xff09; 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的…

基于VGG-16+Android+Python的智能车辆驾驶行为分析—深度学习算法应用(含全部工程源码)+数据集+模型(二)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理1&#xff09;数据集来源2&#xff09;数据集内容3&#xff09;数据集预处理 2. 模型构建1&#xff09;定义模型结构2&#xff09;优化损失函数 相关其它博客工程源代码下载其它资料下载 前言 本项目…

flex布局一行n个

上图 缩小后 主要用了 flex-basis flex-grow flex-shrink flex的三个属性 有兴趣的可以看看 深入理解CSS之flex精要之 flex-basis flex-grow flex-shrink 实战讲解 .bg{background-color: aquamarine;width: 100%;height: 100%;display: flex;flex-wrap: wrap;}.box1{backgr…

Python Thefuck库详解:让错误命令变得“友好”

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com Python中有许多强大的库&#xff0c;其中Thefuck库独具特色&#xff0c;它的作用是纠正用户在终端输入的错误命令&#xff0c;让操作变得更加友好和高效。在本篇博客文章中&#xff0c;我们将深入探讨Thefuck库的…

d2l绘图不显示的问题

之前试了各种方法都不行 在pycharm中还是不行&#xff0c;但是在anaconda中的命令行是可以的 anaconda prompt conda activaye py39 #进入f盘 F: #运行文件 python F:\python_code\softmax.py

vue-print-nb ,element-ui => table打印不全不说原理直接上代码

你的边框的颜色能深就深点&#xff0c;有的时候打印不出来 如果你出现这种情况请复制以下代码&#xff1a; <style media"print" scoped> page {size: auto;/* auto is the initial value *//* margin: 3mm; */margin-bottom: 0mm;/* this affects the margin…

小白必看!海外静态ip和动态ip解析!

在如今的时代&#xff0c;互联网已经成为我们生活中必不可少的一部分。无论是工作、学习还是娱乐&#xff0c;我们都得要一个稳定快速的网络连接。而在某些特殊情况下&#xff0c;海外静态ip和动态IP就变得非常重要。这篇文章就来解析这两种IP的类型&#xff0c;帮助新手们更好…

CGAN笔记总结第二弹~

CGAN原理与源码分析 一、复习GAN1.1损失函数1.2判别器源码1.3 生成器源码 二、什么是CGAN&#xff1f;2.1 CGAN原理图2.2条件GAN的损失函数2.3 生成器源码2.4 判别器源码2.5 训练过程1&#xff09;这里的训练顺序2&#xff09;为什么先训练判别器后训练生成器呢&#xff1f; 2.…

Dijkstra求最短路 II(堆优化Dijkstra算法)

给定一个 n 个点 m 条边的有向图&#xff0c;图中可能存在重边和自环&#xff0c;所有边权均为非负值。 请你求出 11 号点到 n 号点的最短距离&#xff0c;如果无法从 11 号点走到 n 号点&#xff0c;则输出 −1−1。 输入格式 第一行包含整数 n 和 m。 接下来 m 行每行包含…