【Pytorch深度学习开发实践学习】【AlexNet】经典算法复现-Pytorch实现AlexNet神经网络(1)model.py

news2024/11/19 3:33:41

在这里插入图片描述

算法简介

AlexNet是人工智能深度学习在CV领域的开山之作,是最先把深度卷积神经网络应用于图像分类领域的研究成果,对后面的诸多研究起到了巨大的引领作用,因此有必要学习这个算法并能够实现它。

主要的创新点在于:

  1. 首次使用GPU进行神经网络加速训练
  2. 使用使用了非饱和的激活函数ReLU,而不是传统的sigmoid和tanh
  3. 使用了数据增强手段抑制过拟合
  4. 提出了Dropout随机失活抑制过拟合
  5. 提出了LRN局部响应归一化
  6. 使用了重叠池化抑制过拟合

model.py代码讲解

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  使用48个11*11的卷积核,步长为4,padding为2 output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # input[48, 55, 55]  output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
    

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

model.py的全部代码如上
现在逐行进行分析

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  使用48个11*11的卷积核,步长为4,padding为2 output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # input[48, 55, 55]  output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )

class AlexNet(nn.Module):
定义了一个AlexNet的类,这个类继承了nn.Module
def init(self,num_classes=1000):
定义了类的初始化函数,它有个可选的参数 num_classes是我们这个神经网络在输出的分类数

super(AlexNet,self).__init()
这是为了调用父类的初始化函数

self.features = nn.Sequential()
在这里插入图片描述
这里非常重要,我们可以去Pytorch的官方文档上看看,
官方的解释是:
模块将按照传入构造函数的顺序添加到其中。另外,也可以传入一个有序字典的模块。Sequential的forward()方法接受任何输入,并将其转发给它包含的第一个模块。然后,对于每个后续模块,它将输出“链接”到输入,最终返回最后一个模块的输出。
Sequential相对于手动调用一系列模块的优势在于,它允许将整个容器视为单个模块,这样对Sequential执行的转换将应用于其存储的每个模块(它们分别是Sequential的注册子模块)。
Sequential和torch.nn.ModuleList之间有什么区别?ModuleList就像它的名字一样-用于存储Module的列表!另一方面,Sequential中的层以级联方式连接。

论文中的AlexNet网络结构图如下:
在这里插入图片描述
AlexNet是第一个网络结构开始变得更加复杂的神经网络模型(Lenet)只有两个卷积层和两个全连接层,而AlexNet有5个卷积层和3个全连接层,对于逐渐复杂的网络结构,我们可以利用Sequential函数搭建序列化的网络模块

比如这里我们首先定义了一个features模块
nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
第一个卷积层 输入是2242243 48个1111的卷积核 步长是4,填充是2
输出是55
55*48

nn.ReLU(inplace=True),ReLU激活函数

nn.MaxPool2d(kernel_size=3, stride=2),
定义一个最大池化层,使用3x3的池化核,步长为2。这将进一步减少特征图的尺寸。

nn.Conv2d(48, 128, kernel_size=5, padding=2),
又是一个卷积层,输入是272748 128个55的卷积核 填充是2,输出是2727*128

然后以此类推
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), 又是激活函数和池化,池化后输出 1313128
nn.Conv2d(128, 192, kernel_size=3, padding=1), 输入1313128 输出1313192

nn.ReLU(inplace=True),
nn.Conv2d(192, 192, kernel_size=3, padding=1),输入1313192 输出1313192

nn.ReLU(inplace=True),
nn.Conv2d(192, 128, kernel_size=3, padding=1), 输入1313192
输出1313128

nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), 输入1313128 输出 66128

self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

第二个模块,上一个是5层卷积层加3层池化层提取特征
下面这个模块就是全连接层做分类

首先是drouput随机失活抑制过拟合的操作
然后是 nn.Linear(128 * 6 * 6, 2048),12866的原因是全连接层是接着前面的最后一个也是第三个池化层,池化层的输出就是12866
后面再接两个全连接层,最后一个全连接层的输出就是对1000个类的预测结果

   def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

def forward(self, x):

定义一个名为forward的方法,这是PyTorch中自定义神经网络层或模型的标准做法。这个方法描述了输入数据x通过网络的前向传播过程。
x = self.features(x)
将输入数据x传递给feature模块
x = torch.flatten(x, start_dim=1)

使用PyTorch的flatten函数将特征图x在指定的维度(start_dim=1,通常是指从第二个维度开始,即特征图的深度维度)展平。这通常是为了将多维的特征图转换为一维的张量,以便输入到全连接层。
这里要重点说明一下,在feature后输出的x是一个四维的参数(B,C,H,W)分别是batchsize channel 高、宽 而这个函数的意思是从第二维channel开始,对后三维 通道数、宽、高进行展开,转为一维的向量输入全连接层

x = self.classifier(x)
将展平后的特征x传递给classifier
return x
返回经过分类器处理后的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1478846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PyTorch-Ignite的介绍与快速上手

PyTorch-Ignite 是一个用于 PyTorch 的高级库,旨在帮助开发者更快、更简洁地编写可复用的代码来进行深度学习实验。它由 PyTorch 社区开发,提供了一套灵活的抽象,用于构建和管理训练和验证循环,而无需牺牲 PyTorch 的灵活性和强大…

四、《任务列表案例》后端程序实现和测试

本章概要 准备工作功能实现前后联调 4.1 准备工作 数据库脚本 CREATE TABLE schedule (id INT NOT NULL AUTO_INCREMENT,title VARCHAR(255) NOT NULL,completed BOOLEAN NOT NULL,PRIMARY KEY (id) );INSERT INTO schedule (title, completed) VALUES(学习java, true),(学…

电力运维是做什么的?电力行业智能运维工作内容?

电力行业智能运维工作内容具体涉及哪些关键任务?实施智能运维过程中,如何利用现代信息技术、人工智能和大数据分析来提升电力系统的运行效率与维护响应速度?在电力行业中引入智能运维后,对于预防性维护、故障诊断、设备寿命预测以及成本控制…

react native中如何实现tab切换页面以及页面可以左右滑动效果

react native中如何实现tab切换页面以及页面可以左右滑动效果 效果示例图主体代码 效果示例图 主体代码 import React, {useRef, useState} from react; import {View,ScrollView,Text,StyleSheet,Dimensions,Animated, } from react-native; import {pxToPd} from ../../comm…

Linux系统——LAMP架构

目录 一、LAMP架构组成 1.LAMP定义 2.各组件的主要作用 3.CGI和FastCGI 3.1CGI 3.3CGI和FastCGI比较 4.PHP 4.1PHP简介 4.2PHP的Opcode语言 4.3PHP设置 二、LAMP架构实现 1.编译安装Apache httpd服务 2.编译安装Mysql 3.编译安装PHP 4.安装论坛 5.搭建博客 W…

力扣区间题:合并区间、插入区间

我们可以将区间按照左端点升序排列,然后遍历区间进行合并操作。 我们先将第一个区间加入答案,然后依次考虑之后的每个区间: 如果答案数组中最后一个区间的右端点小于当前考虑区间的左端点,说明两个区间不会重合,因此…

当大语言模型遇到AI绘画-google gemma与stable diffusion webui融合方法-矿卡40hx的AI一体机

你有想过建一台主机,又能AI聊天又能AI绘画,还可以直接把聊天内容直接画出来的机器吗? 当Google最新的大语言模型Gemma碰到stable diffusion webui会怎么样? 首先我们安装stable diffusion webui(automatic1111开源项目&#xff…

【基于ChatGPT大模型】GIS应用、数据清洗、统计分析、论文助手、项目基金助手、科研绘图、AI绘图

以ChatGPT、LLaMA、Gemini、DALLE、Midjourney、Stable Diffusion、星火大模型、文心一言、千问为代表AI大语言模型带来了新一波人工智能浪潮,可以面向科研选题、思维导图、数据清洗、统计分析、高级编程、代码调试、算法学习、论文检索、写作、翻译、润色、文献辅助…

Google Genie:创意互动环境

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

渗透测试靶场环境搭建

1.DVWA靶场 DVWA(Damn Vulnerable Web Application)是一个用来进行安全脆弱性鉴定的PHP/MySQL Web应用,包含了OWASP TOP10的所有攻击漏洞的练习环境,旨在为安全专业人员测试自己的专业技能和工具提供合法的环境,同时…

完美解决git 执行git push origin master指令 报错command not found

问题描述 报错信息为:在提交项目时的操作:找不到命令行 解决方案 (1)可以通过如下命令进行代码合并【注:pullfetchmerge】 git pull --rebase origin master(2)再执行语句: git p…

Linux(CentOS)学习

一、认识Linux 1、如何修改Linux时区 2、配置固定IP 3、重启网络服务 3、小技巧快捷键 4、环境变量设置 5、Linux文件的上传和下载 6、压缩和解压 二、基础命令 1、目录命令 (1、)查看目录内容(ls) 1、ls //查看当前目录内容 2、- a //显示隐藏内容 3…

Spatom——利用图神经网络进行蛋白质-蛋白质结合位点预测的新工具

介绍一个蛋白质-蛋白质结合位点预测的新工具——Spatom,这是一个图神经网络框架。其发布在brief in bioinformatics上面。 Paper and tool links 文章,网页工具和github链接如下 paper link: Spatom: a graph neural network for structure-based prot…

lv20 QT事件

1 事件模型 2 事件处理 virtual void keyPressEvent(QKeyEvent *event) virtual void keyReleaseEvent(QKeyEvent *event) virtual void mouseDoubleClickEvent(QMouseEvent *event) virtual void mouseMoveEvent(QMouseEvent *event) virtual void mousePressEvent(QMou…

【Android12】Monkey压力测试源码执行流程分析

Monkey压力测试源码执行流程分析 Monkey是Android提供的用于应用程序自动化测试、压力测试的测试工具。 其源码路径(Android12)位于 /development/cmds/monkey/部署形式为Java Binary # development/cmds/monkey/Android.bp // Copyright 2008 The Android Open Source Proj…

《PyTorch深度学习实践》第九讲多分类问题

一、 1、softmax的输入不需要再做非线性变换,也就是说softmax之前不再需要激活函数。softmax两个作用,如果在进行softmax前的input有负数,通过指数变换,得到正数。所有类的概率求和为1。 2、y的标签编码方式是one-hot。one-hot是…

java爬取深圳新房备案价

Java爬取深圳新房备案价 这是我做好效果,一共分3个页面 1、列表;2、统计;3、房源表 列表 价格分析页面 房源页面 一、如何爬取 第一步:获取深圳新房备案价 链接是:http://zjj.sz.gov.cn/ris/bol/szfdc/index.aspx 第二步:通过楼盘名查询获取明细 链接:http://z…

就业班 2401--2.27 Linux Day6--管道和重定向

管道与重定向 只有在开水里,茶叶才能展开生命浓郁的香气. 一、重定向 标准输入、标准正确输出、标准错误输出 进程在运行的过程中根据需要会打开多个文件,每打开一个文件会有一个数字标识。这个标识叫文件描述符。 进程使用文件描述符来管理打开的文件…

Android PDFView 提示401 pom

背景 在开发安卓app,使用PDF组件来解析URL地址 ,从github找到一个开源组件 AndroidPdfViewer 遇到一个大坑,一直提示下载依赖401 pom 打开控制台链接弹出需要登录jitpack 原因分析: 这个组件项目依赖库链接到了需要鉴权的…

【airtest】自动化入门教程(一)AirtestIDE

目录 一、下载与安装 1、下载 2、安装 3、打开软件 二、web自动化配置 1、配置chrome浏览器 2、窗口勾选selenium window 三、新建项目(web) 1、新建一个Airtest项目 2、初始化代码 3、打开一个网页 四、恢复默认布局 五、新建项目&#xf…