修改网络的结构用于预训练

news2025/1/23 22:43:45

目录

一、模型准备

二、修改结构

1、在网络中添加一层

2、在classifier结点添加一个线性层

3、修改网络中的某一层(features 结点举例)

4、替换网络中的某一层结构(与第3点类似)

5、提取全连接层的输入特征数和输出特征数

6、删除网络层

7、指定某一层冻结

8、批量冻结(只训练最后一层)

9、查看哪些层冻结哪些没有

10、加载网络权重

三、自定义数据集加载


一、模型准备

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)

查看模型结构:

print(model)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace=True)
    (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

二、修改结构

1、在网络中添加一层

model.features.add_module('last_layer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
print(model)

2、在classifier结点添加一个线性层

model.classifier.add_module('Linear', nn.Linear(1000, 10))
print(model)

3、修改网络中的某一层(features 结点举例)

model.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
print(model)

4、替换网络中的某一层结构(与第3点类似)

直接提取这个结构,重新赋值一个结构即可

import torch
import torch.nn as nn
from torchvision import models

# 创建resnet50网络
model = models.resnet50(num_classes=1000)
print(model)

fc_in_features = model.fc.in_features
model.fc = torch.nn.Linear(fc_in_features, 2, bias=True)

5、提取全连接层的输入特征数和输出特征数

查看最后一层全连接的输入数量:

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)
print(model)

print(model.classifier[6].in_features)

其他的结构也类似,主要是写出结构里的参数即可。例如:全连接的in_features和out_features

6、删除网络层

使用一个空结构替换即可,即nn.Sequential()

model.classifier[6] = nn.Sequential()
print(model)

或者用切片来删除(删除后4层):

model.features = nn.Sequential(*list(model.features.children())[:-4])

net.classifier 对应 net.classifier.children()

net.features 对应 net.features.children()

7、指定某一层冻结

# 冻结指定层的预训练参数
model.classifier[3].weight.requires_grad = False

8、批量冻结(只训练最后一层)

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)
print(model)

for param in model.parameters():
    param.requires_grad_(False)
fc_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(fc_in_features, 2, bias=True)

9、查看哪些层冻结哪些没有

for i in model.parameters():
    if i.requires_grad:
        print(i)

10、加载网络权重


model.load_state_dict(torch.load('model.pth')) # 加载网络参数

三、自定义数据集加载

from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, xxxx):
        super(MyDataset, self).__init__()
        pass
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        pass
        return image, classes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1855556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

优秀的“抗霾”神器:气膜体育馆—轻空间

随着空气污染问题日益严重,尤其是雾霾天气频发,体育运动的场地环境质量受到越来越多的关注。气膜体育馆作为一种新型的体育场馆解决方案,以其独特的设计和多重优势,成为了优秀的“抗霾”神器。轻空间将深入探讨气膜体育馆的特点和…

Ubuntu下安装、运行Nginx

安装Ubuntu桌面系统(虚拟机)_虚拟机安装ubuntu桌面版-CSDN博客 默认情况下,Ubuntu并没有安装Nginx服务。用户可以使用以下命令安装Nginx服务及其相关的组件: liuubuntu:~$ sudo apt install nginx 安装完成之后,使用…

高考十字路口:24年考生如何权衡专业与学校的抉择?

文章目录 每日一句正能量前言专业解析理工科专业商科专业人文社科专业艺术与设计专业个人经验与思考过程结论 名校效应分析名校声誉与品牌效应资源获取学术氛围就业优势个人发展结论 好专业和好学校的权衡个人职业目标行业需求教育质量资源和机会学术氛围就业优势经济和地理位置…

【驱动篇】龙芯LS2K0300之按键驱动

实验过程 实验目的: 在龙芯开发板上面验证GPIO按键的输入过程 ① 根据原理图连接按键板 ② 将4个i2c引脚的功能复用为GPIO ③ 注册input设备驱动,绑定中断处理函数,使用定时器消抖 原理图 4个按键引脚:CPU_I2C0_SCL -> G…

全栈人工智能工程师:现代博学者

任何在团队环境中工作过的人都知道,每个成功的团队都有一个得力助手——无论你的问题性质如何,他都能帮助你。在传统的软件开发团队中,这个人是一个专业的程序员,也是另一种技术的专家,可以是像Snowflake这样的数据库技…

【windows|010】OSI七层模型和TCP/IP五层模型详解

🍁博主简介: 🏅云计算领域优质创作者 🏅2022年CSDN新星计划python赛道第一名 🏅2022年CSDN原力计划优质作者 ​ 🏅阿里云ACE认证高级工程师 ​ 🏅阿里云开发者社区专家博主 💊交流社…

植物大战僵尸杂交版2024最新官方原版iOS手机+mac电脑版下载

🌱 绿意盎然的游戏体验,《植物大战僵尸杂交版》带你开启全新战斗模式 亲爱的小伙伴们,今天我要和大家分享一款让人眼前一亮的游戏——《植物大战僵尸杂交版》!这款游戏在经典的基础上进行了大胆创新,给玩家带来了前所未…

手把手带你从零构建一个用于讲故事的 LLM

教程 LLM101:创建能讲故事的 LLM LLM101n 是 llm.c 作者开的新坑,一个系列教程,手把手带你从零构建一个用于讲故事的 LLM,目前只写了目录已经斩获 5.8k star。

户外龙头边城体育签约实在智能,打造财务数字化转型标杆!

财务工作在企业中是一项极其重要且复杂的任务,需要具备高度的准确性和及时性。尽管应用财务软件,企业实现了财务数据的电子化和数字化,但仍然面临着数据采集与处理的实时性和全面性问题,以及由此带来的风险控制和决策支持的不足。…

运动蓝牙耳机哪个口碑最好?五大高口碑顶尖单品推荐

在这个快节奏时代,智能手机的普及使得运动开放式耳机逐渐成为我们日常出行的必备单品。运动开放式耳机凭借独特的外形设计,赢得了众多消费者的喜爱。它们不同于传统的入耳式设计,以舒适佩戴为核心,有效缓解了长时间佩戴对耳部造成…

【机器学习】自然语言处理的新前沿:GPT-4与Beyond

📝个人主页:哈__ 期待您的关注 目录 🔥引言 背景介绍 文章目的 一、GPT-4简介 GPT-4概述 主要特性 局限性和挑战 二、自监督学习的新进展 自监督学习的原理 代表性模型和技术 三、少样本学习和零样本学习 少样本学习的挑战 先…

zkWASM:ZK+zkVM的下一站?

1. 引言 ZK技术具备极大通用性,也帮助以太坊从去中心化投资走向去信任化的价值观。“Don’t trust, Verify it!”,是ZK技术的最佳实践。ZK技术能够重构链桥、预言机、链上查询、链下计算、虚拟机等等一系列应用场景,而通用型的ZK协处理器就是…

Adaptive Server Connection Failed on Windows

最近在使用pymssql (版本2.3.0)连接SQL Server2012遇到如下问题: pymssql._mssql.MSSQLDatabaseException: (20002, bDB-Lib error message 20002, severity 9:\nAdaptive Server connection failed (localhost)\nDB-Lib error message 2000…

前端也需要知道的一些常用linux命令

前端也需要知道的一些常用linux命令 1.问题背景2.连接工具(SecureCRT_Portable)a.下载工具b.连接服务器c.登录到root账户 3.基本命令a.cd命令和cd ..b.ll命令和ls命令c:cp命令d.rm命令e:rz命令f.unzip命令g.mv命令h.pwd命令(这里没有用到&…

Linux基础二

目录 一,tail查看文件尾部指令 二,date显示日期指令 三,cal查看日历指令 四,find搜索指令 五,grep 查找指令 六,> 和>> 重定向输出指令 七, | 管道指令 八,&&逻辑控…

如何发现Redis热Key,有哪些解决方案?

什么是 hotkey? 如果一个 key 的访问次数比较多且明显多于其他 key 的话,那这个 key 就可以看作是 hotkey(热 Key)。例如在 Redis 实例的每秒处理请求达到 5000 次,而其中某个 key 的每秒访问量就高达 2000 次&#x…

【AI大模型】驱动的未来:穿戴设备如何革新血液、皮肤检测与营养健康管理

文章目录 1. 引言2. 现状与挑战3. AI大模型与穿戴设备概述4. 数据采集与预处理4.1 数据集成与增强4.2 数据清洗与异常检测 5. 模型架构与训练5.1 高级模型架构5.2 模型训练与调优 6. 个性化营养建议系统6.1 营养建议生成优化6.2 用户反馈与系统优化 7. 关键血液成分与健康状况评…

grpc教程——proto文件转go

【1】编写一个proto文件 syntax "proto3"; package myproto;service NC{rpc SayStatus (NCRequest) returns (NCResponse){} }message NCRequest{ string name 1; } message NCResponse{string status 1; } 【2】转换:protoc --go_out. myservice.pro…

LLM Agent提效进阶:反思工作流——91%精度大超GPT-4 24%

1. 相关研究 反思依赖于LLM对自己之前提出的工作进行反思并提出改进的方法,有三篇典型论文详细描述了这种模式,我们先来看一下。 2. Self-Refine 顾名思义,它是一种自我精炼的LLM优化技术,使用单一的LLM作为生成器、改进器和反…

go语言day4 引入第三方依赖 整型和字符串转换 进制间转换 浮点数 字符串

Golang依赖下载安装失败解决方法_安装go依赖超时怎么解决-CSDN博客 go安装依赖包(go get, go module)_go 安装依赖-CSDN博客 目录 go语言项目中如何使用第三方依赖:(前两步可以忽略) 一、安装git,安装程序…