PyTorch深度学习网络(二:CNN)

news2024/9/20 9:33:22

卷积神经网络(CNN)是一种专门用于处理具有类似网格结构数据的深度学习模型,例如图像(2D网格的像素)和时间序列数据(1D网格的信号强度)。CNN在图像识别、图像分类、物体检测、语音识别等领域有着广泛的应用。

CNN的核心特点包括局部连接和权值共享:

  1. 局部连接意味着每个神经元只与输入数据的一个局部区域相连,这大大减少了参数的数量,提高了计算效率;
  2. 权值共享是指在卷积层中,相同的卷积核被用于整个输入数据,这不仅进一步减少了参数数量,还使得模型具有平移不变性,即无论物体出现在图像的哪个位置,都能被识别出来。

CNN的基本结构包括输入层、卷积层、激活函数、池化层、全连接层和输出层:

  1. 输入层接收原始图像数据;
  2. 卷积层通过卷积操作提取图像的特征;
  3. 激活函数引入非线性,增强模型的表达能力;
  4. 池化层通过下采样减少数据量,同时保留重要特征,增强模型的鲁棒性,此外多层卷积和池化层的堆叠使得模型能够逐层提取更高层次的特征;
  5. 全连接层将前面各层提取的特征综合起来,用于最终的分类或回归任务。

这种结构使得CNN特别适合处理图像数据,能够自动学习图像中的复杂特征,实现高效准确的图像识别。

本文展示了几种CNN网络结构在图像或文本分类中的应用,包含以下内容:

  1. LeNet的搭建和应用
  2. 微调预训练的VGG16网络
  3. TextCNN的搭建和应用

一、LeNet的搭建和应用

LeNet是早期最经典的卷积神经网络,由 Yann LeCun 等人在 1998 年提出,最初用于手写数字识别(MNIST 数据集),取得了十分显著的效果,其网络结构如图所示:

图片来自 LeNet - Wikipedia

 代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data as Data
import torchvision
from torchvision import models, transforms, datasets

from process import classify  # procss的代码见:https://blog.csdn.net/moyao_miao/article/details/141466047


class LeNet(nn.Module):
    """
    LeNet模型
    """

    def __init__(self, size):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        s = size // 4 - 3
        self.fc1 = nn.Linear(16 * s * s, 120)
        self.fc2 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = self.output(x)
        return output

使用MNIST数据集训练模型:

if __name__ == '__main__':
    train_data = torchvision.datasets.MNIST(
        root=r"C:\Users\57158\data\MNIST",
        train=True,
        transform=transforms.ToTensor(),
        download=False,
    )
    test_data = torchvision.datasets.MNIST(
        root=r"C:\Users\57158\data\MNIST",
        train=False,
        transform=transforms.ToTensor(),
        download=False,
    )
    model = LeNet(28)
    optimizer = Adam(model.parameters(), lr=0.0003)
    criterion = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classify(
        (train_data, test_data),
        model,
        optimizer,
        criterion,
        batch_size=64,
        epochs=5,
        device=device,
    )

分类效果:

 二、微调预训练的VGG16网络

VGG网络是由牛津大学视觉几何组(Visual Geometry Group)在2014年提出的一种卷积神经网络架构。它在当年的ImageNet图像分类挑战赛中取得了优异的成绩,并因其简洁的架构和良好的性能而广受欢迎。其网络结构如图所示:

VGG系列网络结构,图片来自 VGGNet-16 Architecture: A Complete Guide (kaggle.com)

VGG网络的主要特点:

  1. 深度:VGG网络以其深度著称,最深的版本(VGG16和VGG19)分别有16层和19层。这种深度使得网络能够学习到更复杂的特征。
  2. 小卷积核:VGG网络使用3x3的小卷积核,而不是之前常用的更大的卷积核(如7x7)。小卷积核的优势在于可以减少参数数量,同时通过叠加多个3x3卷积层可以模拟更大的感受野。
  3. 固定结构:VGG网络的结构非常规整,主要由卷积层和全连接层组成。卷积层通常使用ReLU激活函数,全连接层后面通常接一个softmax层用于分类。
  4. 池化层:VGG网络在每几个卷积层之后会插入一个最大池化层(Max Pooling),用于降低特征图的尺寸,减少计算量,并增强特征的平移不变性。

以VGG16为例,其结构如图: 

VGG16网络结构,图片来自 VGGNet-16 Architecture: A Complete Guide (kaggle.com)

尽管VGG网络具有简洁的结构和良好的性能,但由于其网络较深、参数较多,VGG网络的计算量和内存消耗都比较大,导致从头开始训练比较费时,而且对算力的要求也比较高。幸运的是PyTorch提供了预训练好的网络模型可供调用,开发者在其基础上微调即可快速搭建自己的网络。

一个通过微调预训练的VGG16网络用于分类10种猴子的图像分类器代码如下:

class MyVggModel(nn.Module):
    """
    自定义的VGG16模型
    """

    def __init__(self):
        super().__init__()
        # 加载预训练的vgg16模型
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        # 冻结参数
        for param in vgg.parameters():
            param.requires_grad_(False)
        # 预训练的vgg16的特征提取层
        self.vgg = vgg
        # 自定义的全连接层
        self.classifier = nn.Sequential(
            nn.Linear(25088, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 10),
            nn.Softmax(dim=1),
        )

    # 定义网络的向前传播路径
    def forward(self, x):
        x = self.vgg(x)
        x = x.view(x.size(0), -1)
        output = self.classifier(x)
        return output

数据集来源:10 Monkey Species (kaggle.com)

数据预处理:

if __name__ == '__main__':
    # 对训练集的预处理
    train_data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),  # 随机将图像裁剪为224*224
        transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
        transforms.ToTensor(),  # 转化为张量并归一化至[0,1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 图像标准化处理
    ])
    # 对测试集的预处理
    test_data_transforms = transforms.Compose([
        transforms.Resize(256),  # 将图像缩放为256*256
        transforms.CenterCrop(224),  # 将图像从中心裁剪为224*224
        transforms.ToTensor(),  # 转化为张量并归一化至[0,1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 图像标准化处理
    ])
    # 读取图像
    train_data_dir = r'C:\Users\57158\data\10-monkey-species\training\training'
    train_data = datasets.ImageFolder(train_data_dir, transform=train_data_transforms)
    test_data_dir = r'C:\Users\57158\data\10-monkey-species\validation\validation'
    test_data = datasets.ImageFolder(test_data_dir, transform=test_data_transforms)

训练模型:

    model = MyVggModel()
    optimizer = Adam(model.parameters(), lr=0.0003)
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    classify(
        (train_data, test_data),
        model,
        optimizer,
        criterion,
        batch_size=32,
        epochs=10,
        device=device,
    )

分类效果:

三、TextCNN的搭建和应用

TextCNN是一种用于文本分类的卷积神经网络模型,它由 Yoon Kim 在 2014 年提出。TextCNN 通过利用卷积层和池化层来捕捉文本中的局部特征,从而实现高效的文本分类。其结构如图:

TextCNN的结构,图片来自 1510.03820 (arxiv.org)

TextCNN 的基本结构包括以下几个部分:

  1. 嵌入层(Embedding Layer):将输入的词索引转换为词向量。这些词向量可以是预训练的,也可以是随机初始化的。
  2. 卷积层(Convolutional Layer):使用多个不同大小的卷积核来提取不同长度的特征。每个卷积核会在输入的词向量序列上滑动,生成特征图。
  3. 池化层(Pooling Layer):通常使用最大池化(Max Pooling)来提取每个特征图中的最大值,从而减少特征维度并保留最重要的特征。
  4. 全连接层(Fully Connected Layer):将池化后的特征向量输入到全连接层中,进行分类。
  5. 输出层(Output Layer):通常使用 softmax 函数来输出每个类别的概率。

一个通用的TextCNN网络代码如下:

import re
from functools import partial
from typing import Iterator

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
import torchtext;torchtext.disable_torchtext_deprecation_warning()
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import numericalize_tokens_from_iterator
from nltk.corpus import stopwords

from process import classify
# procss的代码见:https://blog.csdn.net/moyao_miao/article/details/141466047


class TextCNN(nn.Module):
    """
    TextCNN模型
    """

    def __init__(self, vocab_size: int, embedding_dim: int, num_filters: int,
                 filter_sizes: Iterator, num_classes: int, dropout: float = 0.5):
        """
        初始化TextCNN模型
        :param vocab_size:词典大小
        :param embedding_dim:词向量维度
        :param num_filters:卷积核个数
        :param filter_sizes:卷积核尺寸
        :param num_classes:输出的维度
        :param dropout:Dropout概率
        """
        super().__init__()
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # 卷积层
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embedding_dim)) for fs in filter_sizes])
        # 最大池化层
        self.pool = nn.AdaptiveMaxPool1d(1)
        # Dropout层
        self.dropout = nn.Dropout(dropout)
        # 全连接输出层
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

    # 定义网络的向前传播路径
    def forward(self, text):
        # text:(batch_size,MAX_LENGTH)
        embedded = self.embedding(text).unsqueeze(1)  # embedded:(batch_size,1,MAX_LENGTH,embedding_dim)
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]  # conved[n]:(batch_size,num_filters,MAX_LENGTH-filter_sizes[n]+1)
        pooled = [self.pool(conv).squeeze(2) for conv in conved]  # pooled[n]:(batch_size,num_filters)
        cat = self.dropout(torch.cat(pooled, dim=1))  # cat:(batch_size,num_filters*len(filter_sizes))
        return self.fc(cat)

使用IMDB数据集来训练模型:IMDB Dataset of 50K Movie Reviews (kaggle.com)

数据预处理一,文本清洗:

punctuation_regex = re.compile(r'[!"#$%&\'()*+,-./:;<=>?@\[\\\]^_`{|}~]')
stopwords_regex = re.compile('\\b(' + '|'.join(stopwords.words('english')) + ')\\b')
clean_ops = [
    str.lower,  # 转化为小写
    partial(re.sub, '<br /><br />', ' '),  # 去除换行符
    partial(re.sub, '\d+', ''),  # 去除数字
    partial(punctuation_regex.sub, ''),  # 去除符号
    partial(stopwords_regex.sub, ''),  # 去除停用词
    str.strip,  # 去除两端空格
]
def clean_text(s: str, ops: Iterator) -> str:
    """
    模块化的文本清洗函数
    :param s: 待清洗的文本
    :param ops: 清洗函数列表
    :return: 清洗后的文本
    """
    for op in ops:
        s = op(s)
    return s


tokenizer = get_tokenizer('spacy')
def token_gen(texts):
    """
    生成token迭代器
    :param texts: 文本列表
    :return: token迭代器
    """
    for text in texts:
        yield tokenizer(text)


if __name__ == '__main__':
    df = pd.read_csv(r'C:\Users\57158\data\IMDB Dataset.csv')
    df['review'] = df['review'].apply(clean_text, ops=clean_ops)  # 文本清洗
    df['sentiment'] = df['sentiment'].apply(lambda x: 1 if x == 'positive' else 0)  # 标签转换
    df.to_csv(r'IMDB_Dataset_clean.csv', index=False)

数据预处理二,构建数字化文本矩阵:

    VOCAB_SIZE = 20000
    MAX_LENGTH = 100
    df = pd.read_csv(r'IMDB_Dataset_clean.csv')
    vocab = build_vocab_from_iterator(token_gen(df['review']), specials=['<UNK>'], max_tokens=VOCAB_SIZE)  # 构建词典
    vocab.set_default_index(vocab['<UNK>'])  # 设置默认索引处理未知词
    sequence = numericalize_tokens_from_iterator(vocab=vocab, iterator=token_gen(df['review']))  # 数字化文本
    token_ids = [torch.tensor(list(x)) for x in sequence]  # 将数字化的文本转换为tensor
    padded_text = pad_sequence(token_ids, batch_first=True, padding_value=0)[:, :MAX_LENGTH]  # 填充文本并截断

注意:新版本的torchtext接口较旧版本变化比较大,很多旧版本的用法已经失效了,24年以前的torchtext教程就不用再看了。

训练模型:

    model = TextCNN(
        vocab_size=len(vocab),
        embedding_dim=MAX_LENGTH,
        num_filters=100,
        filter_sizes=[3, 4, 5],
        num_classes=2,
    )
    model.embedding.weight.data[vocab['<UNK>']] = torch.zeros(MAX_LENGTH)
    model.embedding.weight.data[vocab['<PAD>']] = torch.zeros(MAX_LENGTH)
    optimizer = Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    classify(
        (padded_text, torch.tensor(df['sentiment'])),
        model,
        optimizer,
        criterion,
        batch_size=32,
        epochs=3,
        device=device,
        to_tensor=False,
    )

分类效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2080752.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

API网关之Kong

Kong 是一个高性能的开源 API 网关和微服务管理平台&#xff0c;用于管理、保护和扩展 API 和微服务。它最初由 Mashape 公司开发&#xff0c;并于 2015 年作为开源项目发布。Kong 能够处理 API 的路由、认证、负载均衡、缓存、监控、限流等多种功能&#xff0c;是微服务架构中…

Mysql中count(*) over 用法讲解

Mysql中count&#xff08;*&#xff09; over &#xff08;&#xff09;用法讲解 一、原理1、原理介绍 二、下面是一个使用COUNT(*) OVER()的代码示例&#xff1a;1、代码示例2、结果详解3、COUNT(*) OVER() 分区用法 三 、总结 一、原理 1、原理介绍 在MySQL中&#xff0c;C…

MySQL集群的基础部署及主从复制详解

一、Msql在服务器中的部署方法 官网&#xff1a;http://www.mysql.com 在企业中90%的服务器操作系统均为Linux 在企业中对于Mysql的安装通常用源码编译的方式来进行 1.1 在Linux下部署MySQL 1.1.1 部署环境 主机IP角色MySQL-node1172.25.254.13masterMySQL-node2172.25.…

【C语言】深入理解指针(四)qsort函数的实现

指针4 1.回调函数是什么2.qsort使用举例3.qsort函数的模拟实现 1.回调函数是什么 回调函数就是⼀个通过函数指针调⽤的函数。 如果你把函数的指针&#xff08;地址&#xff09;作为参数传递给另⼀个函数&#xff0c;当这个指针被⽤来调⽤其所指向的函数 时&#xff0c;被调⽤的…

【CanMV K230】外接传感器

【CanMV K230】外接传感器 外接LED灯 B站视频链接 抖音链接 我们后面主要做是机器视觉。K230能帮我们捕捉到图像信息。更多小功能需要我们自己来做。 比如舵机抬杆&#xff0c;测温报警等 都需要我们外接传感器。 本篇就来分享一下如何使用K230外接传感器 首先需要知道K230…

Leetcode JAVA刷刷站(98)验证二叉搜索树

一、题目概述 二、思路方向 在Java中&#xff0c;要判断一个二叉树是否是有效的二叉搜索树&#xff08;BST&#xff09;&#xff0c;我们可以采用递归的方法&#xff0c;通过维护一个外部的范围&#xff08;通常是Integer.MIN_VALUE到Integer.MAX_VALUE作为初始范围&#xff…

网络优化4|网络流问题|路径规划问题|车辆路径问题

网络流问题 网络最大流问题 研究网络通过的流量也是生产管理中经常遇到的问题 例如&#xff1a;交通干线车辆最大通行能力、生产流水线产品最大加工能力、供水网络中最大水流量等。这类网络的弧有确定的容量&#xff0c;虽然常用 c i j c_{ij} cij​表示从节点 i i i到节点 j…

怎么检测电脑的RAM?丨什么是RAM?

RAM 是 Random Access Memory 的缩写&#xff0c;它是一个允许计算机短期存储数据以更快访问的组件。众所周知&#xff0c;操作系统、应用程序和各种个人文件都存储在硬盘驱动器中。 当 CPU 需要调用硬盘上的数据进行计算和运行时&#xff0c;CPU 会将数据传输到 RAM 中进行计…

安防视频汇聚平台EasyCVR启动后无法访问登录页面是什么原因?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台基于云边端一体化架构&#xff0c;兼容性强、支持多协议接入&#xff0c;包括国标GB/T28181协议、部标JT808、GA/T1400协议、RTMP、RTSP/Onvif协议、海康Ehome、海康SDK、大华SDK、华为SDK、宇视SDK、乐橙SDK、萤石云SDK等…

科研绘图系列:R语言多组极坐标图(grouped polar plot)

介绍 Polar plot(极坐标图)是一种二维图表,它使用极坐标系统来表示数据,而不是像笛卡尔坐标系(直角坐标系)那样使用x和y坐标。在极坐标图中,每个数据点由一个角度(极角)和一个半径(极径)来确定。角度通常从水平线(或图表的某个固定参考方向)开始测量,而半径则是…

Jenkins发邮件功能如何配置以实现自动化?

jenkins发邮件的设置指南&#xff1f;Jenkins怎么配置服务器&#xff1f; Jenkins作为一个流行的自动化服务器&#xff0c;其发邮件功能是通知团队成员构建状态的重要手段。AokSend将详细介绍如何配置Jenkins发邮件功能&#xff0c;以实现自动化通知。 Jenkins发邮件&#xf…

NVI技术创新联盟成立,BOSMA博冠IP轻量化制播已运用

2024年北京国际广播电影电视展览会&#xff08;BIRTV&#xff09;首日&#xff0c;由中央广播电视总台与中国电影电视技术学会联合牵头组建的NVI技术创新联盟在BIRTV 2024超高清全产业链发展研讨会上宣布正式成立。作为国产8K摄像机先行者&#xff0c;BOSMA博冠受邀加入NVI技术…

Flowable BPMN bpmnjs 设计器

最近半年我一直在打造一款行业顶尖的流程设计器&#xff0c;适配了flowable所有的组件&#xff0c;美观&#xff0c;大方&#xff0c;灵活&#xff0c;好用。所有的组件都进行严格的测试并在生产环境上线了。 1、在线预览 2、整体框架布局 3、组件分组 4、完整模式切换 给大…

若依,前后端分离项目,部署到服务器

1.后端项目用maven打包 正式服的话&#xff0c;测试不用加。 application.yml加上context-path: /prod-api 一定要选择root的ruoyi&#xff0c;他会把你自动打包其他模块的依赖 全部成功。然后去ruoyi-admin拿到这个包&#xff0c;java -jar ruoyi-admin.jar就可以了 将jar上…

STM32嵌套向量中断控制器—NVIC

NVIC简介&#xff1a; NVIC&#xff0c;即Nested Vectored Interrupt Controller&#xff08;嵌套向量中断控制器&#xff09;&#xff0c;是STM32中的中断控制器。它负责管理和协调处理器的中断请求&#xff0c;是STM32中处理异步事件的重要机制。 NVIC提供了灵活、高效、可扩…

IoTDB 在顶级会议 VLDB 2024:四篇最新论文入选,特邀做 TPC 报告与讨论会!

再获权威顶会认可 8 月 26 日至 8 月 30 日&#xff0c;数据库领域的顶级国际会议 VLDB 2024 在广州举行。IoTDB 三篇论文的最新研发成果被本次大会录用&#xff0c;这其中也包括 TsFile 成为 Apache Top-Level 项目后发表的第一篇顶会论文。 同时&#xff0c;在国际权威数据库…

博弈论详解 2(SG函数 和 SG定理)

传送门&#xff1a;博弈论详解 1&#xff08;基本理论定义 和 Nim 游戏&#xff09; 什么是 SG 函数 接着上次的讲解&#xff0c;我们来了解一个更通用的模型。我们把每一个状态变成一个点&#xff08;在 Nim 游戏里就代表 a a a 数组&#xff09;&#xff0c;如果可以从一种…

008、架构_分布式事务

分布式事务控制 对于一个分布式写事务,计算节点会向GTM申请全局事务GTID,GTID申请成功后,称当前GTID对应的事务是活跃事务,处于未提交状态。如果涉及数据更新,则将GTID信息同步更新到该事务要更新的事务中。成功提交事务后,这里的成功是指分布式事务涉及所有数据节点均提…

C++入门基础知识37——【关于C++ 运算符——关系运算符】

成长路上不孤单&#x1f60a;【14后&#xff0c;C爱好者&#xff0c;持续分享所学&#xff0c;如有需要欢迎收藏转发&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#xff01;&#xff01;&#xff01;&#xff01;&#xff…

Day24 第11站 出发 c++!

1> 思维导图 2> 提示并输入一个字符串&#xff0c;统计该字符串中字母个数、数字个数、空格个数、其他字符的个数 string s1;cout << "请输入一个字符串" << endl;getline(cin,s1);int len s1.length();char buf[128]"";strcpy(buf,s1…