使用预训练模型构建自己的深度学习模型(迁移学习)

news2025/1/22 17:51:07

在深度学习的实际应用中,很少会去从头训练一个网络,尤其是当没有大量数据的时候。即便拥有大量数据,从头训练一个网络也很耗时,因为在大数据集上所构建的网络通常模型参数量很大,训练成本大。所以在构建深度学习应用时,通常会使用预训练模型

需求

训练一个模型来分类蚂蚁ants和蜜蜂bees

步骤

  1. 加载数据集
  2. 编写函数(训练并寻找最优模型)
  3. 编写函数(查看模型效果)
  4. 使用torchvision微调模型
  5. 使用tensorboard可视化训练情况

Pytorch保存和加载模型的两种方式

1.完整保存模型、加载模型

torch.save(net, 'mnist.pth')
net = torch.load('mnist.pth', 'map_location="cpu"')
# # Model class must be defined somewhere

load() 默认会将该张量加载到保存时所在的设备上,map_location 可以强制加载到指定的设备上

2. 保存、加载模型的状态字典(模型中的参数)

torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

迁移学习全过程

1.加载数据集

ants和bees各有约120张训练图片。

每个类有75张验证图片,从零开始在 如此小的数据集上进行训练通常是很难泛化的。

由于我们使用迁移学习,模型的泛化能力会相当好。

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
cudnn.benchmark = True  
plt.ion()

lr_scheduler:学习率调度器,用于在训练过程中动态调整学习率

torch.backends.cudnn.benchmark = True:大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,从而加速运算

plt.ion():这允许你在一个交互式环境中运行matplotlib

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

transforms.RandomResizedCrop(224):随机大小裁剪和缩放图像,使得裁剪后的图像尺寸为224x224像素。这个操作有助于模型学习不同尺度和长宽比的图像特征,从而提高模型的泛化能力。

transforms.RandomHorizontalFlip():以一定的概率(默认为0.5)对图像进行水平翻转。这也是一种数据增强技术,有助于模型学习对称性

transforms.Resize(256):将图像缩放到256x256像素

transforms.CenterCrop(224):从图像的中心裁剪出224x224像素的区域。这种裁剪方式确保每次裁剪的都是图像的中心部分,有助于在验证或测试时获得更一致的结果

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化数据

def imshow(inp, title=None): 
    # 可视化一组 Tensor 的图片
    inp = inp.numpy().transpose((1, 2, 0)) 
    mean = np.array([0.485, 0.456, 0.406]) 
    std = np.array([0.229, 0.224, 0.225]) 
    inp = std * inp + mean 
    inp = np.clip(inp, 0, 1) 
    plt.imshow(inp) 
    if title is not None: 
        plt.title(title) 
    plt.pause(0.001) # 暂停一会儿,为了将图片显示出来
# 获取一批训练数据
inputs, classes = next(iter(dataloaders['train'])) 
# 批量制作网格
out = torchvision.utils.make_grid(inputs) 
imshow(out, title=[class_names[x] for x in classes]) 

2.编写函数(训练并寻找最优模型)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): 
    """ 
    训练模型,并返回在验证集上的最佳模型和准确率 
    - criterion: 损失函数 

   Return: 
    - model(nn.Module): 最佳模型 
    - best_acc(float): 最佳准确率 
    """

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # 训练集和验证集交替进行前向传播
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  
            else:
                model.eval()   

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据集
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # 清空梯度,避免累加了上一次的梯度
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # 正向传播
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # 反向传播且仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward() 
                        optimizer.step()

                # 统计loss、准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            def train_model(model, criterion, optimizer, scheduler, num_epochs=25): 
    """ 训练模型,并返回在验证集上的最佳模型和准确率 
    Args: 
    - model(nn.Module): 要训练的模型 
    - criterion: 损失函数 
    - optimizer(optim.Optimizer): 优化器 
    - scheduler: 学习率调度器 
    - num_epochs(int): 最大 epoch 数 
    Return: 
    - model(nn.Module): 最佳模型 
    - best_acc(float): 最佳准确率 
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # 训练集和验证集交替进行前向传播
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置为训练模式,可以更新网络参数
            else:
                model.eval()   # 设置为预估模式,不可更新网络参数

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据集
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清空梯度,避免累加了上一次的梯度
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # 正向传播
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 反向传播且仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward() # 反向传播
                        optimizer.step()

                # 统计loss、准确率
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # 发现了更优的模型,记录起来
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # 加载训练的最好的模型
    model.load_state_dict(best_model_wts)
    return model

 

3.编写函数(查看模型效果)

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

4.使用torchvision微调模型

微调步骤图

加载预训练模型,并将最后一个全连接层重置

model = models.resnet18(pretrained=True) # 加载预训练模型
num_ftrs = model.fc.in_features # 获取低级特征维度 
model.fc = nn.Linear(num_ftrs, 2) # 替换新的输出层 
model = model.to(device) 
criterion = nn.CrossEntropyLoss() 
# 所有参数都参加训练 
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 
# 每过 7 个 epoch 将学习率变为原来的 0.1 
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练

model_ft = train_model(model, criterion, optimizer_ft, scheduler, num_epochs=3)

预估

visualize_model(model_ft)

 

5.使用tensorboard可视化训练情况

将以下代码块合理的加入到训练模块里

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
ep_losses, ep_acces = [], []

ep_losses.append(epoch_loss)
ep_acces.append(epoch_acc.item())

writer.add_scalars('loss', {'train': ep_losses[-2], 'val': ep_losses[-1]}, global_step=epoch)
writer.add_scalars('acc', {'train': ep_acces[-2], 'val': ep_acces[-1]}, global_step=epoch)
writer.close()

运行命令启动tensorboard

tensorboard --logdir runs

可以通过执行命令的终端看到tensorboard可视化页面地址,且该命令会在执行该命令的目录下生成一个runs文件夹(日志文件的保存位置)

以下是我训练了25个epochs的acc和loss的动态图,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1620250.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据学习第四天

文章目录 yaml 三大组件的方式交互流程hive 使用安装mysql(hadoop03主机)出现错误解决方式临时密码 卸载mysql (hadoop02主机)卸载mysql(hadoop01主机执行)安装hive上传文件解压解决版本差异修改hive-env.sh修改 hive-site.xml上传驱动包初始化元数据在hdfs 创建hive 存储目录启…

毫米波雷达模块在高精度人体姿态识别的应用

人体姿态识别是计算机视觉领域中的重要问题之一,具有广泛的应用前景,如智能安防、虚拟现实、医疗辅助等。毫米波雷达技术作为一种无需直接接触目标就能实现高精度探测的感知技术,在人体姿态识别领域具有独特的优势。本文将探讨毫米波雷达模块…

kubeadmin搭建自建k8s集群

一、安装要求 在开始之前,部署Kubernetes集群的虚拟机需要满足以下几个条件: 操作系统 CentOS7.x-86_x64硬件配置:2GB或更多RAM,2个CPU或更多CPU,硬盘30GB或更多【注意master需要两核】可以访问外网,需要…

Python 全栈体系【四阶】(三十四)

第五章 深度学习 六、PaddlePaddle 图像分类 4. 思路及实现 4.1 数据集介绍 来源:爬虫从百度图片搜索结果爬取 内容:包含 1036 张水果图片,共 5 个类别(苹果 288 张、香蕉 275 张、葡萄 216 张、橙子 276 张、梨 251 张&#…

NVIDIA Jetson jtop查看资源信息

sudo -H pip install -U jetson-stats 安装好之后可能需要reboot 执行jtop: 时间久了可能会退出,可参考如下再次启动。 nvidiategra-ubuntu:~$ jtop The jtop.service is not active. Please run: sudo systemctl restart jtop.service nvidiategra-ub…

【古琴】倪诗韵古琴雷修系列(形制挺多的)

雷音系列雷修:“修”字取意善、美好的,更有“使之完美”之意。精品桐木或普通杉木制,栗壳色,纯鹿角霜生漆工艺。 方形龙池凤沼。红木配件,龙池上方有“倪诗韵”亲笔签名,凤沼下方,雁足上方居中位…

mPEG-Biotin,Methoxy PEG Biotin在免疫亲和层析、荧光标记和生物传感器等领域发挥关键作用

【试剂详情】 英文名称 mPEG-Biotin,Methoxy PEG Biotin 中文名称 聚乙二醇单甲醚生物素,甲氧基-聚乙二醇-生物素 外观性状 由分子量决定,固体或者粘稠液体。 分子量 0.4k,0.6k,1k,2k,3.…

Activiti7基础

Activiti7 一、工作流介绍 1.1 概念 工作流(Workflow),就是通过计算机对业务流程自动化执行管理。它主要解决的是“使在多个参与者之间按照某种预定义的规则自动进行传递文档、信息或任务的过程,从而实现某个预期的业务目标,或者促使此目标…

2024-04-23 linux 查看内存占用情况的命令free -h和cat /proc/meminfo

一、要查看 Linux 系统中的内存占用大小,可以使用 free 命令或者 top 命令。下面是这两个命令的简要说明: 使用 free 命令: free -h这将显示系统当前的内存使用情况,包括总内存、已用内存、空闲内存以及缓冲区和缓存的使用情况。…

Git笔记-配置ssh

Git在Deepin中的ssh配置 一、环境二、安装1. 查看GitHub账户2. 配置 git3. 生成 ssh key 三、配置 一、环境 系统: Deepin v23 Git仓库:GitHub 二、安装 1. 查看GitHub账户 在设置界面看到自己的邮箱,这个邮箱就是后面会用到的邮箱 2. …

上位机图像处理和嵌入式模块部署(树莓派4b的一种固件部署方法)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 如果软件开发好了之后,下面就是实施和部署。对于树莓派4b来说,部署其实就是烧录卡和拷贝文件。之前我们烧录卡,…

Jenkins CI/CD 持续集成专题四 Jenkins服务器IP更换

一、查看brew 的 services brew services list 二、编辑 homebrew.mxcl.jenkins-lts.plist 将下面的httpListenAddress值修改为自己的ip 服务器,这里我是用的本机的ip 三 、重新启动 jenkins-lts brew services restart jenkins-lts 四 浏览器访问 http://10.85…

26版SPSS操作教程(高级教程第十三章)

前言 #今日世界读书日,宝子你,读书了嘛~ #本期内容:主成分分析、因子分析、多维偏好分析 #由于导师最近布置了学习SPSS这款软件的任务,因此想来平台和大家一起交流下学习经验,这期推送内容接上一次高级教程第十二章…

卓越体验的秘密武器:评测ToDesk云电脑、青椒云、天翼云的稳定性和流畅度

大家好,我是猫头虎。近两年随着大模型的火爆,我们本地环境常常难以满足运行这些大模型的硬件需求。因此,云电脑平台成为了一个理想的解决方案。今天,我将介绍并评测几款主流云电脑产品:ToDesk云电脑、天翼云电脑和青椒…

网络通信安全

一、网络通信安全基础 TCP/IP协议简介 TCP/IP体系结构、以太网、Internet地址、端口 TCP/IP协议简介如下:(from文心一言) TCP/IP(Transmission Control Protocol/Internet Protocol,传输控制协议/网际协议&#xff0…

PVE虚拟机隐藏状态栏虚拟设备

虚拟机启动后,状态栏会出现一些虚拟设备,点击弹出会导致虚拟机无法使用。 解决方案: 1、在桌面新建disable_virtio_removale.bat文件,内容如下: ECHO OFF FOR /f %%A IN (reg query "HKLM\SYSTEM\CurrentContro…

Docker容器化技术

Docker容器化技术 1.Docker概念 Docker是一个开源的应用容器引擎基于go语言实现Docker可以让开发者们打包他们的应用以及依赖包到一个轻量级的、可移植的容器中,然后发布到任何流行的Linux机器上容器是完全使用沙箱机制,相互隔离容器性能开销极低Docke…

Facebook的时间机器:回溯社交媒体的历史

1. 社交媒体的起源与早期模式 社交媒体的历史可以追溯到互联网的早期发展阶段。在Web 1.0时代,互联网主要是一个信息发布平台,用户主要是被动地接收信息。但随着Web 2.0的兴起,互联网逐渐转变为一个互动和参与的平台,社交媒体应运…

HTTP与SOCKS-哪种协议更适合您的代理需求?

网络代理技术是我们日常使用网络时必不可少的一项技术,它可以为我们提供隐私保护和负载均衡的能力,从而保证我们的网络通信更加安全和顺畅。而其中最主流的两种协议就是HTTP和SOCKS。虽然它们都是用于网络代理的协议,但在实际应用中却存在着一…

时间复杂度和空间复杂度是什么

如何衡量代码好坏,算法的考察到底是在考察什么呢? 衡量代码好坏有两个非常重要的标准就是:运行时间和占用空间,就是我们后面要说到的时间复杂度和空间复杂度,也是学好算法的重要基石。 确切的占内用存或运行时间无法进…