VGG16神经网络搭建

news2024/12/22 23:44:23

一、定义提取特征网络结构

将要实现的神经网络参数存放在列表中,方便使用。

数字代表卷积核的个数,字符代表池化层的结构

cfgs = {
    "vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

二、 定义提取特征网络

如果遍历过程中 v== 'M',就是定义池化层,后面的卷积核与stride步距都是网络的默认参数。

数字代表的就是定义卷积层,然后与激活函数链接在一起。

最后返回时,以非关键字参数的形式传入。

def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)

三、初始化网络

传入参数features,class_num,是否需要初始化权重。

定义分类网络结构,dropout方法缓解过拟合问题,再全连接核relu激活函数链接起来。

如果需要初始化权重,那么就会进入初始化权重的函数中。

class VGG(nn.Module):
    def __init__(self, features, class_num=1000, init_weight=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, class_num)
        )
        if init_weight:
            self._initialize_weights()

 四、初始化权重函数

这个函数会遍历网络的每一个子模块。

如果遍历的当前层是一个卷积层,那么这个方法会初始化卷积核的权重,如果采用了偏置,那就默认初始化为0.

如果遍历的当前层是全连接层,也是用这个方法去初始化全连接层的权重,并将偏置设置为0.

    def _initialize_weights(self):
        for m in self.modules():  # 遍历模块中的每一个子模块
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

五、定义正向传播

x:输入的图像数据

features:提取网络特征结构

flatten:展平处理。因为第0个维度是batch,所以我们从第一个维度开始展平

经过分类网络结构后返回

    def forword(self, x):
        x = self.features(x)
        x = torch.flatten(x, strat_dim=1)
        x = self.classifier(x)
        return x

六、实例化模型

传入参数model_name:实例化给定的配置模型。

将key值传入字典当中

通过VGG这个类来实例化这个网络

features通过make_features这个函数来实现

最后创建对象实现VGG神经网络的搭建。 

def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except:
        print("waring: model number {} not in cfgs dict".format(model_name))
    model = VGG(make_features(cfg), **kwargs)
    return  model

vgg_model = vgg(model_name='vgg13')

 运行成功,网络搭建成功。

 全部代码

import torch.nn as nn
import torch


class VGG(nn.Module):
    def __init__(self, features, class_num=1000, init_weight=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, class_num)
        )
        if init_weight:
            self._initialize_weights()

    def forword(self, x):
        x = self.features(x)
        x = torch.flatten(x, strat_dim=1)
        x = self.classifier(x)
        return x


    def _initialize_weights(self):
        for m in self.modules():  # 遍历模块中的每一个子模块
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)



def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    "vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except:
        print("waring: model number {} not in cfgs dict".format(model_name))
    model = VGG(make_features(cfg), **kwargs)
    return  model

vgg_model = vgg(model_name='vgg13')

 全部代码与分开模块的顺序不同,但不影响最终实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1547660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习:数据降维主成分分析PCA

一、引言 1.数据分析的重要性   在当今的信息爆炸时代,数据已经渗透到各个行业和领域的每一个角落,成为决策制定、科学研究以及业务发展的重要依据。数据分析则是从这些数据中提取有用信息、发现潜在规律的关键手段。通过数据分析,我们能够…

第44期 | GPTSecurity周报

GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区,集成了生成预训练Transformer(GPT)、人工智能生成内容(AIGC)以及大语言模型(LLM)等安全领域应用的知识。在这里,您可以找…

elementui的table根据是否符合需求合并列

<el-table :data"tableData" border style"width: 100%;" :span-method"objectSpanMethodAuto"><!-- 空状态 --><template slot"empty"><div><img src"/assets/images/noData.png" /></di…

【双指针】Leetcode 查找总价格为目标值的两个商品

题目解析 LCR 179. 查找总价格为目标值的两个商品 本题很友好&#xff0c;只需要返回任意一个 算法讲解 这道题很显然就是使用对撞双指针&#xff0c;一个从左边&#xff0c;一个从右边&#xff0c;两边进行和target比较来移动 代码编写 class Solution { public:vector<…

我的创作纪念日 ---- 2024/3/26

前言 2024.3.26是我在CSDN成为创作者的第128天&#xff0c;也是我第一次真正在网上创作的第128天 当我还在日常创作时&#xff0c;突然发现我收到了一封信 我想我可以分享一下这段时间的感想以及收获 机缘 在CSDN的这段时间里&#xff0c;我学习到了很多知识&#xff0c;也…

字节跳动开源视频生成模型:AnimateDiff-Lightning视频生成加速十倍

前言 在近日&#xff0c;字节跳动再次引领AI视频生成领域的革新&#xff0c;推出了其最新研究成果——AnimateDiff-Lightning模型。这款开源的文本到视频生成模型&#xff0c;以其令人惊叹的生成速度和卓越的生成质量&#xff0c;标志着视频生成技术的一个重大突破&#xff0c…

两区域二次调频风火机组,麻雀启发式算法改进simulink与matlab联合

区域1结果 区域2结果 红色曲线为优化后结果〔风火机组二次调频〕

机器人机械手加装SycoTec 4060 ER-S电主轴高精密铣削加工

随着科技的不断发展&#xff0c;机器人技术正逐渐渗透到各个领域&#xff0c;展现出前所未有的潜力和应用价值。作为机器人技术的核心组成部分之一&#xff0c;机器人机械手以其高精度、高效率和高稳定性的优势&#xff0c;在机械加工、装配、检测等领域中发挥着举足轻重的作用…

docker快速安装Es和kibana

文章目录 概要一、Es二、kibana三、dcoker compose管理四、参考 概要 在工作过程中&#xff0c;经常需要测试环境搭建Es环境&#xff0c;本文基于Es V8.12.2来演示如何快速搭建单节点Es和kibana。 服务器默认已按装docker 一、Es 1&#xff1a;拉取镜像 docker pull elast…

【Linux】-Linux下的编辑器Vim的模式命令大全及其自主配置方法

目录 1.简单了解vim 2.vim的模式 2.1命令模式 2.2插入模式 2.3底行模式 3.vim各模式下的命令集 3.1正常&#xff08;命令模式下&#xff09; 3.1.1光标定位命令 3.1.2 复制粘贴 3.1.3 删除 3.1.4 撤销 3.1.5大小写转换 3.1.6替换 「R」&#xff1a;替换光标所到之处的字符&…

「09」媒体源:播放本地或在线的音视频GIF文件

「09」媒体源播放本地或在线的音视频GIF文件 通过媒体源功能&#xff0c;您可以添加自己想要展示的各种视频内容&#xff0c;例如自己的视频课程、电影或客户见证视频、以及GIF动画等。 &#xff08;图层叠加效果&#xff09; &#xff08;绿幕抠像叠加效果&#xff09; 缺点…

MySQL---存储过程详解

目录 一、介绍 二、基础语法 三、变量 四、流程控制 五、参数 六、游标 七、条件处理程序 八、存储函数 一、介绍 存储过程是事先经过编译并存储在数据库中的一段 SQL 语句的集合&#xff0c;调用存储过程可以简化应用开发人员的很多工作&#xff0c;减少数据在数据库和…

(五)图像的标准假彩色合成

环境&#xff1a;Windows10专业版 IDEA2021.2.3 jdk11.0.1 GDAL(release-1928-x64-gdal-3-5-2-mapserver-8-0-0) OpenCV-460.jar 系列文章&#xff1a; &#xff08;一&#xff09;PythonGDAL实现BSQ&#xff0c;BIP&#xff0c;BIL格式的相互转换 &#xff08;二&#xff…

pear-admin 项目结构讲解

上一篇文章介绍了pear-admin用到flask的技术&#xff0c; 深入代码后发现其结构也是令人眼前一亮&#xff0c; 结构化&#xff0c;模块化&#xff0c; 解耦做得非常优秀。 整个项目数据库使用migrate做了版本管理&#xff0c; 使用marshmallow做了序列化&#xff0c;这样数据库…

STL —— string(3)

目录 1. 使用 1.1 c_str() 1.2 find() & rfind() 1.3 substr() 1.4 打印网址的协议域名等 1.5 find_first_of() 2. string() 模拟实现 2.1 构造函数的模拟实现 2.2 operator[] 和 iterator 的模拟实现 2.3 push_back() & append() & 的模拟实现 2.4 ins…

C语言运算符优先级介绍

1. 引言 什么是运算符 运算符是编程中用于执行算术、比较和逻辑操作的符号。它们是构建表达式的基本工具&#xff0c;类似于数学中的加、减、乘和除。 程序片段示例: 简单的算术运算符使用 #include <stdio.h>int main() {int a 5, b 2;int sum a b; // 使用加法…

发车,易安联签约某新能源汽车领军品牌,为科技创新保驾护航

近日&#xff0c;易安联成功签约某新能源汽车领军品牌&#xff0c;为其 数十万终端用户 建立一个全新的 安全、便捷、高效一体化的零信任终端安全办公平台。 随着新能源汽车行业的高速发展&#xff0c;战略布局的不断扩大&#xff0c;技术创新不断引领其市场价值走向高点&am…

计算机网络——数据链路层(差错控制)

计算机网络——数据链路层&#xff08;差错控制&#xff09; 差错从何而来数据链路层的差错控制检错编码奇偶校验码循环冗余校验&#xff08;CRC&#xff09;FCS 纠错编码海明码海明距离纠错流程确定校验码的位数r确定校验码和数据位置 求出校验码的值检错并纠错 我们今年天来继…

C#打印50*30条码标签

示例图&#xff1a; 源码下载地址&#xff1a;https://download.csdn.net/download/tiegenZ/89035407?spm1001.2014.3001.5503

【Java程序设计】【C00379】基于(JavaWeb)Springboot的旅游服务平台(有论文)

【C00379】基于&#xff08;JavaWeb&#xff09;Springboot的旅游服务平台&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;已经做了六年的毕业设计程序开发&#xff0c…