《深度学习》迁移学习综合应用 原理、案例解析与实现

news2024/12/26 19:07:59

目录

一、迁移学习

1、什么是迁移学习

2、迁移学习步骤

1)选择预训练的模型和适当的层

2)冻结预训练模型的参数

3)在新数据集上训练新增加的层

4)微调预训练模型的层

5)评估和测试

二、案例实现

1、数据准备及目的

2、冻结参数、更改输出特征

3、数据增强处理

4、导入图像并打包

1)文件内容

2)代码部分

5、损失函数、优化器、调整学习率

6、定义训练集

7、定义测试集

8、传入参数进行训练和测试

运行结果:


一、迁移学习

1、什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2、迁移学习步骤

        1)选择预训练的模型和适当的层

                通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

        2)冻结预训练模型的参数

                保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

        3)在新数据集上训练新增加的层

                在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

        4)微调预训练模型的层

                在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调目标。这样做可以提高模型在新数据集上的性能。

        5)评估和测试

                在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、案例实现

1、数据准备及目的

        使用上节课所说的残差网络的18层结构来对其进行微调,该残差网络结构如下图所示:

        此时我们可以发现输入图像的特征大小为3*224*224,输出特征图格式为512*1*1,然后将其进行全连接层处理后变成输入512张特征图,输出1000个预测结果,这个结果的种类太多,我们不需要使用这么多的预测类别,所以当下需要对其微调,调整最后输出时的全连接层输出结果个数及其全连接层中的权重参数。

2、冻结参数、更改输出特征

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


""" 将ResNet18模型迁移到食物分类项目中 """   # 残差网络是固定的网络结构,不需要自己来类定义
resnet_model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)  # 即调用了resnet18网络,又使用了训练好的模型
# weights=models.ResNet18_Weights.DEFAULT表示使用在 ImageNet 数据集上预先训练好的权重来初始化模型参数,可进入源代码查看
for param in resnet_model.parameters():  # 遍历模型的所有参数
    print(param)
    param.requires_grad = False
# 模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数

# 使得在反向传播过程中不会计算它们的梯度,以此减少模型的计算量,提高理速度。
in_feature = resnet_model.fc.in_features  # 获取原始模型全连接层fc的输入特征in_feature
resnet_model.fc = nn.Linear(in_feature,20)     # 创建一个全连接层,输入特征为in_features,输出为20,将其赋值给原先的全连接层

params_to_update = []
for param in resnet_model.parameters():   # 再次遍历模型的所有参数,
    if param.requires_grad == True:   # 判断模型参数的属性是否为需要更新,如果是,那么将其参数值增加到列表中,因为我们更改了全连接层,所以此处所有的参数为全连接层的参数
        params_to_update.append(param)

3、数据增强处理

data_transforms = {
    'train':    # 训练集
        transforms.Compose([  # 用来整合图片的数据增强处理
            transforms.Resize([300,300]),   # 将输入的图片尺寸缩放到300*300
            transforms.RandomRotation(45),   # 做数据增强,随机旋转-45-45度
            transforms.CenterCrop(224),   # 对图片做中心裁剪,裁剪为224*224大小
            transforms.RandomHorizontalFlip(p=0.5),   # 随机水平翻转,概率为0.5
            transforms.RandomVerticalFlip(p=0.5),   # 随机垂直翻转,概率为0.5
            # transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),   # 随机更改对比度、饱和度、、
            transforms.RandomGrayscale(p=0.1),   # 随机更改为灰度图,概率为0.1
            transforms.ToTensor(),   # 将图像格式转变为tensor类型
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])   # 对图像做归一化,指定均值,标准差
        ]),
    'valid':    # 验证集
        transforms.Compose([
            transforms.Resize([224,224]),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])

        ])
}

4、导入图像并打包

        1)文件内容

                train.txt文件内容如下:

                test.txt文件内容同上。

                其中地址分别对应相应的图片,尾部数字代表类别:

        2)代码部分
class food_dataset(Dataset):   # 继承Dataset,food_dataset是自己创建的类名称,可以改为你需要的名称
    def __init__(self,file_path,transform=None):   # 类的初始化,传入参数为图片地址及其标签,数据增强默认为None,解析数据文件txt
        self.file_path = file_path   # 图片地址及其标签传入self空间
        self.imgs = []   # 存放图片
        self.labels = []  # 存放图片标签
        self.transform = transform   # 数据增强
        with open(self.file_path) as f:  # 是把train.txt文件中图片的路径保存在 self.imgs,train.txt
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)   # 图像的路径
                self.labels.append(label)     # 标签,还不是tensor
# 初始化:把图片目录加载到self.

    def __len__(self):    # 类实例化对象后,可以使用len函数测量对象的个数
        return len(self.imgs)

    def __getitem__(self, idx):   # 关键,可通过idx索引的形式获取每一个图片数据及标签
        image = Image.open(self.imgs[idx])   # 读取到图片数据,还不是tensor
        if self.transform:    # 将pil图像数据转换为tensor
            image = self.transform(image)   # 图像处理为256*256,转换为tenor

        label = self.labels[idx]   # label还不是tensor
        label = torch.from_numpy(np.array(label,dtype=np.int64))
        return image,label   # 返回图片及其标签信息

# 传入训练集和测试集图片地址,分比对他们进行数据增强处理
training_data = food_dataset(file_path = './trainda.txt',transform = data_transforms['train'])
test_data = food_dataset(file_path ='./testda.txt',transform = data_transforms['valid'])

# 对返回的图片信息做打包处理,每64张打包成一份
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)

# 确定使用的设备是cpu还是GPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

5、损失函数、优化器、调整学习率

# 将上述微调的残差网络结构传入GPU
model = resnet_model.to(device)   # 为什么不需要加括号,resnet_model是一个对象而不是一个类

loss_fn = nn.CrossEntropyLoss()   # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer = torch.optim.Adam(params_to_update,lr=0.001)   # 仅训练部分参数,即params_to_update,其为上述全连接层的参数
# optimizer = torch.optim.Adam(resnet_model.parameters(),lr=0.001)   # 训练更新模型所有层参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)  # 调整学习率,每进行轮训练后,将学习率乘以0.5

6、定义训练集

def train(dataloader,model,loss_fn,optimizer):  # 传入打包好的数据,预定义的残差网络模型,损失函数,优化器
    model.train()   # 模型进行训练模式
    batch_size_num = 1
    for x,y in dataloader:  # 遍历每个打包的图片的信息及标签
        x,y = x.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPl
        pred = model.forward(x)   # 模型进行前向传播
        loss = loss_fn(pred, y)   # 通过交叉熵损失函数计算损失值Loss

        optimizer.zero_grad()  # 梯度值清零
        loss.backward()     # 反向传播计算得到每个参数的梯度
        optimizer.step()   # 根据梯度更新网络参数

        loss = loss.item()   # 获取损失值
        if batch_size_num % 100 == 0:   # 每100轮打印一次损失值和轮数
            print(f"loss: {loss:>7f}[number:{batch_size_num}]")
        batch_size_num += 1

7、定义测试集

bast_acc = 0
def test(dataloader, model,loss_fn):
    global bast_acc   # 定义全局变量
    size = len(dataloader.dataset)   # 返回所有的图片个数
    num_batches = len(dataloader)   # 返回打包的包个数
    model.eval()    # 模型进入测试模式
    test_loss,correct = 0,0   # 初始化总损失值和准确的总个数为0
    with torch.no_grad():   # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()时可以减少
        for x,y in dataloader:
            x,y= x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct +=(pred.argmax(1)== y).type(torch.float).sum().item()   # 判断预测结果是否等于真实值,返回布尔值,将其转换为0、1,然后求和,在转换为python标量

        test_loss /= num_batches
        correct /= size
        print(f"Test result:in Accuracy: {(100*correct)}%, Avg loss: {test_loss}")
        acc_s.append(correct)   # 将总准确个数传入列表acc_s
        loss_s.append(test_loss)  # 键总损失值传入列表loss_s

# 保存最优模型的前2种方法,模型扩展名一般为:py\pth\t7
    if correct > bast_acc:
        bast_acc = correct

8、传入参数进行训练和测试

epochs = 80   # 模型进行80轮训练,每次训练都会更新参数的值
acc_s = []
loss_s = []
for t in range(epochs):

    print(f"Epoch {t+1}\n---------------------------")
    train(train_dataloader,model,loss_fn,optimizer)   # 传入数据进行训练
    scheduler.step()   # 每一轮过后,记录轮数,然后调整学习率
    test(test_dataloader, model, loss_fn)   # 测试
print('最优训练结果',bast_acc)   # 打印最优准确率
        运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2168584.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

内网穿透的应用-Windows系统安装SeaFile并实现远程访问本地共享文件资料详细教程

文章目录 1. 前言2. SeaFile云盘设置2.1 Owncould的安装环境设置2.2 SeaFile下载安装2.3 SeaFile的配置 3. cpolar内网穿透3.1 下载安装3.2 Cpolar注册3.3 Cpolar云端设置3.4 Cpolar本地设置 4.公网访问测试5.结语 1. 前言 本文主要为大家介绍,如何使用两个简单软件…

如何使用ssm实现基于BS的库存管理软件设计与实现+vue

TOC ssm708基于BS的库存管理软件设计与实现vue 绪论 课题背景 身处网络时代,随着网络系统体系发展的不断成熟和完善,人们的生活也随之发生了很大的变化。目前,人们在追求较高物质生活的同时,也在想着如何使自身的精神内涵得到…

【Python报错已解决】ModuleNotFoundError: No module named ‘psutil’

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

【无人机设计与控制】基于改进蚁群算法的机器人_无人机_无人车_无人船的路径规划算法

摘要 改进的蚁群算法 (IACO) 通过结合启发式信息和自适应参数调节,优化了机器人、无人机、无人车和无人船的路径规划问题。本文对传统蚁群算法的局限性进行了分析,并提出了一种改进方法,提升了算法的收敛速度和全局搜索能力。通过实验对比&a…

三篇文章速通JavaSE到SpringBoot框架 (中) IO 进程线程 网络编程 XML MySQL JDBC相关概念与演示代码

文章目录 IOfile类的作用I/O的作用将上篇文章综合项目使用IO流升级所需知识点 进程 线程创建线程的三种方式 网络编程网络编程介绍IP地址端口号网络通信协议网络通信协议的分层演示代码 XMLXML的作用是什么?xml特点 注解什么是注解?注解的使用注解的重要…

STM32堆栈溢出Bug

可以看到x和buf交换位置后,x处于0x200006B0地址上是不会被函数B影响到的,实际上B函数对buf赋值的过程是出现了越界行为的,所以导致了x在buf地址之后的话会被意外修改掉值。

管易云·奇门和金蝶云星空接口打通对接实战

管易云奇门和金蝶云星空接口打通对接实战 对接源平台:管易云奇门 管易云是金蝶旗下专注提供电商企业管理软件服务的子品牌,先后开发了C-ERP、EC-OMS、EC-WMS、E店管家、BBC、B2B、B2C商城网站建设等产品和服务,涵盖电商业务全流程。 写入目标:金蝶云星空…

Python下利用Selenium获取动态页面数据

利用python爬取网站数据非常便捷,效率非常高,但是常用的一般都是使用BeautifSoup、requests搭配组合抓取静态页面(即网页上显示的数据都可以在html源码中找到,而不是网站通过js或者ajax异步加载的),这种类型…

【趣学Python算法100例】冒泡排序

问题描述 对N个整数(数据由键盘输入)进行升序排列。 问题分析 要整理一组相同类型的数,我们可以用一个叫数组的工具来存放它们。冒泡排序,就是通过一次次比较相邻的两个数并交换位置,让原本乱糟糟的数组变得井井有条…

Python画笔案例-066 绘制橙子

1、绘制橙子 通过 python 的turtle 库绘制 橙子,如下图: 2、实现代码 绘制 橙子,以下为实现代码: """橙子.py注意亮度为0.5的时候最鲜艳本程序需要coloradd模块支持,安装方法:pip install coloradd程序运行需要很长时间,请耐心等待。可以把窗口最小化,然后…

【源码+文档+调试讲解】汽车维修管理系统的设计与实现

摘 要 随着计算机技术的高速发展,现代计算机系统已经从以计算为中心向以信息化处理为中心的方向发展。而汽车维修,不仅需要在硬件上为现代社会的人们提供一个汽车维修的平台,获取汽车知识的环境,更要在软件上为车辆提供汽车维修的…

记一次京东自营广电流量卡踩坑

本文首发于只抄博客,欢迎点击原文链接了解更多内容。 前言 最近由于竞合,电信和联通的大流量卡都下架了,29 元的长期套餐流量最多只有 80G 了,想要长期大流量卡只剩下广电这一个选择了。光从套餐上来看 29 元 192G 的流量还是很诱…

Shell 脚本学习

Shell学习 Shell 脚本 Shell 是一个用 C 语言编写的程序,它是用户使用 Linux 的桥梁。Shell 既是一种命令语言,又是一种程序设计语言。 Shell 是指一种应用程序,这个应用程序提供了一个界面,用户通过这个界面访问操作系统内核的服…

安装了 cursor 之后,我写代码不用手了

最近新一代 AI 编程助手 cursor 爆火。 Cloudflare 副总裁家的 8 岁女儿在 45 分钟内用它搭起了一个聊天机器人。 这个女孩甚至不会编程,只是通过输入一些简单的 prompt 就完成了这样一个聊天机器人。 如果我们通过 RPA 或者智能体的方式,将语音直接转…

著名建筑物检测与识别系统源码分享

著名建筑物检测与识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comp…

c++算法第二天

温馨提示:本篇文章适合刚开始练算法的小白,大佬若见勿嘲 题目 题目解析 遇到0写两遍,非0写一遍,其余非零数右移即可 编写原理 第一步找到最后一个被复写的数 先根据题目所给的例子找到最后一次要复写的数字 20240923_142843 第…

【AI学习】Lilian Weng:Extrinsic Hallucinations in LLMs(LLM 的外在幻觉)

来自OpenAI 的 Lilian Weng的《Extrinsic Hallucinations in LLMs》 Date: July 7, 2024 | Estimated Reading Time: 30 min | Author: Lilian Weng 文章链接:https://lilianweng.github.io/posts/2024-07-07-hallucination/ 大概看了一下,这篇文章的核…

重新拉取maven-jar包

问题:经常会出现这种情况:一个项目重新打包之后,在另外一个项目中无法引用。可以尝试一下解决方式 1:右上角重新拉取: 2:清理所有缓存:idea-file-invalidate Caches 3:设置拉取方式&#xff…

【论文速看】DL最新进展20240926-图像分割、图像修复、CNN

目录 【图像分割】【图像修复】【CNN】 【图像分割】 [2024] CAD: Memory Efficient Convolutional Adapter for Segment Anything 论文链接:https://arxiv.org/pdf/2409.15889 代码链接:https://github.com/Kyyle2114/Convolutional-Adapter-for-Segme…

Linux防火墙-什么是防火墙

作者介绍:简历上没有一个精通的运维工程师。希望大家多多关注作者,下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 什么是防火墙 我们想象一下把每台服务器当成一个小区,我们去访问另外一个小区的朋友,我们需要经过什…