计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

news2025/1/11 2:29:01

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用,本文我们借助PyTorch,快速构建和训练卷积神经网络(CNN)等模型,以实现街道房屋号码的准确识别。引入并注意力机制,它是一种模仿人类视觉注意机制的方法,在图像处理任务中具有广泛应用。通过引入注意力机制,模型可以自动关注图像中与房屋号码相关的区域,提高识别的准确性和鲁棒性。

一、项目介绍

街道房屋号码识别是计算机视觉中的一个重要任务,通过对街道房屋号码的自动识别,可以对街道图像进行更好的理解和分析。本文将介绍如何使用PyTorch框架和注意力机制,结合SVHN数据集,来实现街道房屋号码的分类识别。

二、SVHN数据集

SVHN(Street View House Numbers)是一个公开的大规模街道数字图像数据集。该数据集包含了从Google Street View中获取的房屋门牌号码图像,可以用于训练和测试机器学习模型,以实现自动识别街道房屋号码的任务。

2.1 数据集下载和加载

首先,我们需要下载并加载SVHN数据集。在PyTorch中,我们可以使用torchvision库中的datasets模块来实现这一步。

数据集的下载与查看:

train_dataset = datasets.SVHN(root='./data', split='train', download=True)

images = train_dataset.data[:10]  # shape: (10, 3, 32, 32)
labels = train_dataset.labels[:10]

images = np.transpose(images, (0, 2, 3, 1))

# Plot the images
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
axs = axs.ravel()

for i in range(10):
    axs[i].imshow(images[i])
    axs[i].set_title(f"Label: {labels[i]}")
    axs[i].axis('off')

plt.tight_layout()
plt.show()

在这里插入图片描述

数据集的加载,预处理,便于输入模型训练:

import torch
from torchvision import datasets, transforms

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# 下载并加载SVHN数据集
trainset = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
testset = datasets.SVHN(root='./data', split='test', download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

三、卷积网络搭建

使用PyTorch搭建卷积神经网络。卷积神经网络(Convolutional Neural Network, CNN)是一种主要用于处理具有类似网格结构的数据的神经网络,如图像(2D网格的像素点)或者文本(1D网格的单词)。

3.1 网络结构定义

下面是一个基础的卷积神经网络模型,包含两个卷积层、两个最大池化层和两个全连接层。

from torch import nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        return self.fc2(out)

四、加入注意力机制

注意力机制是一种能够改进模型性能的技术。在我们的模型中,我们将添加一个注意力层来帮助模型更好地专注于输入图像中的重要部分。

4.1 注意力层定义

我将实现基本的注意力层,这个层将会生成一个和输入同样大小的注意力图,然后将输入和这个注意力图对应元素相乘,以此来实现对输入的加权。

注意力机制层的数学原理:
注意力机制的数学原理可以用以下公式表示:

给定输入张量 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w,其中 b b b 是批量大小, c c c 是通道数, h h h 是高度, w w w 是宽度。注意力机制分为两个阶段:特征提取和特征加权。
1.特征提取阶段:
首先,通过自适应平均池化层(AdaptiveAvgPool2d)将输入张量 x x x 在高度和宽度上进行平均池化,得到形状为 b × c × 1 × 1 b \times c \times 1 \times 1 b×c×1×1 的张量 y y y。这里使用自适应平均池化是为了使得张量 y y y 在不同尺寸的输入上也能产生相同的输出。
2.特征加权阶段:
接下来,通过全连接层(Linear)和非线性激活函数ReLU对张量 y y y 进行特征变换,减少通道数,并保留重要特征。然后再通过另一个全连接层和Sigmoid激活函数得到权重张量 y ′ ∈ R b × c × 1 × 1 y' \in \mathbb{R}^{b \times c \times 1 \times 1} yRb×c×1×1,表示每个通道的权重值。这里的权重值在0到1之间,用于控制每个通道在后续的计算中所占的比重。将权重张量 y ′ y' y 扩展成与输入张量 x x x 相同的形状,并将其与输入张量相乘,得到经过注意力加权的特征张量。这样就实现了对输入张量的自适应特征加权。

数学表示为:
y = AdaptiveAvgPool2d ( x ) y ′ = Sigmoid ( Linear ( ReLU ( Linear ( y ) ) ) ) output = x ⊙ y ′ y = \text{AdaptiveAvgPool2d}(x) \\ y' = \text{Sigmoid}(\text{Linear}(\text{ReLU}(\text{Linear}(y)))) \\ \text{output} = x \odot y' y=AdaptiveAvgPool2d(x)y=Sigmoid(Linear(ReLU(Linear(y))))output=xy

其中 ⊙ \odot 表示按元素相乘操作。

注意力机制层的搭建代码:

class AttentionLayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(AttentionLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel// reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

4.2 在网络中加入注意力层

我们将注意力层加入到ConvNet模型中:

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(32))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            AttentionLayer(64))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(8 * 8 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        return self.fc2(out)

五、模型训练与测试

接下来,我们将进行模型的训练和测试。

5.1 模型训练

import torch.optim as optim

model = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 0:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5.2 模型测试

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

六、结论

这篇文章就像是一张奇妙的地图,引领你进入计算机视觉任务的神奇世界。在这个世界里,你将与PyTorch和注意力机制这两位强大的伙伴结伴前行,共同探索街道房屋号码识别的奥秘。

想象一下,你置身于繁忙的街道上,满目琳琅的房屋号码挑战着你的视力。而你却拥有了一种神奇的眼力,能轻松识别出每一个号码。这种超凡能力正是计算机视觉任务的魔法所在。

我们要携手PyTorch这位强大的工具,它如同一把巧妙的魔法棒,能帮助我们构建强大的神经网络模型。通过PyTorch,我们可以灵活地定义模型的结构,设置各种参数,并进行高效的训练和推理。

我们遇到了注意力机制,就像是一盏明亮的灯塔,照亮了我们前进的方向。注意力机制能够使神经网络集中注意力于图像中的重要区域,从而提高识别的准确性。利用这种机制,我们可以让模型更加聪明地注重街道房屋号码所在的位置和细节,从而更好地进行识别。而SVHN数据集则是我们探险的指南,其中包含了大量真实世界中的街道房屋号码图像。通过导入这些数据,我们可以让模型从中学习并提高自己的识别能力。这些图像将带领我们穿越城市的角落,感受不同场景下的挑战和变化。通过这篇文章,我们不仅可以更深入地理解计算机视觉任务的本质,还能获得启发。就像是一次奇妙的冒险,我们将学会如何使用PyTorch和注意力机制来实现街道房屋号码的识别任务。让我们一起跟随这个引人入胜的旅程,开拓视野,追寻新的可能性吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/898225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式设计中对于只有两种状态的变量存储设计,如何高效的对循迹小车进行偏差量化

前言 (1)在嵌入式程序设计中,我们常常会要对各类传感器进行数据存储。大多时候的传感器,例如红外光传感器,返回的数据要么是0,要么是1。因此,只需要一bit就能够存储。而很多人却常常使用char型数…

Python中的“ @”

一、介绍 这是Python装饰器的语法,使用符号,表示将装饰器函数放在被装饰函数的上方。当调用被装饰函数时,实际上是调用了装饰器函数,装饰器函数可以在调用被装饰函数之前或之后执行一些额外的操作。 #funA 作为装饰器函数 def f…

慎用!澳洲留学生用ChatGPT写论文被控学术不端!AI论文漏洞百出,各高校已加强捡测!

自从进入ChatGPT时代以来,留学生们纷纷表示写作业,so easy。留学生们在用AI写论文时候没有预计到的是,ChatGPT存在杜撰文献的问题,并且学校已经在使用AI检测工具。 目前全澳大多数大学都可以选择使用现在很流行的反剽窃软件服务T…

【观察】戴尔科技:构建企业创新“韧性”,开辟数实融合新格局

过去几年,国家高度重视发展数字经济,将其上升为国家战略。其中,“十四五”规划中,就明确提出要推动数字经济和实体经济的深度融合,以数字经济赋能传统产业转型升级;而2023年年初正式发布的《数字中国建设整…

LangChain 手记 Conclusion结语

整理并翻译自DeepLearning.AILangChain的官方课程:Conclusion Conclusion 结语 本系列短课展示了大量使用LangChain构建的大语言模型应用,包括处理用户反馈、文档上的问答系统甚至使用LLM来决定发起外部工具的调用(比如搜索)来回答…

如何进行远程debug?

文章目录 前言一、使用步骤1.首先通过nohup在启动jar包的我们可以添加参数:2.具体参数的含义如下:3. 查询监听的端口: 前言 在工作中,排查问题我们经常需要进行debug,而远程debug能够方便的帮助我们排查线上的问题。 …

【C语言基础】宏定义的用法详解

📢:如果你也对机器人、人工智能感兴趣,看来我们志同道合✨ 📢:不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 📢:文章若有幸对你有帮助,可点赞 👍…

【无标题】WIN11下 ESP8266 _RTOS_SDK3.0以上开发环境搭建(记录及避坑必看)

前提参考文档 1、乐鑫官网: https://docs.espressif.com/projects/esp8266-rtos-sdk/en/latest/get-started/index.html 官网上有如何搭建windows linux macos 三种环境,以及如何配置Eclipse去编译和开发项目(如何安装Eclipse环境&#xff0…

高品质音乐下载命令行工具Musicn

又到了小苏同学的生日🎂,宝贝,生日快乐!祝永远健康、快乐、心想事成! 什么是 Musicn ? Musicn 是一个可播放及下载高品质🎵音乐🎵的命令行工具。支持咪咕、酷我、酷狗和网易云的服务…

《合成孔径雷达成像算法与实现》Figure3.10

代码复现如下: clc clear close all% 参数设置 TBP 100; % 时间带宽积 T 7.2e-6; % 脉冲持续时间 t_0 1e-6; % 脉冲回波时延% 参数计算 B TBP/T; …

springboot整合websocker启动失败

在工作的时候,准备使用websocker建立长连接来统计网站在线人数,但是在配置好所有东西后,发现springboot启动失败 详细错误 java.lang.IllegalStateException: Failed to register ServerEndpoint class: class com.example.pipayshopapi.co…

【算法题解】54. 树的冗余连接

这是一道 中等难度 的题 https://leetcode.cn/problems/redundant-connection/ 题目 树可以看成是一个连通且 无环 的 无向 图。 给定往一棵 n n n 个节点 (节点值 1 ~ n 1~n 1~n) 的树中添加一条边后的图。添加的边的两个顶点包含在 1 …

Python可视化在量化交易中的应用(11)_Seaborn折线图

举个栗子,用seaborn绘制折线图。 Seaborn中折线图的绘制方法 在seaborn中,我们一般使用sns作为seaborn模块的别名,因此,在下文中,均以sns指代seaborn模块。 seaborn中绘制折线图使用的是sns.plot()函数: …

【算法学习】两数之和II - 输入有序数组

题目描述 原题链接 给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] &#xff0c;则 1 < index1 < …

JavaScript中的作用域(scope)是什么?以及有哪些类型的作用域?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 作用域&#xff08;Scope&#xff09;是什么&#xff1f;1. 全局作用域&#xff08;Global Scope&#xff09;2. 局部作用域&#xff08;Local Scope&#xff09;3. 块级作用域&#xff08;Block Scope&#xff09; ⭐ 写在最后 ⭐ 专栏简…

MPLAB X IDE 中的查找方式

1.第一种选择到变量&#xff0c;然后按ctrlf&#xff0c;这种方式只能在单个文件中查找&#xff1b; 2.第二种&#xff0c;按ctrlshiftf&#xff0c;前提必须在英文模式下&#xff0c; 对于普通用户来说&#xff0c;只需要知道Containing Text是搜索对象&#xff1b;最下面的F…

stm32红绿灯源代码示例(附带Proteus电路图)

本代码不能直接用于红路灯&#xff0c;只是提供一个思路 #include "main.h" #include "gpio.h" void SystemClock_Config(void); void MX_GPIO_Init(void) {GPIO_InitTypeDef GPIO_InitStruct {0};/* GPIO Ports Clock Enable */__HAL_RCC_GPIOB_CLK_ENAB…

JavaScript中的变量声明方式有哪些?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 变量声明方式var 声明&#xff08;ES5及以前&#xff09;let 声明&#xff08;ES6以后&#xff09;const 声明&#xff08;ES6以后&#xff09; ⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者…

【算法学习】平方数之和

title: 【力扣】633.平方数之和 cover: ‘https://storage.bummon.com/image/202308171051399.png’ tags: 算法LeetCode双指针法二分查找法 categories:算法笔记 abbrlink: 2911343079 date: 2023-08-17 10:49:52 mathjax: true 【力扣】633.平方数之和 题目 原题链接 给定…

Vue 2 计算属性与侦听器

计算属性 vs 方法 vs 侦听器 计算属性的出现是为了解决模板内表达式太过复杂而变得难以维护。 假设我们知道长和宽&#xff0c;要计算一个矩形的面积&#xff0c;如果没有计算属性&#xff0c;我们可能像下面这样处理&#xff1a; <div id"app"><input t…