Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)

news2024/9/23 23:28:48

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)


文章目录

  • 系列文章目录
  • 一、常见的现有网络模型
  • 二、VGG16模型


一、常见的现有网络模型

  1. AlexNet: AlexNet是一个经典的卷积神经网络模型,由Alex Krizhevsky等人提出。它是在ImageNet数据集上取得突破性性能的模型,具有8个卷积层和3个全连接层。
  2. VGG: VGG是由Karen Simonyan和Andrew Zisserman提出的一系列卷积神经网络模型。它以其简单而深层的架构而闻名,有16层或19层的变种。VGG模型以其强大的特征提取能力而受到广泛使用。
  3. ResNet: ResNet是由Kaiming He等人提出的深度残差网络。它通过引入残差连接解决了深层网络训练中的梯度消失和梯度爆炸问题。ResNet模型具有不同深度的变种,如ResNet-18、ResNet-34、ResNet-50等。
  4. DenseNet: DenseNet是由GaoHuang等人提出的一种密集连接卷积神经网络模型。它的特点是在网络中的每一层都与所有后续层进行连接,从而增加了信息传递和特征重用的效果。
  5. Inception: Inception是由ChristianSzegedy等人提出的一系列卷积神经网络模型,其中包含了多种并行的卷积分支。Inception模型以其高效的计算和强大的表示能力而受到广泛关注。
  6. MobileNet: MobileNet是一系列轻量级的卷积神经网络模型,旨在在计算资源受限的环境下实现高效的计算。MobileNet模型通过深度可分离卷积等技术来减少参数量和计算量。

注意:PyTorch通过torchvision.models模块提供了更多的预训练模型.

官网的预训练模型有如下几种:

在这里插入图片描述

二、VGG16模型

torchvision.models.vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any)

VGG-16是一种具有16个卷积层和3个全连接层的卷积神经网络模型,由Karen Simonyan和Andrew Zisserman在2014年提出。

参数如下:
weights(可选):指定要加载的预训练权重。可以是None(默认值)表示不加载预训练权重,或是指定为预定义的某个预训练权重标识符。
progress:指示下载进度条的显示设置,默认为True显示下载进度条。
**kwargs:其它可选参数,传递给VGG-16模型的基类torchvision.models.VGG。

创建VGG16模型并打印输出结果

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

从上述运行结果可知:VGG16网络是由13层卷积层和3层全连接层组成,最后网络输出一共有1000个分类结果。

修改VGG16模型结构
 使用add_module()方法在VGG16模型后增加一个线性层,实现将VGG16的1000个类别输出为类似CIFAR10的10个类别,代码如下:

import torchvision.models as models
from torch import nn

vgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)

# print(vgg16_false)
vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)

运行结果如下:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
  (add_linear): Linear(in_features=1000, out_features=10, bias=True)
)

由上述可以知道,add_linear是在classifier外面的,如果要在classifier里面,可以将

vgg16_true.add_module("add_linear", nn.Linear(1000, 10))

替换为

vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/854875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

photoshop PS 查看像素坐标、像素颜色、像素HSB颜色

方法一 photoshop 菜单栏 窗口菜单->信息菜单项(F8), 在信息窗口里会有当前的 x,y坐标 方法二 photoshop 菜单栏 视图菜单->标尺菜单项(ctrlR) 宽度和高度边上都有标尺,默认的是厘米,右键单机宽度和高度边上…

【自动化测试框架】关于unitttest你需要知道的事

一、UnitTest单元测试框架提供了那些功能 1.提供用例组织和执行 如何定义一条“测试用例”? 如何灵活地控制这些“测试用例”的执行? 2.提供丰定的断言方法 当测试用例的执行结果与预期结果不一致时,判定测试用例失败。在自动化测试中,通过“断言”…

【IMX6ULL驱动开发学习】04.应用程序和驱动程序数据传输和交互的4种方式:非阻塞、阻塞、POLL、异步通知

一、数据传输 1.1 APP和驱动 APP和驱动之间的数据访问是不能通过直接访问对方的内存地址来操作的,这里涉及Linux系统中的MMU(内存管理单元)。在驱动程序中通过这两个函数来获得APP和传给APP数据: copy_to_usercopy_from_user …

MySQL之深入InnoDB存储引擎——Undo页

文章目录 一、UNDO日志格式1、INSERT操作对应的UNDO日志2、DELETE操作对应的undo日志3、UPDATE操作对应的undo日志1)不更新主键2)更新主键的操作 3、增删改操作对二级索引的影响 二、UNDO页三、UNDO页面链表四、undo日志具体写入过程五、回滚段1、回滚段…

Java课题笔记~ Servlet编程

1.Servlet编程基础 (1)什么是Servlet Servlet是基于Java语言的Web编程技术,部署在服务器端的Web容器里,获取客户端的访问请求,并根据请求生成响应信息返回给客户端。 创建Servlet的方式,有 如下图:一般创建Servlet都…

Vue3 表单输入绑定简单应用

去官网学习→表单输入绑定 | Vue.js 运行示例&#xff1a; 代码&#xff1a;HelloWorld.vue <template><div class"hello"><h1>Vue 表单输入绑定</h1><input type"text" placeholder"输入框" v-model"msg"…

Open_PN笔记

>>>仅用作学习用途 1.准备好需要用到的工具 官网下载地址&#xff1a; openvpn 客户端下载地址&#xff1a; https://swupdate.openvpn.org/community/releases/openvpn-install-2.4.5-I601.exe EasyRSA下载地址&#xff1a; https://githu…

4、Rocketmq之存储原理

CommitLog ~ MappedFileQueue ~ MappedFile集合

web-初始前端

不区分大小写&#xff0c;单双引号&#xff0c; <html><head><title>初识HTML</title></head><body><h1>Hello world!</h1><img src OIF-C.jfif/></body> </html> <!-- 文件格式 --> <!DOCTYPE h…

PoseFormer:基于视频的2D-to-3D单人姿态估计

3D Human Pose Estimation with Spatial and Temporal Transformers论文解析 摘要1. 简介2. Related Works2.1 2D-to-3D Lifting HPE2.2 GNNs in 3D HPE2.3 Vision Transformers 3. Method3.1 Temporal Transformer Baseline3.2 PoseFormer: Spatial-Temporal TransformerSpati…

人文景区有必要做VR云游吗?如何满足游客出行需求?

VR云游在旅游行业中的应用正在快速增长&#xff0c;为游客带来沉浸式体验的同时&#xff0c;也为文旅景区提供了新的营销方式。很多人说VR全景展示是虚假的&#xff0c;比不上真实的景区触感&#xff0c;人文景区真的有必要做VR云游吗&#xff1f;我的答案是很有必要。 如果你认…

【vue3+pinia+Cookie】实现token登录及数据持久化

vue2后台登录项目,目前接触到的基本上都是vuex+本地存储或vuex+cookie或vuex+专门用于持久化vuex的库,如:vuex-persistedstate vuex默认情况下是在客户端内存中保存应用程序的状态,当页面刷新的时候,内存中的数据会丢失,导致vuex中的数据也会丢失。因为vuex的状态存储是临…

linux-删除KVM虚拟机

1.查看主机 #virsh list 2.关闭主机 #virsh shutdown 虚拟机名称 3.删除主机定义 #virsh undefine 虚拟机名称 4.删除KVM虚拟机文件 #find / -name 虚拟机名称 #rm -rf 虚拟机文件

Windows11安装Linux子系统,并实现服务自启动,局域网访问,磁盘挂载

Windows11安装Linux子系统&#xff0c;并实现服务自启动&#xff0c;局域网访问&#xff0c;磁盘挂载 一、准备工作二、安装Linux子系统(wsl2)三、为Linux子系统设置桥接网络检查wsl版本在 Hyper-V 管理器中创建虚拟交换机创建 WSL 配置文件启动wsl 四、设置Windows开机自启动L…

【考研复习】24王道数据结构课后习题代码|第3章栈与队列

文章目录 3.1 栈3.2 队列3.3 栈和队列的应用 3.1 栈 int symmetry(linklist L,int n){char s[n/2];lnode *pL->next;int i;for(i0;i<n/2;i){s[i]p->data;pp->next;}i--;if(n%21) pp->next;while(p&&s[i]p->data){i--;pp->next;}if(i-1) return 1;…

openssl安装问题合辑

1.openssl拖累nginx编译失败 问题描述&#xff1a; 因为漏洞原因&#xff0c;升级openssl之后需要重新编译nginx&#xff0c;进行了以下步骤&#xff1a; config没问题&#xff0c;但是make一直报错 初步判断是openssl安装有问题&#xff0c;原因不明&#xff0c;重装了opens…

【ARM Cache 系列文章 9 -- ARM big.LITTLE技术】

文章目录 big.LITTLE 技术背景big.LITTLE 技术详解big.LITTLE 硬件要求 big.LITTLE 软件模型CPU MigrationGlobal Task SchedulingGlobal Task Scheduling比CPU Migration的优势 转自&#xff1a;https://zhuanlan.zhihu.com/p/630981648 如有侵权&#xff0c;请联系删除 big.L…

Codeforces Round 889 (Div. 2)C题题解

文章目录 [Dual (Hard Version)](https://codeforces.com/contest/1855/problem/C2)问题建模问题分析1.按元素值分类讨论&#xff0c;正负不同时存在时2.若正负同时存在时代码 Dual (Hard Version) 问题建模 给定n个数&#xff0c;n不超过20&#xff0c;且每个数ai&#xff0c…

MachineLearningWu_14/P65-P69_Multiclass

x.1 Multiclass多分类问题 对于分类问题&#xff0c;往往指的是二分类问题&#xff0c;而对于二分类的decision boundary较为简单&#xff0c;而实际生活中会有很多问题是多分类问题&#xff0c;例如MNIST手写数字识别&#xff0c; 从特征空间上来看&#xff0c;二分类和多分类…