pytorch工具——使用pytorch构建一个分类器

news2025/1/22 12:52:33

目录

  • 分类器任务和数据介绍
  • 训练分类器的步骤
  • 在GPU上训练模型

分类器任务和数据介绍

在这里插入图片描述

训练分类器的步骤

在这里插入图片描述
在这里插入图片描述

#1
import torch
import torchvision
import torchvision.transforms as transforms

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) #三个部分的数据的均值,标准差都为0.5
trainset=torchvision.datasets.CIFAR10(root='./data1',train=True,download=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True)
testset=torchvision.datasets.CIFAR10(root='./data1',train=False,download=True,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=True)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

在这里插入图片描述

展示若干训练集图片
在这里插入图片描述

#2
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.pool=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        
    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x
    
net=Net()
print(net)

在这里插入图片描述

#3
import torch.optim as optim

criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
#4
for epoch in range(2):
    running_loss=0.0
    #按批次迭代训练模型
    for i,data in enumerate(trainloader,0):
        inputs,labels=data
        optimizer.zero_grad()
        outputs=net(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        #打印训练信息
        running_loss+=loss.item()
        if (i+1)%2000==0:
            print('[%d,%5d] loss:%.3f'%(epoch+1,i+1,running_loss/2000))
            running_loss=0
            
print('finished training')

#设定模型保存位置
PATH='./cifar_net.pth'
#保存模型的状态字典
torch.save(net.state_dict(),PATH)

在这里插入图片描述

#5
dataiter=iter(testloader)
images,labels=next(dataiter)
print('groundtrue:',' '.join('%5s'%classes[labels[j]] for j in range(4)))
#加载模型参数,在测试阶段
net.load_state_dict(torch.load(PATH))
#利用模型对图片进行预测
outputs=net(images)
_,predicted=torch.max(outputs,1)
print('predicted:',''.join('%5s'%classes[predicted[j]] for j in range(4)))

在这里插入图片描述

#5
#在整个测试集上测试模型准确率
correct=0
total=0
with torch.no_grad():
    for data in testloader:
        images,labels=data
        outputs=net(images)
        _,predicted=torch.max(outputs.data,1) #_是最大值,predicted是最大值下标
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
        
print('accuracy of the network on the 10000 test images:%d %%'%(100*correct/total))

在这里插入图片描述

在这里插入图片描述

分别测试不同类别的模型准确率
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在GPU上训练模型

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/779512.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringCloud学习路线(8)—— Docker

一、Docker的开始 (一)项目部署问题: 依赖关系复杂,容易出现兼容性问题开发、测试、生产环境有差异 (二)Docker如何解决问题? 1、依赖兼容问题 (1)将应用的Libs&…

垃圾回收之三色标记法(Tri-color Marking)

关于垃圾回收算法,基本就是那么几种:标记-清除、标记-复制、标记-整理。在此基础上可以增加分代(新生代/老年代),每代采取不同的回收算法,以提高整体的分配和回收效率。 无论使用哪种算法,标记…

(已解决)RuntimeError: Java gateway process exited before sending its port number

今天用Pycharm远程使用pysaprk解释器时,跑代码出现了这个错误: RuntimeError: Java gateway process exited before sending its port number 找了好多博客都没解决问题,有说重装spark的,有说本地配Java_home的,后面我…

[C语言刷题]杨氏矩阵、返回型参数

本文包含知识点 杨氏矩阵极其解法函数return多个值的四种方法 题目: 杨氏矩阵 有一个数字矩阵,矩阵的每行从左到右是递增的,矩阵从上到下是递增的,请编写程序在这样的矩阵中查找某个数字是否存在。 要求:时间复杂度小于…

js 在浏览器窗口关闭后还可以不中断网络请求

有个需求,我们需要在用户发送数据过程中,如果用户关闭了网页(包括整个浏览器关闭),不要中断数据传递 目前XMLHttpRequest对象是不支持的 http服务器 为了测试效果我们用nodejs写了个http服务器代码 文件名为httpServer.js如下,…

获取大疆无人机的飞控记录数据并绘制曲线

机型M350RTK,其飞行记录文件为加密的,我的完善代码如下 gitgithub.com:huashu996/DJFlightRecordParsing2TXT.git 一、下载安装官方的DJIFlightRecord git clone gitgithub.com:dji-sdk/FlightRecordParsingLib.git飞行记录文件在打开【我的电脑】&am…

Windows nvm 安装后webstrom vue项目编译报错,无法识别node

1 nvm安装流程 卸载原先nodejs用管理员权限打开exe安装nvmnvm文件夹和nodejs文件夹 都授权Authenticated Users 完全控制nvm list availablenvm install 16.20.1nvm use 16.20.1输入node和npm检查版本命令,正常显示确认系统变量和用户变量都有nvm 和nodejs 2 bug情…

数学建模-聚类算法 系统(层次)聚类

绝对值距离:网状道路 一般用组间和组内距离 聚类的距离计算如何选取:看结果是否解释的通,选择一种结果解释的通的方法。

【数据挖掘】将NLP技术引入到股市分析

一、说明 在交易中实施的机器学习模型通常根据历史股票价格和其他定量数据进行训练,以预测未来的股票价格。但是,自然语言处理(NLP)使我们能够分析财务文档,例如10-k表格,以预测股票走势。 二、对自然语言处…

【转载+修改】pytorch中backward求梯度方法的具体解析

原则上,pytorch不支持张量对张量的求导,它只支持标量对张量的求导 我们先看标量对张量求导的情况 import torch xtorch.ones(2,2,requires_gradTrue) print(x) print(x.grad_fn)输出,由于x是被直接创建的,也就是说它是一个叶子节…

Vue.js uni-app 混合模式原生App webview与H5的交互

在现代移动应用开发中,原生App与H5页面之间的交互已经成为一个常见的需求。本文将介绍如何在Vue.js框架中实现原生App与H5页面之间的数据传递和方法调用。我们将通过一个简单的示例来展示如何实现这一功能。附完整源码下载地址:https://ext.dcloud.net.cn/plugin?i…

Java集成openAi的ChatGPT实战

效果图: 免费体验地址:AI智能助手 具体实现 public class OpenAiUtils {private static final Log LOG LogFactory.getLog(OpenAiUtils.class);private static OpenAiProxyService openAiProxyService;public OpenAiUtils(OpenAiProxyService openAiP…

【C++】入门 --- 命名空间

文章目录 🍪一、前言🍩1、C简介🍩2、C关键字 🍪二、命名冲突🍪三、命名空间🍩1、命名空间定义🍩2、命名空间的使用 🍪四、C输入&输出 🍪一、前言 本篇文章是《C 初阶…

Data Transfer Object-DTO,数据传输对象,前端参数设计多个数据表对象

涉及两张表的两个实体对象 用于在业务逻辑层和持久层(数据库访问层)之间传输数据。 DTO的主要目的是将多个实体(Entity)的部分属性或多个实体关联属性封装成一个对象,以便在业务层进行数据传输和处理,从而…

八、HAL_UART(串口)的接收和发送

1、开发环境 (1)Keil MDK: V5.38.0.0 (2)STM32CubeMX: V6.8.1 (3)MCU: STM32F407ZGT6 2、UART和USART的区别 2.1、UART (1)通用异步收发收发器:Universal Asynchronous Receiver/Transmitter)。 2.2、USART (1)通用同步异步收发器:Universal Syn…

【《R4编程入门与数据科学实战》——一本“能在日常生活中使用统计学”的书】

《R 4编程入门与数据科学实战》的两名作者均为从事编程以及教育方面的专家,他们用详尽的语言,以初学者的角度进行知识点的讲解,每个细节都手把手教学,以让读者悉数掌握所有知识点,在每章的结尾都安排理论与实操相结合的习题。与同…

banner轮播图实现、激活状态显示和分类列表渲染、解决路由缓存问题、使用逻辑函数拆分业务(一级分类)【Vue3】

一级分类 - banner轮播图实现 分类轮播图实现 分类轮播图和首页轮播图的区别只有一个,接口参数不同,其余逻辑完成一致 适配接口 export function getBannerAPI (params {}) {// 默认为1 商品为2const { distributionSite 1 } paramsreturn httpIn…

VTK是如何显示一个三维立体图像的

VTK是如何显示一个三维立体图像的 1、文字描述2、图像演示 1、文字描述 2、图像演示

MySQL-事务-介绍与操作

思考 假设在一个场景中,学工部解散了,需要删除该部门及该部门下的员工对应的SQL语句涉及的数据表信息如下 员工表 部门表 实现的SQL语句 -- todo 事务 -- 删除学工部 -- 删除1号部门 delete from tb_dept where id 1; -- 删除学工部下的员工 delete …

SPEC CPU 2006 docker gcc:4 静态编译版本 Ubuntu 22.04 LTS 测试报错Invalid Run

runspec.sh #!/bin/bash source shrc ulimit -s unlimited runspec -c gcc41.cfg -T all -n 1 int fp > runspec.log 2>&1 & tail -f runspec.log runspec.log 由于指定了-T all,导致-n 1 失效,用例运行了三次(后续验证&…