动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

news2024/10/6 18:28:45

29残差网络ResNet

在这里插入图片描述

import torch  
from torch import nn  
from torch.nn import functional as F 
import liliPytorch as lp  
import matplotlib.pyplot as plt

# 定义一个继承自nn.Module的残差块类
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # 第一个卷积层,使用3x3的卷积核,填充为1,步幅为指定值
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        # 第二个卷积层,使用3x3的卷积核,填充为1
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        
        # 可选的1x1卷积层,用于匹配输入输出通道数和步幅
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        
        # 批量归一化层
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        # 为什么需要两个不同的批量归一化层?
        # 1.不同的位置,不同的输入特征
        # 2.独立的参数和统计数据
    
    def forward(self, X):
        # 先通过第一个卷积层、批量归一化层和ReLU激活函数
        Y = F.relu(self.bn1(self.conv1(X)))
        # 然后通过第二个卷积层和批量归一化层
        Y = self.bn2(self.conv2(Y))
        # 如果定义了conv3,则通过conv3调整X
        if self.conv3:
            X = self.conv3(X)
        # 将输入X加到输出Y上实现残差连接
        Y += X
        # 通过ReLU激活函数
        return F.relu(Y)

# 创建一个包含输入和输出形状一致的残差块实例,并测试其输出形状
# blk = Residual(3, 3)
# X = torch.rand(4, 3, 6, 6)
# Y = blk(X)
# print(Y.shape)  # 预期输出形状:torch.Size([4, 3, 6, 6])

# 创建一个包含1x1卷积和步幅为2的残差块实例,并测试其输出形状
# blk = Residual(3, 6, use_1x1conv=True, strides=2)
# print(blk(X).shape)  # 预期输出形状:torch.Size([4, 6, 3, 3])

# 定义一个包含初始卷积层、批量归一化层、ReLU激活函数和最大池化层的顺序容器
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

# 定义一个函数,用于创建由多个残差块组成的模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        # 如果是第一个残差块且不是第一个模块,则使用1x1卷积和步幅为2
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

# 创建由残差块组成的各个模块
# *符号有多种用途,但在函数调用时,*符号主要用于将列表或元组解包。
# *resnet_block()的作用是将列表中的元素逐个传递给nn.Sequential
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

# 创建整个ResNet模型
net = nn.Sequential(
    b1, 
    b2, 
    b3, 
    b4, 
    b5,
    nn.AdaptiveAvgPool2d((1, 1)),  # 自适应平均池化层
    nn.Flatten(),  # 展平层
    nn.Linear(512, 176)  # 全连接层,输出10类
)

# 测试整个网络的输出形状
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:    torch.Size([1, 512])
# Linear output shape:     torch.Size([1, 10])

# 设置训练参数
lr, num_epochs, batch_size = 0.05, 10, 256
# 加载训练和测试数据
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=96)
# 训练模型
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# 显示训练结果
plt.show()

# loss 0.009, train acc 0.998, test acc 0.920
# 2306.3 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1855087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

adb 查看哪些应用是双开的

adb shell pm list users 得到 这 里有 user 0 ,11,999 其中0是系统默认的,11是平行空间的,999是双开用户 pm list packages --user 999 -3 得到了999用户安装第三方应用的包名 pm list packages --user 11 -3 得到了隐私空间用户安装第三方应用的…

Git客户端安装步骤详解

git windows7 百度经验:jingyan.baidu.com 方法/步骤 1 从git官网下一个git安装包。 步骤阅读 2 点击git.exe安装程序,点击【next】 ![git的安装和配置](https://imgsa.baidu.com/exp/w500/sign7565f44ba58b87d65042ab1f37092860/21a4462309f790525e5b0144…

兰州理工大学24计算机考研情况,好多专业都接受调剂,只有计算机专硕不接收调剂,复试线为283分!

兰州理工大学(Lanzhou University of Technology),位于甘肃省兰州市,是甘肃省人民政府、教育部、国家国防科技工业局共建高校,甘肃省高水平大学和“一流学科”建设高校;入选国家“中西部高校基础能力建设工…

Android-系统开发_四大组件篇----探讨-Activity-的生命周期

当一个活动不再处于栈顶位置,但仍然可见时,这时活动就进入了暂停状态。你可能会觉得既然活动已经不在栈顶了,还怎么会可见呢? 这是因为并不是每一个活动都会占满整个屏幕,比如对话框形式的活动只会占用屏幕中间的部分…

Retrieval-Augmented Generation for Large Language Models A Survey

Retrieval-Augmented Generation for Large Language Models: A Survey 文献综述 文章目录 Retrieval-Augmented Generation for Large Language Models: A Survey 文献综述 Abstract背景介绍 RAG概述原始RAG先进RAG预检索过程后检索过程 模块化RAGModules部分Patterns部分 RAG…

15天搭建ETF量化交易系统Day9—玩大A必学网格策略

搭建过程 每个交易者都应该形成一套自己的交易系统。 很多交易者也清楚知道,搭建自己交易系统的重要性。现实中,从0到1往往是最难跨越的一步。 授人鱼不如授人以渔,为了帮助大家跨出搭建量化系统的第一步,我…

vscode用vue框架2,续写登陆页面逻辑,以及首页框架的搭建

目录 前言: 一、实现登录页信息验证逻辑 1.实现登录数据双向绑定 2.验证用户输入数据是否和默认数据相同 补充知识1: 知识点补充2: 二、首页和登录页之间的逻辑(1) 1. 修改路由,使得程序被访问先访问首页 知识点补充3&am…

[每周一更]-(第102期):认识相机格式Exif

文章目录 EXIF数据包含的信息读取EXIF数据的工具和库EXIF数据读取示例(Go语言)想法参考 相机拍摄的照片,在照片展示行无水印信息,但是照片属性中会包含比较丰富的信息,相机品牌、型号、镜头信息等,这些我们…

基于NURBS曲线的数据拟合算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1NURBS曲线基础 4.2 数据拟合原理 5.完整程序 1.程序功能描述 基于NURBS曲线的数据拟合算法,非均匀有理B样条(Non-Uniform Rational B-Splines,简称NURBS&#xf…

电子商务的未来:大数据||论主流电商大数据采集API接口的重要性

每天全球生成2.5万亿字节。金融交易,用户出版物和众多内容,数据对于组织了解消费者的行为和实际需求极为重要。 在一个互联、动态和竞争激烈的多元世界中,能够预测技术和市场变化的组织在竞争中生存下来。这是大数据管理的真正目标&#xff0…

鸿蒙生态伙伴SDK市场正式发布,驱动千行百业鸿蒙原生应用开发

6月21-23日,华为开发者大会(HDC 2024)在东莞举办。在22日举办的【鸿蒙生态伙伴SDK】论坛中,正式发布了【鸿蒙生态伙伴SDK市场】(以下简称:伙伴SDK市场),伙伴SDK市场是为开发者提供各…

PHP转Go系列 | ThinkPHP与Gin的使用姿势

大家好,我是码农先森。 安装 使用 composer 进行项目的创建。 composer create-project topthink/think thinkphp_demo使用 go mod 初始化项目。 go mod init gin_demo目录 thinkphp_demo 项目目录结构。 thinkphp_demo ├── LICENSE.txt ├── README.md …

JVM专题三:Java代码如何运行

通过前面的第一篇文章,对JVM整体脉络有了一个大概了解。第二篇文章我们通过对高级语言低级语言不同特性的探讨引出了Java的编译过程。有了前面的铺垫,咱们今天正式进入Java到底是如何运行起来的探讨。 目前大部分公司都是使用maven作为包管理工具&#x…

Flink-03 Flink Java 3分钟上手 Stream 给 Flink-02 DataStreamSource Socket写一个测试的工具!

代码仓库 会同步代码到 GitHub https://github.com/turbo-duck/flink-demo 当前章节 继续上一节的内容:https://blog.csdn.net/w776341482/article/details/139875037 上一节中,我们需要使用 nc 或者 telnet 等工具来模拟 Socket 流。这节我们写一个 …

基于YOLOv5的火灾检测系统的设计与实现(PyQT页面+YOLOv5模型+数据集)

基于YOLOv5的火灾检测系统的设计与实现 概述系统架构主要组件代码结构功能描述YOLOv5检测器视频处理器主窗口详细代码说明YOLOv5检测器类视频处理类主窗口类使用说明环境配置运行程序操作步骤检测示例图像检测视频检测实时检测数据集介绍数据集获取数据集规模YOLOv5模型介绍YOL…

python输入、输出和变量

一、变量 变量是存储数据的容器。在 Python 中,变量在使用前不需要声明数据类型,Python 会根据赋值自动推断变量类型。 定义变量: 二、输入(Input) input() 函数用于获取用户输入。默认情况下,input() 会…

在阿里云使用Docker部署MySQL服务,并且通过IDEA进行连接

阿里云使用Docker部署MySQL服务,并且通过IDEA进行连接 这里演示如何使用阿里云来进行MySQL的部署,系统使用的是Linux系统 (Ubuntu)。 为什么使用Docker? 首先是因为它的可移植性可以在任何有Docker环境的系统上运行应用,避免了在不通操作系…

Android intent 打开链接跳转到外部浏览

前言: 各位同学大家好, 最近接到一个比较诡异的需求 ,不是通常的webview 加URL显示网页 是需要跳转到外部浏览器 ,我这边处理好了就分享给大家 效果图 : 点几就跳转到外部浏览器 如图 具体代码实现: 点击打开链接并跳转外部浏览器方法 public void openBrowser(Con…

urfread刷SQL|left join|175. 组合两个表

175. 组合两个表 题目描述 左连接 因为不是所有人都有地址 如果只是用join,那么只会匹配到有地址的人。没有地址的人,就不会显示在结果中。 如果使用左连接,会把左表都显示在结果中,如果谁匹配不到右表,就为空值。 所…

Excel做简单的趋势预测

这种方法不能代替机器学习,时序分析等,只是为后面的时序预测提供一个经验认识。 step1 选中序号列(或时间列)与预测列如图1所示: 图1 step2 工具栏点击“数据”,然后再“数据”下点击“预测模型”&#x…