机器学习深度学习——softmax回归从零开始实现

news2025/1/22 20:53:26

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——向量求导问题
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

就跟之前从零开始实现线性回归一样,softmax回归也很重要,因此也进行一次从0开始实现。之前的章节中,我们已经引入了Fashion-MNIST数据集,并设置数据迭代器的批量大小为256。

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

softmax回归的从零开始实现

  • 初始化模型参数
  • 定义softmax操作
    • 回顾sum运算符
    • 构建softmax运算函数
  • 定义模型
  • 定义损失函数
    • NumPy的整数数组索引
    • 交叉熵损失函数定义
  • 分类精度
  • 训练
  • 预测

初始化模型参数

和之前线性回归例子一样,每个样本都用固定长度的向量表示,则之前数据集中每个样本都是28×28的图像,将要进行展平,把他们看做是长度为784的向量。(在这里我们暂且把每个像素的位置都看作是一个特征,其实严格意义上要讨论其空间结构的,在这不做讨论)
而在softmax回归中,我们的输出和类别一样多,因为数据集由10个类别,所以网络输出维度为10。因此,权重将构成一个784×10的矩阵,偏置将构成一个1×10的行向量。与线性回归一样,我们将使用正态分布初始化我们的权重W,偏置初始化为0。

num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs),requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

定义softmax操作

回顾sum运算符

按照之前的线性代数的内容,给定一个矩阵X,我们可以利用sum函数给所有元素求和(默认)。也可以对同一列(轴0)或同一行(轴1)进行求和。用例子表示:

X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(X.sum(0, keepdim=True), X.sum(1, keepdim=True))  # keepdim表示还保留着之前维度即二维

结果:

tensor([[5., 7., 9.]]) tensor([[ 6.],
[15.]])

构建softmax运算函数

回想一下softmax的三个步骤:
1、对每个项求幂(使用exp);
2、对每一行求和(因为小批量中每一行就是一个样本),得到每个样本的规范化常数
3、将每一行除以其规范化常数,确保结果的和为1
回顾一下表达式:
s o f t m a x ( X ) i j = e x p ( X i j ) ∑ k e x p ( X i k ) softmax(X)_{ij}=\frac{exp(X_{ij})}{\sum_kexp(X_{ik})} softmax(X)ij=kexp(Xik)exp(Xij)

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 广播机制

可以验证上述的代码:

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
print(X_prob, '\n', X_prob.sum(1))

结果:

tensor([[0.0152, 0.1212, 0.6149, 0.0877, 0.1610],
[0.1921, 0.0852, 0.1945, 0.4261, 0.1020]])
tensor([1.0000, 1.0000])

根据概率原理易得每行的和为1
注意:数学上看起来很正确,但是代码实现太草率了。矩阵中的非常大或非常小的元素可能造成数值上溢或下溢,但是这里没有采取措施来防止这一点。

定义模型

也就是直接将y=XW+b进行softmax运算得到,注意下面的X要使用reshape来将每张原始图像展平为向量(轴0放个-1让他自己根据列长度=784来进行运算,这里应为256,因为批量大小为256,每个批量(图像)都被展开成了784的向量)

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

定义损失函数

引入交叉熵函数,这在深度学习中很可能是最常见的损失函数了(目前分类问题数量远超回归问题数量)
回顾一下,交叉熵采用真实标签的预测概率的负对数似然。这边我们不使用for循环这种低效的方式,而是通过一个运算符选择所有函数。在这里我们先介绍下NumPy的整数数组索引。

NumPy的整数数组索引

整数数组索引,它可以选择数组中的任意一个元素,比如,选择第几行第几列的某个元素,示例如下:

import numpy as np
#创建二维数组
x = np.array([[1,  2],  [3,  4],  [5,  6]])
#[0,1,2]代表行索引;[0,1,0]代表列索引
y = x[[0,1,2],[0,1,0]] 
print (y)

结果:

[1 4 5]

对着样例做简单分析:将行、列索引组合会得到 (0,0)、(1,1) 和 (2,0) ,它们分别对应着输出结果在原数组中的索引位置。

下面,我们创建一个数据样本y_hat,其中包含2个样本在3个类别的预测概率,以及它们对应的标签y。然后使用y作为y_hat中概率的索引,我们选择第一个样本中第一个类的概率和第二个样本中第三个类的概率:

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat[[0, 1], y])

输出:

tensor([0.1000, 0.5000])

交叉熵损失函数定义

那么现在只需要一行就可以实现交叉熵函数了:

def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])

注意,原来的交叉熵损失函数实际上是:
l ( y , y ^ ) = − ∑ j = 1 q y j l o g y ^ j l(y,\hat{y})=-\sum_{j=1}^qy_jlog\hat{y}_j l(y,y^)=j=1qyjlogy^j
其中,q是独热编码的长度,那么容易知道,那个求和符号其实没啥用,因为利用独热编码的话,除了中标的那一项,其他的y中元素全是0。所以引变为代码中的:
l ( y , y ^ ) = − l o g y ^ j l(y,\hat{y})=-log\hat{y}_j l(y,y^)=logy^j
验证:

print(cross_entropy(y_hat, y))

结果:

tensor([2.3026, 0.6931])

分类精度

给定预测概率分布y_hat,我们要给出硬预测时,通常选择预测概率最高的类。
当预测和标签分类y一致时,就是正确的。分类精度即正确预测数量与总预测数量之比。虽然直接优化精度可能很难(精度计算不可导),但我们总是要关注他。
我们可以进行下面的操作:
若y_hat是矩阵,假定第二维度存储每个类的预测分数,我们就可以使用argmax来获得每行的最大元素的索引,用来获得预测的类别。然后和真实的y比较。(注意,由于等式运算符号"=="对数据类型很敏感,因此我们需要将数据类型转换为一致的。)结果会是一个包含0和1的张量,求和就可以得到正确预测的数量了。

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:  # 判断是矩阵
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.astype(y.dtype) == y
    return float(cmp.astype(y.dtype).sum())

我们将继续使用之前定义的变量y_hat和y分别作为预测的概率分布和标签。 可以看到,第一个样本的预测类别是2(该行的最大元素为0.6,索引为2),这与实际标签0不一致。 第二个样本的预测类别是2(该行的最大元素为0.5,索引为2),这与实际标签2一致。 因此,这两个样本的分类精度率为0.5。

print(accuracy(y_hat, y) / len(y))

结果:

0.5

同样,对于任意数据迭代器data_iter可访问的数据集,我们可以评估在任意模型net的精度。
我们先定义一个实用程序类Accumulator用于对多个变量进行累加:

class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):  # *号会拆解为元组
        self.data = [a + float(b) for a, b in zip(self.data, args)]  # zip就是把两元组组合起来
    def reset(self):
        self.data = [0.0] * len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

接着我们定义evaluate_accuracy函数用于计算在指定数据集上模型的精度:

def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

在上面的evaluate_accuracy函数中,我们在Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都将随着时间的推移而累加。

训练

首先,我们定义一个函数来训练一个迭代周期。(注意:updater是更新模型参数的常用函数,它接受批量大小作为参数。它可以是d2l.sgd函数,也可以是框架的内置优化函数。)

def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用Pytorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()  # 损失后向传播
            updater.step()  # 更新网络参数
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数实现前,定义一个在动画中绘制数据的应用程序类Animator(会用就行):

class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        d2l.plt.draw()
        d2l.plt.pause(0.001)
        display.clear_output(wait=True)

接下来,实现一个训练函数,它会在train_iter访问到的训练数据集上训练一个模型net。该训练函数会运行多个迭代周期。在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。我们利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    # assert语句表示断言,表达式为False时会触发AssertionError异常
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

我们使用之前定义的小批量随机梯度下降来优化模型的损失函数,设学习率为0.1:

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

现在,训练10个迭代周期:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

在这里插入图片描述
这边是可以跑出动图的,如果跑不出来动态的效果,解决方案:
File ——> Settings ——> Tools ——> Python Scientific ——> 取消勾选 Show plots in toolwindow
(电脑快跑炸了)

预测

训练已经完成,我们的模型可以进行分类预测了,给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 2, n, titles=titles[0:n])

predict_ch3(net, test_iter)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/792586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

39. Linux系统下在Qt5.9.9中搭建Android开发环境

1. 说明 QT版本:5.9.9 电脑系统:Linux JDK版本:openjdk-8-jdk SDK版本:r24.4.1 NDK版本:android-ndk-r14b 效果展示: 2. 具体步骤 大致安装的步骤如下:①安装Qt5.9.9,②安装jdk,③安装ndk,④安装sdk,⑤在qt中配置前面安装的环境路径 2.1 安装Qt5.9.9 首先下载…

国产化的接口测试、接口自动化测试工具Apipost的介绍及使用

Apipost介绍&#xff1a; Apipost是 API 文档、API 调试、API Mock、API 自动化测试一体化的研发协作赋能平台&#xff0c;它的定位 Postman Swagger Mock JMeter。 Apipost 是接口管理、开发、测试全流程集成工具&#xff0c;能支撑整个研发技术团队同平台工作&#xff0…

win10日程怎么同步到安卓手机?电脑日程同步到手机方法

在如今快节奏的生活中&#xff0c;高效地管理时间变得至关重要。而对于那些经常在电脑上安排日程的人来说&#xff0c;将这些重要的事务同步到手机上成为了一个迫切的需求。因为目前国内使用win10系统电脑、安卓手机的用户较多&#xff0c;所以越来越多的职场人士想要知道&…

手机怎么把word转换成pdf?这几种方法超简单

手机怎么把word转换成pdf&#xff1f;现在很多人在手机上处理文档&#xff0c;但是可能会遇到将Word文档转换为PDF的需求&#xff0c;以便更好地分享和传输文件。在下面这篇文章中&#xff0c;就给大家介绍几种将Word文档转换为PDF的方法。 方法一&#xff1a;使用迅捷PDF转换器…

spring复习:(55)注解配置的情况下@ComponentScan指定的包中的组件是怎么被注册到容器的?

配置类&#xff1a; 主类&#xff1a; 结论&#xff1a;是在context.refresh()处完成扫描和注册的。 fresh()的代码片段如下&#xff1a; 其中调用的invokeBeanFactoryPostProcessor代码如下&#xff1a; 其中调用的静态方法invokeBeanFactoryPostProcessors代码如下&#…

一些联动树形数据组装

export const pieselectdata [{entrustOrganization: 智慧法院电子诉讼平台,entrustOrganizationId: 161,productNames: [{batchCodes: [],productName: CL测试调解产品,},{batchCodes: [2022927_001,2022927_003,2022927_004,2022927_005,2022927_006,2022927_008,2022927_00…

文本预处理——文本数据增强

目录 文本数据增强回译数据增强法 文本数据增强 回译数据增强法

windows 系统安装sonarqube

SonarQube是一种自动代码审查工具&#xff0c;用于检测代码中的错误&#xff0c;漏洞和代码异味。它可以与您现有的工作流程集成&#xff0c;以便在项目分支和拉取请求之间进行连续的代码检查。 官方网站&#xff1a; https://www.sonarqube.org/ 1. 使用前提条件 运行SonarQ…

Excel双向柱状图的绘制

Excel双向柱状图在绘制增减比较的时候经常用到&#xff0c;叫法繁多&#xff0c;双向柱状图、上下柱状图、增减柱状图都有。 这里主要介绍一下Excel的基础绘制方法和复杂一点的双向柱状图的绘制 基础双向柱状图的绘制 首先升降的数据如下&#xff1a; 月份上升下降20220359-…

【二叉树】刷题(以递归写法为主)

226.翻转二叉树 101. 对称二叉树 104.二叉树的最大深度 111.二叉树的最小深度 222.完全二叉树的节点个数 110.平衡二叉树 102. 二叉树的所有路径 226.翻转二叉树 class Solution:def invertTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:if not root:return…

13、PHP面向对象2(方法的访问控制、子类继承、常量)

1、类中的方法可以被定义为公有&#xff0c;私有或受保护。如果没有设置这些关键字&#xff0c;则该方法默认为公有。 public定义的方法&#xff0c;可以在类外使用。 protected定义的方法&#xff0c;只能在本类或子类的定义内使用。 private定义的方法&#xff0c;只能在本…

第八章 非编码RNA简介

第八章 非编码RNA简介 第一节 引言 第二节 长链非编码RNA简介 第三节 环形RNA简介 第四节 小RNA简介 4.1 miRNA 4.2 piRNA 4.3 小RNA数据分析

图像篡改及防篡改

有时候我们是攻击方&#xff0c;发送被网站或微信屏蔽的敏感图像&#xff0c;分享瓜时剔除可能暴露的个人信息&#xff0c;在平台分享其他平台的购物记录 有时候我们是防守方&#xff0c;判断他人给的图有没有造假嫌疑&#xff0c;判断是不是网图盗图 调研了图像造假的判别方案…

soft ip与hard ip

ip分soft和hard两种&#xff0c;soft就是纯代码&#xff0c;买过来要自己综合自己pr。hard ip如mem和analog与工艺有关。 mem的lib和lef是memory compiler产生的&#xff0c;基于bitcell&#xff0c;是foundry给的。 我正在「拾陆楼」和朋友们讨论有趣的话题&#xff0c;你⼀起…

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…

Upload文件导入多条数据到输入框

需求场景&#xff1a;文本框内容支持批量导入(文件类型包括’.txt, .xls, .xlsx’)。使用AntD的Upload组件处理。 下面是Upload的配置&#xff08;伪代码&#xff09;&#xff0c;重点为beforeUpload中的逻辑 // Antd 中用到的Upload组件 import { UploadOutlined } from ant…

ADSCOPE加入中国互联网协会!

近日&#xff0c;经协会批准&#xff0c;ADSCOPE&#xff08;上海倍孜网络技术有限公司&#xff09;正式加入中国互联网协会&#xff0c;成为会员单位。 中国互联网协会隶属于中华人民共和国工业和信息化部&#xff0c;是由中国互联网行业及与互联网相关的企事业单位、社会组织…

网络编程--模拟HTTP服务器

下面是一个简单的例子&#xff0c;来模拟HTTP服务器 这里只是简单的按照 HTTP 协议来构造数据 #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <sys/wait.h> #include <string.h> #include <arpa/inet.h>//处理连…

牛客30道Java专项练习-错题-01

一、Java初始化过程&#xff1a; 初始化父类种的静态成员变量和静态代码块&#xff0c;顺序执行初始化子类种的静态成员变量和静态代码块&#xff0c;顺序执行初始化父类的普通成员变量和代码块&#xff0c;再执行父类的构造函数初始化子类的成员变量和代码块&#xff0c;在执…

2024考研408-计算机网络 第一章-计算机网络体系结构学习笔记.md

文章目录 前言一、计算机网络概述1.1、概念及功能1.1.1、计算机网络的概念1.1.2、计算机网络的功能功能1、数据通信功能2、资源共享功能3、分布式处理功能4、提高可靠性&#xff08;分布式处理引申功能&#xff09;功能5、负载均衡&#xff08;也是分布式处理引申功能&#xff…