如何用pytorch进行图像分类

news2024/12/23 14:18:55

如何用pytorch进行图像分类

在这里插入图片描述

使用PyTorch进行图像分类是深度学习中的一个常见任务,涉及一系列步骤,从数据预处理到模型训练和评估。下面将详细描述每个步骤,从零开始构建一个图像分类器。

1. 安装必要的库

在开始之前,首先需要确保已经安装了PyTorch及其相关的库,这些库包括torch、torchvision(用于处理图像数据集)以及matplotlib(用于数据可视化)。这些库可以通过pip进行安装:

pip install torch torchvision matplotlib

2. 导入必要的库

在编写代码前,需要导入PyTorch和相关的Python库,这些库将为我们提供创建、训练和测试神经网络所需的工具。

import torch
import torch.nn as nn  # 用于构建神经网络
import torch.optim as optim  # 用于优化网络
import torchvision  # 包含了流行的数据集和模型
import torchvision.transforms as transforms  # 用于数据增强和预处理
import matplotlib.pyplot as plt  # 用于绘图和数据可视化

3. 数据预处理

在进行图像分类之前,需要对图像数据进行预处理。常见的预处理步骤包括调整图像大小、将图像转换为PyTorch张量(Tensor)格式、以及对图像进行标准化。
在这里插入图片描述

transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将所有图像调整为32x32像素
    transforms.ToTensor(),  # 将图像转换为Tensor格式,范围为[0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化到[-1, 1]范围
])
  • Resize:调整图像大小,使所有图像的尺寸一致,方便后续处理。
  • ToTensor:将图像从PIL Image格式转换为PyTorch张量。
  • Normalize:将图像的每个通道(红、绿、蓝)的像素值标准化,使其均值为0.5,标准差为0.5,这有助于加速模型的收敛。

4. 加载数据集

PyTorch提供了许多常用的数据集,例如CIFAR-10。我们可以使用torchvision.datasets来轻松加载这些数据集,并使用DataLoader类来迭代数据。

# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

# CIFAR-10数据集中的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  • CIFAR-10:这是一个包含10个类别的彩色图像数据集,每个类别包含6000张32x32的图像。
  • DataLoader:这是PyTorch中用于批量加载数据的工具,batch_size指定每个批次加载的图像数量,shuffle决定是否打乱数据顺序。

5. 定义神经网络

在这里插入图片描述
在这个步骤中,我们将定义一个简单的卷积神经网络(CNN),用于图像分类任务。CNN由一系列卷积层、池化层、激活函数和全连接层组成。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 第一层卷积,输入通道3(RGB),输出通道6,卷积核大小5x5
        self.pool = nn.MaxPool2d(2, 2)  # 最大池化层,窗口大小2x2
        self.conv2 = nn.Conv2d(6, 16, 5)  # 第二层卷积,输入通道6,输出通道16,卷积核大小5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全连接层,输入维度16*5*5,输出维度120
        self.fc2 = nn.Linear(120, 84)  # 第二个全连接层,输入维度120,输出维度84
        self.fc3 = nn.Linear(84, 10)  # 最后一层,全连接层,输出维度10(对应CIFAR-10的10个类别)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 卷积 -> ReLU激活 -> 最大池化
        x = self.pool(F.relu(self.conv2(x)))  # 卷积 -> ReLU激活 -> 最大池化
        x = x.view(-1, 16 * 5 * 5)  # 展平操作,将卷积层的输出展平成一维向量
        x = F.relu(self.fc1(x))  # 全连接 -> ReLU激活
        x = F.relu(self.fc2(x))  # 全连接 -> ReLU激活
        x = self.fc3(x)  # 全连接层输出分类结果
        return x

net = Net()
  • Conv2d:二维卷积层,用于提取图像的特征。
  • MaxPool2d:最大池化层,用于下采样,减少特征图的大小。
  • ReLU:一种常用的激活函数,能够增加模型的非线性。

6. 定义损失函数和优化器

损失函数用于衡量模型输出与真实标签之间的差距,而优化器用于更新模型参数,以最小化损失函数。

criterion = nn.CrossEntropyLoss()  # 交叉熵损失,用于分类任务
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # 随机梯度下降优化器,带动量
  • CrossEntropyLoss:交叉熵损失函数,常用于多分类任务。
  • SGD:随机梯度下降,lr是学习率,momentum是动量,用于加速收敛。

7. 训练模型

在这里插入图片描述

模型的训练过程通常涉及多个epoch,每个epoch是一次完整的训练集迭代。在每个epoch中,我们通过前向传播计算输出,通过损失函数计算损失,然后通过反向传播更新模型的参数。

for epoch in range(2):  # 训练2个epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data  # 获取输入数据和对应的标签

        optimizer.zero_grad()  # 清零梯度缓存

        outputs = net(inputs)  # 前向传播:计算输出
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播:计算梯度
        optimizer.step()  # 更新模型参数

        running_loss += loss.item()
        if i % 2000 == 1999:  # 每2000个mini-batch打印一次损失
            print(f'[Epoch {epoch + 1}, Mini-batch {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
  • zero_grad:在每次迭代时清除上一次迭代的梯度。
  • backward:计算损失的梯度,并进行反向传播。
  • step:使用优化器更新模型参数。

8. 在测试集上评估模型

训练完成后,我们需要在测试集上评估模型的性能。通过比较模型的预测结果和真实标签,计算准确率。

correct = 0
total = 0
with torch.no_grad():  # 禁用梯度计算,以节省内存和加速计算
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)  # 获取最大值的索引,即预测的类别
        total += labels.size(0)  # 累计样本总数
        correct += (predicted == labels).sum().item()  # 累计正确预测的样本数

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
  • torch.no_grad():在评估模型时禁用梯度计算,以减少内存消耗。
  • torch.max:从模型输出中选择概率最大的类别。

总结

使用PyTorch进行图像分类是一项系统性任务,涉及数据预处理、模型构建、训练、评估和保存模型等多个环节。首先,我们通过数据预处理将图像转换为适合输入模型的格式,同时进行标准化以加速训练。然后,我们构建了一个简单的卷积神经网络(CNN),通过卷积层和池化层逐步提取图像的特征,最终通过全连接层输出分类结果。

在训练过程中,我们使用了交叉熵损失函数来度量模型预测与真实标签之间的差距,并通过随机梯度下降(SGD)优化器来更新模型的参数。训练过程涉及多次迭代,每次迭代都会通过前向传播计算输出,通过反向传播更新权重,从而使模型逐步学习到数据的特征。

完成训练后,我们在测试集上评估了模型的性能,计算了模型的准确率。这一过程通过禁用梯度计算加快了评估速度,并通过对比模型预测与真实标签的匹配程度,确定模型的准确性。

最后,我们将训练好的模型保存,以备将来使用或进一步微调。整个流程展示了如何从数据到模型,逐步实现图像分类任务。通过这种方法,可以灵活地调整网络架构、超参数和数据处理方式,来应对不同的图像分类任务,进一步提高模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2095132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

驱动(RK3588S)第四课时:模块化编程

目录 一、什么是模块化编程二、怎么把自己编译代码给加载到开发板上运行三、驱动编程的框架四、驱动编程具体实例1、编写单模块化驱动代码2、编写多模块化驱动代码3、编写向模块传参驱动代码4、编写多模块化驱动代码另一种方式 一、什么是模块化编程 在嵌入式里所谓的模块化编…

Vue——day07之条件渲染、列表渲染以及监测数据

目录 1.template标签 2.条件渲染 3.列表渲染 4.v-for中的key的作用以及原理 5.列表过滤 placeholder 前端空字符串 使用数据监视watch实现 使用计算属性实现 6.列表排序 7.Vue更新数据检测失败 原因 总结 1.template标签 template标签是Vue.js中的一个特殊元素&am…

kube-scheduler调度策略之预选策略(三)

一、概述 摘要:本文我们继续分析源码,并聚焦在预选策略的调度过程的执行。 二、正文 说明:基于 kubernetes v1.12.0 源码分析 上文我们说的(g *genericScheduler) Schedule()函数调用了findNodesThatFit()执行预选策略。 2.1 findNodesTha…

Truncated incorrect max_connections value: ‘999999‘

MySQL 的最大连接数(max_connections)可以设置的上限值在不同的资料中有所不同。以下是一些关键信息: 默认值和默认范围: MySQL 的默认最大连接数通常为 100 。一些资料提到默认值为 151 。 最大允许值: MySQL 的最大…

ant-design-vue:a-table表格中插入自定义按钮

本文将介绍如何使用ant-design-vue在a-table表格中加入自定义按钮和图标的代码。 结果如下图所示&#xff0c; 一、简单示例 <template><a-table:columns"columns":data-source"data":row-selection"rowSelection":ellipsis"tru…

对称密码学

1. 使用OpenSSL 命令行 在 Ubuntu Linux Distribution (发行版&#xff09;中&#xff0c; OpenSSL 通常可用。当然&#xff0c;如果不可用的话&#xff0c;也可以使用下以下命令安装 OpenSSL: $ sudo apt-get install openssl 安装完后可以使用以下命令检查 OpenSSL 版本&am…

深度学习基础案例4--构建CNN卷积神经网络实现对猴痘病的识别(测试集准确率86.5%)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 下一周会很忙&#xff0c;更新可能不及时&#xff0c;请大家见谅这个项目我感觉是一个很好的入门案例&#xff0c;但是自己测试的时候测试集准确率只比较…

mcu loader升级固件原理与实现

1 mcu loader升级固件原理 mcu 固件有两部分&#xff0c;如下图所示&#xff0c;一部分是 loader.bin&#xff0c;一部分是 app.bin&#xff0c;将两部分的固件合并在一起烧录进 mcu 的 flash 当中。mcu 上电进入loader 模式执行 loader.bin 部分的程序&#xff0c;然后读取 fl…

前端踩坑记录:javaScript复制对象和数组,不能简单地使用赋值运算

问题 如图&#xff0c;编辑table中某行的信息&#xff0c;发现在编辑框中修改名称的时候&#xff0c;表格中的信息同步更新。。。 检查原因 编辑页面打开时&#xff0c;需要读取选中行的信息&#xff0c;并在页面中回显。代码中直接将当前行的数据对象赋值给编辑框中的表单对…

51单片机——I2C总线

1、I2C总线简介 I2C总线&#xff08;Inter IC BUS&#xff09;是由Philips公司开发的一种通用数据总线 两根通信线&#xff1a;SCL&#xff08;Serial Clock&#xff09;、SDA&#xff08;Serial Data&#xff09; 同步、半双工&#xff0c;带数据应答 通用的I2C总线&#…

Linux基础(包括centos7安装、linux基础命令、vi编辑器)

一、安装CentOS7 需要&#xff1a;1、VMware Workstation&#xff1b;2、CentOS7镜像 1、安装镜像 2、虚拟机配置 开启虚拟机&#xff0c;鼠标从vm中移出来用快捷键ctrlalt 点击开始安装&#xff0c;设置密码&#xff0c;等待安装完成,&#xff0c;重启。 3、注意事项 如果没…

通往RAG之路(二):版面结构检测方法介绍

一、基于yolov5的版面结构检测 AG系统搭建过程中&#xff0c;版面分析是不可缺少的一个步骤&#xff0c;本文介绍用yolov5进行版面结构信息识别&#xff0c;后续再搭配表格识别、公式识别、文字识别等模块进行版面还原&#xff0c;完成PDF结构化输出。 1.1、环境搭建 conda c…

解决方案:在autodl环境下为什么已安装torch打印出来版本号对应不上

文章目录 一、现象二、解决方案 一、现象 平台&#xff1a;autodl 镜像&#xff1a;PyTorch 2.0.0 Python 3.8(ubuntu20.04) Cuda 11.8 GPU&#xff1a;A40(48GB) * 1 CPU&#xff1a;15 vCPU AMD EPYC 7543 32-Core Processor 内存&#xff1a;80GB 安装torch:1.13.0环境&a…

深入理解指针(6)

目录&#xff1a; 1.字符指针变量 2.数组指针变量 3.二维数组传参本质 4.函数指针变量 5.函数指针的应用 1.字符指针变量 #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> int main() {char a w;char* p &a;printf("%p ", p);} 当我们想取出…

UE 【材质编辑】自定义材质节点

使用UE的材质编辑器&#xff0c;蓝图提供了大量的节点函数&#xff1a; 实际上&#xff0c;这是一段封装好的包含一串HLSL代码的容器。打开“Source/Runtime/Engine/Classes/Material”&#xff0c;可以看到很多不同节点的头文件&#xff1a; 照葫芦画瓢 以UMaterialExpressi…

ORM 编程思想

一、ORM简介 对象关系映射&#xff08;英语&#xff1a;Object Relational Mapping&#xff0c;简称ORM&#xff0c;或 O/R mapping&#xff09;是一种为了解决面向对象语言与关系数据库存在的 互不匹配的现象。 二、实体类 实体类就是一个定义了属性&#xff0c;拥有getter、…

51单片机——存储器

1、存储器简介 RAM优点存储速度非常快&#xff0c;缺点成本高&#xff0c;掉电丢失数据。 ROM优点掉电不丢失数据&#xff0c;缺点存储速度比较慢。 所以在实际应用中&#xff0c;我们都是采用两者结合的方式。程序运行时&#xff0c;数据存储在RAM中&#xff0c;需…

自己开发完整项目一、登录功能-04(集成jwt)

一、说明 前面文章我们写到了通过数据库查询出用户信息并返回&#xff0c;那么在真实的项目中呢&#xff0c;后端是需要给前端返回一个tocken&#xff0c;当前端通过登录功能认证成功之后&#xff0c;我们后端需要将用户信息和权限整合成一个tocken返回给前端&#xff0c;当前端…

【Python技术】使用langchain、fastapi、gradio生成一个简单的智谱AI问答界面

前几天&#xff0c;智谱AI BigModel开放平台宣布&#xff1a;GLM-4-Flash 大模型API完全免费了&#xff0c;同时开启了GLM-4-Flash 限时免费微调活动。对想薅免费大模型羊毛的个人玩家&#xff0c;这绝对是个好消息&#xff0c;我们不仅可以免费使用BigModel开放平台上的GLM-4-…

产品入门篇笔记

产品和产品经理 产品&#xff1a;解决某个问题的物品&#xff0c;无形、有形都可以。 产品经理&#xff1a;简单而言就是想清楚怎么做的人&#xff0c;需要想清楚产品怎么设计&#xff0c;要分析什么用户、在什么场景、怎么样的需求&#xff1b;然后考虑产品的功能、优势、价值…