基于迁移学习的手势分类模型训练

news2024/9/23 11:27:46

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation

# 定义数据预处理和增强器
transform = Compose([
    RandomHorizontalFlip(),  # 随机水平翻转
    RandomRotation(10),      # 随机旋转10度
    Resize((224, 224)),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考


# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了

# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):
    for epoch in range(num_epochs):
        model.train()  # 设置模型为训练模式
        train_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = train_loss / total
        epoch_acc = 100. * correct / total

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')


# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):

    image = Image.open(image_path).convert("RGB")
    # 对测试的图片进行预处理,需要和训练时处理的方式一样
    transform = Compose([
        Resize((224, 224)),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image_tensor = transform(image).unsqueeze(0)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    image_tensor = image_tensor.to(device)


    model = torch.load(model_path,map_location=device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output.data, 1)  # 获得分类标记
    return predicted.item()
if __name__=="__main__":
    image_path = "test2/6.jpg"
    print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1950508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【性能优化】在大批量数据下使用 HTML+CSS实现走马灯,防止页面卡顿(二)

上一篇只是简单演示了’下一张’的操作和整体的设计思路,这两天把剩余功能补全了,代码经过精简,可封装当成轮播组件使用,详细如下. 代码 <template><div class"container"><button click"checkNext(last)">上一张</button><b…

C++之栈和队列使用及模拟实现

目录 栈的使用 队列的使用 栈的模拟实现 队列的模拟实现 deuqe容器介绍 在C语言中我们已经学习了栈和队列的相关性质&#xff0c;今天我们主要来学习C语法中栈和队列的相关概念。 栈的使用 在C中栈是一种容器适配器&#xff0c;在其内部适配了其它的容器&#xff0c;其相…

go程序在windows服务中优雅开启和关闭

本篇主要是讲述一个go程序&#xff0c;如何在windows服务中优雅开启和关闭&#xff0c;废话不多说&#xff0c;开搞&#xff01;&#xff01;&#xff01;   使用方式&#xff1a;go程序 net服务启动 Ⅰ 开篇不利 Windows go进程编译后&#xff0c;为一个.exe文件,直接执行即…

使用api 调试接口 ,配置 Header 、 body 远程调试 线上接口

学习目标&#xff1a; 目标 使用api 调试接口 &#xff0c;配置 Header 、 body 远程调试 线上接口 学习内容&#xff1a; 内容 设置请求方式 2. 选择 POST 提交 3.设置 Header 一般默认的 4个 header 属性就可以直接使用&#xff0c;如有特殊情况&#xff0c;需进行属性设…

Docusaurus VS VuePress:哪一个更适合你的技术文档?

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

springcloud接入seata管理分布式事务

下载安装包 链接: seata 配置seata-server 文件上传Linux解压 压缩包我放在/usr/local/seata中 tar -zxvf seata-server-2.0.0.tar.gz修改配置文件 设置nacos为注册和配置中心 进入文件夹 cd /usr/local/seata/seata/conf修改application.yml文件 ...... ...... cons…

关键词查找【Aho-Corasick 算法】

【全程干货】程序员必备算法&#xff01;AC自动机算法敏感词匹配算法&#xff01;动画演示讲解&#xff0c;看完轻松掌握&#xff0c;面试官都被你唬住&#xff01;&#xff01;_哔哩哔哩_bilibili 著名的多模匹配算法 引入依赖&#xff1a; <dependency><groupId>…

ICMPv6与DHCPv6之网络工程师软考中级

ICMPv6概述 ICMPv6是IPv6的基础协议之一。 在IPv6报文头部中&#xff0c;Next Header字段值为58则对应为ICMPv6报文。 ICMPv6报文用于通告相关信息或错误。 ICMPv6报文被广泛应用于其它协议中&#xff0c;包括NDP、Path MTU发现机制等 ICMPv6控制着IPv6中的地址自动配置、地址…

前端知识--前端访问后端技术Ajax及框架Axios

一、异步数据请求技术----Ajax Ajax是前端访问后端的技术&#xff0c;为异步请求&#xff08;不刷新页面&#xff0c;请求数据&#xff0c;只更新局部数据&#xff09;。 例如&#xff1a;在京东网站中搜索电脑&#xff0c;就会出现一些联想搜索&#xff0c;但此时页面并没有…

AI行业合适做必应bing推广吗?怎么开户呢?

快速发展的AI行业中&#xff0c;有效的市场获客渠道是关键&#xff0c;随着数字营销领域的不断演进&#xff0c;必应Bing以其独特的市场定位、庞大的用户基础和高效的广告系统&#xff0c;成为AI企业推广策略中的重要一环。特别是针对那些寻求精准触达、高效转化的AI企业而言&a…

2024国际燃气轮机运维周线上分享第一期开启!共探燃机新生态

为促进国内重型燃气轮机运维技术发展&#xff0c;加快建立独立自主的燃气轮机运维技术体系&#xff0c;2024国际燃气轮机运维大会将于2024年10月17-18日在中国广州盛大召开&#xff01; 2024国际燃气轮机运维大会将通过线上直播会议、线下技术分享及颁奖典礼等形式展开&#xf…

血泪史!ora-00600 16305报错解决过程

一个客户重启操作系统后数据库启动不了,检查日志发现报错ORA-00600 [16305] 在MOS中找了一下,发现说是loopback地址不通: 测试了一下ping 127.0.0.1不通. 再次多次尝试发现登录到服务器上面,在本机上ping 127 localhost 本机实际地址 都不通,但是其它服务器可以ping通他的…

W30-python03-迭代器和生成器

迭代器&#xff1a;迭代是Python最强大的功能之一&#xff0c;是访问集合元素的一种方式。迭代器对象从集合的第一个元素开始访问&#xff0c;直到所有的元素被访问完结束。迭代器只能往前不会后退。 迭代器有两个基本的方法&#xff1a;iter() 和 next()。 生成器&#xff1…

Godot入门 05收集物品

创建新场景&#xff0c;添加Area2D节点&#xff0c;AnimatedSprite2D节点 &#xff0c;CollisionShape2D节点 添加硬币 按F键居中&#xff0c;放大视图。设置动画速度设为10FPS&#xff0c;加载后自动播放&#xff0c;动画循环 碰撞形状设为圆形&#xff0c;修改Area2D节点为Co…

Java---后端文件上传详解

袁门才俊志高远&#xff0c; 震古烁今意决然。 风采翩翩才情显&#xff0c; 雄姿英发立世间。 目录 一&#xff0c;简单案例演示 二&#xff0c;服务器本地存储 三&#xff0c;配置单个文件上传大小限制 一&#xff0c;简单案例演示 首先简单编写一个前端网页&#xff1a; &l…

scrapy 爬取旅游景点相关数据(一)

第一节 Scrapy 练习爬取穷游旅游景点 配套视频可以前往B站&#xff1a;https://www.bilibili.com/video/BV1Vx4y147wQ/?vd_source4c338cd1b04806ba681778966b6fbd65 本项目为scrapy 练手项目&#xff0c;爬取的是穷游旅游景点列表数据 0 系统的环境 现在网上可以找到很多scr…

java基础概念05-运算符

一、自增自减运算符 二、赋值运算符 2-1、注意 三、关系运算符 四、逻辑运算符 4-1、短路逻辑运算符 五、三元运算符 六、运算符的优先级

PostgreSQL 中如何重置序列值:将自增 ID 设定为特定值开始

我是从excel中将数据导入&#xff0c;然后再通过sql插入数据&#xff0c;就报错。 需要设置自增ID开始值 1、确定序列名称&#xff1a; 首先&#xff0c;需要找到与的增字段相关的序列名称。假设表名是 my_table 和自增字段是 id&#xff0c;可以使用以下查询来获取序列名称…

嵌入式C++、Raspberry Pi、LoRa和Wi-Fi技术、TensorFlow、ROS/ROS2:农业巡检数据导航机器人设计流程(代码示例)

随着科技的不断进步&#xff0c;农业领域也在逐渐向智能化发展。农业巡检机器人作为农业智能化的重要组成部分&#xff0c;能够自动化地监测农作物生长状况&#xff0c;提高农业管理的效率和精确度。本文将介绍一个基于Raspberry Pi和NVIDIA Jetson的农业巡检机器人&#xff0c…

华天动力OA downloadWpsFile接口处任意文件读取漏洞复现 [附POC]

文章目录 华天动力OA downloadWpsFile接口处任意文件读取漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现华天动力OA downloadWpsFile接口处任意文件读取漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内…