CNN代码实战

news2024/9/21 2:34:07

CNN的原理

从 DNN 到 CNN
(1)卷积层与汇聚
⚫ 深度神经网络 DNN 中,相邻层的所有神经元之间都有连接,这叫全连接;卷积神经网络 CNN 中,新增了卷积层(Convolution)与汇聚(Pooling)。
⚫ DNN 的全连接层对应 CNN 的卷积层,汇聚是与激活函数类似的附件;单个卷积层的结构是:卷积层-激活函数-(汇聚),其中汇聚可省略。
(2)CNN:专攻多维数据
在深度神经网络 DNN 课程的最后一章,使用 DNN 进行了手写数字的识别。但是,图像至少就有二维,向全连接层输入时,需要多维数据拉平为 1 维数据,这样一来,图像的形状就被忽视了,很多特征是隐藏在空间属性里的,而卷积层可以保持输入数据的维数不变,当输入数据是二维图像时,卷积层会以多维数据的形式接收输入数据,并同样以多维数据的形式输出至下一层

导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

制作数据集

# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)

训练网络

class CNN(nn.Module):
    def __init__(self):
       super(CNN,self).__init__()
       self.net = nn.Sequential(
       nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),
       nn.AvgPool2d(kernel_size=2, stride=2),
       nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),
       nn.AvgPool2d(kernel_size=2, stride=2),
       nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),
       nn.Flatten(),
       nn.Linear(120, 84), nn.Tanh(),
       nn.Linear(84, 10)
)
    def forward(self, x):
         y = self.net(x)
         return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(
    model.parameters(),
    lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):
    for (x, y) in train_loader: # 获取小批次的 x 与 y
        x, y = x.to('cuda:0'), y.to('cuda:0')
        Pred = model(x) # 一次前向传播(小批量)
        loss = loss_fn(Pred, y) # 计算损失函数
        losses.append(loss.item()) # 记录损失函数的变化
        optimizer.zero_grad() # 清理上一轮滞留的梯度
        loss.backward() # 一次反向传播
        optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()

测试网络

# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能
    for (x, y) in test_loader: # 获取小批次的 x 与 y
        x, y = x.to('cuda:0'), y.to('cuda:0')
        Pred = model(x) # 一次前向传播(小批量)
        _, predicted = torch.max(Pred.data, dim=1)
        correct += torch.sum( (predicted == y) )
        total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')

使用网络

# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')

完整代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)


class CNN(nn.Module):
    def __init__(self):
       super(CNN,self).__init__()
       self.net = nn.Sequential(
       nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),
       nn.AvgPool2d(kernel_size=2, stride=2),
       nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),
       nn.AvgPool2d(kernel_size=2, stride=2),
       nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),
       nn.Flatten(),
       nn.Linear(120, 84), nn.Tanh(),
       nn.Linear(84, 10)
)
    def forward(self, x):
         y = self.net(x)
         return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(
    model.parameters(),
    lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):
    for (x, y) in train_loader: # 获取小批次的 x 与 y
        x, y = x.to('cuda:0'), y.to('cuda:0')
        Pred = model(x) # 一次前向传播(小批量)
        loss = loss_fn(Pred, y) # 计算损失函数
        losses.append(loss.item()) # 记录损失函数的变化
        optimizer.zero_grad() # 清理上一轮滞留的梯度
        loss.backward() # 一次反向传播
        optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()

# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能
    for (x, y) in test_loader: # 获取小批次的 x 与 y
        x, y = x.to('cuda:0'), y.to('cuda:0')
        Pred = model(x) # 一次前向传播(小批量)
        _, predicted = torch.max(Pred.data, dim=1)
        correct += torch.sum( (predicted == y) )
        total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')

# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')

运行截图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2044430.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 操作 Redis和redis持久化

一、Jedis 我们要使用 Java 来操作 Redis&#xff0c;Jedis 是 Redis 官方推荐的 java连接开发工具&#xff01; 使用Java 操作 Redis 中间件&#xff01; 1.导入对应的依赖 https://mvnrepository.com/artifact/redis.clients/jedis <dependency><groupId>redi…

【学习笔记】Matlab和python双语言的学习(最小生成树——Kruskal算法、Prim算法)

文章目录 前言一、最小生成树树的一些概念关键特性最小生成树和最短路径的主要区别常用算法1. Kruskal算法(适合点多边少的图)2. Prim算法(适合边多点少的图) 二、示例三、代码实现----Matlab四、代码实现----python1. Kruskal算法2. Prim算法 总结 前言 通过模型算法&#xf…

硬件模拟的基本原理是什么?

具体来说&#xff0c;这种设计方法减少了集成电路 (IC) 设计和开发的设计迭代次数&#xff0c;并且广泛适用于所有电力电子设计。我详细介绍了我在快速上市 IC 开发方面的经验&#xff0c;并将该方法与其他旨在缩短产品开发时间的技术进行了对比。 产品开发流程 图 1&#xff…

三菱定位控制(三,步进电机与定位模块的接线详情)

相信大家对前面的学习已经对前面的内容有所了解&#xff0c;下面就来看看步进电机&#xff0c;步进电机驱动器还有定位模块之间是如何接线的吧&#xff01; 一&#xff0c;将定位模块转换为端子排 首先&#xff0c;我们肯定是无法之间再定位模块上直接进行接线的。所以我们需要…

基于Java中的SSM框架实现家政预约管理系统项目【项目源码+论文说明】

基于Java中的SSM框架实现家政预约管理系统演示 摘要 随着线上预约服务应用的不断普及&#xff0c;为了给用户提供更加便捷的服务&#xff0c;很多行业都实行了线上预约制&#xff0c;比如医疗行业的线上挂号以及交通行业预约购票等&#xff0c;预约服务可以帮助人们节约大量排…

Ubuntu 安装 mysql 与 远程连接配置

1、安装 mysql ubuntu 默认安装 8.0 版本&#xff1a; sudo apt install mysql-server安装过程中 提示 是否继续操作 y 即可 2、使用ubuntu 系统用户 root 直接进入 mysql 切换至 系统用户 su root 输入命令 可直接进入 mysql: mysql3、创建一个允许远程登录的用户 创建 …

使用国内镜像站点安装Qt6 for Mac

使用国内镜像站点安装Qt6 for Mac 从下列网址下载在线安装包 Index of /archive/online_installers (qt.io) 双击前述dmg文件&#xff0c;在终端执行语句 使用一句命令行语句&#xff1a; open qt-unified-macOS-x64-4.6.1-online/qt-unified-macOS-x64-4.6.1-online.app --…

【leetcode】回文链表-25-3

方法&#xff1a;快慢指针递归遍历 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), next(nullptr) {}* ListNode(int x, ListNode *next) …

超吸睛!用AI绘画做沙雕图文号,有趣又解压!轻松吸粉变现!

大家好&#xff0c;我是画画的小强 小强可谓是小某书资深用户&#xff0c;最近突然刷到了某篇主打沙雕日常文案的 AI手绘插画图文号&#xff0c;不仅篇篇笔记上千点赞量&#xff0c;重点&#xff01;这类图文号植入的均是报价贼高的软广&#xff0c;我的老天鹅&#xff01;一搜…

基于vue框架的爱购电商平台256tr(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;用户,商家,商品分类,商品信息 开题报告内容 基于Vue框架的爱购电商平台 开题报告 一、研究背景与意义 随着互联网技术的迅猛发展和消费者购物习惯的深刻变革&#xff0c;电子商务已成为推动全球经济增长的重要力量。然而&#xff0c;…

分销商城小程序系统渠道拓展

线上卖货渠道很多&#xff0c;想要不断提高营收和新客获取&#xff0c;除了自己和工具本身努力外&#xff0c;还需要其他人的帮助来提高商城店铺的整体销量。 搭建saas商城系统网站/小程序&#xff0c;后台上货&#xff0c;设置支付、配送、营销、精美模板商城装修等内容&…

网络研讨会 | Unlock your SAP Data:用新一代 SAP 集成方案加速释放企业核心数据价值

伴随着业务的快速发展&#xff0c;企业对 SAP 系统的使用逐渐深化。为了更好释放数据价值&#xff0c;有效驱动商业决策&#xff0c;将 SAP 数据与其他数据系统集成打通&#xff0c;已成为企业数字化发展的必然趋势。 然而&#xff0c;由于 SAP 系统的独特性&#xff0c;企业想…

Lesson 59 Is that all?

Lesson 59 Is that all? 词汇 envelope n. 信封 复数&#xff1a;envelopes 例句&#xff1a;在红色信封里有一封书信。 There is a letter in the red envelope. writing paper n. 信纸【不可数】 构成&#xff1a;write v. 写 writer n. 作家    相关&#xff1a;hand…

SQL SERVER 2008多表关联查询,关联条件用(*=、=*)无法使用,高版本数据库不兼容,报错“*=”附近有语法错误!

专业问题&#xff0c;已经习惯了先问AI 下图是百度AI的回复 下图是讯飞星火的回复 下面是SQL SERVER 2022数据库引擎&#xff0c;查询使用&#xff01; 报错&#xff1a;“*”附近有语法错误。 查询数据库版本语法 SELECT VERSION 得到数据库版本&#xff1a;SQL Server 202…

【CVE-2024-38077】修复Windows 远程桌面授权服务远程代码执行漏洞记录

官方漏洞指南&#xff1a;Security Update Guide - Microsoft Security Response Center 受影响版本&#xff1a; Windows Server 2012 R2 (Server Core installation) Windows Server 2012 R2 Windows Server 2012 (Server Core installation) Windows Server 2012 Windows …

AI时代来临,程序员歇业在家,如何利用AI来赚钱?

我没有团队&#xff0c;就一个人&#xff0c;能不能用AI来赚点小钱&#xff1f; 当然可以&#xff0c;单打独斗并非阻碍&#xff0c;AI技术完全可以助您拓展收入来源。接下来&#xff0c;我将为您详细介绍10种个体利用AI提升收入的有效途径。 首先&#xff0c;您可以通过以下…

PRVF-4037 : CRS is not installed on any of the nodes

描述&#xff1a;公司要求替换centos&#xff0c;重新安装ORACLE LINUX RAC的数据库做备库&#xff0c;到时候切换成主库&#xff0c;安装Linux7GRID 19C 11G Oracle&#xff0c;顺利安装grid 19c&#xff0c;安装11G数据库软件的时候检测报如题错误&#xff1a;**PRVF-4037 …

显卡刷坏BIOS黑屏怎么办_万能救砖恢复方法来了!

电脑显卡刷bios黑屏不开机无显示无信号&#xff0c;是因为显卡vbios刷错文件了&#xff0c;找到正确的显卡BIOS文件刷回去就行了。但是黑屏进不去系统如何刷显卡BIOS呢。给大家推荐个万能刷显卡BIOS工具&#xff0c;名字是《离线刷bios设备》&#xff0c;用免拆夹子直接连接显卡…

49-DRC的设置及检查

1. DRC设置入口 2.如下设置 3.运行DRC

IDEA 解决创建新项目Maven配置变化问题

我们发现&#xff0c;每次创建新项目&#xff0c;配置好的maven路径就变了&#xff0c;总是恢复成IDEA自带的maven配置&#xff0c;我们想永久使用我们的Maven配置&#xff0c;该如何修改呢&#xff1f; 配置好要重启IDEA&#xff01;&#xff01;&#xff01;