softmax回实战

news2024/11/28 12:37:12

1.数据集

MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

2.初始化模型参数

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

3.定义softmax操作

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)

(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],
         [0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),
 tensor([1.0000, 1.0000]))

4.定义模型

定义softmax操作后,我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

5.定义损失函数

接下来,我们引入交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

6.分类精度

为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

7.训练

在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。

def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数的实现之前,我们定义一个在动画中绘制数据的实用程序类Animator, 它能够简化本书其余部分的代码。

class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型(定义见第3章)"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

8.预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

9.总结:

  • 借助softmax回归,我们可以训练多分类的模型。
  • 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

10.简洁实现:

10.1读取数据

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

10.2定义模型,初始化参数

跟线性规划模型是一样的,只不过有十个输出

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

10.3定义损失函数

loss = nn.CrossEntropyLoss(reduction='none')

10.4定义优化器

我们使用学习率为0.1的小批量随机梯度下降作为优化算法

updater = torch.optim.SGD(net.parameters(), lr=0.1)

10.5训练

调用上面写好的函数,只不过这次的模型,损失函数,优化器都是pytorch内置的

num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。
    数,优化器**都是pytorch内置的
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1399211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

25计算机考研408专业课复习计划

点击蓝字&#xff0c;关注我们 今天要分享的是25计算机考研408专业课复习计划。 以下内容供大家参考&#xff0c;大家要根据自己的复习情况进行适当调整。 统考与自命题 统考科目是指计算机学科专业基础综合&#xff08;408&#xff09;&#xff0c;满分150分&#xff0c;试…

用MATLAB函数在图表中建立模型

本节介绍如何使用Stateflow图表创建模型&#xff0c;该图表调用两个MATLAB函数meanstats和stdevstats。meanstats计算平均值&#xff0c;stdevstats计算vals中值的标准偏差&#xff0c;并将它们分别输出到Stateflow数据平均值和stdev。 请遵循以下步骤&#xff1a; 1.使用以下…

医保移动支付加密解密请求工具封装【国密SM2SM4】

文章目录 医保移动支付加密解密请求工具封装一、项目背景二、使用方法三、接口调用四、源码介绍五、下载地址 医保移动支付加密解密请求工具封装 定点医药机构向地方移动支付中心发起费用明细上传、支付下单、医保退费等交易时需要发送密文&#xff0c;由于各大医疗机构厂商的开…

揭秘AI换脸技术:从原理到应用

随着人工智能技术的不断发展&#xff0c;AI换脸技术逐渐成为人们关注的焦点。这项神奇的技术能够将一张图像或视频中的人脸替换成另一张人脸&#xff0c;让人不禁惊叹科技的神奇。那么&#xff0c;AI换脸技术究竟是如何实现的呢&#xff1f;本文将带您深入了解AI换脸技术的原理…

python系列-输入输出关系运算符算术运算符

&#x1f308;个人主页: 会编程的果子君​&#x1f4ab;个人格言:“成为自己未来的主人~” 目录 注释的语法 注释的规范 输入输出 通过控制台输出 通过控制台输入 运算符 算术运算符 关系运算符 注释的语法 python中有两种注释风格&#xff1a; 1.注释行&#xff1a;…

无人机打击激光器

激光器的应用非常广泛&#xff0c;涵盖了多个领域。以下是一些主要的激光器应用&#xff1a; 医疗领域&#xff1a;激光器在医疗行业中有着重要应用&#xff0c;比如用于激光手术&#xff08;如眼科手术&#xff09;、皮肤治疗、牙科治疗、肿瘤治疗等。 工业制造&#xff1a;在…

(菜鸟自学)初学脚本编程

&#xff08;菜鸟自学&#xff09;初学脚本编程 Bash脚本概述编写一个测试在线主机的脚本程序 Python脚本概述编写一个与Netcat功能类似的脚本程序 C语言脚本概述编写C语言脚本程序&#xff08;Hello World&#xff09; Bash脚本概述 Bash脚本是一种基于Bash&#xff08;Bourn…

图片批量建码怎么用?每张图片快速生成二维码

当我们需要给每个人分别下发对应的个人证件类图片信息&#xff0c;比如制作工牌、荣誉展示或者负责人信息展示时&#xff0c;现在都开始使用二维码的方法来展示员工信息。那么如何快速将每个人员的信息图片分别制作成二维码图片呢&#xff0c;最简单的方法就是使用图片批量建码…

vue中内置指令v-model的作用和常见使用方法介绍以及在自定义组件上支持

文章目录 一、v-model是什么二、什么是语法糖三、v-model常见的用法1、对于输入框&#xff08;input&#xff09;&#xff1a;2、对于复选框&#xff08;checkbox&#xff09;&#xff1a;3、对于选择框&#xff08;select&#xff09;&#xff1a;4、对于组件&#xff08;comp…

群发邮件效果追踪:掌握数据,优化营销策略

我们在邮件群发结束后&#xff0c;如果想要了解到这次群发活动的效果怎么样&#xff0c;就需要通过一些数据。比如说邮件达到率、打开率、跳出率、退订率等。这些信息可以将收件人的行为数据化&#xff0c;让我们可以更清晰地对活动进行深入分析让我们及时地找出问题和优点&…

C语言数据结构——顺序表

&#xff08;图片由AI生成&#xff09; 0.前言 在程序设计的世界里&#xff0c;数据结构是非常重要的基础概念。本文将专注于C语言中的一种基本数据结构——顺序表。我们将从数据结构的基本概念讲起&#xff0c;逐步深入到顺序表的内部结构、分类&#xff0c;最后通过一个实…

网络安全:守护数字世界的盾牌

在当今数字化的时代&#xff0c;网络已经渗透到我们生活的方方面面。从社交媒体到在线银行&#xff0c;从在线购物到工作文件传输&#xff0c;网络几乎无处不在。然而&#xff0c;随着网络的普及&#xff0c;网络安全问题也日益凸显。那么&#xff0c;如何确保我们的数字资产安…

Vue2的双向数据绑定

Vue2的双向数据绑定 Observer&#xff1a;观察者&#xff0c;这里的主要工作是递归地监听对象上的所有属性&#xff0c;在属性值改变的时候&#xff0c;触发相应的watcher。 Watcher&#xff1a;订阅者&#xff0c;当监听的数据值修改时&#xff0c;执行响应的回调函数&#x…

KubeSphere 核心实战之二【在kubesphere平台上部署redis】(实操篇 2/4)

文章目录 1、登录kubesphere平台2、redis部署分析3、redis容器启动代码4、kubesphere平台部署redis4.1、创建redis配置集4.2、创建redis工作负载4.3、创建redis服务 5、测试连接redis 在kubesphere平台上部署redis应用都是基于redis镜像进行部署的&#xff0c;所以所有的部署操…

【Github搭建网站】零基础零成本搭建个人Web网站~

Github网站&#xff1a;https://github.com/ 这是我个人搭建的网站&#xff1a;https://xf2001.github.io/xf/ 大家可以搭建完后发评论区看看&#xff01;&#xff01;&#xff01; 搭建教程&#xff1a;https://www.bilibili.com/video/BV1xc41147Vb/?spm_id_from333.999.0.0…

2023.12 电子学会青少年软件编程(Python) 等级考试试卷(三级)

2023年12月 电子学会青少年软件编程&#xff08;Python&#xff09; 等级考试试卷&#xff08;三级&#xff09; 分数&#xff1a; 100 题数&#xff1a; 38 一、单选题(共 25 题&#xff0c; 共 50 分) 1. 一个非零的二进制正整数&#xff0c; 在其末尾添加两个“0” &#xf…

【排序算法】六、快速排序(C/C++)

「前言」文章内容是排序算法之快速排序的讲解。&#xff08;所有文章已经分类好&#xff0c;放心食用&#xff09; 「归属专栏」排序算法 「主页链接」个人主页 「笔者」枫叶先生(fy) 目录 快速排序1.1 原理1.2 Hoare版本&#xff08;单趟&#xff09;1.3 快速排序完整代码&…

70.Redis缓存优化实践(基于分类树场景)

文章目录 前言第一次优化第二次优化第三次优化第四次优化第五次优化 前言 分类树查询功能&#xff0c;在各个业务系统中可以说随处可见&#xff0c;特别是在电商系统中。 而在实际工作中&#xff0c;这样一个分类树查询&#xff0c;我们都不断的改进了好几次。这是为什么呢&…

Lite AD的安装

1、Lite AD的安装及配置 Lite AD流程&#xff1a; &#xff08;1&#xff09;创建一个新的Windows 10&#xff0c;安装tools&#xff0c;再安装ITA组件&#xff08;安装Lite AD会自动安装VAG/VLB&#xff09; &#xff08;2&#xff09;创建一个新的Windows 10&#xff0c;安…

【GNN报告】“青源Talk”-图可信学习与图大模型研究进展

北航王啸-图自监督学习 简介 介绍 浙大杨洋-探索大图模型预训练 总括 介绍 参考 Yang Yang - Zhejiang University dgraph-web DGraph: ALarge-Scale Financial Dataset for Graph Anomaly Detection All in One: Multi-task Prompting for Graph Neural Networks&#xf…