简单搭建卷积神经网络实现手写数字10分类

news2024/9/21 12:41:08

搭建卷积神经网络实现手写数字10分类

1.思路流程

1.导入minest数据集

2.对数据进行预处理

3.构建卷积神经网络模型

4.训练模型,评估模型

5.用模型进行训练预测

一.导入minest数据集

 

MNIST--->raw--->test-->(0,1,2...) 10个文件夹

MNIST--->raw--->train-->(0,1,2...) 10个文件夹

共60000张图片.可自己去网上下载

二.对数据进行预处理

----读取图片,将图片先转为张量。

img=cv2.imread(path)

----将图片进行归一化,即将像素值标准化到0-1之间

img_tensor=util.transforms_train(img)

----裁剪,翻转等,实现数据增强。

数据增强:通过对原始图像进行旋转、翻转等操作,可以增加数据的多样性。这有助于模型学习到更具泛化性的特征,减少对特定方向或位置的依赖,从而提高模型的鲁棒性和准确性

transforms_train=transforms.Compose([
    # transforms.CenterCrop(10),
    # transforms.PILToTensor(),
    transforms.ToTensor(),#归一化,转tensor
    transforms.Resize((28,28)),
    transforms.RandomVerticalFlip()
])

ps:为什么要归一化

  1. 消除量纲影响:不同图像的像素值范围可能差异很大。归一化可以将像素值范围统一到一个特定的区间,例如 [0, 1] 或 [-1, 1],消除不同图像之间因像素值范围差异带来的影响,使模型更关注图像的特征和结构,而不是像素值的绝对大小。

  2. 提高训练稳定性:有助于优化算法的收敛性和稳定性。如果像素值范围较大且分布不均匀,可能导致梯度计算不稳定,从而影响模型的训练效率和效果。

  3. 缓解过拟合:一定程度上可以减少数据中的噪声和异常值对模型的影响,降低模型对某些特定像素值的过度依赖,从而提高模型的泛化能力,减少过拟合的风险。

三.构建卷积神经网络模型

常见卷积神经网络(CNN),主要由卷积,池化,全连接组成。卷积核在输入图像上滑动,通过卷积运算提取局部特征。卷积核在整个图像上重复使用,大大减少了模型的参数数量,降低了计算复杂度,同时也增强了模型对平移不变性的鲁棒性。池化层对特征进行压缩,提取主要特征,减少噪声和冗余信息。

x=torch.randn(2,3,28,28)

用x表示初始图形的信息。为了简单理解,简单表述。其中

2--->两张图片

3--->图片的通道数是3个,即 RGB

28,28--->图片的宽高是28px 28px

采用以上的神经网络conv为卷积操作,maxpool为池化。Linear为全连接。relu为激活函数。

进入全连接层时需要将展平。torch.Size([2, 16, 5, 5])--->torch.Size([2, 400])

x=torch.flatten(x,1)

因为全连接是只进行的线性的变化。所以要把每张图片的维数参数降为1。

使用print(summary(net, x))可查看网络的层次结构。其中-1就表示自己算,是多少张图片就是多少

输入的的是x=torch.randn(2,3,28,28),最终输出的是(2,10)

四.训练模型,评估模型

需要初始化之前的数据和网络,然后选择合适的优化器和损失函数,学习率和加载图片的批次去训练模型。使用loss_avg和accurary来评估模型的性能。对于pytorch来说优化器可以实现自动梯度清0,自动更新参数。我们需要主要的是就实现其中的维度的转化。loss越小越接近真实值。其中计算精度的方法使用one-hot编码。其中0表示[0,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],2表示[0,0,1,0,0,0,0,0,0,0].。。。其他依次类推。我们把用网络得出的参数,类似[0.1,0.2,0.1,0.5,0,0,0,0,0,0](数据我随便写的),然后用Python的argmax去处最大值的索引与one-hot真实值的索引相比,如果相等就是正确的结果。

----本次实验使用的是MSE损失函数

----lr(学习率)设为0.01

----使用的优化器Adam ,其实其他优化器你也可以随便试试。

Adam 算法的主要优点包括:

  1. 自适应学习率:能够为每个参数自适应地调整学习率。

  2. 偏差校正:在初始阶段对梯度估计进行校正,加速初期的学习速率。

  3. 适应性强:在很多不同的模型和数据集上都表现出良好的性能。

  4. 实现简单,计算高效,对内存需求少。

使用tensorboard进行可视化

五.用模型进行训练预测

需要读取之前训练好的模型,然后用这个模型来实现预测一个自己手写的图片

    # 加载整个模型
    loaded_model = torch.load('whole_model.pth')
​
    # 保存模型参数
    torch.save(loaded_model.state_dict(),'model_params.pth')

代码附上:

dataset.py

import glob
import os.path
​
import cv2
import torch
import util
​
class DataAndLabel:
    def __init__(self,path='D:\\0MNIST\\raw',is_train=True):
        super().__init__()
        # 拼接路径
        #data里面是path,label
        clas='train' if is_train==True else 'test'
        path=os.path.join(path,clas)
        paths=glob.glob(os.path.join(path,'*','*'))
        # print(paths)
        # print(path)
        self.data=[]
        for path in paths:
            label=int(path.split('\\')[-2])
            self.data.append((path,label))
    def __getitem__(self, idx):
        #返回一个tensor,one-hot
        path,label =self.data[idx]
        img=cv2.imread(path)
        # cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
        img_tensor=util.transforms_train(img)
        one_hot=torch.zeros(10)
        one_hot[label]=1
        return img_tensor,one_hot
    def __len__(self):
        return len(self.data)
# if __name__ == '__main__':
#     data=DataAndLabel()
#     print(data[0])
#     print()

lenet5.py

import torch
import torch.nn as nn
from torchkeras import summary
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(3,6,5,1)
        self.maxpool1=nn.MaxPool2d(2)
        self.conv2=nn.Conv2d(6,16,3,1)
        self.maxpool2=nn.MaxPool2d(2)
        self.layer1=nn.Linear(16*5*5,10)
        self.layer2=nn.Linear(10,10)
        self.relu=nn.Softmax()
    def forward(self,x):
        x=self.conv1(x)
        x=self.relu(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.relu(x)
        x=self.maxpool2(x)
        # print(x.shape)
        x=torch.flatten(x,1)
        # print(x.shape)
​
        x=self.layer1(x)
        x=self.layer2(x)
        return x
if __name__ == '__main__':
    x=torch.randn(2,3,28,28)
    net=Net()
    out=net(x)
    # print(out.shape)
    # print(summary(net, x))

train_and_test

import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from lenet5 import Net
import torch.nn as nn
from dataset import DataAndLabel
class TrainAndTest(Dataset):
    def __init__(self):
        super().__init__()
        # self.writer=SummaryWriter("logs")
        net=Net()
        self.net=net
        self.loss=nn.MSELoss()
        self.opt = torch.optim.Adam(net.parameters(), lr=0.1)
        self.train_data=DataAndLabel(is_train=True)
        self.test_data=DataAndLabel(is_train=False)
        self.train_loader=DataLoader(self.train_data,batch_size=100,shuffle=False)
        self.test_loader=DataLoader(self.test_data,batch_size=100,shuffle=False)
    # 拿到数据,网络
    def train(self,epoch):
        loss_sum = 0
        accurary_sum = 0
        for img_tensor, label in tqdm.tqdm(self.train_loader, desc='train...', total=len(self.train_loader)):
            out = self.net(img_tensor)
            loss = self.loss(out, label)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            loss_sum += loss.item()
            accurary_sum += torch.mean(
                torch.eq(torch.argmax(label, dim=1), torch.argmax(out, dim=1)).to(torch.float32)).item()
        loss_avg = loss_sum / len(self.train_loader)
        accurary_avg = accurary_sum / len(self.train_loader)
        print(f'train---->loss_avg={round(loss_avg, 3)},accurary_avg={round(accurary_avg, 3)}')
        # self.writer.add_scalars('loss',{'loss_avg':loss_avg},epoch)
    def train1(self):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm.tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy =torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.train_loader)
        avg_acc = sum_acc / len(self.train_loader)
        print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
​
​
    def run(self):
        for epoch in range(10):
            self.train1()
            # self.test(epoch)
if __name__ == '__main__':
    tt=TrainAndTest()
    tt.run()

util.py

from torchvision import transforms
​
transforms_train=transforms.Compose([
    # transforms.CenterCrop(10),
    # transforms.PILToTensor(),
    transforms.ToTensor(),#归一化,转tensor
    transforms.Resize((28,28)),
    transforms.RandomVerticalFlip()
])
transforms_test=transforms.Compose([
    transforms.ToTensor(),  # 归一化,转tensor
    transforms.Resize((28, 28)),
])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933243.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

爬虫与 Zapier 集成

利用与 Zapier 集成的爬虫 API,以最小的工作量自动完成数据收集、处理和报告等复杂任务。 什么是 Zapier? Zapier 是一家为网络应用程序提供集成的公司,可用于自动化工作流程。 无代码集成 爬虫 API 集成无需编码,您只需点击几下&#x…

ERR SELECT is not allowed in cluster mode

在redis集群模式下,默认且只能使用0号database库,不允许使用SELECT 操作选择database 。

Django Q()函数

Q() 函数的作用 在Django中,Q()函数是一个非常有用的工具,主要用于构建复杂的查询。它允许你创建复杂的查询语句,包括AND、OR和NOT逻辑操作。这对于处理复杂的数据库查询特别有用,特别是在你需要组合多个条件或处理复杂的过滤逻辑…

乐鑫ESP-IoT-Bridge方案简化设备智能联网通信,启明云端乐鑫代理商

随着物联网技术的快速发展,设备联网已成为实现智能化的关键一步。然而,不同设备之间的通信协议、接口等差异,使得设备联网变得复杂且困难。 乐鑫推出的ESP-IoT-Bridge联网方案,正是为了解决这一难题,为物联网场景下的…

搞定前端面试题——ES6同步与异步机制、async/await的使用以及Promise的使用!!!

文章目录 同步和异步async/awaitPromisePromise的概念 同步和异步 ​ 同步:代码按照编写顺序逐行执行,后续的代码必须等待当前正在执行的代码完成之后才能执行,当遇到耗时的操作(如网络请求等)时,主线程会…

第二证券:深股和沪股区别?一文解析深股沪股区别?

深股,即在深圳证券生意所上市、生意的股票;沪股,即在上海证券生意所上市、生意的股票。 深市上市公司以小型和中型企业为主,上市条件相对较松;沪市的上市公司多为大型企业和国有企业,上市条件相对严峻。 …

Chromium CI/CD 之Jenkins实用指南2024- Windows节点开启SSH服务(七)

1.引言 在现代软件开发和持续集成的过程中,自动化部署和远程管理是不可或缺的关键环节。SSH(Secure Shell)协议以其强大的安全性和灵活性,成为连接和管理远程服务器的首选工具。对于使用Windows虚拟机作为Jenkins从节点的开发者而…

【整体介绍】HTML和JS编写多用户VR应用程序的框架

一、Networked-Aframe是什么? 简称NAF,底层基于Mozilla的AFrame框架,用HTML和JS编写多用户VR应用程序的框架。 二、特性 支持 WebRTC 和/或 WebSocket 连接。 语音聊天。音频流让您的用户在应用程序内交谈(仅限 WebRTC&#xff…

【cocos creator】ts中export的模块管理

在 TypeScript(TS)中,export 和 import 的概念与 Java 中的 public 类、接口以及 import 语句有一些相似之处。可以用以下方式来类比理解: Export 在 TypeScript 中,export 用于将模块中的变量、函数、类等暴露给外部…

白酒销售的新零售模式|琼台酱酒醉仙洞酒商业模式

近年来,酱酒在白酒市场中始终占据绝对C位,被称为“液体黄金”,其金融属性深受市场认同。酱酒之所以如此盛行,与其低风险、高收益、低门槛和快速变现能力密不可分。酒类流通市场的主要参与者包括上游的生产商、中游的酒类流通企业和…

服务器的80和443端口关闭也能申请SSL证书

一、简介 在服务器的80和443端口关闭的情况下,确实可以申请SSL证书,但申请过程和方法会根据证书类型和验证方式的不同而有所差异。 通常如果是网站域名申请SSL证书,哪怕服务器的80、443端口都打不开,也可以通过DNS解析的方式来验…

vue复制链接操作

vue复制链接操作 使用clipboardclipboard属性代码实现 发布测试出现问题问题分析解决方案最终代码实现document.execCommand扩展常用例子 给要复制的文本或者按钮加上点击事件后,并将要复制的值传过来 使用clipboard clipboard属性 –解释read从剪贴板读取数据&a…

代码重构思想和VSCode编辑器中代码重构插件

目录 一、参考资料 二、VSCode重构插件 1、小浣熊 (1)功能 (2)使用说明 2、Code Spell Checker 3、Abracadabra, refactor this! (1)重命名变量或函数名称 (2)提取变量 &a…

【STC89C51单片机】定时器/计数器的理解

目录 定时器/计数器1. 定时器怎么定时简单理解(加1经过了多少时间)什么是时钟周期什么是机器周期 2.如何设置定时基本结构相关寄存器1. TMOD寄存器2. TCON寄存器 代码示例 定时器/计数器 STC89C51单片机的定时器和计数器(Timers and Counter…

Talk|OSU汪博石:Transformer模型能否进行隐式的推理?关于Grokking和泛化的深入探索

本期为TechBeat人工智能社区第609期线上Talk。 北京时间7月17日(周三)20:00,俄亥俄州立大学博士生—汪博石的Talk已经准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “Transformer模型能否进行隐式的推理?关于Grokking和泛化的深入…

【leetcode】 字符串相乘(大数相乘、相加)

记录一下大数相乘相加方法: 给定两个以字符串形式表示的非负整数 num1 和 num2,返回 num1 和 num2 的乘积,它们的乘积也表示为字符串形式。 注意:不能使用任何内置的 BigInteger 库或直接将输入转换为整数。 示例 1: 输入: nu…

【C语言】联合体(union)

文章目录 1.联合体的含义2. 联合体的声明3. 联合体大小的计算4. 联合体的特点 1.联合体的含义 联合体也叫做共用体&#xff0c;是指联合体的所有成员共用同一块内存空间。这也就说明了&#xff0c;联合体的大小至少是其成员所占空间的最大值。 2. 联合体的声明 #include<…

CSS基础学习之元素定位(6)

目录 1、定位类型 2、取值 2.1、static 2.2、relative 2.3、absolute 2.4、fixed 2.5、stickty 3、示例 3.1、相对定位(relative) 3.2、绝对定位&#xff08;absolute&#xff09; 3.3、固定定位&#xff08;fixed&#xff09; 3.4、粘性定位&#xff08;sticky&…

力扣第九题(回文数)

9. 回文数 - 力扣&#xff08;LeetCode&#xff09; 提示 给你一个整数 x &#xff0c;如果 x 是一个回文整数&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 回文数 是指正序&#xff08;从左向右&#xff09;和倒序&#xff08;从右向左&#xff09;读…

Husco在汽车自动变速器电磁阀总产量超1200万台最大汽车电磁阀线圈制造商Amisco同时高产量增长

Husco Inc.是一家位于美国威斯康星州的汽车零部件制造商&#xff0c;专门生产汽车自动变速器电磁阀。根据最新的数据&#xff0c;Husco的汽车自动变速器电磁阀总产量已经超过了1200万台&#xff0c;成为全球最大的汽车电磁阀生产商之一。 与此同时&#xff0c;Amisco是一家专门…