Pytorch-CNN-Mnist

news2024/11/19 13:38:07

文章目录

  • model.py
  • main.py
  • 网络设置
  • 注意事项及改进
  • 运行截图

model.py

import torch.nn as nn
class CNN_cls(nn.Module):
    def __init__(self,in_dim=28*28):
        super(CNN_cls,self).__init__()
        self.conv1 = nn.Conv2d(1,32,1,1)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(32,64,1,1)
        self.pool2 = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(64,128,1,1)
        self.lin1 = nn.Linear(128*7*7,512)
        self.lin2 = nn.Linear(512,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(-1,128*7*7)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import CNN_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = CNN_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

网络设置

在CNN_cls里面查看。

注意事项及改进

1.注意第一个输入通道是1,因为是灰度图像。
2.可以考虑加入GPU

运行截图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017825.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023上半年软件设计师上午题目总结

1 在计算机中系统总线用于连接 主存及外设部件 2 在由高速缓存、主存、硬盘构成的三级存储体系中,CPU执行指令时需要读取数据,DMA控制器和中断CPU发出的数据地址是 主存物理地址 。 DMA(Direct Memory Access)控制器是计算机硬…

Nacos深入原理从源码层面讲解

文章目录 1 Nacos原理1.1 Nacos架构1.2 注册中心原理1.3 SpringCloud服务注册1.4 NacosServiceRegistry实现1.4.1 心跳机制1.4.2 注册原理1.4.3 总结 1.5 服务提供者地址查询1.6 Nacos服务地址动态感知原理 1 Nacos原理 1.1 Nacos架构 Provider APP:服务提供者Cons…

STM32 学习笔记1:STM32简介

1 概述 STM32,从字面上来理解,ST 是意法半导体,M 是 Microelectronics 的缩写,32 表示 32 位,合起来理解,STM32 就是 ST 公司开发的 32 位微控制器。是一款基于 ARM 公司推出的基于 ARMv7 架构的 32 位 Co…

【详细教程hexo博客搭建】1、从零开始搭建一个能用的博客

1、开始 2.环境与工具准备 本教程主要面对的是Windows用户 操作系统:Windows10NodeGitHexo文本编辑器(强烈推荐VSCODE)GitHub 帐号一个域名(强烈推荐买个域名)云服务器(可选) 3.Node的安装 打开Node官网&#xff0…

vivo数据中心网络链路质量监测的探索实践

作者:vivo 互联网服务器团队- Wang Shimin 网络质量监测中心是一个用于数据中心网络延迟测量和分析的大型系统。通过部署在服务器上的Agent发起5次ICMP Ping以获取端到端之间的网络延迟和丢包率并推送到存储与分析模块进行聚合和分析与存储。控制器负责分发PingList…

【大数据】Neo4j 图数据库使用详解

目录 一、图数据库介绍 1.1 什么是图数据库 1.2 为什么需要图数据库 1.3 图数据库应用领域 二、图数据库Neo4j简介 2.1 Neo4j特性 2.2 Neo4j优点 三、Neo4j数据模型 3.1 图论基础 3.2 属性图模型 3.3 Neo4j的构建元素 3.3.1 节点 3.3.2 属性 3.3.3 关系 3.3.4 标…

JS生成器的介绍

1、 什么是生成器 生成器是ES6中新增的一种函数控制、使用的方案,它可以让我们更加灵活的控制函数什么时候继续执行、暂停执行等。 平时我们会编写很多的函数,这些函数终止的条件通常是返回值或者发生了异常。 生成器函数也是一个函数,但是…

阿里云无影云电脑是干什么用的?五大使用场景

阿里云无影云电脑是一种易用、安全、高效的云上桌面服务,阿里云无影云电脑可用于高数据安全管控、高性能计算等要求的金融、设计、视频、教育等领域,适用于多种办公场景,如远程办公、多分支机构、安全OA、短期使用、专业制图等。阿里云百科来…

【LeetCode热题100】--49.字母异位词分组

49.字母异位词分组 两个字符串互为字母异位词,当且仅当两个字符串包含的字母相同。同一组字母异位词中的字符串具备相同点,可以使用相同点作为一组字母异位词的标志,使用哈希表存储每一组字母异位词,哈希表的键为一组字母异位词的…

DockerCompose

DockerCompose Docker Compose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器! 初识DockerCompose Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。格式如下: version: &…

[golang 流媒体在线直播系统] 4.真实RTMP推流摄像头把摄像头拍摄的信息发送到腾讯云流媒体服务器实现直播

用RTMP推流摄像头把摄像头拍摄的信息发送到腾讯云流媒体服务器实现直播,该功能适用范围广,比如:幼儿园直播、农场视频直播, 一.准备工作 要实现上面的功能,需要准备如下设备: 推流摄像机(监控) 流媒体直播服务器(腾讯云流媒体服务器,自己搭建的流媒体服务…

React中组件通信01——props

React中组件通信01——props 1. 父传子——props1.1 简单例子——props1.2 props 可以传递任何数据1.2.1 传递数字、对象等1.2.2 传递函数1.2.3 传递模版jsx 2. 子传父 子传子——props2.1 父传子——传递函数2.2 子传父——通过父传子的函数实现2.3 优化 子传子(…

uniapp开发小程序中实现骨架屏

第一步:小程序中实现骨架屏在微信开发者工具中点击生成骨架屏: 第二步:复制html代码,到骨架屏vue组件汇中再把之前写的样式代码引入进去: import ../../pages/user/user.css; 第三步:组件中引入骨架屏&am…

python pytesseract 中文文字批量识别

用pytesseract 来批量把图片转成文字 1、安装好 pytesseract 包 2、下载安装OCR https://download.csdn.net/download/m0_37622302/88348824https://download.csdn.net/download/m0_37622302/88348824 Index of /tesseracthttps://digi.bib.uni-mannheim.de/tesseract/ 我是…

百度SEO优化TDK介绍(分析下降原因并总结百度优化SEO策略)

TDK是SEO优化中很重要的部分,包括标题(Title)、描述(Description)和关键词(Keyword),为百度提供网页内容信息。其中标题是最重要的,应尽量突出关键词,同时描述…

【C++学习】继承

目录 一、继承的概念及定义 1、继承的概念 2、继承的定义 2.1 定义格式 2.2 继承关系和访问限定符 2.3 继承基类成员访问方式的变化 二、基类和派生类对象赋值转换 三、继承中的作用域 四、派生类的默认成员函数 五、继承与友元 六、继承与静态成员 七、复杂的菱形…

天然气跟踪监管系统功能模块实现

天然气跟踪监管系统功能模块实现 1. 数据库查询3. 仓库管理(1)仓库查询与展示。代码说明 1. 数据库查询 救援物资跟踪监管系统的绝大部分功能都会涉及关系数据库中的业务数据,因此关系数据库的查询是本系统不可或缺的重要部分。 本系统中的数…

Matlab--微积分问题的计算机求解

目录 1.单变量函数的极限问题 1.1.公式例子 1.2.对应例题 1 2.多变量函数的极限问题 3.函数导数的解析解 4.多元函数的偏导数 5.Jacobian函数 6.Hessian矩阵 7.隐函数的偏导 8.不定积分问题的求解 9.定积分的求解问题 10. 多重积分的问题求解 1.单变量函数的极限问题 …

【Vue】快速入门案例与工作流程的讲解

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟在这里,我要推荐给大家我的专栏《Vue快速入门》。&#x1f…

运维Shell牛刀小试(十一):for循环读取多个命令行参数|read重定向读取文件内容

运维Shell脚本小试牛刀(一) 运维Shell脚本小试牛刀(二) 运维Shell脚本小试牛刀(三)::$(cd $(dirname $0); pwd)命令详解 运维Shell脚本小试牛刀(四): 多层嵌套if...elif...elif....else fi_蜗牛杨哥的博客-CSDN博客 Cenos7安装小火车程序动画 运维Shell脚本小试…