12 从0开始学PyTorch| PyTorch全连接网络:建立区分鸟和飞机的模型

news2024/11/24 13:49:14

上一小节我们终于开始搭建神经网络了,只不过它很简单,并且对我们更早的时候做的温度计转换模型做了一次迭代,甚至连效果都没有太大的变化,这一小节我们开始处理一些有意思的事情:做一个图像分类的模型。

小图像数据集

今天要用的数据集称为CIFAR-10,关于这个数据集我前几天还看到一个跟它相关的趣闻,谷歌一个大牛发布了一篇论文,用数万美元 TPU 算力,实现在 CIFAR-10 上 0.03% 的改进,创造了新的 SOTA,受到了很多人的质疑。因为CIFAR-10数据集过于简单,不过这仍然是一个非常经典的数据集,很适合我们拿来做小实验。

CIFAR-10数据集一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。
示例如下,可以看到32×32已经很模糊了,不过人眼大概也能看出是什么东西。

image.png

image.png

我们先把数据集下载下来。下载的过程可能有点慢。

这里用到了一个方法torch.manual_seed,用来设置CPU生成随机数的种子,方便下次复现实验结果。

%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import torch

torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)

from torchvision import datasets
data_path = '../data-unversioned/p1ch7/'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True) # 这里用train=True限定下载训练集,下面一行用train=False限定下载验证集
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True) 

Dataset类

下载完数据集之后,这里需要介绍一个数据集的类Dataset,我们也可以自己构建数据集并使它符合Dataset的规范,这样我们可以使用一些Dataset的方法。
先把下载的数据集展示一下

class_names = ['airplane','automobile','bird','cat','deer',
               'dog','frog','horse','ship','truck'] #这里给每个类别定义了名字

fig = plt.figure(figsize=(8,3))
num_classes = 10
for i in range(num_classes):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    ax.set_title(class_names[i])
    img = next(img for img, label in cifar10 if label == i)
    plt.imshow(img)
plt.show()

方法 _ _ mro _ _ :(这里为了清除的看,以及屏蔽markdown的语法,下划线之间加了空格,本来是没有空格的)
python 类有多继承特性,如果继承关系太复杂,很难看出会先调用那个属性或方法。
为了方便且快速地看清继承关系和顺序,可以用 _ _ mro_ _方法来获取这个类的调用顺序。

type(cifar10).__mro__
outs:
(torchvision.datasets.cifar.CIFAR10,
 torchvision.datasets.vision.VisionDataset,
 torch.utils.data.dataset.Dataset,  #可以看到首先继承了视觉数据集VisionDataset,然后继承了Dataset对象
 object)

Dataset类有两个函数对象_ _ len _ _ () 和 _ _ getitem_ _()
len返回的是数据集的大小,getitem可以按序号返回具体的数据item。

len(cifar10)
outs:
50000

img, label = cifar10[99]
img, label, class_names[label]
outs:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x7FB383657390>,
 1,
 'automobile')

plt.imshow(img) #把这个图画出来
plt.show()

这是一辆可爱的红色小汽车

image.png

image.png

数据集我们现在已经有了,接下来我们要回忆一下,在最开头的给图片分类的试验中,我们还需要一个预处理的环节,在里面对图像做了各种变换,然后才能够输入到模型中,如果你已经忘了可以翻一下这个系列的第0节课看一下。PyTorch提供了丰富的图像变换方法,方便我们对图像做各种预处理。

from torchvision import transforms
dir(transforms) #通过dir方法查看transforms包里面都包含哪些方法,如果你有兴趣可以查一下每个方法的作用
outs:
['CenterCrop', #我这里写几个注释,中央裁剪
 'CenterCropVideo',#中央裁剪视频
 'ColorJitter',
 'Compose',#压缩
 'FiveCrop',
 'Grayscale',
 'Lambda',
 'LinearTransformation',
 'Normalize', #标准化
 'NormalizeVideo',
 'Pad',
 'RandomAffine',
 'RandomApply',
 'RandomChoice',
 'RandomCrop',#随机裁剪
 'RandomCropVideo',
 'RandomErasing',
 'RandomGrayscale',
 'RandomHorizontalFlip',
 'RandomHorizontalFlipVideo',
 'RandomOrder',
 'RandomPerspective',
 'RandomResizedCrop',
 'RandomResizedCropVideo',
 'RandomRotation',
 'RandomSizedCrop',
 'RandomVerticalFlip',
 'Resize',
 'Scale',
 'TenCrop',
 'ToPILImage',
 'ToTensor', #转换为Tensor
 'ToTensorVideo',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'functional',
 'functional_video',
 'transforms',
 'transforms_video']

这个时候我们先不做其他的变换,我们看到了最重要的那个ToTensor,它可以帮我们把图片转换成tensor,从而才能进行接下来的工作。

from torchvision import transforms

to_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape
outs:
torch.Size([3, 32, 32])

除了显式调用,我们还可以把它作为一个参数传入Dataset.CIFAR10来获得tensor

关于下面所用的 单下划线 _
按照习惯,有时候单个独立下划线是用作一个名字,来表示某个变量是临时的或无关紧要的。

由于tensor_cifar10[99]返回的是两个参数,一个是图像,一个是标签,这里我们不用标签,所以用_作为占位符接收结果

tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False,
                          transform=transforms.ToTensor())
img_t, _ = tensor_cifar10[99]
type(img_t)
outs: torch.Tensor

在ToTensor变换中,数据的值也会被缩小到0-1的范围。

img_t.min(), img_t.max()
outs:
(tensor(0.), tensor(1.))

接下来我们需要对数据进行标准化,前面我们学到过,标准化有助于我们的训练。
在这里需要为RGB三通道每个通道的数据进行处理,使得标准化后的数据均值为0,标准差为1。

imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3) #把5w张图都读进去,此时imgs的大小为 3 * 32 * 32 * 50000
imgs.view(3, -1).mean(dim=1)  # view建立一个视图,只保留第一维,剩下的维度合成第二维,这样计算出的是RGB三通道每个通道的均值和标准差
imgs.view(3, -1).std(dim=1)
transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616)) #接着调用Normalize方法,对数据进行标准化

接下来把标准化跟之前的转为tensor合起来,作为我们的预处理方法,分别处理训练集和验证集。

transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

transformed_cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

这时候再把图像画出来看看

img_t, _ = transformed_cifar10[99]

plt.imshow(img_t.permute(1, 2, 0))
plt.show()

image.png

image.png

可以看到图像更模糊了,有点像动画的感觉,而且这里Notebook给出了一个提示

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
#这是说输入到imshow方法的数据范围应该在0-255的整数范围或者0-1的浮点数范围,这里我们有一些负的值显然不能满足要求,所以你可以看到大片区域都黑了。

2分类:鸟还是飞机

我们的数据集有10个类别,这里我们先不做那么多的分类,我们先处理一个二分类问题,把鸟和飞机的图像拿出来,做一个全连接的网络来学习,看看能不能用神经网络模型来区分这两个类别。

label_map = {0: 0, 2: 1} #把原来的标签映射到新的标签上,原来分别是0和2,映射到0和1
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10 
          if label in [0, 2]]
cifar2_val = [(img, label_map[label])
              for img, label in cifar10_val
              if label in [0, 2]]

我们筛选出数据之后,就要考虑怎么把数据丢进模型里面,这里用一个最简单的办法,就是把图像的数据顺序拼成一个一维向量,如下图所示,然后经过全连接网络,最后输出两个概率值,分别对应它是鸟的概率或者是飞机的概率。

image.png

image.png

我们的数据是3 * 32 * 32,也就是有3072个元素,接下来就构建我们的模型。

import torch.nn as nn

n_out = 2

model = nn.Sequential(
            nn.Linear(#输入单元
                3072,  
                512,   
            ),
            nn.Tanh(), #激活层
            nn.Linear( #输出单元
                512,   
                n_out, 
            )
        )

对于上面输入单元,有3072的输入维度,512的输出维度,那么在神经网络中就需要3072 * 512个参数,约为157w。后面我们会再讨论关于参数量的问题。
这里我们定义好了模型,下面来看一下该怎么输出分类结果。
考虑我们前面关于温度变换的输出,是预测一个温度结果值,但是在这个分类任务中,我们这里有两个结果,0表示飞机,1表示鸟。前面我们说过处理这种分类数值,可以使用one-hot编码,对出现飞机的时候表示为[1,0],对出现鸟的时候表示为[0,1]。我们在样本上可以做这样的标注,但是让模型给出这样确定的结果确实有点为难,因为它是从数学的角度来理解这个图像到底是鸟还是飞机,比如说你给它一个鸟型的飞机,那它可能就搞不清楚了。思考我们的模型,上面的一个点,对于某个类别的结果损失是比较低的,而另一个结果损失会更高。因此模型给出的结果介于0-1之间,我们可以认为第1项是‘飞机’的概率,第2项是‘鸟’的概率。

根据上面的假设,我们期望我们的模型输出结果满足两个条件:

  • 输出的每个元素值介于0-1之间
  • 输出的元素综总和为1

这听起来似乎有点难度,但是有一个很聪明的函数可以实现这个功能,最关键的是它还可微,满足我们计算梯度的需求,那就是softmax。

softmax方法

softmax的表达式如下,原图上貌似少了几个加号,我给加上了,你应该能看出来吧

image.png

image.png

从字面意思来理解,softmax就是软化的最大化方法,相对而言就有hardmax,比如说我们给最高的那个置为1,其他的都置为0;或者用xi / xi累加,那为啥要用softmax,肯定是有它的优势,这是一个很深入的问题,我们有时间再去研究。

这里我们不妨试一下softmax的效果:

def softmax(x):
    return torch.exp(x) / torch.exp(x).sum()

x = torch.tensor([1.0, 2.0, 3.0])

softmax(x)
outs:#经过softmax运算之后,1,2,3变为如下
tensor([0.0900, 0.2447, 0.6652])

显然softmax的结果满足我们前面提到的模型输出结果的两条要求。
同时,输出的结果中,对于较大的结果有了一定的放大作用,而对于较小的数值有了一定的缩小的作用。
这么好用的功能,自然是已经躺在nn模块里面,我们调用nn.Softmax()方法就可以了。

softmax = nn.Softmax(dim=1) #参数指定了softmax执行的维度

x = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0]])

softmax(x)
outs:
tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])

#如果我们把dim改成0,输出就是下面这样,就沿着维度0进行了softmax,1和1,2和2,3和3
outs:
tensor([[0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000]])

既然确定要用softmax来作为我们输出时候的变换,那就把它加进我们的神经网络序列里面

model = nn.Sequential(
            nn.Linear(3072, 512),
            nn.Tanh(),
            nn.Linear(512, 2),
            nn.Softmax(dim=1))

这时候我们可以调用一下我们搭建的模型,让我们随便找一幅鸟的图像

img, _ = cifar2[0]

plt.imshow(img.permute(1, 2, 0))
plt.show()

这幅图看起来实在有点难以辨认。

image.png

image.png

然后把图像塞进我们的批数据中,并使用模型去预测它

img_batch = img.view(-1).unsqueeze(0)
out = model(img_batch)
out

outs:
tensor([[0.4784, 0.5216]], grad_fn=<SoftmaxBackward0>)

可以看到,我们的模型输出结果是飞机的概率为0.47,鸟的概率是0.52,它竟然分对了,当然这只是一个巧合,我们随机的模型参数恰巧在这个图像上获得了一个正确的结果。

今天的内容就先到这里,下一节我们再研究怎么去迭代模型的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/765292.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink数据流

文章目录 一.基本概念二.Flink和Spark三. Flink配置文件四. yarn部署flink4.1 session-cluster模式4.2 pre-job-cluster模式 五.Flink运行时架构5.1 任务提交流程5.2 如何实现并行计算5.3 并行任务需要占用多少slot5.4 一个流处理包含多少任务 一.基本概念 官网介绍 Apache F…

css 禁止多次点击导致的选中了目标div的文字

像下面这样的情况&#xff0c;就可以用这种方法避免掉 禁止多次点击&#xff0c;导致的&#xff0c;选中了目标div的文字 或者 禁止多次点击&#xff0c;导致&#xff0c;html结构被选中显示出来 .targetDiv {-webkit-user-select: none;-moz-user-select: none;-ms-user-sel…

Vue3卡片(Card)

可自定义设置以下属性&#xff1a; 卡片宽度&#xff08;width&#xff09;&#xff0c;类型&#xff1a;number | string&#xff0c;默认 ‘auto’是否有边框&#xff08;bordered&#xff09;&#xff0c;类型&#xff1a;boolean&#xff0c;默认 true卡片右上角的操作区域…

所有语言数据类型大汇总(持续更新)

一 c语言 参考 C语言-整数&#xff1a;short、int、long、long long&#xff08;signed和unsigned&#xff09;、原码、反码、补码_c语言signed是什么类型_Talent Q的博客-CSDN博客https://blog.csdn.net/qq_43177371/article/details/105703234 二 system verilog

服务器数据恢复-网站服务器宕机无法重启的数据恢复案例

服务器数据恢复环境&#xff1a; 一台linux操作系统网站服务器&#xff0c;该服务器上部署了几十个网站&#xff0c;服务器上只有一块SATA硬盘。 服务器故障&分析&#xff1a; 服务器正常运行中突然宕机&#xff0c;管理员尝试多次重新启动服务器失败&#xff0c;将服务器…

软件测试之测试用例设计方法

目录 1.基于需求设计测试用例 2.具体的测试用例设计方法 1.等价类 2.边界值法 3.判定表 1.基于需求设计测试用例 需求文档->梳理分析需求&#xff08;掌握需求&#xff09;->针对文档设计测试用例 在分析测试需求时&#xff0c;一般分为功能测试需求和非功能测试…

职工管理系统

woker.h #pragma once #include<iostream> #include<string> using namespace std; class worker { public://显示岗位信息virtual void showInfo() 0;//获取岗位名称virtual string getDeptName() 0;int m_Id;//职工编号string m_Name;//职工姓名int m_DeptId;…

大学生用一周时间给麦当劳做了个App(uni-app版)

背景 有个大学生粉丝最近私信联系我&#xff0c;说基于我之前开源的多语言项目做了个仿麦当劳的项目&#xff0c;虽然只是个样子货&#xff0c;但是收获颇多&#xff0c;希望把自己写的代码开源出来供大家一起学习进度。这个小伙伴确实是非常积极上进&#xff0c;很多大学生&a…

ssh 连接出现错误: kex_exchange_identification: Connection closed by remote host

错误如下表示&#xff1a; windstormLocalHost-Server ~> ssh webase-front192.168.122.22 Couldnt get a file descriptor referring to the console. fish: Unknown command: nc fish: exec nc -X connect -x 127.0.0.1:15732 192.168.122.22 22 ^^ kex_exchange_id…

个人博客系统(二)

该博客系统共有八个页面,即注册页面、登录页面、添加文章页面、修改文章页面、我的博客列表页面、主页、查看文章详情页面、个人中心页面。 1 注册页面 该页面如图所示: 首先,要先判断注册的用户名、密码、确认密码以及验证码是否为空,若有一个为空,点击提交,则会提醒 …

代码随想录二刷day56 | 动态规划之 583. 两个字符串的删除操作 72. 编辑距离

day56 583. 两个字符串的删除操作1.确定dp数组&#xff08;dp table&#xff09;以及下标的含义2.确定递推公式3.dp数组如何初始化4.确定遍历顺序5.举例推导dp数组 72. 编辑距离1. 确定dp数组&#xff08;dp table&#xff09;以及下标的含义2. 确定递推公式3. dp数组如何初始化…

信号采样基本概念 —— 4. 移动平均滤波(Moving Average Filtering)

对于信号的滤波算法中&#xff0c;除了FFT和小波&#xff08;wavelet&#xff09;以外&#xff0c;还有其他一些常见的滤波算法可以对信号denoising。接下来的几个章节里&#xff0c;将逐一介绍这些滤波算法。而今天首先要介绍的就是&#xff0c;移动平均滤波&#xff08;Movin…

android studio 离线打包配置push模块

1.依赖引入 SDK\libs aps-release.aar, aps-unipush-release.aar, gtc.aar, gtsdk-3.2.11.0.aar, 从android studio的sdk中找到对应的包放到HBuilder-Integrate-AS\simpleDemo\libs下面 2.打开build.gradle&#xff0c;在defaultConfig添加manifestPlaceholders节点&#xff0c…

浅谈vue3与vue2的区别

vue3已经出来有一段时间了&#xff0c;相信很多公司项目都已经在用vue3重构项目&#xff0c;或者在新项目中直接用vue3搭建&#xff0c;那么我们学习vue3的必要性就有了。 v2 与 v3 的区别 v3 采用的是 monorepo 方式进行管理&#xff0c;将模块拆分到 package 目录中v3 采用…

用 PerfView 洞察.NET程序非托管句柄泄露

一&#xff1a;背景 1. 讲故事 前几天写了一篇 如何洞察 .NET程序 非托管句柄泄露 的文章&#xff0c;文中使用 WinDbg 的 !htrace 命令实现了句柄泄露的洞察&#xff0c;在文末我也说了&#xff0c;WinDbg 是以侵入式的方式解决了这个问题&#xff0c;在生产环境中大多数情况…

C++ cin

cin 内容来自《C Primer》 cin使用>>运算符从输入流中抽取字符 int carrots;cin >> carrots;如下的例子&#xff0c;用户输入的字符串有空格 #include <iostream>int main() {using namespace std;const int ArSize 20;char name[ArSize]; //用户名char …

HIVE SQL实现通过两字段不分前后顺序去重

--数据建表 drop table if exists db.tb_name; create table if not exists db.tb_name ( suj1 string,suj2 string ) ;insert overwrite table db.tb_name values ("语文","数学") ,("语文","英语") ,("数学","语文&…

[禁止登录]登录失败,建议升级最新版本后重试,或通过问题反馈与我们联系。(错误码:45)

token失效:[禁止登录]登录失败&#xff0c;建议升级最新版本后重试&#xff0c;或通过问题反馈与我们联系。(错误码:45。 [禁止登录]登录失败&#xff0c;建议升级最新版本后重试&#xff0c;或通过问题反馈与我们联系。 使用go-cqhttp开发QQ机器人的时候遇到的问题&#xff0c…

小白入门深度学习 | 6-5:Inception-v1(2014年)详解

1. 理论知识 GoogLeNet首次出现在2014年ILSVRC 比赛中获得冠军。这次的版本通常称其为Inception V1。Inception V1有22层深,参数量为5M。同一时期的VGGNet性能和Inception V1差不多,但是参数量也是远大于Inception V1。 Inception Module是Inception V1的核心组成单元,提出…

市面上的充电桩分类以及系统分析

摘要&#xff1a;智能用电小区是国家电网为了研究智能电网智能用电的先进技术如何运用于居民区&#xff0c;提高人民的生活水平&#xff0c;提高电网智能化水平以及提升用电服务质量而进行的一项尝试。电动汽车作为智能用电小区建设的一个组成部分同样也逐渐被纳入发展规划&…