MNIST手写数字辨识-cnn网路 (机器学习中的hello world,加油)

news2025/2/28 6:31:15

用PyTorch实现MNIST手写数字识别(非常详细) - 知乎 (zhihu.com)

参考来源(这篇文章非常适合入门来看,每个细节都讲解得很到位)

一、模块函数用法-查漏补缺:

1.关于torch.nn.functional.max_pool2d()的用法:

上述示例中,输入张量 input 经过最大池化操作后,使用了 kernel_size=2stride=2,所以输出张量 output 的高度和宽度均为输入的一半(32/2=16)。

2.pytorch中的view函数的用法:

http://t.csdn.cn/AAhdH

这一篇文章写得非常好

3.关于f.log_softmax(x,dim = -1)这个先进行softmax,再取log的函数的讲解:

http://t.csdn.cn/GIJ7g

这篇文章讲解得非常好,补充一点,dim的default值和softmax一样,都是-1,也就是计算最里面那个维度的softmax的结果

4.原来loss和counter计数器数组有这个作用:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

5.关于F.nll_loss这个损失函数:

http://t.csdn.cn/ZoruZ

总的来说就是一句话“损失函数 nn.CrossEntropyLoss() 与 NLLLoss() 相同, 唯一的不同是它为我们去做 log_softmax.”

这篇文章讲述得非常清楚

6.关于loss.item()的作用:

http://t.csdn.cn/AvrnJ

这篇文章讲得非常清楚:

就是输出loss这个数值,但是呢,是用非常高的精度进行输出的,一般我们进行一各batch的训练后,就会得到这一次的loss单个数值,需要输出的话,最好就用item()

7.with torch.no_grad()的用法:

http://t.csdn.cn/STaKp

这篇文章讲述得非常清楚,就是不会进行gradient_descend操作,极大的节省了运算开销

8.data.max()函数的用法:

http://t.csdn.cn/aBmin

上面那里讲得不太好,还是chatGPT比较优秀

9.data.view_as()的用法:

10.torch.eq的用法:

http://t.csdn.cn/Tb0kY

这篇文章讲述得非常清楚,也就是对张量中的数值逐个进行比较,

返回的是同样形状的数据,每个位置要么True要么False,可以用.sum()求和得到True的总数

顺便提一下torch.sum的用法,

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.sum())

输出的结果是21

二、各个部分的代码和注释:

#设置环境
import torch
import torchvision
from torch.utils.data import DataLoader
#准备数据集
#1.设置必要的参数
n_epochs = 3
batch_size_train = 64 #所以呢,这个64其实就是下面train时候的batch_size大小
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10 #这个就是后面用来输出的间隔
random_seed = 1
torch.manual_seed(random_seed)

 

#利用pytorch直接加载对应的train_data集 和 test_Data集
train_loader = torch.utils.data.DataLoader( #这里调用的是torch.utils.data.DataLoader的对象,实例化出train_dataloader
    #限免设置各个参数,比如,第一个就是Dataset参数,这里是引用MNIST作为参数,并且设置MNIST中的各个参数
    torchvision.datasets.MNIST('./data/', train=True, download=True, #设为train数据+下载
                               transform=torchvision.transforms.Compose([ #对数据进行transform变换
                                   torchvision.transforms.ToTensor(), #先变tensor后进行Normlize
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True) #这个loader的后两个参数batch_size和shuffle
#同样的道理设置test_data_loader
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_test, shuffle=True)
#查看一条数据:
examples = enumerate(test_loader) #enumerate返回一个(index,data)的元组,本身是一个迭代器,可以用于遍历test_loader
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets) #输出测试(这里的test是有answer作为label的)的1000各answer
print(len(example_targets)) #总共1000各target
print(example_data.shape) #一共有1000张28*28的黑白灰度图
#利用matplotlib进行绘制得到某些数据的可视化结果
import matplotlib.pyplot as plt
fig = plt.figure() #创建一个fig对象
for i in range(6):
  plt.subplot(2,3,i+1) #按照2行3列绘制6张图片
  plt.tight_layout() #设置紧密相连
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none') #利用imshow在下方直接输出图像
  plt.title("Ground Truth: {}".format(example_targets[i]))#设置标题,就是label的数值
  plt.xticks([])
  plt.yticks([])
plt.show()
#定义neural network的结构
import torch.nn as nn #引入neural network的库
import torch.nn.functional as F #引入nn总的常用Func
import torch.optim as optim #引入torch中的optimizer

class Net(nn.Module): #继承nn中的module
    def __init__(self):  #定义这个网络结构的构造函数
        super(Net, self).__init__() #继承nn.Module的初始化构造
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)#参数:输入channel、输出channel、卷积核5*5(filters),strdie(default =1),padding(default=0)
        #所以1*28*28的图像通过后,10*21*21(10是filters的数量)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d() #这个函数http://t.csdn.cn/xK6og这篇文章讲得挺好的,就是让部分filters在某一层不工作,效果是有效防止overfit
        self.fc1 = nn.Linear(320, 50) #定义一个320 -->50 的Linear层函数
        self.fc2 = nn.Linear(50, 10)  #定义一个50 -->10  的Linear层函数
    def forward(self, x): #下面就是直接进行整个network的作用过程定义了 , 输入1*28*28的灰度图
        x = F.relu(F.max_pool2d(self.conv1(x), 2)) #经过一个conv1卷积层后,经过1次2*2窗口的pooling得到,默认??padding=1,之后再算好了
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) #再通过conv2之后->通过conv2_drop->通过max_pool2d
        x = x.view(-1, 320) #第二维是320,并自动计算第一维
        x = F.relu(self.fc1(x)) #通过一个linear层之后,又通过一个relu的激活函数,最后输出的是第二维是50的结果
        x = F.dropout(x, training=self.training) #只有在training模式下才会调用dropout(让某些神经元“熄火”喵)
        x = self.fc2(x) #再让x通过一个linear层,输出的结果是2维的数据,第二维(共10列)
        return F.log_softmax(x) #最后通过对最里面那一层softmax层后,取log对数
#创建model对象+设置optimizer优化器
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum) #lr和momentum都是上面设置好的

#设置用于存储的数组结构:
train_losses = []
train_counter = [] #估计就是一个计数器的作用
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
#print(test_counter)输出[0,60000,120000,180000]不知道再干啥,反正上面的n——epoch==3
#定义这个train函数:
def train(epoch): #这里的epoch是传递进来的参数
  network.train() #开启train模式
  for batch_idx, (data, target) in enumerate(train_loader):#迭代器:以batch为单位逐个从train_loader中获取 索引、data图像数据、label作为target数据
    optimizer.zero_grad() #因为torch中的grad是累加的,所以需要在每个batch训练之前利用optimizer.zero_grad()清零
    output = network(data) #将data图像数据通过network网络得到output输出结果
    loss = F.nll_loss(output, target) #这个loss_func只是比cross_entropy少一个对输入数据的log_softmax操作
    loss.backward()
    optimizer.step() #loss.backward + optimizer.step()常规更新模型参数的操作
    
    if batch_idx % log_interval == 0: #下面都是没啥用的间隔输出操作,上面设置的log_interval =10
      #每经过10各batch处理输出一次:
      #第几个epoch,第几个图像,总共的train有多少图像,已经完成了百分之几的batch,这个batch的loss值
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item())) 
      #将这个batch的loss值添加到train_losses数组中(注意,这里好像是每隔10个batch记录一次loss)
      train_losses.append(loss.item())
      train_counter.append(  #在counter中记录这个batch在考虑epoch情况下的位置
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset))) #这个64是train时候的batch_size上面写了
      #将当前的network的参数状态state_dice存储到对应的路径下, 同时optimizer的状态也要存储?why感觉optim没啥用
      torch.save(network.state_dict(), './model.pth')
      torch.save(optimizer.state_dict(), './optimizer.pth')
          
#train(1) #传递参数epoch=1进行train一次
#这里的train有个地方很有意思,它只是输出loss,没有利用argmax计算出对应的one-hot vec,从而没法和label进行比较得到acc
#定义test函数,并且进行test测试 (不用想,大概率和train的内容没有太大的区别,不过是少了backward和step的更新)
def test():
  network.eval() #开启model的eval模式
  test_loss = 0 #设置loss和acc初值
  correct = 0
  with torch.no_grad(): #不计算SGD
    for data, target in test_loader: #非enumerate,非迭代器版本,不会返回索引,获取data图像batch和target的labels数值
      output = network(data) #调用network获取output结果

      test_loss += F.nll_loss(output, target, size_average=False).item() #这里计算出这一次的 output和target之间的loss

      pred = output.data.max(1, keepdim=True)[1] #通过data.max函数获取对应的索引,这是一个索引的数组,因为是一个batch一起预测的
      correct += pred.eq(target.data.view_as(pred)).sum() #如果pred和target数组对应位置比较,计算总共相等的位置的数量
  test_loss /= len(test_loader.dataset) #计算平均的loss
  test_losses.append(test_loss) #将这一次的平均loss加入到test_losses数组中
  
  #输出:
  #这一次的平均loss,总数中正确预测的数目,正确率
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
    
#test()#调用上述定义的test函数
#再调用一次test()
test()
for epoch in range(1, n_epochs + 1): #调用n_epochs个数的train和测试结果
  train(epoch)
  test()

#下面对上述获取到的数据进行图像的绘制
#绘制图像一开始出错了,我怀疑是我多进行了一次test(),导致x和y的大小不对应
import matplotlib.pyplot as plt
fig = plt.figure()       #创建figure对象 
plt.plot(train_counter, train_losses, color='blue') #绘制曲线图,x是train计数,y是trainloss
#plt.scatter(test_counter, test_losses, color='red') #绘制散点图,x是test_counter计数,y是test_losses数据
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen') #x轴标题
plt.ylabel('negative log likelihood loss')     #y轴标题
plt.show() #绘制结果
#抽取几个直观的例子进行测试:

examples = enumerate(test_loader) #获取test_loader的迭代器
batch_idx, (example_data, example_targets) = next(examples) #获取第一个test_loader中的batch
with torch.no_grad():
  output = network(example_data) #将example_data数据通过network得到output
fig = plt.figure() #创建figure对象
for i in range(6): #构建2行3列的图像排列
  plt.subplot(2,3,i+1)
  plt.tight_layout() #紧密排列
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none') #利用imshow输出example图像
  plt.title("Prediction: {}".format(
    output.data.max(1, keepdim=True)[1][i].item())) #输出预测结果,结果非常美妙
  plt.xticks([])
  plt.yticks([])
plt.show() #绘制-这个似乎可以不用
#为了能够持续训练,这里考虑 获取 上一次的 model_dict 和 optim_dict
continued_network = Net()
continued_optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                                momentum=momentum)

network_state_dict = torch.load('model.pth')
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load('optimizer.pth')
continued_optimizer.load_state_dict(optimizer_state_dict)

#再接着上面练上6次
for i in range(4, 9):
    test_counter.append(i*len(train_loader.dataset))
    train(i)
    test()
#同样进行图像的绘制
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red') #因为之前多test了一次,所以这里应该还是会出错
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()

第一个再vscode完成的神经网络训练, 撒花庆祝!!🎉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/990173.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

56、springboot ------ RESTful服务及RESTful接口设计

★ RESTful服务 RESTful服务是“前后端分离”架构中的主要功能&#xff1a; 后端应用对外暴露RESTful服务&#xff0c;前端应用则通过RESTful服务与后端应用交互。后端应用 RESTful接口 <------------------> 前端★ 基于JSON的RESTful服务 使用RestController注解…

Linux 系统时间同步 ​使用 NTP 服务时间同步​

目录 一、使用 NTP 服务时间同步 二、启动并设置开机自启 三、配置文件 在 /etc/ntp.conf 四、局域网指定一台服务器作为 时间服务器 一、使用 NTP 服务时间同步 安装 ntp yum -y install ntp 二、启动并设置开机自启 systemctl start ntpd systemctl enable ntpd #开…

vue+antd——实现table表格的打印——分页换行,每页都有表头——基础积累

这里写目录标题 场景效果图功能实现1&#xff1a;html代码功能实现2&#xff1a;css样式功能实现3&#xff1a;js代码补充内容page-break-inside 属性page-break-after属性page-break-before 属性 场景 最近在写后台管理系统时&#xff0c;遇到一个需求&#xff0c;就是要实现…

【独家工具】JMeterPerfReporter3.0正式版本,让你的JMeter更好用

Lemon-JMeterPerfReporter工具&#xff0c;是我们性能测试课程教研组根据JMeter性能测试报告的不足&#xff0c;定制开发的一个性能报告生成工具。有需要的同学&#xff0c;可以通过小编官方gitee账户下载&#xff0c;或咨询我免费获取哦&#xff01; 做过性能测试的人员都知道…

单目标应用:基于蜘蛛蜂优化算法(Spider wasp optimizer,SWO)的微电网优化调度MATLAB

一、微网系统运行优化模型 微电网优化模型介绍&#xff1a; 微电网多目标优化调度模型简介_IT猿手的博客-CSDN博客 二、蜘蛛蜂优化算法 蜘蛛蜂优化算法&#xff08;Spider wasp optimizer&#xff0c;SWO&#xff09;由Mohamed Abdel-Basset等人于2023年提出&#xff0c;该…

企业帮助中心如何在线搭建,还能多场景使用呢?

搭建一个企业帮助中心的在线平台可以帮助企业提供高效的客户支持和解决方案。同时&#xff0c;这个平台还可以用于其他场景&#xff0c;例如内部员工培训、知识共享等。下面我将详细介绍如何在线搭建一个企业帮助中心&#xff0c;并且使其能够多场景使用。 选择合适的在线平台…

jeecg vue3版本集成达梦数据库

jeecg他的文档中有一个集成达梦数据库的步骤&#xff0c;链接如下 连接达梦数据库 - JeecgBoot 文档中心&#xff0c;但是我按照步骤去操作的时候并没有适配成功&#xff0c;大部分是他的步骤写的不够清楚&#xff0c;没有说明改哪里的文件&#xff0c;下面是我摸索的适配步骤。…

[移动通讯]【Carrier Aggregation-4】【LTE-1】

前言&#xff1a; 参考&#xff1a; 《Carrier Aggregation Explained In 101 Seconds》 Qualcomm 《Carrier aggregation (CA) in LTE-Advanced by TELCOMA Global》TELCOMA Global 《Carrier Aggregation _CA_Part1》 《Carrier Aggregation _CA_Part2》 《Carrier Aggregati…

【Opencv入门到项目实战】(十一):harris角点检测|SIFT|特征匹配

所有订阅专栏的同学可以私信博主获取源码文件 文章目录 1.harris角点检测2.尺度不变特征变换&#xff08;SIFT&#xff09;2.1图像尺度空间2.2 关键点定位2.3 消除边界响应2.4 代码示例 1.harris角点检测 这一节我们来讨论一下Harris角点检测&#xff0c;由Chris Harris和Mike…

论文分享丨西工大音频语音与语言处理研究组四篇论文被IEEE Trans. ASLP和SPL录用

近日&#xff0c;实验室三篇论文被语音研究顶级期刊IEEE/ACM Transactions on Audio, Speech and Language Processing (TASLP)录用&#xff0c;一篇论文被重要期刊IEEE Signal Processing Letters (IEEE SPL)录用&#xff0c;论文方向涉及说话人识别中的对抗攻击、基于扩散模型…

初学python(一)

一、python的背景和前景 二、 python的一些小事项 1、在Java、C中&#xff0c;2 / 3 0&#xff0c;也就是整数 / 整数 整数&#xff0c;会把小数部分舍掉。而在python中2 / 3 0.66666.... 不会舍掉小数部分。 在编程语言中&#xff0c;浮点数遵循IEEE754标准&#xff0c;不…

vscode中git的使用,以及与webstorm中git的使用对比

前言&#xff1a; 在项目中经常使用的git提交我们代码的时候&#xff0c;vscode和webstorm 是用的非常多的两个工具了&#xff0c;这里再次整理下他们的具体使用以及各自的优势&#xff01; 1、初始化拉取项目 个人习惯&#xff0c;这里就不说框架用法了&#xff0c;原始的最简…

Python函数的概念以及定义方式

一. 前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 二. 什么是函数&#xff1f; 假设你现在是一个工人&#xff0c;如果你实现就准备好了工具&#xff0c;等你接收到任务的时候&#xff0c; 直接带上工…

【安全】正则回溯绕过练习简单案例

目录 环境 案例1 前要 代码审计 分析 案例2 代码审计 分析 payload 环境 phpstudy 案例1 前要 php中0 1 -1 true false null 空字符 数组之间的比较 代码审计 <?php function areyouok($greeting){return preg_match(/Merry.*Christmas/is,$greeting); //2.传…

FP103 双运算放大器和参考调节器芯片

FP103 双运算放大器和参考调节器芯片 一般说明 FP103是一个由一个独立的运放器&#xff08;OPA2&#xff09;和另一个运放器&#xff08;OPA1&#xff09;组成&#xff0c;在非反相输入上具有2.5V精密电压参考&#xff0c;应用于许多应用&#xff0c;如电源、二流/直流转换器或…

python selenium控制浏览器打开网页 模拟鼠标动作

selenium 是一个浏览器控制的库 需要下载安装 谷歌浏览器的驱动 chromedriver https://sites.google.com/chromium.org/driver/downloads 在这里选择跟自己谷歌浏览器版本号一致的驱动程序 如果是最新的浏览器版本可以点这里下面这个链接 Chrome for Testing availability 选…

Web自动化测试详细流程和步骤

一、什么是web自动化测试 自动化&#xff08;Automation&#xff09;是指机器设备、系统或过程&#xff08;生产、管理过程&#xff09;在没有人或较少人的直接参与下&#xff0c;按照人的要求&#xff0c;经过自动检测、信息处理、分析判断、操纵控制&#xff0c;实现预期的目…

echo tail 与 重定向符

1.echo 命令 可以使用echo命令在命令行内输出指定内容 语法: echo输出的内容 无需选项&#xff0c;只有一个参数&#xff0c;表示要输出的内容&#xff0c;复杂内容可以用””包围其类似于 printf 函数 例子&#xff1a; 2. 反引号符 被包围的内容&#xff0c;会被作为命令…

SpringMVC_拦截器

4.拦截器 4.1拦截器概述 概述&#xff1a;一种动态拦截方法调用的机制&#xff0c;在SpringMVC中动态拦截控制器方法的执行实际开发中&#xff0c;静态资源&#xff08;HTML/CSS&#xff09;不需要交给框架处理&#xff0c;需要拦截的是动态资源 4.2图示 图示 4.3案例实现 …

基于Java SSM+layui+mysql实现的图书借记管理系统源代码+数据库

介绍 本项目使用的技术栈是SSMlayuimysql&#xff0c;服务器使用的是tomcat 其中书籍图片存放的位置需要先在tomcat根目录下conf/setting.xml中配置虚拟路径&#xff0c;本项目配置的是D:\upload 完整代码下载地址&#xff1a;图书借记管理系统 用户角色划分 游客 使用本系…