pytorch学习(8)——现有网络模型的使用以及修改

news2024/11/16 19:42:08

1 vgg16模型

在这里插入图片描述

1.1 vgg16模型的下载

采用torchvision中的vgg16模型,能够实现1000个类型的图像分类,VGG模型在AlexNet的基础上使用3*3小卷积核,增加网络深度,具有很好的泛化能力。
首先下载vgg16模型,python代码如下:

import torchvision

# 下载路径:C:\Users\win10\.cache\torch\hub\checkpoints
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("ok")

下载结果:

G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
G:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
ok

1.2 vgg16模型内部结构

查看预训练的模型和未预训练的模型的内部结构:

import torchvision

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)
print(vgg16_false)

预训练的模型和未预训练的模型在整体结构上相同,但内部节点的参数(weight和bias)有所不同。
输出结果如下:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以发现(classifier)中最后一层可以发现out_features=1000,表示该模型能够支持1000种类型的图像分类。

2 迁移学习

迁移学习是机器学习的一个子领域,它允许一个已经在一个任务上训练好的模型用于另一个但相关的任务。通过这种方式,模型可以借用在原任务上学到的知识,从而更快地、更准确地完成新任务。

本文采用CIFAR10数据集,内部包含10个种类的图像,修改vgg16模型对数据集进行图像分类。为了将此数据集代入vgg16模型,需要对模型进行修改。

(classifier): Sequential(
	... ...
	(6): Linear(in_features=4096, out_features=1000, bias=True)
)

2.1 添加层

使用add_module()函数添加模块。由于最后的归一化层为4096通道输出转1000通道输出,因此添加一个归一化层将1000通道输出转换为10通道输出。

import torchvision
from torch import nn

vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

输出结果(部分):

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )

2.2 修改层

对classifier内的第6层进行修改。由于最后的归一化层为4096通道输出转1000通道输出,因此需要修改为4096通道输出转换为10通道输出。

import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)

vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

输出结果(部分):

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )

2.3 模型保存

有两种方式保存模型数据。第一种保存方式是将模型结构和模型参数保存,第二种保存方式只是保存模型参数,以字典类型保存。
python代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential, CrossEntropyLoss

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 未经过训练的模型

# 保存方式1,模型结构+模型参数
torch.save(vgg16_false, "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method1.pth")

# 保存方式2,模型参数(官方推荐)
torch.save(vgg16_false.state_dict(), "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method2.pth")

保存修改过的模型或自己的编写的模型:

# 保存模型和导入模型时都需要导入MYNN这个类
class MYNN(nn.Module):
    def __init__(self):
        super(MYNN, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2, stride=1),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2, stride=1),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2, stride=1),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

mynn = MYNN()
torch.save(mynn, "G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\mynn_method1.pth")

以上两部分Python代码运行结果如下:
在这里插入图片描述

2.4 模型导入

有两种方式导入模型数据。第一种导入方式能够直接使用,第二种导入方法需要将字典数据导入原来的网络模型。

import torch
import torchvision
from torch import nn
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential, CrossEntropyLoss

# 方式1:加载模型
vgg16_import = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method1.pth")
print(vgg16_import)

# 方式2:加载模型(字典数据)
vgg16_import2 = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\vgg16_method2.pth")
vgg16_new = torchvision.models.vgg16(pretrained=False)      # 重新加载模型
vgg16_new.load_state_dict(vgg16_import2)                    # 将数据填入模型
print(vgg16_import2)
print(vgg16_new)

导入保存的自己的模型,python代码如下:

# 需要导入自己网络模型
class MYNN(nn.Module):
    def __init__(self):
        super(MYNN, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2, stride=1),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2, stride=1),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2, stride=1),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

model = torch.load("G:\\Anaconda\\pycharm_pytorch\\learning_project\\model\\mynn_method1.pth")
print(model)

自己的网络模型导入运行结果:

MYNN(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)

完整的模型训练套路-GPU训练

模型验证套路

Github上的代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/933708.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ae 效果:CC Jaws

过渡/CC Jaws Transition/CC Jaws CC Jaws(CC 锯齿)效果为视频或图像创造独特的锯齿状过渡效果。它允许用户控制中心点、方向、高度、宽度和形状,从而提供多种独特的过渡样式。 ◆ ◆ ◆ 效果属性说明 Completion 完成度 控制过渡效果的完成…

【MYSQL】排序时 如何将0排到最后,并让其他值按正序展示?

背景:展示排名时需要1,2,3,4,5,这样展示但是有些没有排名得数据字段默认值时0,这时直接用ASC就会出现问题 实现效果 实现方式:使用MySQL的ORDER BY语句来实现。以下是一个示例的SQL查…

Creo软件安装包分享(附安装教程)

目录 一、软件简介 二、软件下载 一、软件简介 Creo是一款机械设计软件,全称为Creo Parametric,是PTC公司推出的一款三维CAD/CAM/CAE软件。Creo被广泛应用于机械设计、汽车、航空、航天、电子、家电、玩具等各个行业,是世界上应用最广泛、最…

小研究 - Java虚拟机内存管理(二)

Java 语言的面向对象,平台无关,安全,开发效率高等特点,使其在许多领域中得到了越来越广泛的应用。但是由于Java程序由于自身的局限性,使其无法应用于实时领域。由于垃圾收集器运行时将中断Java程序的运行,其…

PC端版面设计之尾部设计

1、莫拉王子,底部就放了一个返回顶部 2 麻辣王子就放了一个认证--放了产地和得到的奖 3 阿芙:就是精油 4阿芙放的是品牌故事 5 这里可以做微博粉丝群体 6 基本返回底部是一个标配,点一下就可以反悔了 7 加一下旺旺店铺 8 BetyBoop的底部 9 底部 10 返回底…

【C++】priority_queue优先级队列

🏖️作者:malloc不出对象 ⛺专栏:C的学习之路 👦个人简介:一名双非本科院校大二在读的科班编程菜鸟,努力编程只为赶上各位大佬的步伐🙈🙈 目录 前言一、priority_queue的介绍二、pr…

PHP聚合支付网站源码/对接十多个支付接口 第三方/第四方支付/系统源码

PHP聚合支付网站源码/对接十多个支付接口 第三方/第四方支付/系统源码 内附数十个支付接口代码文件。 下载地址:https://bbs.csdn.net/topics/616764485

window系统中如何判断是物理机还是虚拟机

总结了如何判断物理机: 1. 用systeminfo的系统型号。(注,有资料是看处理器和bios。但是我这台不明确。看系统型号准确一些) 2. 在任务管理器》性能中查看“逻辑处理器”还是“虚拟处理器”。虚拟机,看“是“、”否”。…

金融客户敏感信息的“精细化管控”新范式

目 录 01 客户信息保护三箭齐发,金融IT亟需把握四个原则‍ 02 制度制约阻碍信息保护的精细化管控 ‍‍‍‍‍‍‍ 03 敏感信息精细化管控范式的6个关键设计 04 分阶段实施,形成敏感信息管控的长效运营的机制 05 未来,新挑战与新机遇并存 …

postgresql常用函数-数学函数

postgresql常用函数 简介数学函数算术运算符绝对值取整函数乘方与开方指数与对数整数商和余数弧度与角度常量 π符号函数生成随机数 简介 函数(function)是一些预定义好的代码模块,可以将输入进行计算和处理,最终输出一个 结果值…

VScode的PHP远程调试模式Xdebug

目录 第一步、安装VScode中相应插件 remote-ssh的原理 ssh插件: PHP相关插件: 第二步、安装对应PHP版本的xdebug 查看PHP具体配置信息的phpinfo页面 1、首先,打开php编辑器,新建一个php文件,例如:inde…

林【2021】

三、应用 1.字符串abaaabaabaa,用KMP改进算法求出next和nextval的值 2.三元组矩阵 4.二叉树变森林 四、代码(单链表递增排序,二叉树查找x,快速排序)

华为eNSP模拟器中,路由器如何添加serial接口

在ensp模拟器中新建拓扑后,添加2个路由器。 在路由器图标上单击鼠标右键,选择设置选项。 在【视图】选项卡的【eNSP支持的接口卡】窗口查找serial接口卡。 选择2SA接口卡,将其拖动到路由器空置的卡槽位。 如上图所示,已经完成路由…

JavaScript用indexOf()在字符串数组中查找子串时需要注意的一个地方

一、遇到问题 在 继续更新完善:C 结构体代码转MASM32代码 中,由于结构体成员中可能为数组类型的情况,因此我们在提取结构体成员信息的过程中,需要检测结构体成员名称字符串中是否包括 [],如果包括那么我们要截取[前面…

基于Java+SpringBoot+Vue前后端分离美食推荐商城设计和实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

Linux(实操篇一)

Linux实操篇 Linux(实操篇一)1. 常用基本命令1.1 帮助命令1.1.1 man获得帮助信息1.1.2 help获得shell内置命令的帮助信息1.1.3 常用快捷键 1.2 文件目录类1.2.1 pwd显示当前 工作目录的绝对路径1.2.2 ls列出目录的内容1.2.3 cd切换目录1.2.4 mkdir创建一个新的目录1.2.5 rmdir删…

[管理与领导-52]:IT基层管理者 - 8项核心技能 - 7 - 决策

目录 前言: 一、什么是决策 二、为什么需要管理者的决策 三、什么时候需要管理者决策 四、常见的决策误区 4.1 关于决策的误区 4.2 错误的决策行为 五、如何进行有效决策 六、进行决策的常用方法 前言: 管理者存在的价值就是制定目标&#xff0…

自然语言处理(一):词嵌入

词嵌入 词嵌入(Word Embedding)是自然语言处理(NLP)中的一种技术,用于将文本中的单词映射到一个低维向量空间中。它是将文本中的单词表示为实数值向量的一种方式。 在传统的文本处理中,通常使用独热编码&…

【定时提醒】的应用场景

应用场景: 定时提醒是一个在多个行业中都有广泛应用的功能,它可以用来提醒用户执行某些任务、活动或事件。以下是几个定时提醒在不同行业中的应用案例: 医疗保健行业: 医疗机构可以利用定时提醒来提醒患者服药、定期检查、体检预…

【Java集合学习1】ArrayList集合学习及集合概述分析

JavaArrayList集合学习及集合学习概述 一、Java集合概述 Java 集合, 也叫作容器,主要是由两大接口派生而来:一个是 Collection接口,主要用于存放单一元素;另一个是 Map 接口,主要用于存放键值对。对于Col…