基于CNN+RNNs(LSTM, GRU)的红点位置检测(pytorch)

news2024/11/25 22:20:10

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

在这里插入图片描述
其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

在这里插入图片描述

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

在之前尝试过纯 RNNs 检测红点,但是准确率感人,在噪声极低的情况下并不能精准识别位置。但是有次尝试transformer位置编码之后发现效果不错:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
Position embedding + GRU16.34035025.0/9000 (56%)
Position embedding + LSTM204.15511603.0/9000 (18%)

这说明模型的难点在于学习位置信息而不是寻找颜色有问题的点。联想到CNN也能提供位置信息,我决定尝试卷积一下的效果。

2 数据集

还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:
在这里插入图片描述
加入噪音后,每个样本的预览如下图所示:

在这里插入图片描述

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:
在这里插入图片描述
另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

在这里插入图片描述

3 思路

其实思路特别朴素。就是在RNNs要读序列化数据之前先用CNN把数据跑一遍,让原始的输入序列变成具有局部特征表示的嵌入表示,卷积后提取的特征输入到 RNN层,RNN 保持了序列中的长时依赖信息。接下来先用 fc1 把 RNN 的输出映射成分数,然后用 fc2 预测三个具体位置,经过 Sigmoid 输出 [0, 1] 的相对位置,再与宽度相乘得到真实位置。具体的流程如下图所示:

在这里插入图片描述

4 结果

在图片长度为1080、低噪声环境时,对比实验的结果如下:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
CNN+GRU1419.5781601.0/9000 (7%)
CNN+LSTM1166.4599762.0/9000 (8%)

1080长度下图片抽样预测的效果如下:

在这里插入图片描述

在简单图片中的效果跟其他方法差距不大——基本都能准确定位红线,但是还是没办法做到像素级别的精确

在这里插入图片描述

可能是我的打开方式不对,但是CNN+RNN的效果并不如意。

从训练过程来看存在过拟合:

在这里插入图片描述

5 代码

CNN+GRU结构:


class CNN_GRU(nn.Module):
    def __init__(self, config):
        super(CNN_GRU, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device

        # CNN
        self.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)

        self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.6)

        self.fc1 = nn.Sequential(
            nn.Linear(self.hidden_size * 2, 1)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(config.width, 3),  # predict 3 points
            nn.Sigmoid(),
        )

        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        x = x.squeeze(2)
        x = F.relu(self.conv1(x))  # (batch_size, 64, width)
        x = F.relu(self.conv2(x))  # (batch_size, 128, width)
        x = F.relu(self.conv3(x))  # (batch_size, input_size, width)
      	
      	x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.gru(x0, h0)
        
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)
        predicted_positions = self.fc2(scores)
        
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

CNN+LSTM结构:

class CNN_GRU(nn.Module):
    def __init__(self, config):
        super(CNN_GRU, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        # CNN
        self.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.6)
        self.fc1 = nn.Sequential(
            nn.Linear(self.hidden_size * 2, 1)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(config.width, 3),  # predict 3 points
            nn.Sigmoid(),
        )
        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        x = x.squeeze(2)
        x = F.relu(self.conv1(x))  # (batch_size, 64, width)
        x = F.relu(self.conv2(x))  # (batch_size, 128, width)
        x = F.relu(self.conv3(x))  # (batch_size, input_size, width)
      
      	x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.lstm(x, (h0, c0))
        
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)
        predicted_positions = self.fc2(scores)
        
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

路过的大佬有什么建议 ball ball 在评论区打出来,我会去尝试~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2247521.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端:JavaScript (学习笔记)【2】

目录 一,数组的使用 1,数组的创建 [ ] 2,数组的元素和长度 3,数组的遍历方式 4,数组的常用方法 二,JavaScript中的对象 1,常用对象 (1)String和java中的Stri…

全面解析多种mfc140u.dll丢失的解决方法,五种方法详细解决

当你满心期待地打开某个常用软件,却突然弹出一个错误框,提示“mfc140u.dll丢失”,那一刻,你的好心情可能瞬间消失。这种情况在很多电脑用户的使用过程中都可能出现。无论是游戏玩家还是办公族,面对这个问题都可能不知所…

STM32总体架构简单介绍

目录 一、引言 二、STM32的总体架构 1、三个被动单元 (1)内部SRAM (2)内部闪存存储器 (3)AHB到APB的桥(AHB to APBx) 2、四个主动(驱动)单元 &#x…

【PHP】 环境以及插件的配置,自学笔记(一)

文章目录 环境的准备安装 XAMPPWindowMacOS 配置开发环境Vscode 关于 PHP 的插件推荐Vscode 配置 php 环境Apache 启动Hello php配置热更新 参考 环境的准备 下载 XAMPP , 可以从 官网下载 https://www.apachefriends.org/download.html 安装 XAMPP XAMPP 是一个跨平台的集成开…

跟着问题学5——深度学习中的数据集详解(1)

深度学习数据集的创建与读取 数据 (计算机术语) 数据(data)是事实或观察的结果,是对客观事物的逻辑归纳,是用于表示客观事物的未经加工的的原始素材。 数据可以是连续的值,比如声音、图像,称为模拟数据。…

实验-Linux文件系统和磁盘管理

操作1 远程连接Linux系统 下述连接方式2选一即可。 使用xshell工具连接Linux系统。打开xshell,新建连接,将主机ip修改为实际Linux系统的ip(ifconfig命令查看),可以新建多个xshell会话,使用不同的用户名登录,方便后续…

GPTZero:高效识别AI生成文本,保障学术诚信与内容原创性

产品描述 GPTZero 是一款先进的AI文本检测工具,专为识别由大型语言模型(如ChatGPT、GPT-4、Bard等)生成的文本而设计。它通过分析文本的复杂性和一致性,判断文本是否可能由人类编写。GPTZero 已经得到了超过100家媒体机构的报道&…

Apple Vision Pro开发003-PolySpatial2.0新建项目

unity6.0下载链接:Unity 实时开发平台 | 3D、2D、VR 和 AR 引擎 一、新建项目 二、导入开发包 com.unity.polyspatial.visionos 输入版本号 2.0.4 com.unity.polyspatial(单独导入),或者直接安装 三、对应设置 其他的操作与之前的版本相同…

百度在下一盘大棋

这两天世界互联网大会在乌镇又召开了。 我看到一条新闻,今年世界互联网大会乌镇峰会发布“2024 年度中国互联网企业创新发展十大典型案例”,百度文心智能体平台入选。 这个智能体平台我最近也有所关注,接下来我就来讲讲它。 百度在下一盘大棋…

038集——quadtree(CAD—C#二次开发入门)

效果如下: using Autodesk.AutoCAD.ApplicationServices; using Autodesk.AutoCAD.DatabaseServices; using Autodesk.AutoCAD.EditorInput; using Autodesk.AutoCAD.Geometry; using System; using System.Collections.Generic; using System.Linq; using System.T…

深入探讨 Puppeteer 如何使用 X 和 Y 坐标实现鼠标移动

背景介绍 现代爬虫技术中,模拟人类行为已成为绕过反爬虫系统的关键策略之一。无论是模拟用户点击、滚动,还是鼠标的轨迹移动,都可以为爬虫脚本带来更高的“伪装性”。在众多的自动化工具中,Puppeteer作为一个无头浏览器控制库&am…

RabbitMQ2:介绍、安装、快速入门、数据隔离

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

Linux 下进程基本概念与状态

文章目录 一、进程的定义二、 描述进程-PCBtask_ struct内容分类 三、 进程状态 一、进程的定义 狭义定义:进程是正在运行的程序的实例(an instance of a computer program that is being executed)。广义定义:进程是一个具有一定…

数据库MYSQL——表的设计

文章目录 前言三大范式:几种实体间的关系:一对一关系:一对多关系:多对多关系: 前言 之前的博客中我们讲解的是关于数据库的增删改查与约束的基本操作, 是在已经创建数据库,表之上的操作。 在实…

开源动态表单form-create-designer 扩展个性化配置的最佳实践教程

在开源低代码表单设计器 form-create-designer 的右侧配置面板里,field 映射规则为开发者提供了强大的工具去自定义和增强组件及表单配置的显示方式。通过这些规则,你可以简单而高效地调整配置项的展示,提升用户体验。 源码地址: Github | G…

面试:请阐述MySQL配置文件my.cnf中参数log-bin和binlog-do-db的作用

大家好,我是袁庭新。星球里的小伙伴去面试,面试官问:MySQL配置文件my.cnf中参数log-bin和binlog-do-db的作用?一脸懵逼~不知道该如何回答。 在MySQL的配置文件my.cnf中,log-bin和binlog-do-db是与二进制日志…

Java语言编程,通过阿里云mongo数据库监控实现数据库的连接池优化

一、背景 线上程序连接mongos超时,mongo监控显示连接数已使用100%。 java程序报错信息: org.mongodb.driver.connection: Closed connection [connectionId{localValue:1480}] to 192.168.10.16:3717 because there was a socket exception raised by…

Hive离线数仓结构分析

Hive离线数仓结构 首先,在数据源部分,包括源业务库、用户日志、爬虫数据和系统日志,这些都是数据的源头。这些数据通过Sqoop、DataX或 Flume 工具进行提取和导入操作。这些工具负责将不同来源的数据传输到基于 Hive 的离线数据仓库中。 在离线…

MySQL1——基本原理和基础操作

文章目录 Mysql数据库1——基本原理和基础操作1. 基本概念2. Mysql体系结构2.1 连接层2.2 服务层2.3 存储引擎层 3. 三级范式与反范式4. 完整性约束4.1 实体完整性约束4.2 参照完整性约束 5. CRUDDDLDMLDCLDQL 6. 高级查询基础查询条件查询分页查询查询结果排序分组聚合查询联表…

【Ubuntu24.04】服务部署(虚拟机)

目录 0 背景1 安装虚拟机1.1 下载虚拟机软件1.2 安装虚拟机软件1.2 安装虚拟电脑 2 配置虚拟机2.1 配置虚拟机网络及运行初始化脚本2.2 配置服务运行环境2.2.1 安装并配置JDK172.2.2 安装并配置MySQL8.42.2.3 安装并配置Redis 3 部署服务4 总结 0 背景 你的服务部署在了你的计算…