Pytorch微调深度学习模型

news2024/11/26 23:47:51

在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。

网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。

大体步骤是:

1、加载训练好的模型。

2、冻结不想微调的层,设置想训练的层。(这里可以新建一个层替换原有层,也可以不新建层,直接微调原有层)

3、训练即可。

1、先加载一个模型

我这里是训练好的一个SqueezeNet模型,所有模型都适用。

## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Fire

model = torch.load("Model/SqueezeNet.pth").to(device)
print(model)
# 输出结果
SqueezeNet(
  (stem): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fire2): Fire(
    (squeeze): Sequential(
      (0): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_1x1): Sequential(
      (0): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_3x3): Sequential(
      (0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (fire3): Fire(
    (squeeze): Sequential(
      (0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_1x1): Sequential(
      (0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_3x3): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (fire4): Fire(
    (squeeze): Sequential(
      (0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_1x1): Sequential(
      (0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (expand_3x3): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (conv10): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (avg): AdaptiveAvgPool2d(output_size=1)
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

print(model)时会显示模型每个层的名字。这里我想对conv10层进行微调,因为它是最后一个具有参数可以微调的层了。当然,如果最后一层是全连接的话,也建议微调最后全连接层。 

2、冻结不想训练的层。

这里就有两种不同的方法了:一是新建一个conv10层,替换掉原来的层。二是不新建,直接微调原来的层。

新建:

model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)

可以直接用model.conv10.in_channels等加载原来层的各种参数。这样就定义好了一个新的conv10层,并且已经替换进了模型中。

然后先冻结所有层(requires_grad = False),再放开conv10层(requires_grad = True)。

# 先冻结所有层
for param in model.parameters():
    param.requires_grad = False

# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():
    param.requires_grad = True

如果不新建层,则不需要运行model.conv10 = nn.Conv2d那一行即可。直接开始冻结就可以。

 3、训练

这里一定要注意,optimizer里要设置参数 model.conv10.parameters(),而不是model.parameters()。这是让模型知道它将要训练哪些参数。

optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)

虽然上面已经冻结了不想训练的参数,但是这里最好还是写上model.conv10.parameters()。大家也可以试试不写行不行。

# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 只优化conv10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)
# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 设置模型为训练模式
model.train()

num_epochs = 10
for epoch in range(num_epochs):
    # model.train()
    running_loss = 0.0
    correct = 0
    for x_train, y_train in data_loader:
        x_train, y_train = x_train.to(device), y_train.to(device)
        print(x_train.shape, y_train.shape)

        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x_train.size(0)

        # 统计训练集的准确率
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_train).sum().item()

    # 计算每个 epoch 的训练损失和准确率
    epoch_loss = running_loss / len(dataset)
    epoch_accuracy = 100 * correct / len(dataset)
    
    # if epoch % 5 == 0 or epoch == num_epochs-1 :
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')

输出显示Loss下降说明模型有在学习。 模型准确率从0变成100,还是非常有成就感的!当然我这里就用了一个样本来微调hhhh。

Epoch [1/10]
Train Loss: 0.8185, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [2/10]
Train Loss: 0.7063, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [3/10]
Train Loss: 0.6141, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [4/10]
Train Loss: 0.5385, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [5/10]
Train Loss: 0.4761, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [6/10]
Train Loss: 0.4244, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [7/10]
Train Loss: 0.3812, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [8/10]
Train Loss: 0.3449, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [9/10]
Train Loss: 0.3140, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [10/10]
Train Loss: 0.2876, Train Accuracy: 100.00%

4、验证一下确实是只有这个层参数变化了,而其他层参数没变。

在训练模型之前,看一下这个层的参数:

raw_parm = model.conv10.weight
print(raw_parm)
# 部分输出为
Parameter containing:
tensor([[[[-0.1621]],

         [[ 0.0288]],

         [[ 0.1275]],

         [[ 0.1584]],

         [[ 0.0248]],

         [[-0.2013]],

         [[-0.2086]],

         [[ 0.1460]],

         [[ 0.0566]],

         [[ 0.2897]],

         [[ 0.2898]],

         [[ 0.0610]],

         [[ 0.2172]],

         [[ 0.0860]],

         [[ 0.2730]],

         [[-0.1053]]],

训练后,也输出一下这个层的参数:

## 查看微调后模型的参数
tuned_parm = model.conv10.weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.1446]],

         [[ 0.0365]],

         [[ 0.1490]],

         [[ 0.1783]],

         [[ 0.0424]],

         [[-0.1826]],

         [[-0.1903]],

         [[ 0.1636]],

         [[ 0.0755]],

         [[ 0.3092]],

         [[ 0.3093]],

         [[ 0.0833]],

         [[ 0.2405]],

         [[ 0.1049]],

         [[ 0.2925]],

         [[-0.0866]]],

可见这个层的参数确实是变了。

然后检查一下别的随便一个层:

训练前:

# 训练前
raw_parm = model.stem[0].weight
print(raw_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151,  0.1123],
          [-0.2114,  0.0173, -0.1322],
          [-0.0819,  0.0748, -0.2790]]],


        [[[-0.0918, -0.2783, -0.3193],
          [ 0.0359,  0.2993, -0.3422],
          [ 0.1979,  0.2499, -0.0528]]],

训练后:

## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151,  0.1123],
          [-0.2114,  0.0173, -0.1322],
          [-0.0819,  0.0748, -0.2790]]],


        [[[-0.0918, -0.2783, -0.3193],
          [ 0.0359,  0.2993, -0.3422],
          [ 0.1979,  0.2499, -0.0528]]],

可见参数没有变化。说明这层没有进行学习。

5、为了让大家更容易全面理解,完整代码如下。

import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset,TensorDataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler # 多数样本下采样

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 加载微调数据
feats = np.load("feats_jn105.npy")
labels = np.array([0])
print(feats.shape)
print(labels.shape)

# 将data和labels转换为 PyTorch 张量
data_tensor = torch.tensor(feats, dtype = torch.float32, requires_grad=True)
labels_tensor = torch.tensor(labels, dtype = torch.long)

# 添加通道维度
# data_tensor = data_tensor.unsqueeze(1)  # 变为(num, 1, 32, 16)
batch_size = 15

# 创建 TensorDataset
dataset = TensorDataset(data_tensor, labels_tensor)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
input, label = next(iter(data_loader))
print(input.shape,label.shape)
# upyter nbconvert --to script ./Model/main_SqueezeNet.ipynb # 终端运行,ipynb转py

## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Fire

model = torch.load("Model/SqueezeNet.pth").to(device)
print(model)

# 为模型写一个新的层
# model.fc = nn.Linear(in_features = model.fc.in_features, out_features = model.fc.out_features)
model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)

# 先冻结所有层
for param in model.parameters():
    param.requires_grad = False

# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():
    param.requires_grad = True

raw_parm = model.stem[0].weight
print(raw_parm)
for name, param in model.named_parameters():
    print(name, param.requires_grad)

# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 只优化c10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)

# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 设置模型为训练模式
model.train()

num_epochs = 10
for epoch in range(num_epochs):
    # model.train()
    running_loss = 0.0
    correct = 0
    for x_train, y_train in data_loader:
        x_train, y_train = x_train.to(device), y_train.to(device)
        print(x_train.shape, y_train.shape)

        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x_train.size(0)

        # 统计训练集的准确率
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_train).sum().item()

    # 计算每个 epoch 的训练损失和准确率
    epoch_loss = running_loss / len(dataset)
    epoch_accuracy = 100 * correct / len(dataset)
    
    # if epoch % 5 == 0 or epoch == num_epochs-1 :
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')

## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)


如有更好的方法,欢迎大家分享~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2248095.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

滑动窗口最大值(java)

题目描述 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的最大值 。 示例 1: 输入:nums [1,3,-1,-3,5,3,6,7]…

拥抱极简主义前端开发:NoCss.js 引领无 CSS 编程潮流

在前端开发的世界里,我们总是在不断追寻更高效、更简洁的方式来构建令人惊艳的用户界面。而今天,我要向大家隆重介绍一款具有创新性的工具 ——NoCss.js,它将彻底颠覆你对传统前端开发的认知,引领我们进入一个全新的无 CSS 编程时…

配置Springboot+vue项目在ubuntu20.04

一、jdk1.8环境配置 (1) 安装jdk8: sudo apt-get install openjdk-8-jdk (2) 检查jdk是否安装成功: java -version(3) 设置JAVA_HOME: echo export JAVA_HOME/usr/lib/jvm/java-17-openjdk-amd64 >> ~/.bashrc echo export PATH$J…

Spring框架特性及包下载(Java EE 学习笔记04)

1 Spring 5的新特性 Spring 5是Spring当前最新的版本,与历史版本对比,Spring 5对Spring核心框架进行了修订和更新,增加了很多新特性,如支持响应式编程等。 更新JDK基线 因为Spring 5代码库运行于JDK 8之上,所以Spri…

软考教材重点内容 信息安全工程师 第 5 章 物理与环境安全技术

5.1.1 物理安全概念 传统上的物理安全也称为实体安全,是指包括环境、设备和记录介质在内的所有支持网络信息系统运行的硬件的总体安全,是网络信息系统安全、可靠、不间断运行的基本保证,并且确保在信息进行加工处理、服务、决策支持的过程中&…

「Chromeg谷歌浏览器/Edge浏览器」篡改猴Tempermongkey插件的安装与使用

1. 谷歌浏览器安装及使用流程 1.1 准备篡改猴扩展程序包。 因为谷歌浏览器的扩展商城打不开,所以需要准备一个篡改猴压缩包。 其他浏览器只需打开扩展商城搜索篡改猴即可。 没有压缩包的可以进我主页下载。 也可直接点击下载:Chrome浏览器篡改猴(油猴…

【案例学习】如何使用Minitab实现包装过程的自动化和改进

Masimo 是一家全球性的医疗技术公司,致力于开发和生产各种行业领先的监控技术,包括创新的测量、传感器和患者监护仪。在 Masimo Hospital Automation 平台的助力下,Masimo 的连接、自动化、远程医疗和远程监控解决方案正在改善医院内外的护理…

Git旧文件覆盖引发思考

一天,我的同事过来找到我,和我讲:张叫兽,大事不好,我的文件被人覆盖了。git是真的不好用啊 git不好用?文件被覆盖;瞬间我似乎知道了什么,让我想到了某位男明星的语法:他…

CSP/信奥赛C++语法基础刷题训练(23):洛谷P1217:[USACO1.5] 回文质数 Prime Palindromes

CSP/信奥赛C语法基础刷题训练(23):洛谷P1217:[USACO1.5] 回文质数 Prime Palindromes 题目描述 因为 151 151 151 既是一个质数又是一个回文数(从左到右和从右到左是看一样的),所以 151 151 …

嵌入式系统与OpenCV

目录 一、OpenCV 简介 二、嵌入式 OpenCV 的安装方法 1. Ubuntu 系统下的安装 2. 嵌入式 ARM 系统中的安装 3. Windows10 和树莓派系统下的安装 三、嵌入式 OpenCV 的性能优化 1. 介绍嵌入式平台上对 OpenCV 进行优化的必要性。 2. 利用嵌入式开发工具,如优…

SAP BC 记录一次因为HANA服务器内存满的问题

用户操作 DB02 进入hana数据库服务器 free -g 内存用完了 如下图 解决方案:增加内存 操作 关应用服务->关闭数据库服务->关闭hana服务器->加内存->起hana服务器->起hana服务->启动应用服务。

ArcGIS应用指南:ArcGIS制作局部放大地图

在地理信息系统(GIS)中,制作详细且美观的地图是一项重要的技能。地图制作不仅仅是简单地将地理数据可视化,还需要考虑地图的可读性和美观性。局部放大图是一种常见的地图设计技巧,用于展示特定区域的详细信息&#xff…

python画图|无坐标轴自由划线操作fig.add_artist(lines.Line2D()函数

【1】引言 新发现了一种自由划线操作函数,和大家共享。 【2】官网教程 点击下述代码,直达官网: https://matplotlib.org/stable/gallery/misc/fig_x.html#sphx-glr-gallery-misc-fig-x-py 官网代码非常简洁,我进行了解读。 …

深度解析:Nginx模块架构与工作机制的奥秘

文章目录 前言Nginx是什么?Ngnix特点: 一、Nginx模块与工作原理1.Nginx的模块1.1 Nginx模块常规的HTTP请求和响应的流程图:1.2 Nginx的模块从结构上分为如下三类:1.3 Nginx的模块从功能上分为如下三类: 2.Nginx的进程模型2.1 Nginx进程结构2.2 nginx进程…

抖音SEO矩阵系统:开发技术分享

市场环境剖析 短视频SEO矩阵系统是一种策略,旨在通过不同平台上的多个账号建立联系,整合同一品牌下的各平台粉丝流量。该系统通过遵循每个平台的规则和内容要求,输出企业和品牌形象,以矩阵形式增强粉丝基础并提升商业价值。抖音作…

面试经典 150 题:205,55

205. 同构字符串 【解题思路】 来自大佬Krahets 【参考代码】 class Solution { public:bool isIsomorphic(string s, string t) {map<char, char> Smap, Tmap;for(int i0; i<s.size(); i){char a s[i], b t[i];//map容器存在该字符&#xff0c;且不等于之前映射…

STM32 Keil5 attribute 关键字的用法

这篇文章记录一下STM32中attribute的用法。之前做项目的时候产品需要支持远程升级&#xff0c;要求版本只能向上迭代&#xff0c;不支持回退。当时想到的方案是把版本号放到bin文件的头部&#xff0c;设备端收到bin文件的首包部数据后判断是否满足升级要求&#xff0c;这里就可…

【Redis 缓存策略】更新、穿透、雪崩、击穿、布隆过滤

目录 缓存简单介绍 缓存更新策略 缓存更新需求 数据库缓存不一致解决方案 先操作缓存还是先操作数据库&#xff1f; 先删除缓存&#xff0c;再操作数据库 先操作数据库&#xff0c;再删除缓存 总结 删除缓存还是更新缓存&#xff1f; 保证缓存与数据库的操作的同时成功或失败 …

利用c语言详细介绍下栈的实现

数据结构中&#xff0c;栈是一种线性结构&#xff0c;数据元素遵循后进先出的原则。栈的一端为栈顶&#xff0c;一端为栈底或栈尾&#xff0c;数据只在栈顶端进行操作。新插入数据称为入栈或者压栈&#xff0c;删除数据叫做出栈或者退栈。 一、图文介绍 我们通过建立一个stack…

5.5 W5500 TCP服务端与客户端

文章目录 1、TCP介绍2、W5500简介2.1 关键函数socketlistensendgetSn_RX_RSRrecv自动心跳包检测getSn_SR 1、TCP介绍 TCP 服务端&#xff1a; 创建套接字[socket]&#xff1a;服务器首先创建一个套接字&#xff0c;这是网络通信的端点。绑定套接字[bind]&#xff1a;服务器将…