pytorch实现图像分类器

news2024/11/25 11:19:00

pytorch实现图像分类器

  • 一、定义LeNet网络模型
    • 1,卷积 Conv2d
    • 2,池化 MaxPool2d
    • 3,Tensor的展平:view()
    • 4,全连接 Linear
    • 5,代码:定义 LeNet 网络模型
  • 二、训练并保存网络参数
    • 1,数据预处理
    • 2,数据集
    • 3,代码
  • 三、图像分类测试

一、定义LeNet网络模型

pytorch 中的卷积、池化、输入输出层中参数的含义与位置,可参考下图:
在这里插入图片描述

1,卷积 Conv2d

常用的卷积(Conv2d)在pytorch中对应的函数是

# in_channels:输入特征矩阵的深度。如输入一张RGB彩色图像,那in_channels=3
# out_channels:输入特征矩阵的深度。也等于卷积核的个数,使用n个卷积核输出的特征矩阵深度就是n
# kernel_size:卷积核的尺寸。可以是int类型,如3 代表卷积核的height=width=3,也可以是tuple类型如(3, 5)代表卷积核的height=3,width=5
# stride:卷积核的步长。默认为1,和kernel_size一样输入可以是int型,也可以是tuple类型
# padding:补零操作,默认为0。可以为int型如1即补一圈0,如果输入为tuple型如(2, 1) 代表在上下补2行,左右补1列。
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

经卷积后的输出层尺寸计算公式为:
在这里插入图片描述
例如:当定义 Conv2d(3, 16, 5) 和 input(3, 32, 32),步长 S 为 1,P 为0时,此时卷积核尺度 F 为5,W 为32,计算得到 output(16, 28, 28)

2,池化 MaxPool2d

最大池化(MaxPool2d)在 pytorch 中对应的函数是:

MaxPool2d(kernel_size, stride)

3,Tensor的展平:view()

注意到,在经过第二个池化层后,数据还是一个三维的Tensor (32, 5, 5),需要先经过展平后(3255)再传到全连接层:

  x = self.pool2(x)            # 第二个池化层 output(32, 5, 5)
  x = x.view(-1, 32*5*5)       # 展平 output(32*5*5)
  x = F.relu(self.fc1(x))      # 传到全连接层 output(120)

4,全连接 Linear

全连接(Linear)在 pytorch 中对应的函数是:

Linear(in_features, out_features, bias=True)

5,代码:定义 LeNet 网络模型

model.py

# 定义LeNet网络模型
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):                  # 继承于nn.Module这个父类
    def __init__(self):                  # 初始化网络结构
        super(LeNet, self).__init__()    # 多继承需用到super函数
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):            # 正向传播过程
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x

二、训练并保存网络参数

1,数据预处理

ToTensor:把输入的图像数据为 shape (H x W x C) in the range [0, 255] 转化为 shape (C x H x W) in the range [0.0, 1.0],同时将 image 和 numpy 输入格式转化为 tensor
Normalize:标准化

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

2,数据集

用的是CIFAR10数据集,是 pytorch 自带的一个很经典的图像分类数据集,一共包含 10 个类别的 RGB 彩色图片。
在这里插入图片描述
注意:第一次运行程序,需要下载数据集到本地,所以第一次运行训练集下载时download=True为True,下载完成后改为False。测试集的加载则不用变化。

3,代码

名词定义
epoch对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step对一个batch的数据训练的过程称为 一个 iteration 或 step
# 加载数据集并训练,训练集计算loss,测试集计算accuracy,保存训练好的网络参数
import torch
import torchvision
import torch.nn as nn
from model import LeNet 
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 导入、加载训练集
# 导入50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='./data',      # 数据集存放目录
                                        train=True,          # 表示是数据集中的训练集
                                        download=False,       # 第一次运行时为True,下载数据集,下载完成后改为False
                                        transform=transform) # 预处理过程
# 加载训练集,实际过程需要分批次(batch)训练                                        
train_loader = torch.utils.data.DataLoader(train_set,      # 导入的训练集
                                           batch_size=50,  # 每批训练的样本数
                                           shuffle=False,  # 是否打乱训练集
                                           num_workers=0)  # 使用线程数,在windows下设置为0

# 导入测试集
# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', 
                                        train=False,    # 表示是数据集中的测试集
                                        download=False,transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set, 
                                          batch_size=10000, # 每批用于验证的样本数
                                          shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()

#训练过程
net = LeNet()                                       # 定义训练的网络模型
loss_function = nn.CrossEntropyLoss()               # 定义损失函数为交叉熵损失函数 
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)

for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练
    running_loss = 0.0	# 累加过程中的损失
    time_start = time.perf_counter()
    
    for step, data in enumerate(train_loader, start=0):   # enumerate遍历训练集,可以同时返回 data 和 步数,step从0开始计算
        inputs, labels = data 	# 获取训练集的图像和标签
        optimizer.zero_grad()   # 清除历史损失梯度
        
        # forward + backward + optimize
        outputs = net(inputs)  				  # 正向传播
        loss = loss_function(outputs, labels) # 计算损失
        loss.backward() 					  # 反向传播
        optimizer.step() 					  # 优化器更新参数

        # 打印耗时、损失、准确率等数据
        running_loss += loss.item()
        if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次
            with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用
                outputs = net(test_image) 				 # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]
                predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                
                print('%f s' % (time.perf_counter() - time_start))        # 打印耗时
                running_loss = 0.0

print('Finished Training')

# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

三、图像分类测试

使用训练并保存好的网络参数,从数据集外找一张图像进行分类测试

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

# 数据预处理
transform = transforms.Compose(
    [transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])    # 数据标准化

# 导入要测试的图像
im = Image.open('./car.jpg').convert('RGB')    # 若图像为4通道,则用 convert('RGB') 转化为3通道,否则 transform 会报错
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width] 

# 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))

# 预测
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()    # 找出最大概率的下标
	predicts = torch.softmax(outputs , dim=1)    # 所有分类的预测概率
print(classes[int(predict)])
print(predicts)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/649059.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mybatis-plus在实际开发中的应用

文章目录 前言一、实体类的注解二、Req查询条件三、Controller接口四、Service接口五、Service接口实现类六、Mapper接口七、枚举的使用总结 前言 最近的项目是使用mybatis-plus作为持久层框架,前面也记录过mybatis-plus的基本使用,此次记录一下本次项目…

行业报告|2022年智能制造人才发展报告:自动化、PLC、机器人等控制执行类研发岗需求增长快

原创 | 文 BFT机器人 近年来,我国智能制造应用规模和发展水平大幅跃升,制造业智能化发展成效明显,有力支撑工业经济的高质量发展。与此同时,我国在2022年首次出现人口负增长,该趋势下我国发展制造业的人口红利正逐步降…

BurpSuite全平台通用扩展

前言 昨天分享的关于springboot3集成ChatGPT实现AI聊天、生成图片,被CSDN以违规拒发了,找人工客服最后一肚子气,从上大学开始入住CSDN一直偏爱,确实不想放弃,似乎现在他已经不再是一个纯技术交流分享平台了&#xff0…

LaWGPT:你的私人法律顾问!

LaWGPT:你的私人法律顾问! LaWGPT 是一系列基于中文法律知识的开源大语言模型。 该系列模型在通用中文基座模型(如 Chinese-LLaMA、ChatGLM 等)的基础上扩充法律领域专有词表、大规模中文法律语料预训练,增强了大模型在…

工商业储能解读

工商业储能解读 0、前言1、2022-2023年工商业储能相关利好政策1.1 2022年1月4日1.2 2022年1月18日1.3 2022年2月10日1.4 2022年3月21日1.5 2022年3月22日1.6 2022年3月29日1.7 2022年4月2日1.8 2022年4月13日1.9 2022年4月25日1.10 2022年5月25日1.11 2022年5月30日1.12 2022年…

传输平台太多?难以管理?看这款跨网传输系统怎样解决

传输作为企业正常运行中最日常的行为,也意味着出现频率最高。微信、QQ、邮件、或是钉钉等办公软件,每天大家上班时开着各种软件,进行着不同的信息交互与传输。很多员工在工作时往往是哪个软件方便顺手就用哪个传输,但是这样也意味…

AI绘画Midjourney的咒语关键词汇总结

近期很多人都在研究Ai,被他强大的运算和准确性所震撼,和我们设计师相关的一个Ai绘画工具-Midjourney,绝对是占设计圈头部流量的,在圈内掀起一片热潮,今天我们就专门围绕他来展开说说,当然除了这个外,我们还…

Linux安装和配置VCenter

Linux安装和配置VCenter 以下演示安装 Linux VCenter,也就是使用VMware-VCSA-all-6.7.0-13010631.iso 镜像包。通过一台 Windows服务器远程连接 ESXI 服务器安装 Linux 版本的 VCenter。也就是Windows 服务器只是安装的界面的一个载体。 Linux VCenter环境搭建 下…

LLM 优先的软件架构:源自 ArchGuard Co-mate 的四个基本设计原则

在优化 ArchGuard 的 AI 辅助架构治理工具 Co-mate 的架构时,发现有一些模式与之前设计 AutoDev、ClickPrompt 等颇为相似。便思考着适合于 ArchGuard Co-mate 的架构设计原则是什么,写下了初步的三条原则。 而正好要在公司内分享 LLM 架构,…

【软考程序员学习笔记】——多媒体基础知识

目录 🍊 一、多媒体的概念及分类 多媒体的分类 🍊二、声音信号的数字化过程 采样 量化 编码 🍊三、常见音频文件格式 🍊四、图形/图像区别 图形 图像 🍊五、常见图像文件格式 🍊六、常见视频文件…

JDK8-2-流(3)- 流操作-distinct

JDK8-2-流&#xff08;3&#xff09;- 流操作-distinct 去重操作&#xff0c;如下开头两个菜品一样&#xff0c;对 menu 去重如下&#xff1a; public class DishDistinctTest1 {public static final List<Dish> menu Arrays.asList(new Dish("pork", false…

享元模式(十四)

每天都是全新的一天&#xff0c;感谢今日努力的自己。 上一章简单介绍了外观模式(十三), 如果没有看过, 请观看上一章 一. 享元模式 引用 菜鸟教程里面的外观模式介绍: https://www.runoob.com/design-pattern/flyweight-pattern.html 享元模式&#xff08;Flyweight Patter…

后,配置文件被清空,导致无法开启WiFi

root cause&#xff1a; /data/vendor/wifi/wpa/wpa_supplicant.conf 是0字节&#xff0c;导致wpa_supplicant_init_iface缺少”p2p_disabled1“的配置就会在走错flow到p2p wpa_supplicant_init_iface 》wpas_p2p_init 从而在HidlManager::registerInterface 进入 if (isP2pIf…

DeepSpeed零冗余优化器Zero Redundancy Optimizer

零冗余优化器 内容 零概述培训环境启用零优化 训练 1.5B 参数 GPT-2 模型训练 10B 参数 GPT-2 模型使用 ZeRO-Infinity 训练万亿级模型 使用 ZeRO-Infinity 卸载到 CPU 和 NVMe分配 Massive Megatron-LM 模型以内存为中心的平铺注册外部参数提取权重 如果您还没有这样做&…

【LeetCode】HOT 100(12)

题单介绍&#xff1a; 精选 100 道力扣&#xff08;LeetCode&#xff09;上最热门的题目&#xff0c;适合初识算法与数据结构的新手和想要在短时间内高效提升的人&#xff0c;熟练掌握这 100 道题&#xff0c;你就已经具备了在代码世界通行的基本能力。 目录 题单介绍&#…

佩戴比较舒适的蓝牙耳机有哪些?值得入手的蓝牙耳机分享

​对于年轻人来说&#xff0c;耳机使用场景丰富&#xff0c;时尚追求度高&#xff0c;喜好的音乐类型也是多种多样&#xff0c;需求侧重也不尽相同。下面我来推荐几款相当不错的蓝牙耳机给大家&#xff0c;总会有喜欢那款&#xff01; 一、南卡OE蓝牙耳机 佩戴舒适度打分&…

【QQ界面展示-设置消息正文的背景图 Objective-C语言】

一、咱们上午说到哪儿了,还记得吗, 1.咱们上午是不是说到这儿了,可以显示正文、可以显示文字、并且,设置好背景图片了, 现在的问题就是,正文里面的文字,是不是超出这个图片了, 正文里面的文字,超出背景图片了, 那么,接下来,就给大家看一下,怎么解决这个问题, …

Macbook Pro双系统装Window10后设置触摸屏滑动方向

最近想给自己的Macbook Pro装Windows10操作系统&#xff0c;毕竟Windows才是真正的生产力工具&#xff0c;装了以后不需要两台笔记本了&#xff0c;直接在一台笔记本上有MacOS和Windows 装好以后发现触摸屏不能轻点触控还有触摸屏的滑动方向是反的 第一个问题&#xff0c;不能轻…

Ansys Zemax | 如何在序列模式下模拟分光棱镜

概述 这篇文章介绍了&#xff1a; 如何在序列模式下使用多重结构创建分光棱镜 如何在布局图以及分析/计算窗口中同时追迹透射和反射光线 在考虑偏振及镀膜的影响下如何计算透射和反射光线的总能量 &#xff08;联系我们获取文章附件&#xff09; 介绍 在 OpticStudio 中…

xxlJob任务管理平台500:xxl-job remoting error(connect timed out)

目录 一、问题截图 二、问题处理 2.1.查看执行器地址 2.2.查看本地端口 2.3.总结 三、关于地址的题外话 一、问题截图 此时可以看到code500&#xff0c;msg是连接超时&#xff0c;说明地址不通&#xff0c;那就是查看地址配置。 二、问题处理 2.1.查看执行器地址 …