PyTorch 系列教程:使用CNN实现图像分类

news2025/3/13 19:55:18

图像分类是计算机视觉领域的一项基本任务,也是深度学习技术的一个常见应用。近年来,卷积神经网络(cnn)和PyTorch库的结合由于其易用性和鲁棒性已经成为执行图像分类的流行选择。

理解卷积神经网络(cnn)

卷积神经网络是一类深度神经网络,对分析视觉图像特别有效。他们利用多层构建一个可以直接从图像中识别模式的模型。这些模型对于图像识别和分类等任务特别有用,因为它们不需要手动提取特征。

cnn的关键组成部分

  • 卷积层:这些层对输入应用卷积操作,将结果传递给下一层。每个过滤器(或核)可以捕获不同的特征,如边缘、角或其他模式。
  • 池化层:这些层减少了表示的空间大小,以减少参数的数量并加快计算速度。池化层简化了后续层的处理。
  • 完全连接层:在这些层中,神经元与前一层的所有激活具有完全连接,就像传统的神经网络一样。它们有助于对前一层识别的对象进行分类。
    在这里插入图片描述

使用PyTorch进行图像分类

PyTorch是开源的深度学习库,提供了极大的灵活性和多功能性。研究人员和从业人员广泛使用它来轻松有效地实现尖端的机器学习模型。

设置PyTorch

首先,确保在开发环境中安装了PyTorch。你可以通过pip安装它:

pip install torch torchvision

用PyTorch创建简单的CNN示例

下面是如何定义简单的CNN来使用PyTorch对图像进行分类的示例。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 第一个卷积层:3输入通道,6输出通道,5x5卷积核
        self.pool = nn.MaxPool2d(2, 2)        # 最大池化层:2x2窗口,步长2
        self.conv2 = nn.Conv2d(6, 16, 5)     # 第二个卷积层:6输入通道,16输出通道,5x5卷积核
        self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出
        self.fc2 = nn.Linear(120, 84)      # 全连接层2:120输入 -> 84输出
        self.fc3 = nn.Linear(84, 10)       # 输出层:84输入 -> 10类 logits

    def forward(self, x):
        # 输入形状:[batch_size, 3, 32, 32]
        x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸减半)
        x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] 
        x = x.view(-1, 16 * 5 * 5)            # 展平为一维向量:16 * 5 * 5=400
        x = F.relu(self.fc1(x))             # -> [batch, 120]
        x = F.relu(self.fc2(x))             # -> [batch, 84]
        x = self.fc3(x)                     # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)
        return x

这个特殊的网络接受一个输入图像,通过两组卷积和池化层,然后是三个完全连接的层。根据数据集的复杂性和大小调整网络的架构和超参数。

模型定义

  • SimpleCNN 继承自 nn.Module
  • 使用两个卷积层提取特征,三个全连接层进行分类
  • 最终输出未应用 softmax,而是直接输出 logits(与 CrossEntropyLoss 配合使用)

训练网络

对于训练,你需要一个数据集。PyTorch通过torchvision包提供了用于数据加载和预处理的实用程序。

import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 初始化模型、损失函数和优化器
net = SimpleCNN()               # 实例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), 
                            lr=0.001,      # 学习率
                            momentum=0.9)   # 动量参数

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),          
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  

# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True,  # 自动下载数据集
    transform=transform
)

trainloader = DataLoader(trainset, 
                     batch_size=4,   # 每个batch包含4张图像
                     shuffle=True)  # 打乱数据顺序

模型配置

  • 损失函数CrossEntropyLoss(自动包含 softmax 和 log_softmax)
  • 优化器:SGD with momentum,学习率 0.001

数据加载

  • 使用 torchvision.datasets.CIFAR10 加载数据集

  • batch_size:4(根据 GPU 内存调整,CIFAR-10 建议 batch size ≥ 32)

  • transforms.Compose 定义数据预处理流程:

    • ToTensor():将图像转换为 PyTorch Tensor
    • Normalize():标准化图像像素值到 [-1, 1]

加载数据后,训练过程包括通过数据集进行多次迭代,使用反向传播和合适的损失函数:

# 训练循环
for epoch in range(2):  # 进行2个epoch的训练
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # 前向传播
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()   # 清空梯度
        loss.backward()         # 计算梯度
        optimizer.step()       # 更新参数
        
        running_loss += loss.item()
        
        # 每2000个batch打印一次
        if i % 2000 == 1999:
            avg_loss = running_loss / 2000
            print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
            running_loss = 0.0

print("训练完成!")

训练循环

  • epoch:完整遍历数据集一次
  • batch:数据加载器中的一个批次
  • 梯度清零:每次反向传播前需要清空梯度
  • 损失计算outputs 的形状为 [batch_size, 10]labels 为整数标签

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 定义CNN模型(修复了变量引用问题)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 第一个卷积层:3输入通道,6输出通道,5x5卷积核
        self.pool = nn.MaxPool2d(2, 2)        # 最大池化层:2x2窗口,步长2
        self.conv2 = nn.Conv2d(6, 16, 5)     # 第二个卷积层:6输入通道,16输出通道,5x5卷积核
        self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层1:400输入 -> 120输出
        self.fc2 = nn.Linear(120, 84)      # 全连接层2:120输入 -> 84输出
        self.fc3 = nn.Linear(84, 10)       # 输出层:84输入 -> 10类 logits

    def forward(self, x):
        # 输入形状:[batch_size, 3, 32, 32]
        x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸减半)
        x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] 
        x = x.view(-1, 16 * 5 * 5)            # 展平为一维向量:16 * 5 * 5=400
        x = F.relu(self.fc1(x))             # -> [batch, 120]
        x = F.relu(self.fc2(x))             # -> [batch, 84]
        x = self.fc3(x)                     # -> [batch, 10](未应用softmax,配合CrossEntropyLoss使用)
        return x

# 初始化模型、损失函数和优化器
net = SimpleCNN()               # 实例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数(自动处理softmax)
optimizer = torch.optim.SGD(net.parameters(), 
                            lr=0.001,      # 学习率
                            momentum=0.9)   # 动量参数

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),            
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])

# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True,  # 自动下载数据集
    transform=transform
)
trainloader = DataLoader(trainset, 
                         batch_size=4,   # 每个batch包含4张图像
                         shuffle=True)  # 打乱数据顺序

# 训练循环
for epoch in range(2):  # 进行2个epoch的训练
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # 前向传播
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()   # 清空梯度
        loss.backward()         # 计算梯度
        optimizer.step()       # 更新参数
        
        running_loss += loss.item()
        
        # 每2000个batch打印一次
        if i % 2000 == 1999:
            avg_loss = running_loss / 2000
            print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
            running_loss = 0.0

print("训练完成!")

最后总结

通过PyTorch和卷积神经网络,你可以有效地处理图像分类任务。借助PyTorch的灵活性,可以根据特定的数据集和应用程序构建、训练和微调模型。示例代码仅为理论过程,实际项目中还有大量优化空间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2314468.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 大视界 -- Java 大数据中的数据可视化大屏设计与开发实战(127)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

「Unity3D」UGUI将元素固定在,距离屏幕边缘的某个比例,以及保持元素自身比例

在不同分辨率的屏幕下,UI元素按照自身像素大小,会发生位置与比例的变化,本文仅利用锚点(Anchors)使用,来实现UI元素,固定在某个比例距离的屏幕边缘。 首先,将元素的锚点设置为中心&…

Deep research深度研究:ChatGPT/ Gemini/ Perplexity/ Grok哪家最强?(实测对比分析)

目前推出深度研究和深度检索的AI大模型有四家: OpenAI和Gemini 的deep research,以及Perplexity 和Grok的deep search,都能生成带参考文献引用的主题报告。 致力于“几分钟之内生成一份完整的主题调研报告,解决人力几小时甚至几天…

关于sqlalchemy的ORM的使用

关于sqlalchemy的ORM的使用 二、创建表三、使用数据表、查询记录三、批量插入数据四、关于with...as...:的使用 二、创建表 使用Mapped来映射字段 from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker,Mapped,mapped_columnBa…

【leetcode hot 100 148】排序序列

解法一:(双重循环)第一个循环head,逐步将head的node加入有序列表;第二个循环在有序列表中找到合适的位置,插入node。 /*** Definition for singly-linked list.* public class ListNode {* int val;* …

【Linux】在VMWare中安装Ubuntu操作系统(2025最新_Ubuntu 24.04.2)#VMware安装Ubuntu实战分享#

今天田辛老师为大家带来一篇关于在VMWare虚拟机上安装Ubuntu系统的详细教程。无论是学习、开发还是测试,虚拟机都是一个非常实用的工具,它允许我们在同一台物理机上运行多个操作系统。Ubuntu作为一款开源、免费且用户友好的Linux发行版,深受广…

AutoGen学习笔记系列(十三)Advanced - Logging

这篇文章瞄的是AutoGen官方教学文档 Advanced 章节中的 Logging 篇章,介绍了怎样在使用过程中添加日志信息,其实就是使用了python自带的日志库 logging。 官网链接:https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-g…

scrcpy pc机远程 无线 控制android app 查看调试log

背景: 公司的安卓机,是那种大屏幕的连接usb外设的。不好挪动,占地方,不能直接连接pc机上的android stduio来调试。 所以从网上找了一个python adb.exe控制器,可以局域网内远程控制开发的app,并在android stduio上看…

UE5.5 Niagara发射器更新属性

发射器属性 在 Niagara 里,Emitter 负责控制粒子生成的规则和行为。不同的 Emitter 属性决定了如何发射粒子、粒子如何模拟、计算方式等。 发射器 本地空间(Local Space) 控制粒子是否跟随发射器(Emitter)移动。 ✅…

MongoDB备份与还原

备份恢复工具介绍 1)mongoexport/mongoimport 2)mongodump/mongorestore 备份工具区别 mongoexport/mongoimport 导入/导出的是JSON格式或者CSV格式 mongodump/mongorestore 导入/导出的是BSON格式。二进制方式,速度快 1)…

计算机:基于深度学习的Web应用安全漏洞检测与扫描

目录 前言 课题背景和意义 实现技术思路 一、算法理论基础 1.1 网络爬虫 1.2 漏洞检测 二、 数据集 三、实验及结果分析 3.1 实验环境搭建 3.2 模型训练 最后 前言 📅大四是整个大学期间最忙碌的时光,一边要忙着备考或实习为毕业后面临的就业升学做准备,…

Java 大视界 -- Java 大数据在智能安防视频摘要与检索技术中的应用(128)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

部署项目至服务器:响应时间太长,无法访问此页面?

在我们部署项目到服务器上的时候,一顿操作猛如虎,打开页面..... 这里记录一下这种情况是怎么回事。一般就是服务器上的安全组没有放行端口。 因为我是用宝塔进行项目部署的。所以遇到这种情况,要去操作两边(宝塔and服务器所属平台…

【数据结构】List介绍

目录 1. 什么是List 2. 常见接口介绍 3. List的使用 1. 什么是List 在集合框架中,List是一个接口,继承自Collection。此时extends意为拓展 Collection也是一个接口,该接口中规范了后序容器中常用的一些方法,具体如下所示&…

vs2022用git插件重置--删除更改(--hard)后恢复删除的内容

1、先到项目工程中打开需要恢复的分支。 2、进入代码管理根目录文件夹。 3、在根目录文件夹点右键,点git bash here 正常情况下如果git目录权限足够,是可以如上图所示显示当前分支和当前目录的。 在git权限不足的情况下会出现如下提示: …

vscode接入DeepSeek 免费送2000 万 Tokens 解决DeepSeek无法充值问题

1. 在vscode中安装插件 Cline 2.打开硅基流动官网 3. 注册并登陆,邀请码 WpcqcXMs 4.登录后新建秘钥 5. 在vscode中配置cline (1) API Provider 选择 OpenAI Compatible ; (2) Base URL设置为 https://api.siliconflow.cn](https://api.siliconfl…

【MySQL】用户管理和权限

欢迎拜访:雾里看山-CSDN博客 本篇主题:【MySQL】用户管理和权限 发布时间:2025.3.12 隶属专栏:MySQL 目录 引言用户用户信息创建用户语法案例 修改用户密码语法案例 删除用户语法案例 权限权限列表查看和刷新用户的权限给用户授权…

指令微调 (Instruction Tuning) 与 Prompt 工程

引言 预训练语言模型 (PLMs) 在通用语言能力方面展现出强大的潜力。然而,如何有效地引导 PLMs 遵循人类指令, 并输出符合人类意图的响应, 成为释放 PLMs 价值的关键挑战。 指令微调 (Instruction Tuning) 和 Prompt 工程 (Prompt Engineerin…

UE5.5 Niagara 发射器粒子更新模块

Particle State (粒子状态)模块 Particle State 主要用于控制粒子的生存状态,包括死亡、消失、响应事件等。 Particle State Kill Particles When Lifetime Has Elapsed 当粒子的生命周期结束时,销毁这些粒子。 Lifetime &…

机器学习(吴恩达)

一, 机器学习 机器学习定义: 计算机能够在没有明确的编程情况下学习 特征: 特征是描述样本的属性或变量,是模型用来学习和预测的基础。如: 房屋面积, 地理位置 标签: 监督学习中需要预测的目标变量,是模型的输出目标。如: 房屋价格 样本: 如: {面积100㎡…