【构建卷积神经网络】

news2024/11/27 11:38:14

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致

全连接层:batch784,各个像素点之间都是没有联系的。
卷积层:batch
12828,各个像素点之间是有联系的。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

首先读取数据

  • 分别构建训练集和测试集(验证集)
  • DataLoader来迭代取数据
# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片

# 训练集
train_dataset = datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 

# 测试集
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

卷积网络模块构建

  • 一般卷积层,relu层,池化层可以写成一个套餐
  • 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

图像是二维卷积 conv2
视频是三维卷积 conv3
单向量是一维卷积 conv1
官网有关conv2d的输出宽度和长度的计算公式
在这里插入图片描述

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # 1:灰度图;3:RGB
                out_channels=16,            # 要得到几多少个特征图,即是卷积核的个数 
                kernel_size=5,              # 卷积核大小
                stride=1,                   # 步长
                padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
            ),                              # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),                      # relu层
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出 (32, 7, 7)
        )
        
        self.conv3 = nn.Sequential(         # 下一个套餐的输入 (32, 7, 7)
            nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (64, 7, 7)
            nn.ReLU(),             # 输出 (64, 7, 7)
        )
        
        self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 64 * 7 * 7)
        output = self.out(x)
        return output

准确率作为评估标准

def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] 
    rights = pred.eq(labels.data.view_as(pred)).sum() 
    return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = [] 
    
    for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环
        net.train()                             
        output = net(data) 
        loss = criterion(output, target) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        right = accuracy(output, target) 
        train_rights.append(right) 

    
        if batch_idx % 100 == 0: 
            
            net.eval() 
            val_rights = [] 
            
            for (data, target) in test_loader:
                output = net(data) 
                right = accuracy(output, target) 
                val_rights.append(right)
                
            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1], 
                100. * val_r[0].numpy() / val_r[1]))

在这里插入图片描述

练习

  • 再加入一层卷积,效果怎么样?
  • 当前任务中为什么全连接层是3277 其中每一个数字代表什么含义

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/847521.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LaTex使用技巧21:设置中文环境、字体、行间距和页边距

我在Overleaf上编写我的中文LaTex,设置了中文环境,字体、行间距以及页间距,记录一下方便以后查询。 使用中文环境命令为: \usepackage{xeCJK}可以使用Overleaf上支持的中文字体Fonts for CJK Chinese,设置字体的命令…

探究使用HTTP爬虫ip后无法访问网站的原因与解决方案

在今天的文章中,我们要一起来解决一个常见问题:使用HTTP爬虫ip后无法访问网站的原因是什么,以及如何解决这个问题。我们将提供一些实际的例子和操作经验,帮助大家解决HTTP爬虫ip无法访问网站的困扰。 1、代理服务器不可用 使用HT…

Debian 12.1 正式发布

导读Debian 12.1 现已发布,这是对稳定发行版 Debian 12(代号 Bookworm )的首次更新。本次发布主要增加了安全问题的修正,并对严重问题进行了一些调整。 一些更新内容包括: 妥善处理系统用户的创建;修复 eq…

ChatGLM-RLHF(五)-PPO(Proximal Policy Optimization)原理实现代码逐行注释

一,前言 从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO代码实现。 上章我们…

Java编程实践:实现Java接口的方法也建议加上@Override注解

说明 作为一个Java编程实践,实现接口的方法也强烈建议加上Override注解。这样做的好处: 阅读代码的时候,一眼就能看出来是新增的函数,还是实现接口的函数。加上Override注解,如果拼写错误,编译器马上就能…

电视盒子哪款好?内行整理超值网络电视盒子推荐

从事电视盒子这行已经五年了,很多朋友在挑选电视盒子时会咨询我的意见,我耗费半个月时间整理了超值网络电视盒子推荐,盘点目前最值得入手的五款电视盒子机型,想买电视盒子不知道电视盒子哪款好可以从下面五款中挑选: 榜…

VIM 编辑器: Bram Moolenaar

VIM 用了很长时间, 个人的 VIM 配置文件差不多10年没有更新了。以前写程序的时候, 编辑都用这个。 linux kernel, boost规模的代码都不在话下。现在虽然代码写的少了,依然是我打开文件的首选。 现在用手机了,配个蓝牙键…

idea中如何处理飘红提示

idea中如何处理飘红提示 在写sql时,总是会提示各种错误 查找资料,大部分都是说关提示,这里把错误提示选择为None即可 关掉以后,也确实不显示任何提示了,但总有一种掩耳盗铃的感觉 这个sms表明明存在,但是还…

android studio安卓真机调试

把usb 手机开启到usb调试模式,然后用usb线连接手机 安装adb 如果下载速度很慢,请使用vpn 终端需要先安装brew /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"使用brew安装adb brew install android-platfor…

面试遇到登录功能测试用例设计,你回答对了吗

给你一个登录功能,如何设计测试用例 哪怕是最常用最小的一个登录功能,其实涉及到的测试用例也是非常多的,这个题目通常会通过面试来考察求职者的综合能力,尤其是测试用例的设计思维,因为你即使你背了各种测试用例设计…

开发一款保护程序检测进程假死,精准打开保护的程序

网上很多保护程序都收费, 有免费的,可以将一般程序改成windows服务,我没用,应该很强大 功能点: 1,首先要有能加入保护程序的功能 2,不断的轮询检测程序是否已经运行 3,不断的轮询检测程序是否假死 4,一些其他检测 将保护的程序存入文件列表 保护程序运行时加载…

嵌入式开发学习(STC51-9-led点阵)

内容 点亮一个点; 显示数字; 显示图像; LED点阵简介 LED 点阵是由发光二极管排列组成的显示器件 通常应用较多的是8 * 8点阵,然后使用多个8 * 8点阵可组成不同分辨率的LED点阵显示屏,比如16 * 16点阵可以使用4个8 *…

恒盛策略:医药股反弹,掀涨停潮!

今天上午A股商场涨跌互现,上证指数一度显着跌落,但临近上午收盘时翻红。 作为行情风向标,券商板块盘中一度大幅跌落,但随后快速收窄跌幅,板块内分解较为显着,其中市值超越1000亿元的龙头券商之一的中金公司…

OPENCV C++(八)HOG的实现

hog适合做行人的识别和车辆识别 对一定区域的形状描述方法 可以表示较大的形状 把图像分成一个一个小的区域的直方图 用cell做单位做直方图 计算各个像素的梯度强度和方向 用3*3的像素组成一个cell 3*3的cell组成一个block来归一化 提高亮度不变性 常用SVM分类器一起使用…

HTML Emoji和Emoji 参考手册

HTML表情可以用来在网页中插入各种表情符号图标,丰富了网页表现形式和视觉效果。下面是一些常用HTML表情代码大全📜 Emoji 参考手册 HTML Emoji 扩展:📌 HTML 自定义实现emoji - (freesion.com)

native vlan tag设置错误,导致交换机无法访问

一同事找来,说他的一个测试交换机,下挂一些测试设备,能正常访问,但交换机的ip192.168.100.128却无法telnet访问,ping过去显示无法访问目的主机,让给看一下原因? 已知组网这个交换机接在交换机的…

用于实体对齐的联合学习实体和关系表示2019 AAAI 8.7+8.8

用于实体对齐的联合学习实体和关系表示 摘要介绍相关工作实体对齐图卷积网络 问题公式我们的方法整体架构初步实体对齐图卷积层对齐训练 近似关系表示联合实体和关系对齐 实验总结 摘要 实体对齐是在不同知识图之间集成异构知识的一种可行方法。该领域的最新发展通常采用基于嵌…

端口映射软件可以做什么?快解析如何设置端口映射?

说到端口映射,首先说说nat。简单地说,nat就是在局域网内部网络中使用内部地址,而当内部节点要与外部网络进行通讯时,就在网关处,将内部地址替换成公用地址,从而在外部公网(internet)…

网络系统观察之道

什么是“可观察性”? 当然,“可观察性”这个术语并不是我们发明的。我们最开始从用户那里听到这个概念,这些用户主要来自网站可靠性工程 (SRE) 社区。有些信息来源认为,这个术语起源于硅谷巨头(如 Twitter&#xff09…

CTF流量题解http2.pcapng

使用wireshark工具打开流量文件。 根据网络协议进行分组排序,对流量文件里面的内容进行观察。 16进制转换,16进制转换文本字符串,在线16进制转换 | 在线工具 (sojson.com) Base64编码/解码器,在线解码Base64 (sojson.com) https:…