从0训练一个神经网络分类器

news2024/9/20 20:39:32

从0训练一个神经网络分类器

      • 0. 关于数据?
    • 训练一个图像分类器
      • 1. 使用``torchvision``可以非常容易地加载CIFAR10。
      • 2. 定义一个卷积神经网络
      • 3. 定义损失函数和优化器
      • 4. 训练网路
      • 5. 在测试集上测试网络
      • 6. 检测网络在整个测试集上的结果如何。
      • 7. 在识别哪一个类的时候好,哪一个不好呢?
      • 8. 在GPU上训练

上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重。
你现在可能在想下一步。

0. 关于数据?


一般情况下处理图像、文本、音频和视频数据时,可以使用标准的Python包来加载数据到一个numpy数组中。
然后把这个数组转换成 torch.*Tensor

  • 图像可以使用 Pillow, OpenCV
  • 音频可以使用 scipy, librosa
  • 文本可以使用原始Python和Cython来加载,或者使用 NLTK或
    SpaCy 处理

特别的,对于图像任务,我们创建了一个包
torchvision,它包含了处理一些基本图像数据集的方法。这些数据集包括
Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器,
torchvision.datasetstorch.utils.data.DataLoader

torchvision包不仅提供了巨大的便利,也避免了代码的重复。

我们使用CIFAR10数据集,它有如下10个类别
:‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是
3x32x32大小的,即,3颜色通道,32x32像素。

训练一个图像分类器

依次按照下列顺序进行:

  1. 使用torchvision加载和归一化CIFAR10训练集和测试集

  2. 定义一个卷积神经网络

  3. 定义损失函数

  4. 在训练集上训练网络

  5. 在测试集上测试网络

  6. 读取和归一化 CIFAR10


1. 使用torchvision可以非常容易地加载CIFAR10。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

我们展示一些训练图像。

import matplotlib.pyplot as plt
import numpy as np

# 展示图像的函数


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

2. 定义一个卷积神经网络


从之前的神经网络一节复制神经网络代码,并修改为输入3通道图像。

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. 定义损失函数和优化器


我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练网路


在数据迭代器上循环,将数据输入给网络,并优化。

for epoch in range(epoch):  # 多批次循环

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入
        inputs, labels = data

        # 梯度置0
        optimizer.zero_grad()

        # 正向传播,反向传播,优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印状态信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000批次打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. 在测试集上测试网络


我们在整个训练集上进行了epoch次训练,但是我们需要检查网络是否从数据集中学习到有用的东西。
通过预测神经网络输出的类别标签与实际情况标签进行对比来进行检测。
如果预测正确,我们把该样本添加到正确预测列表。
第一步,显示测试集中的图片并熟悉图片内容。

dataiter = iter(testloader)
images, labels = dataiter.next()

# 显示图片
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

在这里插入图片描述
让我们看看神经网络认为以上图片是什么。

outputs = net(images)

输出是10个标签的能量。
一个类别的能量越大,神经网络越认为它是这个类别。所以让我们得到最高能量的标签。

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
# Predicted:  plane plane plane plane

6. 检测网络在整个测试集上的结果如何。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

结果看起来不错,至少比随机选择要好,随机选择的正确率为10%。
似乎网络学习到了一些东西。

7. 在识别哪一个类的时候好,哪一个不好呢?

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

在这里插入图片描述
下一步?

我们如何在GPU上运行神经网络呢?

8. 在GPU上训练


把一个神经网络移动到GPU上训练就像把一个Tensor转换GPU上一样简单。并且这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 确认我们的电脑支持CUDA,然后显示CUDA信息:

print(device)

本节的其余部分假定device是CUDA设备。

然后这些方法将递归遍历所有模块并将模块的参数和缓冲区
转换成CUDA张量:

net.to(device)

记住:inputs, targets 和 images 也要转换。

inputs, labels = inputs.to(device), labels.to(device)

为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。

实践:
尝试增加你的网络的宽度(第一个nn.Conv2d的第2个参数,第二个nn.Conv2d的第一个参数,它们需要是相同的数字),看看你得到了什么样的加速。

实现的目标:

  • 深入了解了PyTorch的张量库和神经网络
  • 训练了一个小网络来分类图片

参考链接:

  1. PyTorch 深度学习:60分钟快速入门 (官方)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/729658.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL基操例题

Ⅰ创建数据库使用create语句: create database 数据库名; Ⅱ创建表同理: create table if not exists 表名 ( 字段名 数据类型 主键 约束, 字段名 数据类型 主键 约束) 设置存储引擎和字符集; …

Text-Augmented Open Knowledge Graph Completion viaPre-Trained Language Models

摘要 开放知识图谱(KG)完成的任务是从已知的事实中得出新的发现。增加KG完成度的现有工作需要(1)事实三元组来扩大图推理空间,或(2)手动设计提示从预训练的语言模型(PLM)中提取知识,表现出有限的性能,需要专家付出昂贵的努力。为此,我们提出了TAGREAL,它自动生成高质量的…

【youcans动手学模型】SENet 模型及 PyTorch 实现

欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 【youcans动手学模型】SENet 模型 【经典模型】SENet 模型-Cifar10图像分类1. SENet 卷积神经网络模型1.1 模型简介1.2 论文介绍1.3 分析与讨论 2. 在 PyTorch 中定义 SENet 模型类2.1 定义 SE Block…

STL好难(5):stack的使用

目录 1.stack的介绍和使用: 2.stack的使用 3.有关stack的练习题: 🍉最小栈 🍉栈的压入、弹出序列 4.stack的模拟实现: 1.stack的介绍和使用: 点击查看stack的文档介绍 1. stack是一种容器适配器&#…

(vue)element-ui表格中插入switch开关

(vue)element-ui表格中插入switch开关 效果&#xff1a; <el-table-column property"enabled" label"启用/禁用" width"150"><template slot-scope"scope"> <el-switchv-model"scope.row.enabled"active-co…

动态规划之746 使用最小花费爬楼梯(第3道)

题目&#xff1a; 给你一个整数数组 cost &#xff0c;其中 cost[i] 是从楼梯第 i 个台阶向上爬需要支付的费用。一旦你支付此费用&#xff0c;即可选择向上爬一个或者两个台阶。 你可以选择从下标为 0 或下标为 1 的台阶开始爬楼梯。 请你计算并返回达到楼梯顶部的最低花费…

差分学习笔记

1.前言 同步于 c n b l o g s cnblogs cnblogs 发布。 前置芝士&#xff1a; 基本树上操作&#xff0c;lca。&#xff08;用于树上差分。&#xff09; 如有错误&#xff0c;欢迎各位大佬指出。&#xff08;顺便复习一下远古算法。&#xff09; 2.什么是差分 我们先给定一…

AR增强现实技术解决企业远程协作需求

随着科技的不断发展&#xff0c;AR(增强现实)远程协同系统已经成为了一种新型的工作方式。这种系统利用AR技术将虚拟信息叠加到现实世界中&#xff0c;从而实现异地高效协作。 由广州华锐互动开发的AR远程协同系统&#xff0c;广泛应用于各个行业的远程协作场景中&#xff0c;…

44. 通配符匹配(从暴力递归到动态规划)

题目链接&#xff1a;力扣 所有的动态规划都可以使用暴力递归求解&#xff0c;如果推导dp方程比较困难&#xff0c;可以先使用暴力递归进行尝试&#xff0c;然后将从递归改为动态规划&#xff0c;这种方式在dp方程求解困难的情况下非常有效&#xff0c;而且从递归修改为动态规划…

计算机网络?

那么这样能通过审核吗&#xff1f;

二次元古代美女【InsCode Stable Diffusion美图活动一期】

二次元古代美女【InsCode Stable Diffusion美图活动一期】 一、前言二、初识 InsCode三、 试玩 Stable Diffusion 模型1.阅读Stable Diffusion 模型在线引导说明2.实际体验 Stable Diffusion 模型 四、模型相关版本和参数配置&#xff1a;五、图片生成提示词与反向提示词六、种…

游戏术语英语

王者荣耀英文术语大全&#xff01;玩这么久你都听懂了吗&#xff1f; 王者荣耀AP、AD、ADC、AOE等专业术语大全_乐游网 Operating System win2003, winXP, win7, win10 MacOS Game Platform 游戏平台 TGP&#xff08;Tencent Game Platform &#xff09; PC &#xff08;Per…

Linux上部署docker与docker-compose的步骤

Centos上部署docker与docker-compose的步骤 linux系统版本为Centos7.2 第一步-检查前置条件是否符合部署docker 64-bit 系统 kernel 3.10 使用uname -r 检查内核版本&#xff0c;返回的值大于3.10即可。 Centos 7.2的kernel是&#xff1a;3.10.0-327&#xff0c;刚好满足条件…

【算法与数据结构】225、LeetCode用队列实现栈

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;第一种解法是利用两个队列&#xff0c;一个用作输出队列&#xff0c;一个用作备份队列。主要难点在于p…

<Java导出Excel> 4.0 Java实现Excel动态模板字段增删改查

思路&#xff1a; 主要是同时操作两张表&#xff1a;一张存储数据的表&#xff0c;一张存储模板字段的表&#xff1b; 查询&#xff1a;只查询模板字段的表&#xff1b; 新增&#xff0c;修改&#xff0c;删除&#xff1a;需要同时操作两张表中的字段 如果两张表字段不一致&…

51单片机--点亮LED灯和流水灯

文章目录 前言LED模块的原理点亮一个LED灯LED灯的闪烁LED流水灯 前言 大家好&#xff0c;这里是诡异森林。我使用的是普中科技的A2的51开发板&#xff0c;适合新手入门。用到的应用是Keil5和Stc-isp&#xff0c;第一个软件主要用来写代码的&#xff0c;第二个是将代码程序输送…

RocketMQ5.0--部署与实例

RocketMQ5.0–部署与实例 一、Idea调试 1.相关配置文件 在E:\rocketmq创建conf、logs、store三个文件夹。从RocketMQ distribution部署目录中将broker.conf、logback_namesrv.xml、logback_broker.xml文件复制到conf目录。如下图所示。 其中logback_namesrv.xml、logback_b…

2.2.cuda驱动API-初始化和检查的理解,CUDA错误检查习惯

目录 前言1. cuInit-驱动初始化2. 返回值检查总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程&#xff0c;之前有看过一遍&#xff0c;但是没有做笔记&#xff0c;很多东西也忘了。这次重新撸一遍&#xff0c;顺便记记笔记 本次课程学习精简 CUDA 教程-Driver API 案例…

氢燃料电池汽车储氢技术及其发展现状

摘要&#xff1a; 氢能的发展可有效地解决经济发展和生态环境间日益增长的矛盾。氢燃料汽车将处于氢能产业体系中核心地位&#xff0c;加快对氢燃料电池车的技术研发&#xff0c;大范围提高氢能源利用率&#xff0c;对于全世界形成以低碳排放为特征的工业体系具有重要意义。在…

【数据库】忘记mysql本地密码

目录 说明 操作步骤操作失败解决1.在以上操作步骤的第四步&#xff0c;输入mysql&#xff0c;报错第一种报错解决办法如下 第二种报错解决办法如下 2.从上面操作第二步后重新操作步骤如下报错解决办法如下 参考链接 说明 太久没使用本地mysql数据库&#xff0c;忘记了密码。 …