刘二大人《Pytorch深度学习实践》第十一讲卷积神经网络(高级篇)

news2024/9/21 14:26:01

文章目录

  • Inception-v1实现
  • Skip Connect实现

Inception-v1实现

在这里插入图片描述
 Inception-v1中使用了多个11卷积核,其作用:
(1)在大小相同的感受野上叠加更多的卷积核,可以让模型学习到更加丰富的特征。传统的卷积层的输入数据只和一种尺寸的卷积核进行运算,而Inception-v1结构是Network in Network(NIN),就是先进行一次普通的卷积运算(比如5
5),经过激活函数(比如ReLU)输出之后,然后再进行一次11的卷积运算,这个后面也跟着一个激活函数。11的卷积操作可以理解为feature maps个神经元都进行了一个全连接运算。
(2)使用1*1的卷积核可以对模型进行降维,减少运算量。当一个卷积层输入了很多feature maps的时候,这个时候进行卷积运算计算量会非常大,如果先对输入进行降维操作,feature maps减少之后再进行卷积运算,运算量会大幅减少。

import torch
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize ((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST (root='./dataset/mnist/', train = True, download= True, transform = transform)
train_loader = DataLoader (train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST (root='./dataset/mnist/', train = False, download= True, transform = transform)
test_loader = DataLoader (test_dataset, shuffle = False, batch_size = batch_size)

class InceptionA (torch.nn.Module):
  def __init__(self, in_channels):
    super (InceptionA, self).__init__()
    self.branch1x1 = torch.nn.Conv2d (in_channels,16, kernel_size=1)

    self.branch5x5_1 = torch.nn.Conv2d (in_channels, 16, kernel_size=1)
    self.branch5x5_2 = torch.nn.Conv2d (16, 24, kernel_size=5, padding=2)

    self.branch3x3_1 = torch.nn.Conv2d (in_channels, 16, kernel_size=1)
    self.branch3x3_2 = torch.nn.Conv2d (16, 24, kernel_size=3, padding=1)
    self.branch3x3_3 = torch.nn.Conv2d (24, 24, kernel_size=3, padding=1)

    self.branch_pool = torch.nn.Conv2d (in_channels, 24, kernel_size=1)

  def forward (self, x):
    branch1x1 = self.branch1x1 (x)

    branch5x5 = self.branch5x5_1 (x)
    branch5x5 = self.branch5x5_2 (branch5x5)

    branch3x3 = self.branch3x3_1(x)
    branch3x3 = self.branch3x3_2(branch3x3)
    branch3x3 = self.branch3x3_3(branch3x3)

    branch_pool = F.avg_pool2d (x, kernel_size=3, stride=1, padding=1)
    branch_pool = self.branch_pool (branch_pool)

    outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
    return torch.cat (outputs, dim = 1)
  
class Net (torch.nn.Module):
  def __init__(self):
    super (Net, self).__init__()
    self.conv1 = torch.nn.Conv2d (1, 10, kernel_size = 5)
    self.conv2 = torch.nn.Conv2d (88, 20, kernel_size = 5)

    self.incep1 = InceptionA (in_channels=10)
    self.incep2 = InceptionA (in_channels=20)

    self.mp = torch.nn.MaxPool2d (2)
    self.fc = torch.nn.Linear (1408, 10)
  
  def forward (self, x):
    in_size = x.size (0)
    x = F.relu (self.mp (self.conv1(x)))
    x = self.incep1 (x)
    x = F.relu (self.mp (self.conv2(x)))
    x = self.incep2 (x)
    x = x.view (in_size, -1)
    x = self.fc (x)

    return x

model = Net()
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
 
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
 
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0
 
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100*correct/total))
 
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

Skip Connect实现

在这里插入图片描述
 跳连结构解决了梯度消失的问题

import torch
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize ((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST (root='./dataset/mnist/', train = True, download= True, transform = transform)
train_loader = DataLoader (train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST (root='./dataset/mnist/', train = False, download= True, transform = transform)
test_loader = DataLoader (test_dataset, shuffle = False, batch_size = batch_size)

class ResidualBlock (torch.nn.Module):
  def __init__(self, channels):
    super (ResidualBlock, self).__init__()
    self.channels = channels
    self.conv1 = torch.nn.Conv2d (channels, channels, kernel_size = 3, padding = 1)
    self.conv2 = torch.nn.Conv2d (channels, channels, kernel_size = 3, padding = 1)

  def forward (self, x):
    y = F.relu (self.conv1(x))
    y = self.conv2 (y)
    return F.relu (x + y)
 
  
class Net (torch.nn.Module):
  def __init__(self):
    super (Net, self).__init__()
    self.conv1 = torch.nn.Conv2d (1, 16, kernel_size = 5)
    self.conv2 = torch.nn.Conv2d (16, 32, kernel_size=5)
    
    self.rblock1 = ResidualBlock(16)
    self.rblock2 = ResidualBlock(32)

    self.mp = torch.nn.MaxPool2d (2)
    self.fc = torch.nn.Linear (512, 10)
    
  
  def forward (self, x):
    in_size = x.size (0)

    x = self.mp (F.relu (self.conv1(x)))
    x = self.rblock1(x)
    x = self.mp (F.relu (self.conv2(x)))
    x = self.rblock2(x)

    x = x.view (in_size, -1)
    x = self.fc (x)

    return x

model = Net()
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
 
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
 
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0
 
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100*correct/total))
 
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440777.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows系统本地批量预览svg图标

一、为何需要此操作 目前前端使用图标大致分为两类: iconfont方式:通过引入在线或者下载到本地的iconfont.css类文件实现显示图标第二类是封装图标组件,通过传入指定的svg名称快速生成图标 目前第二种是比较方便的,不需要频…

【记录】Truenas Scale|中危漏洞,需要SMB签名

部分内容参考:等保测试问题——需要SMB签名(SMB Signing not Required) 以及 ChatGPT。 Truenas常用SMB服务,但默认并不开启SMB签名。这样具有中间人攻击的风险。 一、漏洞详情 1.1 漏洞报告 漏洞提示如下: 1.2 漏洞介绍 SMB是一个协议名…

Mybatis-Plus -01 Mybatis-Plus入门

Mybatis-Plus入门 1 Mybatis-Plus1.1 Mybatis-Plus简介1.2 Mybatis-Plus特性1.3 Mybatis-Plus框架结构1.1 Mybatis-Plus简介1.2 Mybatis-Plus特性1.3 Mybatis-Plus框架结构 2 Mybatis-Plus 快速入门2.1 数据库准备2.2 导入mybatis-plus依赖2.3 Spring整合MP2.4 编写实体类2.5 编…

i.MX8MP平台开发分享(gicv3篇)-- set_handle_irq及中断路由过程分析

专栏目录:专栏目录传送门 平台内核i.MX8MP5.15.71 文章目录 set_handle_irqhard中断入口 set_handle_irq(gic_handle_irq);set_handle_irq 这个函数的功能很简单,将gic_handle_irq设置为中断处理函数。在发生中断异常后,内核就会切入到这个…

060201面积-定积分在几何学上的应用-定积分的应用

文章目录 1 平面图形的面积1.1 直角坐标情形1.2 极坐标情形1.2.1 极坐标的定义1.1.2 曲边扇形的面积 结语 1 平面图形的面积 1.1 直角坐标情形 ①平面图形由 y f ( x ) , y 0 , x a , y b yf(x),y0,xa,yb yf(x),y0,xa,yb围成图像的面积,如下图1.1-1所示&#…

防洪决策指挥系统(Axure高保真原型)

使用Axure制作的rp高保真原型防洪决策指挥系统可用于行业参考、实际业务需求开发、学习交流使用,本原型需求可以作为开发使用,业务需求均为作者本人行业经验。本系统包括水系展示系统、城区调度决策系统、实时监测预警和防洪调度四大功能模块的界面。 原…

MICCAI 2023 FLARE国际竞赛:打造腹部泛癌CT分割Foundation Models

竞赛官网 CodaLab - Competitionhttps://codalab.lisn.upsaclay.fr/competitions/12239 背景介绍 腹部器官是相当常见的患癌部位,例如结直肠癌和胰腺癌,分别位列癌症死亡率排名的第二位和第三位。Computed tomography (CT) 成像可以为医生提供重要的诊…

前端错误合集

Uncaught Reference Error: xx is not defined 未捕获的引用错误:未定义xx 原因 1.关键字写错了 解决办法 1.修改成正确的关键字 NAN 计算错误 原因 计算时数据类型不同 解决办法 使用数据类型相同的数据进行计算 Uncaught SyntaxError: Invalid left-h…

计算广告(十四)

营销是一个涉及产品、服务或品牌从概念到消费者的全过程的商业活动。它包括分析市场需求、识别潜在消费者、制定和实施策略以满足他们的需求、创造价值和实现销售。营销的目标是在满足客户需求的同时,实现企业的利润和业务增长。 营销涉及以下几个关键环节&#xf…

vim编辑器的使用介绍

文章目录 vim编辑器的使用介绍vim的缓存、恢复与打开时的警告信息vim的额外功能可视化区块多文件编辑多窗口功能vim的关键词补全功能vim环境设置与记录:~/.vimrc、~/.viminfovim的环境设置参数 vim常用的命令示意图 其他vim使用注意事项中文编码的问题DOS与Linux的换…

【PyTorch】课堂测试一:线性回归的求解

作者🕵️‍♂️:让机器理解语言か 专栏🎇:PyTorch 描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。 寄语💓:🐾没有白走的路,每一步都算数&#…

如何在自定义数据集上训练YOLOv8的各个模型

YOLOv8效果图(可以应用到图片和视频): 四个模式命令 yolo taskdetect modepredict modelmodel/yolov8n.pt sourceinput/test.mp4 showTrueyolo tasksegment modepredict modelmodel/yolov8x-seg.pt sourceinput/zidane.jpg showTrueyolo tas…

JavaSE-part2

文章目录 Day07 IO流1.IO流1.1背景介绍1.2File类1.2.1常用方法 1.3IO流原理1.4IO流的分类1.4.1InputStream 字节输入流1.4.1.1FileInputStream1.4.1.2FileOutPutStream1.4.1.3练习 1.4.2Reader and Writer1.4.2.1FileReader1.4.2.2FileWriter 1.4.3节点流和处理流1.4.3.1处理流…

MSNet网络结构与代码搭建深入解读

模型结构 1、首先,将多光谱遥感图像的波段分为可见光和不可见光两组,然后进行分组同步特征提取; 代码 先看总体结构,主代码 __init__定义了声明MSNet模型有哪些类,MSNet的forward方法规定数据如何在层之间流动。 1、首先是获得图片的输入尺寸input_size = (rgbnnd.size(…

Python数据结构与算法-动态规划(钢条切割问题)

一、动态规划(DP)介绍 1、从斐波那契数列看动态规划 (1)问题 斐波那契数列递推式: 练习:使用递归和非递归的方法来求解斐波那契数列的第n项 (2)递归方法的代码实现 import time # 递…

Spark----RDD(弹性分布式数据集)

RDD 文章目录 RDDRDD是什么?为什么需要RDD?RDD的五大属性WordCount中的RDD的五大属性如何创建RDD?RDD的操作两种基本算子/操作/方法/API分区操作重分区操作聚合操作四个有key函数的区别 关联操作排序操作 RDD的缓存/持久化cache和persistchec…

Java学习-MySQL-DQL数据查询-联表查询JOIN

Java学习-MySQL-DQL数据查询-联表查询JOIN 1.分析需求,查找那些字段 2.分析查询的字段来自哪些表 3.确定使用哪种连接查询 4.确定交叉点 5.确定判断条件 操作描述inner join返回左右表的交集left join返回左表,即使右表没有right join返回右表&#xf…

iptables深度总结--基础篇

iptables 五表五链 链:INPUT OUTPUT FORWARD PREROUTING POSTROUTING 表:filter、nat、mangle、raw、security 数据报文刚进网卡,还没有到路由表的时候,先进行了prerouting,进入到路由表,通过目标地址判…

FFMPEG 关于smaple_fmts的理解及ffplay播放PCM

问题 当我将一个aac的音频文件解码为原始的PCM数据后,使用ffplay播放测试是否成功时,需要提供给ffplay 采样率,通道数,PCM的格式类型 3个参数,否则无法播放! 所以使用ffprobe 查看原来的aac文件信息&…

Python手写板 画图板 签名工具

程序示例精选 Python手写板 画图板 签名工具 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<Python手写板 画图板 签名工具>>编写代码&#xff0c;代码整洁&#xff0c;规则&am…