使用Pytorch加载预训练模型及修改网络结构

news2024/9/22 7:31:03

Pytorch有自带的训练好的AlexNet、VGG、ResNet等网络架构。详见官网

1.加载预训练模型

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchsummary import summary

#加载预训练的AlexNet,设置参数为pretrained=True即可
#如果只需要AlexNet的网络结构,设置参数为pretrained=False
model=models.alexnet(pretrained=True)

#因为summary默认在cuda上加载模型所以要把模型加载到GPU上
model.to("cuda")
summary(model,(3,227,227))

#如果没有GPU,直接输入
summary(model,(3,227,227),device="cpu")

用summary输出网络结构如下:

在这里插入图片描述

2.增加、更改某些层

print(model)
#输出如下
"""
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)"""

可以看到,AlexNet有三个层,分别是features、avgpool、classifier。用model.features查看features层(也就是卷积层)的网络结构。

model.features
#输出
"""
Sequential(
  (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
"""
#如果我们要在features最后加上一层全连接层,名字叫fc
model.features.add_module("fc",nn.Linear(120,20))
print(model)
#可以看到多了一个fc的全连接层
"""
Sequential(
  (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=120, out_features=20, bias=True)
)
"""
#如果要修改卷积层的卷积核大小,修改成7*7
model.features[0]=nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
model
#第一层的卷积核大小已经改变,例如修改model.features[-1]即可修改最后一层
"""
(features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (fc): Linear(in_features=120, out_features=20, bias=True)
  )
"""

3.冻结层(取消梯度反向传播)

# 方式一:冻结所有 features 层
for name,layer in model.named_children():
    if name == "features":
#     print(name,layer)
        for param in layer.parameters():
            param.requires_grad=False

# 方式二: 冻结所有层(包括features层和classifier层)
    for param in model.parameters():
        param.requires_grad = False
        
for name,layer in model.named_parameters():
#      for param in layer:
    print("%s层未冻结 %s"%(name,layer.requires_grad))

#可以看到features这一层的梯度都是False,不会进行反向传播。classiofier的梯度都是True。

"""
features.0.weight层未冻结 False
features.0.bias层未冻结 False
features.3.weight层未冻结 False
features.3.bias层未冻结 False
features.6.weight层未冻结 False
features.6.bias层未冻结 False
features.8.weight层未冻结 False
features.8.bias层未冻结 False
features.10.weight层未冻结 False
features.10.bias层未冻结 False
features.fc.weight层未冻结 False
features.fc.bias层未冻结 False
classifier.1.weight层未冻结 True
classifier.1.bias层未冻结 True
classifier.4.weight层未冻结 True
classifier.4.bias层未冻结 True
classifier.6.weight层未冻结 True
classifier.6.bias层未冻结 True
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/740067.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VBA系列技术资料MF33:VBA_将文本文件转换为Excel

【分享成果,随喜正能量】一心热枕对待生活,静静的安抚自己内心的急迫和焦虑,你人生的好运,常常在你沉醉于生活时悄悄临门的。。 我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高…

vue本地开发集成https

背景:在本地项目开发中,调用第三方服务获取音视频通话,音视频通话是采用 WebRTC 来实现的,而 WebRTC 中使用音视频设备进行取流是需要在安全域下才可以调起的设备权限 解决方案:使用npm安装mkcert,配置证书…

spring boot+MySQL实现学习平台

本次设计任务是要设计一个学习平台,通过这个系统能够满足学习信息的管理及学生和教师的学习管理功能。系统的主要功能包括首页,个人中心,学生管理,教师管理,课程信息管理,类型管理,作业信息管理…

Hive(18):DML之Load加载数据

1 背景 回想一下,当在Hive中创建好表之后,默认就会在HDFS上创建一个与之对应的文件夹,默认路径是由参数hive.metastore.warehouse.dir控制,默认值是/user/hive/warehouse。 要想让hive的表和结构化的数据文件产生映射,就需要把文件移到到表对应的文件夹下面,当然,可以在…

天天刷题-->LeetCode(无重复字符的最长字串)

个人名片: 🐅作者简介:一名大二在校生,热爱生活,爱好敲码! \ 💅个人主页 🥇:holy-wangle ➡系列内容: 🖼️ tkinter前端窗口界面创建与优化 &…

轻松学会Java导出word,一篇文章就够了!

很多小伙伴在工作中&#xff0c;可能又这样一个需求&#xff1a;根据word模板去填充数据&#xff0c;变成我们想要的word文档&#xff0c;这是很多刚进入职场的小白都会碰到的需求。 当遇上这种需求&#xff0c;我们可以通过这篇文章要讲的poi-tl 来做处理。 导入依赖 <dep…

下载pycharm专业版

PyCharm: the Python IDE for Professional Developers by JetBrainsThe Python & Django IDE with intelligent code completion, on-the-fly error checking, quick-fixes, and much more...https://www.jetbrains.com/pycharm/Pycharm安装使用与版本切换_pycharm专业版换…

华为开发者大会2023(Cloud)之旅

【摘要】 金鱼哥畅游记&#xff1a;华为开发者大会2023&#xff08;Cloud&#xff09; 2023年7月7日华为开发者大会2023&#xff08;Cloud&#xff09;在广东东莞正式揭开帷幕&#xff0c;金鱼哥很庆幸能有机会参加此次盛大聚会&#xff0c;看到众开发者共聚一堂&#xff0c;在…

812. 打印数字

链接&#xff1a; 812. 打印数字 - AcWing题库 题目&#xff1a; 输入一个长度为 nn 的数组 aa 和一个整数 sizesize&#xff0c;请你编写一个函数, void print(int a[], int size), 打印数组 aa 中的前 sizesize 个数。 输入格式 第一行包含两个整数 nn 和 sizesize。 第二行包…

MySQL (select查询的基本用法及select相关练习)

如图插入数据&#xff1a; 得 1、显示所有职工的基本信息。 mysql> select * from worker;效果如图&#xff1a; 2、查询所有职工所属部门的部门号&#xff0c;不显示重复的部门号 mysql> select distinct 部门号 from worker;效果如图&#xff1a; 3、求出所有职…

痴呆≠阿尔茨海默病?5个特征或是发生痴呆!

痴呆是一种智力退化的综合症&#xff0c;其特点包括记忆力减退、思维能力下降、判断力和语言能力受损等。然而&#xff0c;很多人错误地将痴呆等同于阿尔茨海默病。事实上&#xff0c;阿尔茨海默病只是痴呆症的一种常见类型。下面将介绍痴呆的五个主要特征以及导致痴呆的原因。…

基于Java+SpringBoot+Vue的开放实验管理系统设计与实现

博主介绍&#xff1a;✌擅长Java、微信小程序、Python、Android等&#xff0c;专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3fb; 不然下次找不到哟 Java项目精品实战案…

找回存储在DBeaver连接中的数据库密码

一、拿到 credentials-config.json 文件 1、打开 Dbeaver 后&#xff0c;点击 “窗口 — 首选项” 2、找到worksapce path 3、进入 workspace path 的文件夹&#xff0c;再进入到 \General.dbeaver 文件夹&#xff0c;找到文件 credentials-config.json &#xff08;可以备…

数据结构初阶--排序1

目录 前言冒泡排序思路代码实现 选择排序思路代码实现 插入排序思路代码实现 希尔排序思路代码实现 堆排序思路向上调整建堆向下调整建堆 代码实现 前言 排序在我们的日常生活中无处不在&#xff0c;比如对若干个学生的期末成绩&#xff0c;可以依据姓氏&#xff0c;学号&…

交易成本模型与Python技术共同促进高频交易的发展走向

高频交易是一种在金融市场中越来越受到关注的交易方式&#xff0c;其具有快速、高效、低风险的特点&#xff0c;可以为投资者带来丰厚的利润。然而&#xff0c;在高频交易中&#xff0c;交易成本往往占据了很大的比例&#xff0c;可以说是一个非常重要的因素。因此&#xff0c;…

辅助驾驶功能开发-功能规范篇(22)-1-L2级辅助驾驶方案功能规范

1. 系统概览 System Overview 1.1 系统架构 各模块描述如下: ADSADS控制器,包括5R1V感知融合算法模块、level0-level2相关功能、控制决策模块、响应执行模块。EPBi制动执行机构,包括主制动和冗余制动,可实现行车和驻车控制。EPS转向执行机构,TJP包含冗余转向,L-TJP不包含…

DynaSLAM代码详解(3) — MaskNet.cc加载Mask R-CNN网络部分

目录 3.1 Mask R-CNN运行 3.2 SegmentDynObject::SegmentDynObject() 3.3 SegmentDynObject::GetSegmentation() 3.1 Mask R-CNN运行 在Examples/RGB-D/rgbd_tum.cc文件开始运行Mask R-CNN网络&#xff0c;首先进入MaskNet->GetSegmentation函数。 // Segment out the i…

使用Dockerfile创建nginx+php镜像,采用分层

什么是Dockerfile Dockerfile是一种能被Docker程序解释的脚本&#xff0c;它是由一条条的命令所组成&#xff0c;每条命令对应Linux下面的一条命令&#xff0c;Docker程序将这些Dockerfile命令翻译成真正的Linux命令 Dockerfile命令 Dockerfile通常会包含如下命令&#xff1a…

【AGC】认证服务HarmonyOS(api9)实现手机号码认证登录

【问题背景】 近期AGC上线了HarmonyOS(api9)平台的SDK&#xff0c;这样api9的设备也能使用认证服务进行快速认证登录了。下面为大家带来如何使用auth SDK&#xff08;api9&#xff09;实现手机号码认证登录。 【开通服务】 1.登录AppGallery Connect&#xff0c;点击“我的项…

lc202306

785. 判断二分图 对于单个连通图&#xff1a;一个dfs判断图中所有节点符合二分。 遍历节点列表>遍历所有连通图。 133. clone graph 994. rotting oranges 力扣 维护一个time表&#xff0c;表示所有orange rot的最快时间。对每一个 t0 就 rot 的 orange dfs&#xff0c;遇…