Pytorch笔记之分类

news2024/11/19 21:23:22

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、模型评估
  • 总结


前言

使用Pytorch进行MNIST分类,使用TensorDataset与DataLoader封装、加载本地数据集。


一、导入库

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader # 数据集工具
from load_mnist import load_mnist # 本地数据集

二、数据处理

1、导入本地数据集,将标签值设置为int类型,构建张量
2、使用TensorDataset与DataLoader封装训练集与测试集

# 构建数据
x_train, y_train, x_test, y_test = \
    load_mnist(normalize=True, flatten=False, one_hot_label=False)
# 数据处理
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.int64))
x_test = torch.from_numpy(x_test.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.int64))
# 数据集封装
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
batch_size = 64
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,
                          shuffle=True)

三、构建模型

输入到全连接层之前需要把(batch_size,28,28)展平为(batch_size,784)
交叉熵损失函数整合了Softmax,在模型中可以不添加Softmax

# 继承模型
class FC(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 10)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        y = self.fc1(x.view(x.shape[0],-1))
        y = self.softmax(y)
        return y
# 定义模型
model = FC()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

四、迭代训练

从DataLoader中取出x和y,进行前向和反向的计算

for epoch in range(10):
    print('Epoch:', epoch)
    for i,data in enumerate(train_loader):
        x, y = data
        y_pred = model.forward(x)
        loss = loss_function(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

五、模型评估

在测试集中进行验证
使用.item()获得tensor的取值

	correct = 0
    for i,data in enumerate(test_loader):
        x, y = data
        y_pred = model.forward(x)
        _, y_pred = torch.max(y_pred, 1)
        correct += (y_pred == y).sum().item()
    acc = correct / len(test_dataset)
    print('Accuracy:{:.2%}'.format(acc))


总结

记录了TensorDataset与DataLoader的使用方法,模型的构建与训练和上一篇Pytorch笔记之回归相似。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1064635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

10.5汇编语言整理

【汇编语言相关语法】 1.汇编语言的组成部分 1.伪操作:不参与程序的执行,但是用于告诉编译器程序该怎么编译 .text .global .end .if .else .endif .data 2.汇编指令 编译器将一条汇编指令编译成一条机器码,在内存里一条指令占4字节内存&…

c++---模板篇

1、模板 概念:模板就是建立通用的模具,大大提高复用性 特点: 模板不可以直接使用,它只是一个框架模板的通用并不是万能的 1.1、函数模板 C另一种编程思想称为泛型编程,主要利用的技术就是模板C提供两种模板机制&a…

数据结构与算法(四):哈希表

参考引用 Hello 算法 Github:hello-algo 1. 哈希表 1.1 哈希表概述 哈希表(hash table),又称散列表,其通过建立键 key 与值 value 之间的映射,实现高效的元素查询 具体而言,向哈希表输入一个键…

Linux CentOS7 vim宏操作

vim的macro就是用来解决重复的问题。在vim寄存器的文章里面已经对macro有所涉及,macro的操作都是以文本的方式存放在寄存器中。 宏是一组命令的集合,应用极其广泛,包括MS Office中的word编辑器,excel编辑器和各种文本编辑器&…

Pytorch笔记之回归

文章目录 前言一、导入库二、数据处理三、构建模型四、迭代训练五、结果预测总结 前言 以线性回归为例,记录Pytorch的基本使用方法。 一、导入库 import numpy as np import matplotlib.pyplot as plt import torch from torch.autograd import Variable # 定义求…

ESP32/ESP8266在线刷写Sonoff Tasmota固件以及配置简要

ESP32/ESP8266在线刷写Sonoff Tasmota固件以及配置简要 📍原项目Github地址:https://github.com/arendst/Tasmota/tree/v13.1.0📑官方文档介绍:https://tasmota.github.io/docs/🚩(✨推荐方式✨)在线固件刷写地址&…

strcpy函数详解:字符串复制的利器

目录 一,strcpy函数的简介 二,strcpy函数的实现原理 三,strcpy函数的注意事项 四,strcpy函数的模拟实现 一,strcpy函数的简介 strcpy函数是C语言中的字符串复制函数,其原型如下: char * str…

Linux中的wc命令

2023年10月6月,周五晚上 目录 wc命令的主要功能和用法如下:统计文件行数、字数和字节数只统计行数只统计字数只统计字节数 wc命令在Linux/Unix系统中是word count的缩写,它用来统计文件的行数、字数和字节数。 wc命令的主要功能和用法如下: 统计文件行数、字数和字…

英语四六级高频核心词(故事版)

第一组:" A Century of Community Effort to Improve Quality of Life and Climate" In the early years of the 20th century, a small community found itself facing a decade of challenges. The most pressing issue was the mental quality of life…

VSC-HVDC直流输电matlab仿真模型

微❤关注“电气仔推送”获得资料(专享优惠) VSC-HVDC直流输电仿真,换流站采用两电平结构,全控型器件(IGBT),采用双环控制,包括电压外环,电流内环,分为d、q两…

【论文阅读】An Evaluation of Concurrency Control with One Thousand Cores

An Evaluation of Concurrency Control with One Thousand Cores Staring into the Abyss: An Evaluation of Concurrency Control with One Thousand Cores ABSTRACT 随着多核处理器的发展,一个芯片可能有几十乃至上百个core。在数百个线程并行运行的情况下&…

Springboot+vue的开放性实验室管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。

演示视频: Springbootvue的开放性实验室管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的开放性实验室管理系统,采用M&#xff08…

基于SSM的家庭财务管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

基于SSM的大学生就业信息管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

matlab之zeros函数语法与举例说明(附代码)

一、zeros函数语法与举例说明 (1)X zeros——返回标量0 X zeros 示例: (2)X zeros(n)——返回一个 nn 的全零矩阵 零矩阵: 示例:创建一个由零值组成的 33 矩阵 X zeros(3) (…

《protobuf》基础语法3

文章目录 默认值更新规则保留字段未知字段 默认值 在反序列化时,若被反序列化的二进制序列中不包含某个字段,则在反序列化时,就会设置对应默认值。不同的类型默认值不同: 类型默认值字符串“”布尔型false数值类型0枚举型0设置了…

基于SSM的旅游网站设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

国庆看坚如磐石

坚如磐石上映了,可以在爱奇艺观看。 而博主在使用蓝牙耳机连接电脑的过程中,发现没有蓝牙开启选项,并且在服务的设备管理器中也没有找到,很明显这是缺少驱动导致的,因此便去联想官方网站下载对应的驱动。 这里可以输入…

二分查找:34. 在排序数组中查找元素的第一个和最后一个位置

个人主页 : 个人主页 个人专栏 : 《数据结构》 《C语言》《C》《算法》 文章目录 前言一、题目解析二、解题思路1. 暴力查找2. 一次二分查找 部分遍历3. 两次二分查找分别查找左右端点1.查找区间左端点2. 查找区间右端点 三、代码实现总结 前言 本篇文…