动手学深度学习——从零实现softmax分类模型

news2025/1/11 6:03:42

1. 数据集

fashion mnist是一个由10个类别图像组成的服装分类数据集,共包含60000张训练集图像和10000张测试集图像, 前者用于训练模型参数,后者用于评估模型性能。

2.1 数据集下载

先进行依赖库导入:

%matplotlib inline       # jupyter魔法命令,用于显示matplotlib生成的图形。
import torch             # 用于构建和训练深度学习模型。
import torchvision       # pytorch视觉工具库,用于处理图像数据。
from torch.utils import data       # 一些数据处理的工具类
from torchvision import transforms # 图像转换和增强
from d2l import torch as d2l

d2l.use_svg_display()              # 使用svg来显示图片,清晰度更高

接下来使用框架内置函数来下载数据集并读取到内存中,数据集大概在100MB左右。

# ToTensor:图像预处理,将图像数据转为tensor格式
trans = transforms.ToTensor()
# 从网上下载训练数据集,并通过transform转换
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
# 从网上下载验证数据集,并通过tranform转换为张量
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

数据集下载和解析的过程如下,以train开头的为训练集,以t10k开头的为测试集:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

每张图像28*28像素,全部为灰度图像,通道数为1,形状如下:

len(mnist_train), len(mnist_test), mnist_train[0][0].shape, mnist_test[0][0].shape

> (60000, 10000, torch.Size([1, 28, 28]), torch.Size([1, 28, 28]))

数据图形示例如下:
在这里插入图片描述

1.2 数据读取

同前面的线性回归一样,我们采用小批量数据读取来训练和测试模型,所以需要封装一个小批量数据读取的迭代器。

batch_size = 256
workers = 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=workers)
test_iter = data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=workers))
  • batch_size: 分批的批次大小
  • shuffle: 置为True可以打乱样本顺序,随机读取
  • num_workers: 使用多少个进程来并发读取数据

train_iter和test_iter都是一个数据迭代器,可以理解为集合中的iterator,只不过每次迭代的不是一条数据,而是batch_size大小的小批量数据集。

以train_iter为例输出下形状:

for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
> torch.Size([256, 1, 28, 28]) torch.float32 torch.Size([256]) torch.int64

读数据是常见的性能瓶颈,训练之前最好先测试下数据读取速度。

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

# 使用1个进程读取数据
> '10.63 sec'
# 使用4个进程读取数据
> '5.77 sec'

到这里,已经准备好Fashion-MNIST数据集,下面可以有它来训练和评估分类算法性能。

2. 模型

2.1 初始化模型参数

原始数据集中的每个样本都是28x28的图像,每个图像都有784个像素,可以理解为784个特征,我们可以把输入数据都看作长度为784的向量。

前文提到过,在softmax回归中,输出与类别一样多。 因为我们的数据集有10个类别,所以网络模型的输出维度为10。 因此,权重W将构成一个784x10的矩阵, 偏置b将构成一个长度为10的行向量。

num_inputs = 784
num_outputs = 10

# 与线性回归一样,使用正态分布初始化我们的权重W,偏置初始化为0。
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2.2 定义softmax操作

参考前文,实现softmax由三个步骤组成:

  • 对每个项求幂(使用exp);
  • 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
    将每一行除以其规范化常数,确保结果的和为1。

数学表达式如下:
在这里插入图片描述
代码实现:

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)  # 这里的1表示坐标轴1,即每一行的所有列求和
    return X_exp / partition  # 这里应用了广播机制

接下来验证是否正确,主要在于两方面:

  • 所有元素是否为正
  • 每一行的和是否为1
X = torch.normal(0, 1, (2, 5))   # 均值为0,标准差为1,2行5列的元素
X_prob = softmax(X)
X_prob, X_prob.sum(1)

> (tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],
         [0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),
 tensor([1.0000, 1.0000]))

2.3 定义模型

模型定义了如何将输入数据通过网络映射到输出。

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
  • X.reshape((-1, W.shape[0])): 将输入X的形状由4维矩阵[256, 1, 28, 28]调整为2维矩阵[256, 784],0维为批量大小,1维为向量W的0维长度784
  • 与线性回归一样,使用torch.matmul来计算矩阵X与向量W的矩阵向量积,再加上偏置b就是线性输出
  • 对线性输出softmax就得到各个类别的预测概率

2.4 定义损失函数

前文提到,交叉熵可以认为是真实标签的预测概率的负对数。那在计算交叉熵之前要先拿到真实标签的预测概率。

拿下面的样本数据来说明,y_hat是一个包含2个样本在3个类别的预测概率, y是对应的真实标签,采用下标来表示类别。

y = torch.tensor([0, 2, 1])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5], [0.075, 0.88, 0.045]])

样本1中,第一类是正确的预测,预测概率为0.1;
样本2中,第三类是正确的预测,预测概率为0.5;
样本3中,第二类是正确的预测,预测概率为0.88;

方法一:采用循环:

result = []
for i in range(len(y)):
    result.append(y_hat[i, y[i]])

torch.tensor(result)

> tensor([0.1000, 0.5000, 0.8800])

方法二:直接将y作为y_hat中概率的索引,因为y中存放的正确类别下标与y_hat中是对应的。

y_hat[[0, 1, 2], y]

> tensor([0.1000, 0.5000, 0.8800])
  • y_hat[[0, 1, 2], y] 本质上与常规二维数组索引方式y_hat[i, j]形式相同,不同点在于i、j不再是具体的数字,因为要一次性取多个样本的预测值;
  • i = [0, 1, 2]表示行方向上取第0、1、2三个样本;
  • j = y表示三个样本列方向分别取第0, 2, 1个元素;
  • 最终取出的元素是y_hat张量中第0行的第0列、第1行的第2列和第2行的第1列;

方法二比方法一要简单很多,由于是python内置语法,运行效率也更高。这样只需一行代码就可以实现交叉熵损失函数。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)

> tensor([2.3026, 0.6931, 0.1278])
  • 第1个正确值的概率只有0.1,所以计算出来交叉熵损失2.3026就比较大;
  • 第2个正确值的概率有0.5,所以交叉熵0.6931也有所收敛;
  • 第3个正确值 的概率较高0.88, 所以交叉熵0.1278就比较小;

2.4 分类精度

给定预测概率分布y_hat,当我们必须输出预测类别时,我们通常会选择预测概率最高的类别来作为预测结果,但预测概率高的类别有时候不一定是正确预测,这时候就产生了错误预测。

就如同上面第一个样本数据中,预测概率最高的0.6并非正确类别,实际正确类别的预测概率只有0.1。

我们需要一个指标来衡量模型预测的正确率,称之为分类精度,它是正确预测数量与总预测数量之比。

以上面的y和y_hat示例数据为例,可以通过如下步骤来计算分类精度:

  1. 使用argmax获得每行中最大元素的索引来获得预测类别。
  2. 将预测类别与真实y元素进行等值比较,比较前需要将y_hat的数据类型转换为与y的数据类型一致,因为等式运算符“==”对数据类型很敏感,
  3. 结果是一个包含0(错)和1(对)的张量,进行求和就可以得到正确预测的数量。

代码实现如下:

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

以上面的数据来测试:

accuracy(y_hat, y) / len(y)

> 0.6666666666666666
  • 第一个样本的预测错误,预测概率最大的索引2(概率0.6)与正确标签0不一致。
  • 第二个样本的预测正确,预测概率最大的索引2(概率0.5)与正确标签2一致。
  • 第三个样本的预测正确,预测概率最大的索引1(概率0.88)与正确标签1一致。

由于我们采用的是小批量多轮迭代训练,会有产生多轮预测数据,所以我们需要封装一个能支持多轮迭代的精度计算函数(主要用于训练后的精度测试)。

# @param net: 网络模型,用于对输入数据X进行类别预测,输出预测概率
# @param data_iter: 数据迭代器,每一轮迭代都包含输入数据X和对应的标签y
def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    metric = Accumulator(2)  # 2个元素的累加器,用于统计正确预测数、预测总数;
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

步骤解读:

  1. 先用模型net对输入X进行类别预测,得到预测概率;
  2. 再使用accuracy对预测结果和真实标签计算精度,并把精度和标签数量进行累加;
  3. 返回模型在数据集上的精度,正确预测数与总预测数的比值。

3. 训练

3.1 定义参数更新函数

这里我们复用线性回归中定义的参数优化函数sgd(小批量随机梯度下降),学习率设为0.1。

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

3.2 定义单轮迭代训练流程

def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 长度为3的累加器,分别累加训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 使用模型来计算得到预测概率
        y_hat = net(X)
        # 计算损失
        l = loss(y_hat, y)
        # 反向累积计算梯度
        l.sum().backward()
        # 更新优化参数
        updater(X.shape[0])
        # 累加损失、精度、样本数
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

3.3 定义整体训练流程

整体训练流程比较简单,就是循环执行多轮训练,每轮训练后参数都会得到更新,再拿测试数据集基于更新的参数去执行模型当前的表现,得到一个精度值。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型(定义见第3章)"""
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f"epoch: {epoch + 1}, loss: {train_metrics[0]}, test_acc: {test_acc}")


3.4 运行训练

基于前面定义的模型,进行10次迭代训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
  • num_epochs: 迭代训练次数
  • net: 网络模型
  • train_iter: 训练数据集
  • test_iter: 测试数据集,用于测试模型训练后的性能
  • cross_entropy: 损失函数
  • updater: 参数优化器

整个训练过程中的损失和测试精度变化:

epoch: 1, loss: 0.7857203146616618, test_acc: 0.7882
epoch: 2, loss: 0.5686315283457438, test_acc: 0.7985
epoch: 3, loss: 0.5252757650375366, test_acc: 0.8192
epoch: 4, loss: 0.5007046510060629, test_acc: 0.8231
epoch: 5, loss: 0.4856935443242391, test_acc: 0.8196
epoch: 6, loss: 0.4738648806254069, test_acc: 0.8249
epoch: 7, loss: 0.46540179011027016, test_acc: 0.8299
epoch: 8, loss: 0.45916082598368324, test_acc: 0.8271
epoch: 9, loss: 0.45219682502746583, test_acc: 0.833
epoch: 10, loss: 0.4484250022888184, test_acc: 0.8328

可以看出,随着训练的不断迭代,损失在持续减小,测试精度虽然有略微起伏,但总体上也是在不断提升。

4. 预测

使用训练好的模型对图像进行分类预测,比较图像的实际标签和模型预测是否相同:

def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

结果如下:
在这里插入图片描述

总结

本文softmax分类模型与前面线性回归模型的整体训练过程比较相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1635998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JENKINS 安装,学习运维从这里开始

Download and deployJenkins – an open source automation server which enables developers around the world to reliably build, test, and deploy their softwarehttps://www.jenkins.io/download/首先点击上面。下载Jenkins 为了学习,从windows开始&#x…

ES集群分布式查询原理

集群分布式查询 elasticsearch的查询分成两个阶段: scatter phase:分散阶段,coordinating node会把请求分发到每一个分片gather phase:聚集阶段,coordinating node汇总data node的搜索结果,并处理为最终结…

【stomp 实战】Spring websocket 用户订阅和会话的管理

通过Spring websocket 用户校验和业务会话绑定我们学会了如何将业务会话绑定到spring websocket会话上。通过这一节,我们来分析一下会话和订阅的实现 用户会话的数据结构 SessionInfo 用户会话 用户会话定义如下: private static final class Sessio…

03 - 步骤 Kafka producer

简介 Kafka producer 步骤,用于将 Kettle 中经过处理或转换的数据发送到 Kafka 的主题中 使用 场景 我需要把经过Kettle处理完的数据发送到一个Kafka中,让后端服务器进行下一步处理。 1、拖拽 Kafka producer 到面板 2、配置 Kafka producer 3、调试…

FSD自动驾驶泛谈

特斯拉的FSD(Full-Self Driving,全自动驾驶)系统是特斯拉公司研发的一套完全自动驾驶系统。旨在最终实现车辆在多种驾驶环境下无需人类干预的自动驾驶能力。以下是对FSD系统的详细探讨: 系统概述 FSD是特斯拉的自动驾驶技术&…

架设WebSocket的最后一环,如何设置好nginx反向代理

WebScoket都已经完工快一个月,经过一段时间的测试,公司还是准备把服务器换到鹅厂,用EO来解决CDN内容分发和DDOS防护问题,由于EO并不支持URL 路径转发,只支持转发到一个站点的80或则443端口,如果想做路径分发…

从Paint 3D入门glTF

Paint 3D Microsoft Paint 3D是微软的一款图像编辑软件,它是传统的Microsoft Paint程序的升级版。 这个新版本的Paint专注于三维设计和创作,使用户可以使用简单的工具创建和编辑三维模型。 Microsoft Paint 3D具有直观的界面和易于使用的工具&#xff0…

小程序地理位置权限如何申请?

这篇内容会教大家如何快速申请“获取当前的地理位置(onLocationChange)”接口,以便帮助大家顺利开通接口。以下内容是本人经历了多次的申请经历得出来的经验,来之不易,望大家给予鼓励! 小程序地理位置接口有…

百川crm系统 汽车销售租赁CRM客户管理系统是不可或缺的利器?

在竞争激烈的汽车销售租赁市场中,如何提升客户满意度、优化业务流程、提高销售效率,成为了每一家汽车销售租赁公司必须面对的问题。而CRM(客户关系管理)客户管理系统,正是应对这些挑战的重要利器。本文将从汽车销售租赁…

18 如何设计微服务才能防止宕机?

在上一讲里,介绍了构建一个稳健的微服务的具体法则:防备上游、做好自己、怀疑下游, 并介绍了为什么要防备上游,以及一些防备上游的具体手段。 在本讲里,咱们一起来学习,做好微服务自身的设计和代码编写的常…

ollama-python-Python快速部署Llama 3等大型语言模型最简单方法

ollama介绍 在本地启动并运行大型语言模型。运行Llama 3、Phi 3、Mistral、Gemma和其他型号。 Llama 3 Meta Llama 3 是 Meta Inc. 开发的一系列最先进的模型,提供8B和70B参数大小(预训练或指令调整)。 Llama 3 指令调整模型针对对话/聊天用…

Centos7+Hadoop3.3.4+KDC1.15+Ranger2.4.0集成

一、集群规划 本次测试采用3台虚拟机,操作系统版本为centos7.6。 kerberos采用默认YUM源安装,版本为:1.15.1-55 Ranger版本为2.4.0 系统用户为ranger:ranger IP地址主机名KDCRanger192.168.121.101node101.cc.localKDC masterRanger Admin…

如何找到台式电脑的ip地址

在数字时代,每台接入网络的设备都拥有一个独特的标识,这就是IP地址。无论是手机、笔记本电脑还是台式电脑,IP地址都扮演着至关重要的角色,它帮助设备在网络世界中定位并与其他设备进行通信。对于许多电脑用户来说,了解…

JavaScript原型链深度剖析

目录 前言 一、原型链 1.原型链的主要组成 原型(Prototype) 构造函数(Constructor) 实例(Instance) 2.原型链的工作原理 前言 在JavaScript的世界中,原型链(Prototype Chain&…

“Postman 中文版使用教程:如何切换到中文界面?”

Postman 的很好用的接口测试软件。但是,Postman 默认是英文版的,也不支持在软件内切换为中文版。很多同学的英语并不是很好,看到一堆的英文很是头痛。 今天我们来介绍下:切换到 Postman 中文版的方法。想要学习更多的关于 Postma…

IDEA 中 git fetch 验证报错 The provided password or token is incorrect

参考链接: 【GitLab】-HTTP Basic: Access denied.remote:You must use a personal access token_http basic: access denied. the provided password o-CSDN博客 idea使用gitLab报错:remote: HTTP Basic: Access denied_idea remote: http basic: acc…

MoonBit 周报 Vol.39:新增 JS 后端、插件和构建系统同步支持多后端开发……

MoonBit 更新 新增JavaScript后端 目前MoonBit已新增对JavaScript的支持并带来前所未有的性能提升,在JS后端实现了超出Json5近8倍性能的优势。更详细的介绍可以看一下这篇文章:IDEA研究院编程语言MoonBit发布JavaScript后端,速度提升25倍 …

Copilot Workspace是GitHub对人工智能驱动的软件工程的诠释

软件开发的未来是人工智能驱动的集成开发环境吗?至少GitHub 是这样想的。 在今年初秋于旧金山举行的 GitHub Universe 年度大会之前,GitHub 发布了 Copilot Workspace,这是一种开发环境,利用 GitHub 所称的 “Copilot 驱动的代理…

[游戏陪玩系统] 陪玩软件APP小程序H5游戏陪玩成品软件源码-线上线下可爆改家政,整理师等功能

简介 随着电竞行业的快速发展,电竞陪玩APP正在逐渐成为用户在休闲娱乐时的首选。为了吸引用户和提高用户体验,电竞陪玩APP开发需要定制一些特色功能,并通过合适的盈利模式来获得收益。本文将为您介绍电竞陪玩APP开发需要定制的特色功能以及常…

超简单的Spring-mvc示例

超简单的Spring-mvc示例