卷积神经网络实战

news2024/10/6 4:08:40

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致

1.首先读取数据

 - 分别构建训练集和测试集(验证集)
- DataLoader来迭代取数据

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片

# 训练集
train_dataset = datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 

# 测试集
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

2.卷积网络模块构建

- 一般卷积层,relu层,池化层可以写成一个套餐
- 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # 灰度图
                out_channels=16,            # 要得到几多少个特征图
                kernel_size=5,              # 卷积核大小
                stride=1,                   # 步长
                padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
            ),                              # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),                      # relu层
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出 (32, 7, 7)
        )
        
        self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),             # 输出 (32, 7, 7)
        )
        
        self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

 

3.准确率作为评估标准

def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] 
    rights = pred.eq(labels.data.view_as(pred)).sum() 
    return rights, len(labels) 

 

4训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = [] 
    
    for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环
        net.train()                             
        output = net(data) 
        loss = criterion(output, target) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        right = accuracy(output, target) 
        train_rights.append(right) 

    
        if batch_idx % 100 == 0: 
            
            net.eval() 
            val_rights = [] 
            
            for (data, target) in test_loader:
                output = net(data) 
                right = accuracy(output, target) 
                val_rights.append(right)
                
            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1], 
                100. * val_r[0].numpy() / val_r[1]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1574341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Springfox Swagger实现API自动生成单元测试

目录 第一步:在pom.xml中添加依赖 第二步:加入以下代码,并作出适当修改 第三步:在application.yaml中添加 第四步:添加注解 第五步:运行成功之后,访问相应网址 另外:还可以导出…

JavaScript代码小挑战

题目如下: 朱莉娅和凯特正在做一项关于狗的研究。于是,她们分别询问了 5 位狗主人他们的狗的年龄,并将数据存储到一个数组中(每人一个数组)。目前,她们只想知道一只狗是成年狗还是小狗。如果狗的年龄至少为…

算力在现实生活中的多方面应用!

算力在现实生活中的应用是多方面的,它已经成为推动现代社会发展的重要力量。 以下是算力在不同领域中的具体应用: 立即免费体验:https://gpumall.com/login?typeregister&sourcecsdn #分布式云服务#算力#GpuMall#GpuMall智算云#训练#…

【AI-3】Transformer

Transformer? Transformer是一个利用注意力机制来提高模型训练速度的模型,因其适用于并行化计算以及本身模型的复杂程度使其在精度和性能上都要高于之前流行的循环神经网络。 标准的Transformer结构如下图所示(图来自知乎-慕文)&#xff0c…

特征提取算法

特征提取算法 0. 写在前边1. Harris算法1.1 写在前面1.2 Harris算法的本质1.3 Harris算法的简化 2. Harris3D2.1 Harris3D算法问题定义2.2 Harris3D with intensity2.3 Harris3D without intensity 3. ISS特征点的应用 0. 写在前边 本篇将介绍几种特征提取算法,特征…

C++从入门到精通——类对象模型

类对象模型 前言一、如何计算类对象的大小问题 二、类对象的存储方式猜测对象中包含类的各个成员代码只保存一份,在对象中保存存放代码的地址只保存成员变量,成员函数存放在公共的代码段问题总结 三、结构体内存对齐规则四、例题结构体怎么对齐&#xff…

3D桌面端可视化引擎HOOPS Visualize如何实现3D应用快速开发?

HOOPS Visualize是一个开发平台,可实现高性能、跨平台3D工程应用程序的快速开发。一些主要功能包括: 高性能、以工程为中心的可视化,使用高度优化的OpenGL或DirectX驱动程序来充分利用可用的图形硬件线程安全的C和C#接口,内部利用…

mysql索引相关知识点

1. 索引是什么? 索引是一种特殊的文件(InnoDB数据表上的索引是表空间的一个组成部分),它们包含着对数据表里所有记录的引用指针。 索引是一种数据结构。数据库索引,是数据库管理系统中一个排序的数据结构,以协助快速查询、更新数…

【Java业务需求解决方案】分布式锁应用详情,多种方案选择,轻松解决,手把手操作(非全数字编码依次加一问题)

背景: 现有编码格式为业务常量数字,每新增一条数据在基础上1,比如: 文件类型1 编码为ZS01 文件类型1下文件1 编码为ZS0101 文件类型1下文件2 编码为ZS0102 文件类型2 编码…

Vue - 3( 15000 字 Vue 入门级教程)

一:初识 Vue 1.1 收集表单数据 收集表单数据在Vue.js中是一个常见且重要的任务,它使得前端交互变得更加灵活和直观。 Vue中,我们通常使用v-model指令来实现表单元素与数据之间的双向绑定,从而实现数据的收集和更新。下面总结了…

Springboot引入swagger

讲在前面&#xff1a;在spring引入swagger时&#xff0c;由于使用的JDK、Spring、swagger 的版本不匹配&#xff0c;导致启动报错&#xff0c;一直存在版本依赖问题。所以在此声明清楚使用版本。JDK 1.8、Spring boot 2.6.13、 Swagger 2.9.2。 引入maven依赖 <dependency&…

【Canvas与艺术】绘制金色Brand Award品牌嘉奖奖章

【成果图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>金色Brand Award品牌嘉奖</title><style type"text/…

WebGL异步绘制多点

异步绘制线段 1.先画一个点 2.一秒钟后&#xff0c;在左下角画一个点 3.两秒钟后&#xff0c;我再画一条线段 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"…

Games101-几何(基本表示方法)

几何分类 几何分类&#xff1a;隐式几何和显示几何 隐式几何&#xff1a;不会告诉空间中的点具体在哪&#xff0c;告诉这些点满足的一定关系。 如球的描述 x 2 y 2 z 2 1 x^2 y^2 z^2 1 x2y2z21 缺点&#xff1a;这个面都有哪些点是不容易看出来的&#xff0c;从上述的…

[Apple Vision Pro]开源项目 Beautiful Things App Template

1. 技术框架概述&#xff1a; - Beautiful Things App Template是一个为visionOS设计的免费开源软件&#xff08;FOSS&#xff09;&#xff0c;用于展示3D模型画廊。 2. 定位&#xff1a; - 该模板作为Beautiful Things网站的延伸&#xff0c;旨在为Apple Vision Pro用户…

从300亿分子中筛出6款,结构新且易合成,斯坦福抗生素设计AI模型登Nature子刊

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。 新建了免费的人工智能中文站https://ai.weoknow.com 新建了收费的人工智能中文站https://ai.hzytsoft.cn/ 更多资源欢迎关注 全球每年有近 500 万人死于抗生素耐药性&#xff0c;因此迫切需要新的方法来对抗耐药菌株。 …

最具有影响力的三个视觉平台 | 3D高斯、场景重建、三维点云、工业3D视觉、SLAM、三维重建、自动驾驶

大家好&#xff0c;我是小柠檬 这里给大家推荐三个国内具有影响力的3D视觉方向平台&#xff01; 原文&#xff1a;最具有影响力的三个视觉平台 | 3D高斯、场景重建、三维点云、工业3D视觉、SLAM、三维重建、自动驾驶

青风环境带您了解2024第13届生物发酵展

参展企业介绍 浙江青风环境股份有限公司创立于1998年&#xff0c;是一家集科研、生产及贸易为一体的高新技术企业。公司座落于浙江省丽水市水阁工业区&#xff0c;占地面积120亩&#xff0c;建筑面积近11万平方米&#xff0c;年产值可达20亿元&#xff0c;建有标准的冷&#x…

【JAVASE】带你了解instanceof和equals的魅力

✅作者简介&#xff1a;大家好&#xff0c;我是橘橙黄又青&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;再无B&#xff5e;U&#xff5e;G-CSDN博客 1.instanceof instanceof 是 Java 的保留关键字。它的作用是测试…

编译原理实验3(基于算符优先文法分析的语法分析器 )

实验目的 加深对语法分析器工作过程的理解&#xff1b;加强对算符优先分析实现语法分析程序的掌握&#xff1b;能够产用一种编程语言实现简单的语法分析程序&#xff1b;能够使用自己编写的分析程序对简单的程序段进行语法分析。 实验要求 根据简单表达式文法构造算符优先分…