深度学习blog-卷积神经网络(CNN)

news2025/1/3 1:15:02

卷积神经网络(Convolutional Neural Network,CNN)是一种广泛应用于计算机视觉领域,如图像分类、目标检测和图像分割等任务中的深度学习模型。

1. 结构
卷积神经网络一般由以下几个主要层组成:
输入层:接收原始图像数据,通常是三维(高、宽、通道)的张量。
卷积层(Convolutional Layer):使用多个卷积核(滤波器)对输入数据进行卷积操作,提取特征。该层的输出是特征图,显示了输入数据中的特征。卷积层是 CNN 的核心组成部分,它的主要功能是通过卷积操作提取局部特征。
在这里插入图片描述

卷积操作是通过一个小的滤波器(或卷积核)在输入图像上滑动来计算的,每次滑动时,卷积核与局部区域的像素值做点积运算,并输出一个新的值。这些新值组成了特征图(feature map)。

步长指定卷积核在输入数据上滑动的步伐。
填充(Padding)
填充是为了确保卷积操作不会丢失边缘信息,通常会在输入数据的边缘添加一些零值,称为零填充。
激活层(Activation Layer):常用的激活函数包括ReLU(修正线性单元)等,负责引入非线性因素,提高网络学习能力。通常放在卷积层之后。
池化层(Pooling Layer):对特征图进行下采样,通常使用最大池化或平均池化,减少特征的尺寸,降低计算复杂度,同时保留重要特征。避免过拟合。
常见的池化操作有最大池化和平均池化
最大池化(Max Pooling),对每个子区域选择最大值。
平均池化(Average Pooling),对每个子区域取平均值。

全连接层(Fully Connected Layer):将高层次的特征输出转换为最终的分类结果。每个神经元与前一层的所有神经元相连接。(将提取的高维特征映射到标签空间)

输出层:提供最终的预测结果,比如分类标签或回归值。

  1. 原理
    卷积神经网络的核心原理是利用卷积操作进行特征提取。卷积层通过卷积核在输入图像上滑动,不断提取局部区域的特征,能够自动学习并优化这些特征。
    卷积操作:通过卷积核与输入图像的局部区域进行点积,生成特征图。这个过程能够捕捉图像中的边缘、角点等基础特征。
    参数共享:同一个卷积核在整个图像上重复使用,可以减少模型参数,提高模型的泛化能力。
    局部感知:卷积核的大小限制了每个神经元的感知范围,使网络能学习到局部特征。

  2. 工作流程
    卷积神经网络的工作流程通常包括以下几个步骤:
    图像输入:将图像数据输入到网络中。
    特征提取:
    在卷积层中,通过多个卷积核对输入图像进行卷积,生成特征图。
    通过激活函数引入非线性。
    使用池化层进行特征降维。
    分类阶段:
    将经过多层特征提取后的特征图展平成一维向量,输入到全连接层。
    使用激活函数进行处理。
    损失计算:通过损失函数计算预测值与真实值之间的误差。
    反向传播:通过反向传播算法更新网络中的权重和偏置,以最小化损失。
    预测输出:经过最后的输出层,网络给出分类结果或回归输出。

例子,识别手写数字:

import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# 构建 CNN 模型
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.15),
    layers.Dense(10, activation='softmax')  # 10 类输出
])

# 编译模型
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
history = model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))

# 在测试集上评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc:.4f}")

# 随机选择一些测试图像的索引
num_images = 10
random_indices = np.random.choice(x_test.shape[0], num_images, replace=False)

test_images = x_test[random_indices]
true_labels = np.argmax(y_test[random_indices], axis=1)
predicted_labels = np.argmax(model.predict(test_images), axis=1)

plt.figure(figsize=(12, 4))
for i in range(num_images):
    plt.subplot(2, 5, i + 1)
    plt.imshow(test_images[i].reshape(28, 28), cmap='gray')
    plt.title(f"True: {true_labels[i]}\nPred: {predicted_labels[i]}")
    plt.axis('off')
plt.show()


# 绘制训练过程中的准确率和损失
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')

在这里插入图片描述
pytorch实现:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision import datasets

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为 tensor,并归一化为 [0, 1] 区间
])
train_dataset = datasets.MNIST(root='../../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='../../data', train=False, download=True, transform=transform)

# 数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 构建 CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 28x28x1 -> 28x28x32
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 下采样
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 28x28x32 -> 28x28x64
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)  # 28x28x64 -> 28x28x64
        self.fc1 = nn.Linear(64 * 7 * 7, 64)  # 根据池化后特征的形状计算输入大小
        self.dropout = nn.Dropout(0.15)  # Dropout 层
        self.fc2 = nn.Linear(64, 10)  # 输出10类

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # Conv1 + ReLU + Pooling
        x = self.pool(torch.relu(self.conv2(x)))  # Conv2 + ReLU + Pooling
        x = torch.relu(self.conv3(x))               # Conv3 + ReLU
        x = x.view(-1, 64 * 7 * 7)  # 展平
        x = torch.relu(self.fc1(x))  # FC1 + ReLU
        x = self.dropout(x)           # Dropout
        x = self.fc2(x)               # FC2
        return x

    # 创建模型实例
model = CNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for batch_images, batch_labels in train_loader:
        optimizer.zero_grad()  # 梯度清零
        outputs = model(batch_images)  # 前向传播
        loss = criterion(outputs, batch_labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 在测试集上评估模型
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        outputs = model(batch_images)  # 前向传播
        loss = criterion(outputs, batch_labels)  # 计算损失
        test_loss += loss.item()  # 累加损失
        _, predicted = torch.max(outputs.data, 1)  # 预测
        correct += (predicted == batch_labels).sum().item()  # 统计正确样本数

# 计算准确率
test_accuracy = correct / len(test_dataset)
print(f"Test accuracy: {test_accuracy:.4f}")

# 随机选择一些测试图像的索引并可视化
num_images = 10
random_indices = np.random.choice(len(test_dataset), num_images, replace=False)

test_images = []
true_labels = []
predicted_labels = []

for idx in random_indices:
    img, label = test_dataset[idx]
    test_images.append(img)
    true_labels.append(label)

test_images_tensor = torch.stack(test_images)
with torch.no_grad():
    outputs = model(test_images_tensor)  # 前向传播
    _, predicted = torch.max(outputs.data, 1)  # 预测
    predicted_labels = predicted.numpy()

# 可视化结果
plt.figure(figsize=(12, 4))
for i in range(num_images):
    plt.subplot(2, 5, i + 1)
    plt.imshow(test_images[i].numpy()[0], cmap='gray')  # 仅显示通道1
    plt.title(f"True: {true_labels[i]}\nPred: {predicted_labels[i]}")
    plt.axis('off')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2268041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习笔记(6)——循环神经网络RNN

循环神经网络 RNN 核心思想:RNN内部有一个“内部状态”,随着序列处理而更新 h t f W ( h t − 1 , x t ) h_tf_W(h_{t-1},x_t) ht​fW​(ht−1​,xt​) 一般来说 h t t a n h ( W h h h t − 1 W x h x t ) h_ttanh(W_{hh}h_{t-1}W_{xh}x_t) ht​tanh(Whh​ht−1​Wxh​xt…

最新版Edge浏览器加载ActiveX控件技术——alWebPlugin中间件V2.0.28-迎春版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品,致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器,实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

【前端,TypeScript】TypeScript速成(三):枚举类型

枚举类型 枚举类型是 TypeScript 相较于 JavaScript 而言特有的部分。一个简单的枚举声明如下: enum HTTPStatus {OK,NOT_FOUND,INTERNAL_STATUS_ERROR, }与编译成 JavaScript 的代码相比较: 显然 TypeScript 非常的简洁。 尝试使用上述枚举类型&…

Webpack学习笔记(6)

首先搭建一个基本的webpack环境: 执行npm init -y,创建pack.json,保存安装包的一些信息 执行npm install webpack webpack-cli webpack-dev-server html-webpack-plugin -D,出现node_modules和package-lock.json。 1.source-Ma…

Java高频面试之SE-06

hello啊,各位老6!!!本牛马baby今天又来了!哈哈哈哈哈嗝🐶 访问修饰符 public、private、protected的区别是什么? 在Java中,访问修饰符用于控制类、方法和变量的访问权限。主要的访…

报表工具DevExpress Reporting v24.2亮点 - AI功能进一步强化

DevExpress Reporting是.NET Framework下功能完善的报表平台,它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集,包括数据透视表、图表,因此您可以构建无与伦比、信息清晰的报表。 报表工具DevExpress Reporting v24.2将于近期发…

每天40分玩转Django:Django表单集

Django表单集 一、知识要点概览表 类别知识点掌握程度要求基础概念FormSet、ModelFormSet深入理解内联表单集InlineFormSet、BaseInlineFormSet熟练应用表单集验证clean方法、验证规则熟练应用自定义配置extra、max_num、can_delete理解应用动态管理JavaScript动态添加/删除表…

MVCC实现原理以及解决脏读、不可重复读、幻读问题

MVCC实现原理以及解决脏读、不可重复读、幻读问题 MVCC是什么?有什么作用?MVCC的实现原理行隐藏的字段undo log日志版本链Read View MVCC在RC下避免脏读MVCC在RC造成不可重复读、丢失修改MVCC在RR下解决不可重复读问题RR下仍然存在幻读的问题 MVCC是什么…

自学记录鸿蒙API 13:实现人脸比对Core Vision Face Comparator

完成了文本识别和人脸检测的项目后,我发现人脸比对是一个更有趣的一个小技术玩意儿。我决定整一整,也就是对HarmonyOS Next最新版本API 13中的Core Vision Face Comparator API的学习,这项技术能够对人脸进行高精度比对,并给出相似…

代码解析:安卓VHAL的AIDL参考实现

以下内容基于安卓14的VHAL代码。 总体架构 参考实现采用双层架构。上层是 DefaultVehicleHal,实现了 VHAL AIDL 接口,并提供适用于所有硬件设备的通用 VHAL 逻辑。下层是 FakeVehicleHardware,实现了 IVehicleHardware 接口。此类可模拟与实…

通过 Ansys Electronics Desktop 中的高级仿真优化 IC 设计

半导体行业继续通过日益复杂的集成电路 (IC) 设计突破技术界限。随着工艺节点缩小和电路密度达到前所未有的水平,电磁效应对设备性能和可靠性变得越来越重要。现代 IC 设计面临着来自复杂的布局相关耦合机制、信号完整性问题和功率分布问题的挑战,这些问…

Kafka数据迁移全解析:同集群和跨集群

文章目录 一、同集群迁移二、跨集群迁移 Kafka两种迁移场景,分别是同集群数据迁移、跨集群数据迁移。 一、同集群迁移 应用场景: broker 迁移 主要使用的场景是broker 上线,下线,或者扩容等.基于同一套zookeeper的操作。 实践: 将需要新添加…

【OpenGL ES】GLSL基础语法

1 前言 本文将介绍 GLSL 中数据类型、数组、结构体、宏、运算符、向量运算、矩阵运算、函数、流程控制、精度限定符、变量限定符(in、out、inout)、函数参数限定符等内容,另外提供了一个 include 工具,方便多文件管理 glsl 代码&a…

ffmpeg之播放一个yuv视频

播放YUV视频的步骤 初始化SDL库: 目的:确保SDL库正确初始化,以便可以使用其窗口、渲染和事件处理功能。操作:调用 SDL_Init(SDL_INIT_VIDEO) 来初始化SDL的视频子系统。 创建窗口用于显示YUV视频: 目的:…

复习打卡大数据篇——Hadoop MapReduce

目录 1. MapReduce基本介绍 2. MapReduce原理 1. MapReduce基本介绍 什么是MapReduce MapReduce是一个分布式运算程序的编程框架,核心功能是将用户编写的业务逻辑代码和自带默认组件整合成一个完整的分布式运算程序,并发运行在Hadoop集群上。 MapRed…

小程序配置文件 —— 13 全局配置 - window配置

全局配置 - window配置 这里讲解根目录 app.json 中的 window 字段,window 字段用于设置小程序的状态栏、导航条、标题、窗口背景色; 状态栏:顶部位置,有网络信号、时间信息、电池信息等;导航条:有一个当…

el-pagination 为什么只能展示 10 条数据(element-ui@2.15.13)

好的&#xff0c;我来帮你分析前端为什么只能展示 10 条数据&#xff0c;以及如何解决这个问题。 问题分析&#xff1a; pageSize 的值&#xff1a; 你的 el-pagination 组件中&#xff0c;pageSize 的值被设置为 10&#xff1a;<el-pagination:current-page"current…

单片机与MQTT协议

MQTT 协议简述 MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息队列遥测传输协议&#xff09;&#xff0c;是一种基于发布 / 订阅&#xff08;publish/subscribe&#xff09;模式的 “轻量级” 通讯协议&#xff0c;该协议构建于 TCP/IP 协议上&#xf…

Debian-linux运维-docker安装和配置

腾讯云搭建docker官方文档&#xff1a;https://cloud.tencent.com/document/product/213/46000 阿里云安装Docker官方文档&#xff1a;https://help.aliyun.com/zh/ecs/use-cases/install-and-use-docker-on-a-linux-ecs-instance 天翼云常见docker源配置指导&#xff1a;htt…

使用Docker-compose部署SpringCloud项目

docker编写dockfile遇到的问题&#xff1a; 需要在docker-compose.yml文件下执行命令 docker-compose.yml文件格式的问题 1和2处空2格&#xff0c;3处空1格&#xff0c;4为本地配置文件目录&#xff0c;5为docker容器的目录&#xff0c;version为自己安装的docker-compose版本 …