《迁移学习》—— 将 ResNet18 模型迁移到食物分类项目中

news2024/11/18 9:26:59

文章目录

  • 一、迁移学习的简单介绍
    • 1.迁移学习是什么?
    • 2.迁移学习的步骤
  • 二、数据集介绍
  • 三、代码实现
    • 1. 步骤
    • 2.所用到方法介绍的文章链接
    • 3. 完整代码

一、迁移学习的简单介绍

1.迁移学习是什么?

  • 迁移学习是指利用已经训练好的模型,在新的任务上进行微调。
  • 迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2.迁移学习的步骤

  • (1) 选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
  • (2) 冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
  • (3) 在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
  • (4) 微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
  • (5) 评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、数据集介绍

  • 下图是数据集的结构
    • 在 food_dataset2 文件夹下含有训练数据和测试数据
    • 训练集和测试集数据中都含有 20 种食物图片,数量在200~400不等
    • trainda.txt 和 testda.txt 文本中存放了每张图片的路径及标签,用 0~19 这20个数字分别对20种食物进行标签
    • 在代码中通过trainda.txt 和 testda.txt 文本中的内容来获取每张图片及对应的标签
      在这里插入图片描述
    • 下面是trainda.txt文本中的部分内容(testda.txt 中的内容格式相同)
      在这里插入图片描述
  • 送福利!!! 私信送此数据集 !!!

三、代码实现

1. 步骤

  • 1.调用resnet18模型,并保存需要训练的模型参数
  • 2.定义一个图像预处理和数据增强字典
  • 3.定义获取每张食物图片和标签的类方法
  • 4.获取训练集和测试集数据
  • 5.对数据集进行打包
  • 6.调用交叉熵损失函数并创建优化器
  • 7.定义训练模型的函数
  • 8.定义测试模型的函数
  • 9.训练模型,并每训练一轮测试一次

2.所用到方法介绍的文章链接

  • ResNet 残差网络神经网络
    • https://blog.csdn.net/weixin_73504499/article/details/142575775?spm=1001.2014.3001.5501
  • 数据增强
    • https://blog.csdn.net/weixin_73504499/article/details/142499263?spm=1001.2014.3001.5501
  • 调整学习率
    • https://blog.csdn.net/weixin_73504499/article/details/142526863?spm=1001.2014.3001.5501

3. 完整代码

import torch
	import torchvision.models as models  # 导入存有各种深度学习模型的模块
	from torch import nn  # 导入神经网络模块
	from torch.utils.data import Dataset, DataLoader  # Dataset: 抽象类,一种用于获取数据的方法  DataLoader:数据包管理工具,打包数据
	from torchvision import transforms  # transforms模块提供了一系列用于图像预处理和数据增强的函数和类
	from PIL import Image  # 用于处理图片
	import numpy as np
	
	""" 调用resnet18模型 """
	resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
	
	for param in resnet_model.parameters():
	    param.requires_grad = False
	# 模型所有参数(即权重和偏差)的 requires_grad 属性设置成 False,从而冻结所有模型参数
	# 使得在反向传播过程中不会计算他们的梯度,从此减少模型的计算量,提高推理速度
	
	in_features = resnet_model.fc.in_features  # 获取resnet18模型全连接层原输入的特征个数
	resnet_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层输入特征个数为: in_features  输出特征个数为:数据集中事务的种类数量
	
	params_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数
	for param in resnet_model.parameters():
	    if param.requires_grad == True:
	        params_to_update.append(param)
	
	""" 图像预处理和数据增强 """
	data_transforms = {
	    'train':
	        transforms.Compose([
	            transforms.Resize([300, 300]),
	            transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机
	            transforms.CenterCrop(224),  # 中心裁剪
	            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平反转 选择一个概率
	            transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
	            # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 亮度、对比度
	            transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R G B
	            transforms.ToTensor(),  # 转化为神经网络可以识别的 Tensor 类型
	            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 对图片数据进行归一化,[均值],[标准差]
	        ]),
	    'valid':
	        transforms.Compose([
	            transforms.Resize([224, 224]),
	            transforms.ToTensor(),
	            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
	        ])
	}
	
	""" 定义获取每张食物图片和标签的类方法 """
	
	
	class food_dataset(Dataset):
	    def __init__(self, file_path, transform=None):
	        self.file_path = file_path
	        self.imgs = []
	        self.labels = []
	        self.transform = transform
	        with open(self.file_path) as f:
	            samples = [x.strip().split(' ') for x in f.readlines()]
	            for img_path, label in samples:
	                self.imgs.append(img_path)
	                self.labels.append(label)
	
	    def __len__(self):
	        return len(self.imgs)
	
	    def __getitem__(self, idx):
	        image = Image.open(self.imgs[idx])
	        if self.transform:
	            image = self.transform(image)
	
	        label = self.labels[idx]
	        label = torch.from_numpy(np.array(label, dtype=np.int64))
	        return image, label
	
	
	""" 获取训练集和测试集数据 """
	training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
	test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])
	
	""" 对数据集进行打包 """
	train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 64张图片为一个包,shuffle --> 打乱顺序
	test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
	
	""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """
	device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
	print(f"Using {device} device")
	
	# 把模型传入到 gpu 或 cpu
	model = resnet_model.to(device)
	
	""" 调用交叉熵损失函数 """
	loss_fn = nn.CrossEntropyLoss()
	
	"""" 创建优化器并调整优化器中的学习率--> lr """
	optimizer = torch.optim.Adam(params_to_update, lr=0.001)
	scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
	
	""" 定义训练模型的函数 """
	
	
	def train(dataloader, model, loss_fn, optimizer):
	    model.train()  # 告诉模型,开始训练
	    # pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
	    # 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.
	    for X, y in dataloader:
	        X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或gpu
	        pred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值
	        loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值 loss
	
	        optimizer.zero_grad()  # 梯度值清零
	        loss.backward()  # 反向传播计算得到每个参数的梯度值w
	        optimizer.step()  # 根据梯度更新网络w参数
	
	
	""" 定义测试模型的函数 """
	
	best_acc = 0  # 用于更新准确率
	
	
	def test(dataloader, model, loss_fn):
	    global best_acc
	    size = len(dataloader.dataset)
	    num_batches = len(dataloader)
	    model.eval()  # 测试,w就不能再更新
	    test_loss, correct = 0, 0
	    with torch.no_grad():
	        for X, y in dataloader:
	            X, y = X.to(device), y.to(device)
	            pred = model.forward(X)
	            test_loss += loss_fn(pred, y).item()  # test_loss是会自动累加每一个批次的损失值
	            correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # correct是会自动累加每一个批次的正确率
	    test_loss /= num_batches  # 平均的损失值
	    correct /= size  # 平均的正确率
	    print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")
	
	    # 找到最好的准确率
	    if correct > best_acc:
	        best_acc = correct
	
	
	""" 定义模型训练的轮数,并每训练一轮测试一次 """
	epochs = 30
	for e in range(epochs):
	    print(f"Epoch {e + 1}\n---------------------------")
	    train(train_dataloader, model, loss_fn, optimizer)
	    scheduler.step()  # 在每个epoch的训练中,使用scheduler.step()语句进行学习率更新
	    test(test_dataloader, model, loss_fn)
	print('最优的训练结果为:', best_acc)
  • 结果如下
    • 此结果只是训练了30轮后的结果,可以训练更多轮,最后的准确率还会有所提高
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2175482.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛顿迭代法求解x 的平方根

牛顿迭代法是一种可以用来快速求解函数零点的方法。 为了叙述方便,我们用 C C C表示待求出平方根的那个整数。显然, C C C的平方根就是函数 f ( x ) x c − C f(x)x^c-C f(x)xc−C 的零点。 牛顿迭代法的本质是借助泰勒级数,从初始值开始快…

【软件测试】最新Linux大全(超详细!超级全!)

目录 前言1. 操作系统是干什么的2. Linux 是什么3. 为什么要学习 Linux4. Linux 发行版本5. Linux 系统特点6. Linux 安装7. Linux 系统启动8. Linux 操作方式9. Shell 与命令10. 命令格式 一、 Linux终端命令格式1. 终端命令格式2. 查阅命令帮助信息 二、 常用Linux命令的基本…

项目计划软件如何助力企业策略规划和执行监控

项目管理软件助力任务、时间和协作管理,如ZohoProjects集成了任务管理、时间跟踪、协作工具等功能,提高性价比,适合不同规模团队。其简化流程、专业度高,成为企业提升效率的重要工具。 一、项目计划软件的由来 项目计划软件的历史…

暴雨受邀出席2024 AI大模型生态算力峰会

9月25日,2024 AI大模型生态暨算力峰会在北京国家会议中心正式开幕,AI行业头部厂家、业界专家及人工智能行业精英齐聚一堂,暴雨华北大区产品总监丁海受邀出席并发表演《用AI奔赴新质生产力》的主题演讲,深度诠释了人工智能如何驱动…

解开BL锁之后如何安装模块及安装注意事项

本文是在解开BL锁的前提下进行的。 解开BL锁请参考:出厂非澎湃OS手机解BL锁 本文 参考: Magisk中文网 Magisk资源分享 ROM基地 我安装了这几个模块,切记先按照救砖模块。 解开BL锁之后,需要将下载系统ROM包提取boot.img。 目前我知道的又…

基于云开发进行快速搭建企业智能名片小程序

如何基于云开发进行快速搭建企业智能名片小程序? 首先,需要注册一个小程序账号,获取AppID。如果还不知道怎么注册的朋友,可以去看我前面写的那篇教程,有比较详细的注册步骤图文教程。 复制AppID,打开开发者…

基于SpringBoot+Vue+MySQL的旅游管理系统

系统展示 用户前台界面 管理员后台界面 系统背景 随着旅游业的蓬勃发展,传统的旅游信息查询与订票方式已难以满足现代游客的多元化需求。为了提升用户体验,提高旅游管理的效率,我们开发了基于SpringBootVueMySQL的旅游管理系统。该系统旨在通…

大模型微调4:Alpaca模型微调、Adalora、Qlora

Alpaca模型微调: 整个pipeline 1. 主流底座:Candidate 中文:YI-34B 英文:LLama,mistral 2. 验证: 我们自己的Instructoin data 通用的Instruction data(适合我们场景的) 3. 收集…

kubernetes存储入门(kubernetes)

实验环境依旧是三个节点拉取镜像,然后在master节点拉取资源清单: 然后同步会话,导入镜像; 存储入门 ConfigMap volume卷--》volumemount(挂载卷) Glusterfs NFS ISCSI HostPath ConfigMap Secret E…

acw(树的重心)

给定一颗树,树中包含 n𝑛 个结点(编号 1∼n1∼𝑛)和 n−1𝑛−1 条无向边。 请你找到树的重心,并输出将重心删除后,剩余各个连通块中点数的最大值。 重心定义:重心是指树…

基于SSM的“在线汽车交易系统”的设计与实现(源码+数据库+文档+开题报告)

基于SSM的“在线汽车交易系统”的设计与实现(源码数据库文档开题报告) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统总体设计图 首页 新闻信息 用户注册 后台登录界面…

从0学习React(2)

经过上一篇的文章,对index.tsx文件的每行代码进行了一个简单的分析之后,我大概对React有了一个简单的了解。虽然也是一知半解,但是起码在心里已经对React有了一个基本的概念。这篇文章,我就讲一下关于React中index.tsx的大致框架。…

Metahuman sdk官方 AI驱动口型蓝图优化

combo stream ATL stream ( audio to lip sync) 以上时实时驱动口型 非实时驱动口型可以在metahuman blueprint里直接加上talk component,实现聊天/回复功能。 Talk sound可以放自己的声音/ talk chat是回复你输入的message和你聊天/ talk text是念出你输入的me…

828华为云征文|部署个人知识管理系统 SiyuanNote

828华为云征文|部署个人知识管理系统 SiyuanNote 一、Flexus云服务器X实例介绍二、Flexus云服务器X实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置2.4 Docker 环境搭建 三、Flexus云服务器X实例部署 SiyuanNote3.1 SiyuanNote 介绍3.2 SiyuanNote 部署3.3 Siyua…

Awcing 799. 最长连续不重复子序列

Awcing 799. 最长连续不重复子序列 解题思路: 让我们找到一个数组中,最长的 不包含重复的数 的连续区间的长度。 最优解是双指针算法: 我们用 c n t [ i ] cnt[i] cnt[i]记录 i i i 这个整数在区间内出现的次数。(因为每个数的大小为 1 0 5 10^5 105, …

报数游戏 - 华为OD统一考试(E卷)

2024华为OD机试(E卷D卷C卷)最新题库【超值优惠】Java/Python/C合集 题目描述 100个人围成一圈,每个人有一个编号,编号从1开始到100。他们从1开始依次报数,报到为M的人自动退出圈圈,然后下一个人接着从1开始…

数据链路层 ——MAC

目录 MAC帧协议 mac地址 以太网帧格式 ARP协议 ARP报文格式​编辑 RARP 其他的网络服务或者协议 DNS ICMP协议 ping traceroute NAT技术 代理服务器 网络层负责规划转发路线,而链路层负责在网络节点之间的转发,也就是"一跳"的具体传输…

ubuntu18.04 Anconda安装及使用

1、安装Anaconda 1)下载: 下载链接:https://www.anaconda.com/download#downloads 点击图中Free Download,登录并下在 下载对应版本 2)安装 sudo bash Anaconda3-2024.06-1-Linux-x86_64.sh输入后,直接回车安装。 出…

NSSCTF [HNCTF 2022 WEEK2]e@sy_flower

将文件拖入ida 就看到很显眼的花指令 对着jmp指令nop掉 将main函数按p定义 F5查看伪c代码 思路就是输入的flag先互换位置,再与0x30异或。 int __cdecl __noreturn main(int argc, const char **argv, const char **envp) {signed int v3; // 存储临时值int i; // 循…

栏目二:Echart绘制动态折线图+柱状图

栏目二:Echart绘制动态折线图+柱状图 配置了一个ECharts图表,该图表集成了数据区域缩放、双Y轴显示及多种图表类型(折线图、柱状图、象形柱图)。图表通过X轴数据展示,支持平滑折线展示比率数据并自动添加百分比标识,柱状图以渐变色展示评论数量,而象形柱图则以矩形形式展…