pytorch11:模型加载与保存、finetune迁移训练

news2025/1/14 18:31:50

在这里插入图片描述

目录

  • 一、模型加载与保存
    • 1.1 序列化与反序列化概念
    • 1.2 pytorch中的序列化与反序列化
    • 1.3 模型保存的两种方法
    • 1.4 模型加载两种方法
  • 二、断点训练
    • 2.1 断点保存代码
    • 2.2 断点恢复代码
  • 三、finetune
    • 3.1 迁移学习
    • 3.2 模型的迁移学习
    • 3.2 模型微调步骤
      • 3.2.1 模型微调步骤
      • 3.2.2 模型微调训练方法
    • 3.3 迁移训练实验

一、模型加载与保存

1.1 序列化与反序列化概念

序列化是将数据结构或对象转换为可以存储或传输的格式的过程,而反序列化则是将存储或传输的数据重新转换为数据结构或对象的过程。
在计算机科学中,序列化和反序列化通常用于数据持久化、网络传输和进程间通信等场景。以下是对序列化和反序列化的详细解读:

  1. 序列化:
    • 序列化的过程将数据结构或对象转换为字节流或其他格式,以便在存储或传输时能够被保存下来或发送出去。这通常涉及将数据结构中的字段和属性转换为二进制码或文本格式,以便能够被存储在文件中或通过网络传输。
    • 序列化的结果可以是二进制数据、JSON、XML等格式,不同的数据类型和应用场景可能采用不同的序列化格式。
    • 序列化的过程可以包括将对象进行扁平化、编码、压缩等操作,以便提高存储和传输的效率和安全性。
  2. 反序列化:
    • 反序列化的过程是将序列化后的数据重新转换为原始的数据结构或对象,使得在存储或传输后能够恢复原来的数据格式和内容。
    • 反序列化的过程需要根据序列化时采用的格式和规则,对序列化后的数据进行解码、解压缩等操作,最终还原为原始的数据结构或对象。
    • 反序列化的过程需要确保数据的完整性和正确性,以及适当地处理可能存在的异常和错误情况。
      在实际应用中,序列化和反序列化广泛应用于各种领域,如数据库持久化、分布式系统通信、缓存存储、远程过程调用等。常见的序列化和反序列化技术包括JSON、XML、Protocol Buffers、Thrift等,它们在不同的场景和需求下有着不同的优势和适用性。

通过序列化技术,将内存中的数据存储到硬盘,在需要使用的时候通过反序列化的方法转化成可读取数据。
在这里插入图片描述

1.2 pytorch中的序列化与反序列化

  1. torch.save(序列化):用于保存模型
    主要参数:
    • obj:对象
    • f:输出路径
  2. torch.load(反序列化):用于加载模型
    主要参数
    • f:文件路径
    • map_location:指定存放位置, cpu or gpu

1.3 模型保存的两种方法

方法1:保存整个Module模型
torch.save(net, path)
方法2:保存模型参数parameter
state_dict = net.state_dict()
torch.save(state_dict , path)

使用方法1会比较耗时耗费资源,通常我们会使用方法2,只保存模型训练过程中的参数。

代码实现:

import torch
import numpy as np
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2024111)


net = LeNet2(classes=2024)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()  #模型模型训练参数改变
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

输出结果:
在这里插入图片描述
在这里插入图片描述

1.4 模型加载两种方法

方法1:加载模型
代码实现:

# ================================== load net ===========================
flag = 1
# flag = 0
if flag:

    path_model = "./model.pkl"
    net_load = torch.load(path_model)

    print(net_load)

输出结果:
在这里插入图片描述
在这里插入图片描述

方法2: 加载参数
代码实现:

# ================================== load state_dict ===========================

flag = 1
# flag = 0
if flag:
    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)
    print(state_dict_load.keys())

输出结果:
将保存的参数名称打印出来;
在这里插入图片描述
方法3:将参数加载到新的模型当中

# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:

    net_new = LeNet2(classes=2019)

    print("加载前: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])

输出结果:
在这里插入图片描述

以上完整代码:

# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x
    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)
# ================================== load net ===========================
# flag = 1
flag = 0
if flag:
    path_model = "./model.pkl"
    net_load = torch.load(path_model)
    print(net_load)
# ================================== load state_dict ===========================
flag = 1
# flag = 0
if flag:
    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)
    print(state_dict_load.keys())
# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:
    net_new = LeNet2(classes=2024)
    print("加载前: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])

二、断点训练

首先我们需要确定模型训练过程中哪些参数是会一直发生变化的,模型中的权值以及优化器中的可学习参数是一直发生变化的,数据以及损失函数是保持不变的。
在这里插入图片描述
断点训练函数方法:
在这里插入图片描述

2.1 断点保存代码

当训练到第5次的时候我们进行人为中断训练,将当前训练阶段的模型权值参数、优化器参数、训练轮数保存到checkpoint

    if (epoch+1) % checkpoint_interval == 0:  # checkpoint_interval初始值设置为5
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
    if epoch > 5:
        print("训练意外中断...")
        break

输出结果:
在这里插入图片描述

2.2 断点恢复代码

加载上一次训练相关参数数据

# ============================ step 5+/5 断点恢复 ============================
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数
start_epoch = checkpoint['epoch']  # 加载上一次训练轮数
scheduler.last_epoch = start_epoch  # 学习率策略更新

输出结果:
当前训练初始轮数从第5轮开始训练,所以精度可以很快增加。
在这里插入图片描述
完整代码展示

save_checkpoint.py;保存断点参数数据

# -*- coding: utf-8 -*-
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import sys

hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision

set_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # 设置学习率下降策略

# ============================ step 5+/5 断点恢复 ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数

start_epoch = checkpoint['epoch']  # 加载上一次训练轮数

scheduler.last_epoch = start_epoch  # 学习率策略更新

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dic": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("训练意外中断...")
    #     break

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct / total))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

checkpoint_resume.py:加载断点参数数据

# -*- coding: utf-8 -*-
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
import sys

hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision

set_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # 设置学习率下降策略

# ============================ step 5+/5 断点恢复 ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])  # 加载网络模型参数

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器当中相关可学习参数

start_epoch = checkpoint['epoch']  # 加载上一次训练轮数

scheduler.last_epoch = start_epoch  # 学习率策略更新

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dic": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("训练意外中断...")
    #     break

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val / len(valid_loader), correct / total))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

三、finetune

3.1 迁移学习

首先了解一下迁移学习(Transfer Learning)概念:它机器学习分支,研究源域(source domain)的知识如何应用到目标域(target
domain),来提高模型的性能。
在这里插入图片描述

图(a)是传统的机器学习过程,针对某一个任务进行网络模型训练
图(b)是迁移学习过程,通过对源任务进行模型训练得到一个"知识",当我们需要训练一个新的任务时,可以在源任务训练的"知识"上继续进行训练,从而得到target模型;

3.2 模型的迁移学习

假设我们已经训练好一个模型了,我们把网络训练过程中的权值当做"知识",当我想要再训练一个新的模型任务时,但是数据量较小,不足以训练一个较好的模型,我们把上一个模型的知识应用到新的任务当中,这就是模型的迁移训练,从而提高模型的精度和效果。就好比一个人学会了一门乐器之后,已经掌握了相关乐理知识,再让他去学习另外一门乐器就会更加容易学习!!!在这里插入图片描述

3.2 模型微调步骤

通常我们会找到模型训练过程中具有相同共性的部分,例如下面这个神经网络,分为两部分,分别是特征提取部分feature和图像分类部分classifier两个部分,当我们需要进行其他图像分类任务的时候,我们可以保留图像特征提取部分,改变分类部分的output类别数。在这里插入图片描述

3.2.1 模型微调步骤

  1. 获取预训练模型参数:源任务当中学习到的"知识"。
  2. 加载模型(load_state_dict):将知识加载到新的模型当中。
  3. 修改输出层,不同任务输出层类别数不同。

3.2.2 模型微调训练方法

  1. 固定预训练的参数(固定参数的方法:requires_grad =False;lr=0)。
  2. Features Extractor较小学习率(params_group),不同的参数组设置不同的学习率,例如特征提取模块我们希望它变动不大,可以设置较小的学习率,在特征分类模块可以设置较大的学习率。

3.3 迁移训练实验

1、数据准备
Finetune Resnet -18 用于二分类,蚂蚁蜜蜂二分类数据,训练集:各120~张 验证集:各70~张
下载Resnet -18预训练模型,下载地址:https://download.pytorch.org/models/resnet18-5c106cde.pth
在这里插入图片描述

2、实验结果
在经过25轮训练之后,不使用预训练模型的精度提升的很慢,但是使用了预训练模型之后精度可以快速上升。
在这里插入图片描述


数据处理模块,用于获取蜜蜂和蚂蚁图片路径以及文件夹标签:

class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # 初始化函数,接收数据目录和数据变换操作
        self.label_name = {"ants": 0, "bees": 1}  # 定义标签对应的字典
        self.data_info = self.get_img_info(data_dir)  # 获取图片信息
        self.transform = transform  # 保存数据变换操作

    def __getitem__(self, index):
        # 获取指定索引处的数据
        path_img, label = self.data_info[index]  # 获取图片路径和标签
        img = Image.open(path_img).convert('RGB')  # 打开图片并转换为RGB格式

        if self.transform is not None:
            img = self.transform(img)  # 对图片进行数据变换
        return img, label  # 返回处理后的图片和标签

    def __len__(self):
        # 返回数据集的长度
        return len(self.data_info)

    def get_img_info(self, data_dir):
        # 获取图片信息的函数
        data_info = list()  # 创建空列表用于保存图片信息
        for root, dirs, _ in os.walk(data_dir):
            # 遍历数据目录中的子目录
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))  # 获取子目录下的文件列表
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))  # 筛选出.jpg格式的文件名

                # 遍历该子目录下的图片
                for i in range(len(img_names)):
                    img_name = img_names[i]  # 获取图片文件名
                    path_img = os.path.join(root, sub_dir, img_name)  # 构建完整的图片路径
                    label = self.label_name[sub_dir]  # 获取图片的标签
                    data_info.append((path_img, int(label)))  # 将图片路径和标签添加到数据信息列表中

        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(
                data_dir))  # 若数据信息列表为空,则抛出异常
        return data_info  # 返回图片信息列表

加载预训练模型模块:

flag = 1
if flag:
    path_pretrained_model = os.path.join("finetune_resnet18-5c106cde.pth")
    if not os.path.exists(path_pretrained_model):
        raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip\n放到 {}下,并解压即可".format(
            path_pretrained_model, os.path.dirname(path_pretrained_model)))
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

冻结网络层方法

# 冻结所有网络层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.fc.weight[0, 0, ...]))

替换fc层,因为我们是2分类模型,所以需要修改分类网络;

# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features  # 获取原网络中的特征输入
resnet18_ft.fc = nn.Linear(num_ftrs, classes)  #classes=2

网络分组模块,我们希望特征提取部分参数更新小一些,分类部分参数更新大一些;

if flag:
    # 将网络划分为两个参数组
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))
    """
    这一行首先使用resnet18_ft.fc.parameters()获取了ResNet18模型中全连接层的参数,
    然后通过map(id, ...)将每个参数的内存地址映射为一个列表。这样得到的fc_params_id列表包含了全连接层参数的内存地址。
    """
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    """
    这一行使用filter函数和lambda表达式来过滤resnet18_ft模型中不属于全连接层的参数。具体来说,filter函数通过
    lambda p: id(p) not in fc_params_id对resnet18_ft.parameters()中的参数进行过滤,
    保留那些内存地址不在fc_params_id列表中的参数。这样就得到了base_params,其中包含了除全连接层参数外的其他层的参数,也就是特征提取部分。
    """
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * 0},  # 卷积层的学习率设置小一些,如果设置为0的话,会直接冻结卷积层
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # 选择优化器

完整模型训练代码如下:

# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from PIL import Image
import sys
import random

hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)
import torchvision.models as models
import torchvision

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))


# =====================参数设置=====================
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # 设置随机种子
label_name = {"ants": 0, "bees": 1}

# 参数设置
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 5


# =======================读取图片数据==============================
class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # 初始化函数,接收数据目录和数据变换操作
        self.label_name = {"ants": 0, "bees": 1}  # 定义标签对应的字典
        self.data_info = self.get_img_info(data_dir)  # 获取图片信息
        self.transform = transform  # 保存数据变换操作

    def __getitem__(self, index):
        # 获取指定索引处的数据
        path_img, label = self.data_info[index]  # 获取图片路径和标签
        img = Image.open(path_img).convert('RGB')  # 打开图片并转换为RGB格式

        if self.transform is not None:
            img = self.transform(img)  # 对图片进行数据变换
        return img, label  # 返回处理后的图片和标签

    def __len__(self):
        # 返回数据集的长度
        return len(self.data_info)

    def get_img_info(self, data_dir):
        # 获取图片信息的函数
        data_info = list()  # 创建空列表用于保存图片信息
        for root, dirs, _ in os.walk(data_dir):
            # 遍历数据目录中的子目录
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))  # 获取子目录下的文件列表
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))  # 筛选出.jpg格式的文件名

                # 遍历该子目录下的图片
                for i in range(len(img_names)):
                    img_name = img_names[i]  # 获取图片文件名
                    path_img = os.path.join(root, sub_dir, img_name)  # 构建完整的图片路径
                    label = self.label_name[sub_dir]  # 获取图片的标签
                    data_info.append((path_img, int(label)))  # 将图片路径和标签添加到数据信息列表中

        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(
                data_dir))  # 若数据信息列表为空,则抛出异常
        return data_info  # 返回图片信息列表


# ============================ step 1/5 数据 ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "..", "data", "hymenoptera_data"))
if not os.path.exists(data_dir):
    raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(
        data_dir, os.path.dirname(data_dir)))

train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载模型参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join("finetune_resnet18-5c106cde.pth")
    if not os.path.exists(path_pretrained_model):
        raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip\n放到 {}下,并解压即可".format(
            path_pretrained_model, os.path.dirname(path_pretrained_model)))
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 冻结所有网络层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.fc.weight[0, 0, ...]))

# 冻结卷积层
flag_c = 0
# flag_c = 1
if flag_c:
    for name, param in resnet18_ft.named_parameters():
        if "fc" in name:  # 如果参数名中不包含"fc",即不是全连接层的参数
            param.requires_grad = False

# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features  # 获取原网络中的特征输入
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()  # 选择损失函数

# ============================ step 4/5 优化器 ============================
# 法2 : conv 小学习率
flag = 0
# flag = 1
if flag:
    # 将网络划分为两个参数组
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))
    """
    这一行首先使用resnet18_ft.fc.parameters()获取了ResNet18模型中全连接层的参数,
    然后通过map(id, ...)将每个参数的内存地址映射为一个列表。这样得到的fc_params_id列表包含了全连接层参数的内存地址。
    """
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    """
    这一行使用filter函数和lambda表达式来过滤resnet18_ft模型中不属于全连接层的参数。具体来说,filter函数通过
    lambda p: id(p) not in fc_params_id对resnet18_ft.parameters()中的参数进行过滤,
    保留那些内存地址不在fc_params_id列表中的参数。这样就得到了base_params,其中包含了除全连接层参数外的其他层的参数,也就是特征提取部分。
    """
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * 0},  # 卷积层的学习率设置小一些,如果设置为0的话,会直接冻结卷积层
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # 选择优化器
    # optimizer = optim.Adam(resnet18_ft.parameters(), lr=0.01)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # if flag_m1:
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.fc.weight[0, 0, ...]))

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1376145.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

编码器与解码器LLM全解析:掌握NLP核心技术的关键!

让我们深入了解:基于编码器和基于解码器的模型有什么区别? 编码器与解码器风格的Transformer 从根本上说,编码器和解码器风格的架构都使用相同的自注意力层来编码词汇标记。然而,主要区别在于编码器旨在学习可以用于各种预测建模…

BP神经网络(公式推导+举例应用)

文章目录 引言M-P神经元模型激活函数多层前馈神经网络误差逆传播算法缓解过拟合化结论实验分析 引言 人工神经网络(Artificial Neural Networks,ANNs)作为一种模拟生物神经系统的计算模型,在模式识别、数据挖掘、图像处理等领域取…

TYPE-C接口取电芯片介绍和应用场景

随着科技的发展,USB PDTYPE-C已经成为越来越多设备的充电接口。而在这一领域中,LDR6328Q PD取电芯片作为设备端协议IC芯片,扮演着至关重要的角色。本文将详细介绍LDR6328Q PD取电芯片的工作原理、应用场景以及选型要点。 一、工作原理 LDR63…

MySQL——性能优化与关系型数据库

文章目录 什么是性能?什么是关系型数据库?数据库设计范式 常见的数据库SQL语言结构化查询语言的六个部分版本 MySQL数据库故事历史版本5.6/5.7差异5.7/8.0差异 什么是性能? 吞吐与延迟:有些结论是反直觉的,指导我们关…

芯课堂 | 如何配置SWM系列系统时钟?

如何配置SWM系列 系统时钟? 华芯微特科技有限公司SWM系列芯片可通过软件配置改变时钟的速度,可以让我们的设计更加灵活,频率可选空间也更加广泛,用户可以根据自己的实际需求配置需要的系统时钟。为了让用户能够更简单的使用这一功能&#xf…

java导出word套打

这篇文档手把手教你完成导出word套打&#xff0c;有这个demo&#xff0c;其他word套打导出都通用。 1、主要依赖 <!--hutool--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.3.0</ve…

为什么要做FP独立站?FP独立站有哪些优势?

近年来&#xff0c;跨境电商的商家们面临越来越大的平台政策压力&#xff0c;商家们纷纷把眼光聚焦到独立站上&#xff0c;眼下独立站已经成为出海卖家的标配。 特别是想做FP商品的卖家&#xff0c;相对于亚马逊平台&#xff0c;独立站才是你们的最终出路... 那么&#xff0c;问…

低代码平台可以开发哪些软件系统

在当今数字化时代&#xff0c;企业管理系统的开发已成为各行各业中不可或缺的一环。然而&#xff0c;传统的软件开发过程往往复杂且耗时&#xff0c;难以满足快速变化的市场需求。低代码平台的出现为企业管理系统的开发带来了革命性的变革。本文将探讨低代码平台在企业管理系统…

菱形以及各种组合图形讲解(*#@¥$)

引言&#xff1a; ***形对于新手了解循环以及嵌套循环帮助是非常大的。&#xff08;以下的题各题之间有关联&#xff09; 我们最终目的&#xff0c;就是会编程写菱形&#xff1b;看下面的图片 解题思路&#xff1a;运用拆分法&#xff0c;我们将菱形分为4个部分&#xff0c;看…

MySQL一主一从读写分离

​ MySQL主从复制 一、主从复制概念 主从复制是指将主数据库的DDL和DML操作通过二进制日志传到从服务器中&#xff0c;然后在从服务器上对这些日志重新执行也叫重做&#xff0c;从而使得从数据库和主库的数据保持同步。 MySQL支持一台主库同时向多台从库进行赋值&#xff0c;从…

Halcon实例:提取图像的纹理特征

Halcon实例&#xff1a;提取图像的纹理特征 举例说明&#xff0c;输入的是一幅灰度图像&#xff0c;分别选取其中两个矩形区域的灰度图像&#xff0c;分析其灰度变化。首先选取灰度变化较为明显的矩形1&#xff0c;然后选取灰度变化比较平滑的矩形2&#xff0c;生成灰度共生矩…

AD软件与其他EDA软件工程的问题汇总

1:如何在AD中使用eagle工程 在ad中打不开原理图&#xff0c;要使用导入功能,转化为ad的文件后&#xff0c;就可以打开了 2:打开旧版本的Protel文件 有时候新版本的AD打不开以前Protel的PCB文件&#xff0c;可以在DXP菜单下的Extension下进行配置&#xff08;Configure&…

高效降压控制器FP7132XR:为高亮度LED提供稳定可靠的电源

目录 一. FP7132概述 二. 驱动电路&#xff1a;FP7132 三. FP7132应用 高亮度LED作为新一代照明技术的代表&#xff0c;已经广泛应用于各种领域。然而&#xff0c;高亮度LED的工作电压较低&#xff0c;需要一个高效降压控制器来为其提供稳定可靠的电源。在众多降压控制器…

【AI大模型应用开发】1.0 Prompt Engineering(提示词工程)- 典型构成、原则与技巧,代码中加入Prompt

从这篇文章开始&#xff0c;我们就正式开始学习AI大模型应用开发的相关知识了。首先是提示词工程&#xff08;Prompt Engineering&#xff09;。 文章目录 0. 什么是提示词&#xff08;Prompt&#xff09;1. 为什么Prompt会起作用 - 大模型工作原理2. Prompt的典型构成、原则与…

ubuntu20.04 deepstream 6.3安装

1.基础环境gstreamer sudo apt install \ libssl-dev \ libgstreamer1.0-0 \ gstreamer1.0-tools \ gstreamer1.0-plugins-good \ gstreamer1.0-plugins-bad \ gstreamer1.0-plugins-ugly \ gstreamer1.0-libav \ libgstreamer-plugins-base1.0-dev \ libgstrtspserver-1.0-0 …

微信小程序开发学习笔记《8》tabBar

微信小程序开发学习笔记《8》tabBar 博主正在学习微信小程序开发&#xff0c;希望记录自己学习过程同时与广大网友共同学习讨论。tabBar官方文档 tabBar这一节还是相当重要的。 一、什么是tabBar tabBar是移动端应用常见的页面效果&#xff0c;用于实现多页面的快速切换。小…

第十四章JSON

第十四章JSON 1.什么是JSON2.JSON的定义和访问3.JSON在JavaScript中两种常用的转换方式4.JavaBean和JSON的相互转换5.List集合和JSON的相互转换6.map集合和JSON的相互转换 1.什么是JSON 2.JSON的定义和访问 JSON的定义 JSON的类型是一个Object类型 JSON的访问 我们要…

kafka下载安装部署

Apache kafka 是一个分布式的基于push-subscribe的消息系统&#xff0c;它具备快速、可扩展、可持久化的特点。它现在是Apache旗下的一个开源系统&#xff0c;作为hadoop生态系统的一部分&#xff0c;被各种商业公司广泛应用。它的最大的特性就是可以实时的处理大量数据以满足各…

详解如何撰写一个基础的技术交底书

大家好,我是英子老师。作为一名知识产权专家,深耕于专利行业十余年,具有丰富的专利工作经验:曾在大型专利代理机构从事专利代理工作、专利质检工作(抽查代理机构的专利代理人的撰写质量并评分);之后在知名上市企业、行业龙头企业担任高级专利工程师的职位,主要工作内容…

使用Flash_Download_Tool下载PlatformIO生成的bin程序到ESP32

使用Flash_Download_Tool下载PlatformIO生成的bin程序到ESP32 来源 当我们没有PlatformIO环境时&#xff0c;还要下载PlatformIO生成的程序时&#xff0c;可以使用Flash_Download_Tool工具下载。 说明 使用PlatformIO时&#xff0c;用cmd终端命令下载程序pio run -v -t upl…