MNIST手写数据集项目

news2025/1/11 8:18:13
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

一、数据处理部分

①定义数据转换 ②加载训练和测试数据,实验数据转换 ③使用dataloader加载数据,类似与生成器,需要多少数据取多少数据

#数据转换
transforms = torchvision.transforms.Compose([
    transforms.ToTensor()   #将array转换为tensor 才能作为神经网络的输入
])

train = torchvision.datasets.MNIST('/mnist',train=True,download=True,transform = transforms) #mnist中训练集和测试集已经分好了
test = torchvision.datasets.MNIST('/mnist',train=False,download=True,transform= transforms) # 要记得加上transforms进行数据的转换
#定义批次 一次性给神经网络输入的数据量   是否进行打乱
dl_train = torch.utils.data.DataLoader(train,batch_size=64,shuffle=True) 
dl_test = torch.utils.data.DataLoader(test,batch_size=64,shuffle=True)

以下可以查看图像

image,label = next(iter(dl_train)) #图片和真实的标签,取一个batch也就是64张
image.shape#64批量大小,1是黑白图像,28*28是图像大小
#查看一张图片
im = image[0]#tensor
im = im.numpy()#转换为np
im.shape#(1,28,28)
#灰度图
im= np.squeeze(im)#将维度为1 的维度删掉
im.shape #(28,28)
#图片批量展示
plt.figure(figsize=(16,8)) ##长为16高为8的画板
#写循环输出
for i in range(len(image[:8])): #image[0:8]   的维度是【8,1,28,28】   也就是从一个批次中取了八张图片
    #每次取一张图片
    img = image[:8][i].numpy()
    img = np.squeeze(img)##(1,28,28)的图像需要删除1 才可以show   
    label_img = label[:8][i].numpy()#将标签作为标题
    plt.subplot(2,4,i+1)#在大图中绘制小图
    plt.title('img_label:{}'.format(label_img))
    plt.imshow(img)  

在这里插入图片描述

二、 定义模型 定义一个class

class IMNet(nn.Module):
    def __init__(self):  #定义神经网络
        super(IMNet,self).__init__() #从nn.module中继承一些基本的参数
        self.layer1 = nn.Sequential(
            nn.Conv2d(1,32,kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32,64,kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )
        self.fc = nn.Sequential(
            nn.Linear(64*5*5,1024),
            nn.Linear(1024,10)
        )
        #self.conv1 = nn.Conv2d(1,32,kernel_size=3)  #输入通道 卷积核个数 卷积核大小
        #self.conv2 = nn.Conv2d(32,64,kernel_size=3) 
        #self.pool = nn.MaxPool2d(2,2) #卷积核大小 步长
        #self.fc1 = nn.Linear(64*5*5,1024) 
        #self.fc2 = nn.Linear(1024,10) #十分类
        
    def forward(self,x): #前向传播,定义神经网络的使用
        #x = F.relu(self.conv1(x))
        #x = self.pool(x)
        #x = F.relu(self.conv2(x))
        #x = self.pool(x)   
        #print(x.shape)#知道图片的大小     torch.Size([64,64,5,5])
        #x = x.view(-1,64*5*5)#压缩数据  linear层的数据需要的是一维的,而输入数据是四维的需要进行压缩
        #x = self.fc1(x)
        #x = self.fc2(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(-1,64*5*5)
        x = self.fc(x)
        
        return x#返回预测值
model = IMNet() #实例化模型   #可以通过模型中的print知道 图片的大小

三、定义损失函数优化器,训练函数

#损失函数    
loss_func = nn.CrossEntropyLoss()
#优化器
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to('cuda')#放到GPU上
import time#查看学习时间
from tqdm import tqdm #查看学习的速度
correct = 0 #正确率
total_number = 0
running_loss = 0 
epochs = 5
def trainNet(epochs,model,dl_train,dl_test):
    correct = 0
    total_number = 0
    running_loss = 0
    test_correct = 0
    test_total_number = 0
    test_running_loss = 0
    for epoch in range(epochs):
        model.train()#训练模式
        for x,y in tqdm(dl_train):#x是数据,y是label   #显示进度 
            x,y = x.to('cuda'),y.to('cuda')
            y_prediction = model(x) #输入到图像中输出预测值
            loss = loss_func(y_prediction,y)#真实值与预测值之间差距
            optimizer.zero_grad()#对优化器进行初始化
            loss.backward()#通过loss进行反向传播
            optimizer.step() #把模型放到优化器中进行梯度下降,对loss进行梯度下降的计算
            #统计准确率
            with torch.no_grad(): #以下内容不放到梯度中,统计准确度(要优化的只有神经网络的权重)
                y_finalPred = torch.argmax(y_prediction,dim=1)#取出最大值,dim=1在10中找到最大值
                #输出是【64,10】(64张图片每张图片有10个预测值)最大的值是预测的值
                correct += (y_finalPred==y).sum().item() #64张图中预测正确也就是一个批次中预测正确的
                total_number += y.size(0) #总数 
                running_loss += loss.item()  #一个批次中的loss 累加成总体的损失值
        #一整个epoch中的loss  所有图片加起来的loss/图片数量(一个epoch60000)  平均到每张图片的loss
        epoch_loss = running_loss/len(dl_train.dataset) 
        epoch_acc = correct / total_number

        model.eval()#预测模式
        for x,y in tqdm(dl_test):
            x,y = x.to('cuda'),y.to('cuda')
            y_prediction = model(x)
            loss = loss_func(y_prediction,y)
            with torch.no_grad():
                y_finalPred = torch.argmax(y_prediction,dim=1)
                test_correct += (y_finalPred==y).sum().item()
                test_correct += y.size(0)
                test_running_loss += loss.item()
        test_epoch_loss = test_running_loss/len(dl_test.dataset) 
        test_epoch_acc = test_correct / test_correct
        print(epoch,'accuracy:{}'.format(round(epoch_acc,3)),'loss:{}'.format(round(epoch_loss,3)))
        print(epoch,'test_accuracy:{}'.format(round(test_epoch_acc,3)),'test_loss:{}'.format(round(test_epoch_loss,3)))
trainNet(epochs,model,dl_train,dl_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1412046.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ansible自动化运维(三)Playbook 模式详解

👨‍🎓博主简介 🏅云计算领域优质创作者   🏅华为云开发者社区专家博主   🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入! 🐋 希望大家多多支…

输入框内容和占位符过长悬浮提示

1. 输入框内容过长&#xff0c;内容悬浮提示 <el-tooltip :disabled"isShowTooltip" effect"dark" :content"formData.toChineseCode" placement"top"><el-inputv-model"formData.toChineseCode"mouseover.native&…

【开源项目】超经典开源项目实景三维数字孪生海洋牧场、海上平台

飞渡科技数字孪生海上运营平台&#xff0c;基于数字孪生、物联网IOT、远程控制与遥感等技术&#xff0c;结合前端智能感知设备数据和专业模型算法&#xff0c;对海洋养殖环境、船只、网箱、监测设备等各要素态势进行综合监测分析&#xff0c;对各类异常事件进行可视化预警、告警…

微信开放平台第三方授权(第一篇)

1.项目需要的功能 需求不可以用微信原有的开发接管方式&#xff0c;之后发现可以进行第三方授权。先登录微信开放平台&#xff0c;找到第三方平台。 微信开放平台 配置第三方平台 基本概念介绍 | 微信开放文档 进入到管理配置 以上连接配置参考 接收授权事件&#xff1a;h…

深度解析Java8社招面试题:Lambda序列化到底行不行?

大家好&#xff0c;我是小米&#xff0c;一个热爱技术分享的小伙伴。今天&#xff0c;我们来聊一个关于Java8的话题&#xff0c;一个颇具技术深度的问题&#xff1a;“社招面试题&#xff1a;Java8中的Lambda表达式可以序列化吗&#xff1f;”废话不多说&#xff0c;让我们一起…

stm32中的SPI

SPI的简介 文章目录 SPI的简介物理层协议层基本通讯过程起始和终止信号数据有效性CPOL/CPHA及通讯模式 STM3的SPI特性及架构通讯引脚时钟控制逻辑数据控制逻辑整体控制逻辑通讯过程 代码配置实现指令集结构体的定义SPI时钟信号的定义SPI端口定义SPI命令 flash驱动代码初始化代码…

【Python】paddleocr快速使用及参数详解

文章目录 1. paddleocr快速使用1.1 使用默认模型路径1.2 设定模型路径 2. PaddleOCR其他参数介绍PaddleOCR模型推理参数解释 其它相关推荐&#xff1a; PaddleOCR模型训练及使用详细教程 官方网址&#xff1a;https://github.com/PaddlePaddle/PaddleOCR PaddleOCR是基于Paddle…

【软考问题】-- 3 - 知识精讲 - 项目整合管理

一、基本问题 1&#xff1a;项目章程的内容包括什么&#xff1f;&#xff08;助记&#xff1a;疯木鱼-要进庙里-发神经&#xff09; 疯&#xff1a;项目整体风险木&#xff1a;项目目标鱼&#xff1a;整体预算要&#xff1a;概要设计进&#xff1a;总体里程碑进度庙&#xff1a…

深度学习(4)--Keras安装

目录 Keras安装: 1.1.安装CUDA/cuDDN工具包 1.1.1.安装前准备 1.1.2.安装CUDA 1.1.3.安装cuDDN 1.2.安装Anaconda 1.3.安装tensorflow框架 1.3.1.使用cmd安装 1.3.2.使用Anaconda Prompt安装 1.4.安装Keras框架 1.5.打开jupyter notebook&#xff0c;执行import调用…

Linux/Academy

Enumeration nmap 首先扫描目标端口对外开放情况 nmap -p- 10.10.10.215 -T4 发现对外开放了22,80,33060三个端口&#xff0c;端口详细信息如下 结果显示80端口运行着http&#xff0c;且给出了域名academy.htb&#xff0c;现将ip与域名写到/et/hosts中&#xff0c;然后从ht…

Redis数据结构与底层实现揭秘

在高并发的系统开发中&#xff0c;缓存和高效的数据存储机制对于提升应用性能至关重要。Redis&#xff0c;作为其中的佼佼者&#xff0c;以其卓越的性能和丰富的数据结构赢得了开发者的青睐。本文将深入探讨Redis的数据结构及其底层实现&#xff0c;带领读者走进这个高性能数据…

【云原生】Docker的镜像创建

目录 1&#xff0e;基于现有镜像创建 &#xff08;1&#xff09;首先启动一个镜像&#xff0c;在容器里做修改 ​编辑&#xff08;2&#xff09;然后将修改后的容器提交为新的镜像&#xff0c;需要使用该容器的 ID 号创建新镜像 实验 2&#xff0e;基于本地模板创建 3&am…

【网站项目】基于SSM的249作业提交与查收系统

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

【Python爬虫入门到精通】小白也能看懂的知识要点与学习路线

文章目录 1. 写在前面2. 爬虫行业情况3. 学习路线 【作者主页】&#xff1a;吴秋霖 【作者介绍】&#xff1a;Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作&#xff01; 【作者推荐】&#xff1a;对JS逆向感兴趣的朋友可以关…

计数指针:shared_ptr (共享指针)与函数 笔记

推荐B站视频&#xff1a; 4.shared_ptr计数指针_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV18B4y187uL?p4&vd_sourcea934d7fc6f47698a29dac90a922ba5a3 5.shared_ptr与函数_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV18B4y187uL?p5&vd_sourcea…

AI引爆算力需求,思腾推出支持大规模深度学习训练的高性能AI服务器

近日人工智能研究公司OpenAI公布了其大型语言模型的最新版本——GPT-4&#xff0c;可10秒钟做出一个网站&#xff0c;60秒做出一个游戏&#xff0c;参加了多种基准考试测试&#xff0c;它的得分高于88%的应试者&#xff1b;随后百度CEO李彦宏宣布正式推出大语言模型“文心一言”…

扫雷游戏——数组和函数实现

扫雷游戏的功能说明 使⽤控制台实现经典的扫雷游戏 游戏可以通过菜单实现继续玩或者退出游戏扫雷的棋盘是9*9的格⼦ 默认随机布置10个雷可以排查雷如果位置不是雷&#xff0c;就显⽰周围有⼏个雷如果位置是雷&#xff0c;就炸死游戏结束把除10个雷之外的所有⾮雷都找出来&…

域名缩短平台搭建

前言 当自己搭建的项目和网站相关文章的链接过长&#xff0c;可以参考一下本文搭建的平台 遵纪守法&#xff0c;不要乱缩网址。 代码&#xff1a; https://github.com/dyanst/shorturlhttps://github.com/dyanst/shorturl shorturl-main.zip官方版下载丨最新版下载丨绿色版…

Linux(linux版本 centos 7) 下安装 oracle 19c详细教程(新手小白易上手)

一、安装前准备 1、下载预安装包 wget http://yum.oracle.com/repo/OracleLinux/OL7/latest/x86_64/getPackage/oracle-database-preinstall-19c-1.0-1.el7.x86_64.rpm预安装包下载成功 2、下载oracle安装包 下载地址如下 https://www.oracle.com/cn/database/technologies…

Maven命令运行单元测试

使用idea开发多模块项目时,有时别的模块编译不通过会导致不能运行单元测试,这是我们可以使用maven命令来运行单元测试 格式 mvn -DtestDingTalkTest#getAllUsers 命令说明 mvn -Dtest 固定格式 DingTalkTest 单元测试类名 getAllUsers 单元测试方法 单元测试类和单元测试方法…