【PyTorch】Torchvision Models

news2024/9/22 9:52:50

文章目录

  • 六、Torchvision Models
    • 1、VGG
      • 1.1 add
      • 1.2 modify
    • 2、Save and Load
      • 2.1 模型结构 + 模型参数
      • 2.2 模型参数(官方推荐)
      • 2.3 Trap

六、Torchvision Models

1、VGG

VGG参考文档:https://pytorch.org/vision/stable/models/vgg.html

VGG16为例:

https://pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16

ImageNet数据集:

https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageNet.html#torchvision.datasets.ImageNet

ImageNet描述:

https://image-net.org/challenges/LSVRC/index.php

train_data = torchvision.datasets.ImageNet("../data", split="train", transform=torchvision.transforms.ToTensor(),
                                           download=True)

报错:需要手动下载!!!(100多G,还是算了吧)

RuntimeError: The archive ILSVRC2012_devkit_t12.tar.gz is not present in the root directory or is corrupted. You need to download it externally and place it in ../data.

import torchvision

vgg16_false = torchvision.models.vgg16(pretrained=False)  # False 加载网络模型 不需要下载
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

out_features=1000,输出为1000个类,如果想要输出10个类,应该如何?

1.1 add

(1)在VGG16中的features中添加add_linear:

import torchvision
from torch import nn

vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)

vgg16_true.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
print(vgg16_true)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    ...
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
  (add_linear): Linear(in_features=1000, out_features=10, bias=True)
)

(2)在VGG16中的classifier中添加add_linear:

vgg16_true.classifier.add_module('add_linear', nn.Linear(in_features=1000, out_features=10))
VGG(
  (features): Sequential(
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    ...
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )
)

1.2 modify

直接将out_features=1000,修改为,输出100:

vgg16_true.classifier[6] = nn.Linear(in_features=1000, out_features=10)
VGG(
  (features): Sequential(
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
   ...
    (6): Linear(in_features=1000, out_features=10, bias=True)
  )
)

2、Save and Load

2.1 模型结构 + 模型参数

vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16, "../model/vgg16_method1.pth")
vgg16 = torch.load("../model/vgg16_method1.pth")
print(vgg16)
VGG(
  (features): Sequential(
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    ...
  )
)

2.2 模型参数(官方推荐)

vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16.state_dict(), "../model/vgg16_method2.pth")

查看一下保存的字典:

vgg16 = torch.load("../model/vgg16_method2.pth")
print(vgg16)
OrderedDict([('features.0.weight', tensor([[[[-0.0108,  0.0403, -0.0032],
          [-0.0723,  0.0372, -0.1241],
          [-0.0583, -0.1042, -0.0469]],
          ...
          ...
          ...
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))])

加载:

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("../model/vgg16_method2.pth"))
print(vgg16)
VGG(
  (features): Sequential(
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    ...
  )
)

2.3 Trap

当我们使用第一种方式保存自己定义的网络模型时:

class Liang(nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)

    def forward(self, x):
        x = self.conv1(x)
        return x


liang = Liang()
torch.save(liang, "../model/Liang.pth")

再使用第一种方式加载模型时:

liang = torch.load("../model/Liang.pth")
print(liang)

会报错:AttributeError: Can't get attribute 'Liang' on <module '__main__' from 'C:/Users/MK/Desktop/Pytorch/小土堆/program/TorchVision_Load Model.py'>>

解决方法一:加上class类

class Liang(nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)

    def forward(self, x):
        x = self.conv1(x)
        return x


liang = torch.load("../model/Liang.pth")
print(liang)
Liang(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
)

解决方法二:引入

首先将class Liang类 写入all_class.py 文件中,再使用 from all_class import *,直接引用!

from all_class import *

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("../model/vgg16_method2.pth"))

liang = torch.load("../model/Liang.pth")
print(liang)
Liang(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/32995.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

apache-atlas-hbase-hook源码分析

元数据类型 Hbase元数据类型, 包括命令空间、表、列族、列 public enum HBaseDataTypes {// ClassesHBASE_NAMESPACE,HBASE_TABLE,HBASE_COLUMN_FAMILY,HBASE_COLUMN;public String getName() {return name().toLowerCase();} }Hbase元数据采集实现 1&#xff09;批量采集HBa…

【windows】实战部署二(使用)SVNserver服务端+SVNclient客户端

SVN服务器应用 创建版本库 1、打开VisualSVN Server&#xff1a; 2、建立版本库&#xff1a; 需要右键单击左边窗口的Repositores,在弹出的右键菜单中选择Create New Repository或者新建-Repository 3、默认选择&#xff0c;点击 “下一步” 按钮&#xff1a; Regular FSFS…

物联网安全年报漏洞情况

物联网 威胁分析漏洞篇物联网威胁分析—漏洞篇 引言 本章将从漏洞利用角度对物联网威胁进行分析。首先&#xff0c;我们分析了 NVD和 Exploit-DB中的物联网 年度漏洞及利用 1 变化趋势&#xff1b;之后统计了绿盟威胁捕获系统捕获到的物联网漏洞利用的整体情况&#xff1b;最…

Matlab深度学习实战一:LeNe-5图像分类篇MNIST数据集分十类且matlab提供模型框架全网为唯一详细操作流程

1.数据集简介下载与准备 2.matlab搭建模型相关知识 3.matlab软件的操作过程&#xff1a; &#xff08;1&#xff09;界面操作 &#xff08;2&#xff09;深度学习设计器使用 &#xff08;3&#xff09;图像数据导入 &#xff08;4&#xff09;训练可视化 一、数据集简介下载与…

mysql基本命令操作

MySQL数据库管理 SQL语句 SQL语句用于维护管理数据库&#xff0c;包括数据查询、数据更新、访问控制、对象管理等功能 DDL&#xff1a;数据定义语言&#xff0c;用于创建数据库对象&#xff0c;如库、表、索引等 DML&#xff1a;数据操纵语言&#xff0c;用于对表中的数据进行…

[附源码]计算机毕业设计JAVA民宿网站管理系统

[附源码]计算机毕业设计JAVA民宿网站管理系统 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybati…

Android 10.0 11.0 12.0 启动模拟器教程

Android 10.0 11.0 12.0 启动模拟器教程 一、android 12.0 模拟器二、创建模拟器设备三、创建删除路经文件夹avd和配置环境变量四、启动模拟器一、android 12.0 模拟器 Android 10.0 11.0 12.0 启动模拟器都行,我选择android 12.0 模拟器 二、创建模拟器设备 第一步骤:在 …

i211网卡在Monterey及以上驱动方法

两种方法&#xff1a;一、驱动换成别人修改后的AppleIGB.kext。这么做一般情况用着没问题。但是如果你虚拟机桥接到这个网卡&#xff0c;可以获取到IP,网关等所有参数&#xff0c;就是不能上网 二、刷网卡固件&#xff0c;将i211刷成i210&#xff0c;直接免驱。 macos下操作 …

Brooks曾经在UMLChina网站留过言-回忆和送别(1)

&#xff08;抱歉&#xff0c;有点标题党。&#xff09; 《人月神话》作者Frederick Phillips Brooks Jr. 于2022年11月17日逝世&#xff0c;享年91岁。 图1 摘自 https://christianityandscholarship.org/event/making-things-ncsu-10-27-2015/ 这个岁数即使在今天也算是“…

玩转MySQL:都2022年了,这些数据库技术你都知道吗

引言 MySQL数据库从1995年诞生至今&#xff0c;已经过去了二十多个年头了&#xff0c;到2022.04.26日为止&#xff0c;MySQL8.0.29正式发行了GA版本&#xff0c;在此之前版本也发生了多次迭代&#xff0c;发行了大大小小N多个版本&#xff0c;其中每个版本中都有各自的新特性&…

亚马逊差评怎么删?常用的几种删差评方法介绍

正常情况下每个电商产品都是有好评和差评的&#xff0c;如果一味的都是好评&#xff0c;那么也显的很假&#xff0c;但是差评太多也会影响销售&#xff0c;特别是面对那些恶意差评&#xff0c;这会严重的影响客户下单&#xff0c;因此对于恶意差评&#xff0c;我们还是的想办法…

【Java八股文总结】之Java Web

文章目录Java Web一、Java Web介绍Q&#xff1a;什么是Java Web&#xff1f;Q&#xff1a;Java Web的工作原理&#xff1f;Q&#xff1a;Java Web的知识体系二、JDBC1、JDBC的使用步骤2、JDBC API详解1、DriverManager2、Connection3、Statement4、ResultSet5、PreparedStateme…

连续词袋模型(Continous bag of words, CBOW)

将高维度的词表示转换为低纬度的词表示方法我们称之为词嵌入&#xff08;word embedding&#xff09;。 CBOW是通过周围词去预测中心词的模型。&#xff08;Skip-gram是用中心词预测周围词&#xff09; CBOW模型的结构 最左边为上下文词&#xff0c;这些词用One-hot编码表示&a…

codeforces:C. Set Construction【构造 + 入度观察】

目录题目截图题目分析ac code总结题目截图 题目分析 题目要找n个集合给出一个矩阵b如果bij 1&#xff0c;表示第i个集合为第j个集合的真子集bij 0&#xff0c;表示不是真子集寻找集合间的关系&#xff0c;g记录下一个更大的集合&#xff0c;smaller表示被本集合包含的集合的…

以数据为中心的标记语言-->yaml

目录 一.yaml 介绍 二.yaml 基本语法 三.数据类型 1.字面量 2.对象 3.数组 四.yaml 应用实例 1.需求: 2.需求图解 3.代码实现 五.yaml 使用细节 一.yaml 介绍 YAML 是"YAML Aint a Markup Language"(YAML 不是一种标记语言) 的递归缩写。在开发 的这种语言…

每日一练2——C++排序子序列问题倒置字符串问题

文章目录排序子序列问题思路&#xff1a;代码&#xff1a;倒置字符串思路&#xff1a;方法一&#xff1a;代码&#xff1a;方法二&#xff1a;代码&#xff1a;排序子序列问题 题目链接 这道题题意不难理解&#xff0c;但是想写对还是有很多细节的。 本题要求解的是排序子序列…

python之正则表达式【简化版】

大家好&#xff0c;我们今天说一说正则表达式&#xff0c;在之前我们也介绍了关于正则表达式&#xff0c;今天&#xff0c;我们来深入的了解一下。我们知道正则表达式是处理字符串的强大工具&#xff0c;它有自己的语法结构&#xff0c;什么匹配啊&#xff0c;都不算什么。 正…

JavaIO流:BIO梳理

BIO&#xff08;blocking I/O&#xff09; &#xff1a; 同步阻塞&#xff0c;服务器实现模式为一个连接一个线程&#xff0c;即客户端有连接请求时服务器端就需要启动一个线程进行处理&#xff0c;如果这个连接不做任何事情会造成不必要的线程开销&#xff0c;可以通过线程池机…

Java8新特性 Stream流

Stream流 在Java 8中&#xff0c;得益于Lambda所带来的函数式编程&#xff0c;引入了一个全新的Stream概念&#xff0c;用于解决已有集合类库既有的弊端。 传统集合的多步遍历代码几乎所有的集合&#xff08;如 Collection 接口或 Map 接口等&#xff09;都支持直接或间接的遍…

我这样写代码,比直接使用 MyBatis 效率提高了 100 倍

对一个 Java 后端程序员来说&#xff0c;mybatis、hibernate、data-jdbc 等都是我们常用的 ORM 框架。它们有时候很好用&#xff0c;比如简单的 CRUD&#xff0c;事务的支持都非常棒。但有时候用起来也非常繁琐&#xff0c;比如接下来我们要聊到的一个常见的开发需求&#xff0…