深度学习之PyTorch实现卷积神经网络(CNN)

news2025/3/17 1:12:43

在深度学习领域,卷积神经网络(Convolutional Neural Networks,CNN)是一种非常强大的模型,专门用于处理图像数据。CNN通过卷积操作和池化操作来提取图像中的特征,具有较好的特征学习能力,特别适用于图像识别和计算机视觉任务。PyTorch作为一种流行的深度学习框架,提供了方便易用的工具来构建和训练CNN模型。本文将介绍如何使用PyTorch构建一个简单的CNN,并通过一个图像分类任务来演示其效果。

1. CNN的结构

典型的CNN结构包括卷积层、池化层和全连接层。卷积层通过卷积操作提取图像特征,池化层通过降采样操作减小特征图的尺寸,全连接层用于最终的分类。

在这里插入图片描述

2. 环境配置

在开始之前,确保已经安装了PyTorch和相关的Python库。可以通过以下命令安装:

pip install torch torchvision matplotlib

3. 数据集准备

在这个示例中,我们将使用PyTorch提供的CIFAR-10数据集,它包含了10个类别的60000张32x32彩色图像。我们将图像分为训练集和测试集,并加载到PyTorch的数据加载器中。

import torch
import torchvision
import torchvision.transforms as transforms

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

4. 构建CNN模型

我们将构建一个简单的CNN模型,包括卷积层、池化层、全连接层和激活函数。这个模型将接受3通道的32x32图像作为输入,并输出10个类别的概率分布。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
	super().__init__()
        # 定义网络层:卷积层+池化层
        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层
        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)


    def forward(self, x):
        # 卷积+relu+池化
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        # 卷积+relu+池化
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # 将特征图做成以为向量的形式:相当于特征向量
        x = x.reshape(x.size(0), -1)
        # 全连接层
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # 返回输出结果
        return self.out(x)

net = Net()

5. 模型训练

定义损失函数和优化器,并在训练集上训练CNN模型。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练网络
for epoch in range(2):  # 多次遍历数据集
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # 每2000个小批量数据打印一次损失值
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('训练结束!!!')

6. 模型测试

在测试集上评估训练好的模型的性能。

correct = 0
total = 0
# 禁用梯度追踪
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('10000 测试图片的准确率: %d %%' % (100 * correct / total))

7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

%%’ % (100 * correct / total))


## 7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

这篇博客介绍了如何使用PyTorch构建和训练一个简单的卷积神经网络。通过实际的代码示例,读者可以了解CNN模型的基本原理,并掌握如何在PyTorch中实现和训练这样的模型。希望本文能够对你理解和应用深度学习模型有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1598148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为OD机试 - 连续天数的最高利润额(Java 2024 C卷 100分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…

AI音乐,8大变现方式——Suno:音乐版的ChatGPT - 第505篇

悟纤之歌 这是利用AI为自己制作的一首歌,如果你也感兴趣,可以花点时间阅读下本篇文章。 ​ 导读 随着新一代AI音乐创作工具Suno V3、Stable audio2.0、天工SkyMusic的发布,大家玩自创音乐歌曲,玩的不亦乐乎。而有创业头脑的朋友…

一些实用的工具网站

200 css渐变底色 https://webgradients.com/ 200动画效果复制 https://css-loaders.com/classic/ 二次贝塞尔曲线 https://blogs.sitepointstatic.com/examples/tech/canvas-curves/bezier-curve.html 三次贝塞尔曲线 https://blogs.sitepointstatic.com/examples/tech/c…

百货商场用户画像描绘and价值分析(下)

目录 内容概述数据说明技术点主要内容4 会员用户画像和特征字段创造4.1 构建会员用户基本特征标签4.2 会员用户词云分析 5 会员用户细分和营销方案制定5.1 会员用户的聚类分析及可视化5.2 对会员用户进行精细划分并分析不同群体带来的价值差异 内容概述 本项目内容主要是基于P…

环形链表II

给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 …

【数据结构|C语言版】单链表

前言1. 单链表的概念和结构1.1 单链表的概念1.2 单链表的结构 2. 单链表的分类3.单链表的实现3.1 新节点创建3.2 单链表头插3.3 单链表头删3.4 单链表尾插3.5 单链表尾删3.6 链表销毁 4. 代码总结4.1 SLT.h4.2 SLT.c4.3 test.c 后言 前言 各位小伙伴大家好!时隔不久…

mysql 日环比 统计

接到一个任务,要计算日环比的情况。 16、查询销售额日环比情况 日环比: (今日-昨日)/ 昨日 的一个比率情况。 1,建表 DROP TABLE IF EXISTS sale; create table sale(id int not null AUTO_INCREMENT,record_date da…

偏微分方程算法之二维初边值问题(交替方向隐(ADI)格式)

一、研究对象 以二维抛物型方程初边值问题为研究对象: 为了确保连续性,公式(1)中的相关函数满足: 二、理论推导 2.1 向前欧拉格式 首先进行网格剖分。将三维长方体空间(二维位置平面一维时间轴&#xff09…

还在担心报表不好做?不用怕,试试这个方法(四)

系列文章: 《还在担心报表不好做?不用怕,试试这个方法》(一) 《还在担心报表不好做?不用怕,试试这个方法》(二) 《还在担心报表不好做?不用怕,…

UE5学习日记——制作多语言版本游戏,同时初步学习UI制作、多语言化、控制器配置、独立进程测试、打包配置和快速批量翻译等

所有的文本类,无论变量还是控件等都能实现本地化,以此实现不同语言版本。 在这里先将重点注意标注一下: 所有文本类的变量、控件等都可以多语言;本地化控制板中收集、编译时,别忘了编译这一步;支持批量复制…

海思Hi3519 DV500 部署yolov5并加速优化

本项目代码已开源,见文末 导出onnx模型 yolov5官方地址 利用官方命令导出python export.py --weights yolov5n.pt --include onnx 或者自写代码导出 import os import sys os.chdir(sys.path[0]) import onnx import torch sys.path.append(..) from models.co…

ASP.NET MVC企业级程序设计 (EF+三层架构+MVP实现查询数据)

目录 效果图 实现过程 1创建数据库 2创建项目文件 3创建控制器,右键添加,控制器 ​编辑 注意这里要写Home​编辑 创建成功 数据模型创建过程之前作品有具体过程​编辑 4创建DAL 5创建BLL 6创建视图,右键添加视图 ​编辑 7HomeContr…

[计算机效率] 本地视频播放器:QQPlayer

3.26 本地视频播放器:QQPlayer QQPlayer是一款由腾讯公司开发的视频播放软件,它支持多种视频格式,包括MP4、AVI、FLV等,并且可以播放高清视频。 强大的播放功能:QQPlayer具有强大的解码功能,可以轻松播放…

GIS 数据格式转换

1、在线工具 mapshaper 2、数据上传 3、数据格式转换 导入数据可导出为多种格式:Shapefile、Json、GeoJson、CSV、TopJSON、KML、SVG

面试算法-174-二叉树的层序遍历

题目 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:[[3],[9,20],[15,7]] 解 class Solut…

Antd:在文本框中展示格式化JSON

要想将对象转换为格式化 JSON 展示在文本框中,需要用到 JSON.stringify JSON.stringify 方法接受三个参数: value:必需,一个 JavaScript 值(通常为对象或数组)要转换为 JSON 字符串。replacer&#xff1a…

基于springboot+vue实现的疫情防控物资调配与管理系统

作者主页:Java码库 主营内容:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】:Java 【框架】:spring…

【软考高项】二十一、政策法规之招投标法知识点学习

一、总则部分 必须招标的 社会公共利益、公众安全的;国有资金投资或者国家融资的;外国政府贷款、援助资金的原则 公开、公平、公正和诚实信用 不得将依法必须进行招标的项目化整为零或者以其他任何方式规避招标、不受地区或者部门的限…

【Go】原子并发操作

目录 一、基本概念 支持的数据类型 主要函数 使用场景 二、基础代码实例 开协程给原子变量做加法 统计多个变量 原子标志判断 三、并发日志记录器 四、并发计数器与性能监控 五、优雅的停止并发任务 worker函数 Main函数 应用价值 Go语言中,原子并发操…

ABAP MESSAGE 常用的类型

类型文本描述A终止处理终止,用户必须重启事务X退出与消息类型A 类似,但带有程序崩溃 MESSAGE_TYPE_XE错误处理受到干扰,用户必须修正输入条目,左下角提示!W警告处理受到干扰,用户可以修正输入条目,左下角提示!I信息处理受到干扰&a…