softmax回归从零开始实现

news2024/11/25 10:49:16

1. 引入Fashion-MNIST数据集

并设置数据迭代器的批量大小为256

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
# 每次随机读256张图片,返回训练集和测试集的迭代器
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

2. 初始化参数模型

每个样本都将用固定长度的向量表示。 原始数据集中的每个样本都是28 * 28 的图像。 在本节中,我们将展平每个图像,把它们看作长度为784的向量。现在我们暂时只把每个像素位置看作一个特征

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784 * 10 的矩阵, 偏置将构成一个1 * 10的行向量。 与线性回归一样,我们将使用正态分布初始化我们的权重W,偏置初始化为0

num_inputs = 784
num_outputs = 10

W = torch.normal(0,0.01,size = (num_inputs,num_outputs),requires_grad = True)
b = torch.zeros(num_outputs,requires_grad = True)

3. 定义softmax操作

回想一下,实现softmax由三个步骤组成:

  1. 对每个项求幂(使用exp);
  2. 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
  3. 将每一行除以其规范化常数,确保结果的和为1。

在这里插入图片描述

# 在实现softmax时,当X是矩阵时,是对每一行做softmax
def softmax(X):
    X_exp = torch.exp(X) # 对矩阵中每一个元素做指数运算
    
    # keepdim=True 表示仍然使得X是一个矩阵,保持维度
    # 按照维度为1来进行求和sum(axis=1),就是按列压缩,也就是把每一行求和
    partition = X_exp.sum(1,keepdim=True) 
    
    # 这里用了广播机制,矩阵中每个元素/对应行元素之和
    return X_exp / partition 

用如下这个例子,可以知道softmax是怎样工作的:

在这里插入图片描述

可以看出,经过softmax之后,矩阵的维度不变,仍然是2 * 5,并且每一行元素相加为1.

4. 实现softmax回归模型

def net(X):
    # reshape(-1,W.shape[0]):-1表示自动计算,实际上算出来就等于批量大小
    # 列数=W.shape[0],也就是权重矩阵的行数784,也是784个特征
    # 之前定义了batch_size=256,因此最后X会被reshape成256 * 784 的矩阵
    # 再对 矩阵X和矩阵W进行乘法,最后通过广播机制加上偏移
    # 最后,放入softmax中,得到所有元素值非0,且每一行的和为1的输出
    return softmax(torch.matmul(X.reshape(-1,W.shape[0]),W)+b)

在这里插入图片描述

5. 定义交叉熵损失函数

先补充一下如何根据标号把对应的预测值拿出来。

交叉熵采用真实标签的预测概率的负对数似然。我们创建一个数据样本y_hat,其中包含2个样本在3个类别的预测概率, 以及它们对应的标签y。 有了y,我们知道在第一个样本中,第一类是正确的预测; 而在第二个样本中,第三类是正确的预测。 然后使用y作为y_hat中概率的索引, 我们选择第一个样本中第一个类的概率和第二个样本中第三个类的概率。

在这里插入图片描述

在这里插入图片描述

实现交叉熵损失函数

交叉熵 = -log(预测的类别的概率)

在这里插入图片描述

6. 分类精度

分类精度即正确预测数量与总预测数量之比。

为了计算精度,我们执行以下操作:

  1. 如果y_hat是矩阵(即有多个列,也就是有多个分类类别),那么假定第二个维度存储每个类的预测分数。
  2. 我们使用argmax获得每行中最大元素的索引来获得预测类别。
  3. 然后我们将预测类别与真实y元素进行比较。
  4. 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含**0(错)和1(对)**的张量。
  5. 最后,我们求和会得到正确预测的数量。
def accuracy(y_hat,y):
    '''计算预测的正确率'''
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        # 判断张量y_hat是否大于一维,以及张量的第二个维度是否大于1
        y_hat = y_hat.argmax(axis=1) # 对每一行求argmax,选出索引最大的列
        # 从而得到每一个样本的最大值的下标,存到y_hat中,也就是得到预测的类别
        # 例如,在上面的例子中,第一个样本中的最大概率的下标是2,也就是预测是第二类,
        # 然而真实的类别是第0类,这样就表示预测错误。
        
    # 将y_hat的数据类型转换为与y的数据类型一致
    cmp = y_hat.type(y.dtype) == y # 变成一个bool的tensor
    # 预测对了就等于1,预测错了就等于0
    
    # 再把cmp转成和y一样的形状,求和,再转成一个浮点数,将其转化为标量
    # 仍然看上面的例子,可以看出,第一个样本预测错了,因此是0
    # 第一个预测对了,因此是1,两个相加得到结果为1
    return float(cmp.type(y.dtype).sum()) 
    #cmp.type(y.dtype)实现了bool到y的类型tensor的转化
    # 也就是把True和False转化成1和0

# 得到的sum=1,除以len(y),也就是类别个数,1/2,得到0.5
# 0.5也就是预测正确的概率
accuracy(y_hat,y) / len(y)   # 返回的float类型的数是一个标量,len(y)也是一个标量

在这里插入图片描述
自己的理解:
在这里插入图片描述

在这里插入图片描述
同样,对于任意数据迭代器data_iter可访问的数据集, 我们可以评估在任意模型net的精度.

def evaluate_accuracy(net,data_iter):
    '''计算在指定数据集上模型的精度'''
    # isinstance和type都是比较类型是否一致#
    # 只是isinstance算继承类,type不算继承类
    if isinstance(net,torch.nn.Module):
        # 如果是用torch.nn 实现一个模型的话,把它转成一个评估模式,不用计算梯度
        # 评估模式:指输入后得出结果用来评估模型的正确率,不做反向传播
        net.eval() # 将模型设置为评估模式
    metric = Accumulator(2) # 实用程序类Accumulator,用于对多个变量进行累加
    # Accumulator实例中创建了2个变量, 分别用于存储正确预测的数量和预测的总数量
    for X,y in data_iter:
        # accuracy(net(X),y)预测正确的样本数,y.numel()表示样本总数
        metric.add(accuracy(net(X),y),y.numel())
    return metric[0] / metric[1] # 最后返回的是分类正确的样本数和总样本数的相除

接下来实现Accumulator,用于对多个变量进行累加,在Accumulator实例中创建了2个变量, 分别用于存储正确预测的数量和预测的总数量:

class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
  1. 关于[0.0] * n:假设n=2,则 [0.0] * n 表示[0.0 , 0.0],得到的是数组
  2. 关于zip(self.data, args),可以知道 第一项是【0.0 ,0.0】,第二项是(1,2)是元组,假设 obj1 = [0.0 ,0.0] ,obj2=(1,2),则 [a+float(b) for a,b in zip] 就会得到一个数组[0.0+float(1),0.0+float(2)],也就是最终得到数组[1.0,2.0] ,因此,调用add函数最终会得到一个数组。

ps:如果是 for i in zip(),得到是两个元组(0.0,1) 和 (0.0,2)

最后通过下面这个函数,看随机出来的模型和测试集的迭代器得到的精度是0.1645,因为类别数是10,所以随机的话,应该是10%的正确率,二者相差不大,基本上可以认为是随机的:
在这里插入图片描述

7. softmax回归的训练

首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。 它可以是d2l.sgd函数,也可以是框架的内置优化函数。

def train_epoch_ch3(net,train_iter,loss,updater):
    if isinstance(net,torch.nn.Module):
        # 将模型设置为训练模式,需要计算梯度
        net.train()
    metric.Accumulator(3) # 长度为3的迭代器,来累加需要的信息
    for X,y in train_iter:
        y_hat = net(X) # 通过softmax得到预测值
        l = loss(y_hat,y) # 计算损失函数
        # 如果updater是torch.optim.Optimizer(python内置的优化器)
        if isinstance(updater,torch.optim.Optimizer): # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad() # 梯度清零
            l.backward() # 反向传播
            updater.step() # 对参数进行一次更新
            metric.add( # float(l) * len(y) 就是总的训练损失累加
                # 因为内置的损失函数会自动对loss取均值,所以得到的损失要乘以len(y)
                float(l) * len(y),accuracy(y_hat,y),
                y.size().numel()) # y.size().numel()相当于y.numel,求tensor包含的元素个数
        else:
            # 如果使用自己定制的优化器和损失函数
            l.sum().backward() 
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

定义一个在动画中绘制数据的实用程序类,代码略。

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):
    animator = Animator(xlabel = 'epoch',xlim = [1,num_epochs],ylim=[0.3,0.9],
                       legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs): # 扫num_epochs遍数据
        # 对整个数据扫一次,返回训练损失和训练精度
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        # 在net()这个函数中,W和b两个参数都被指定了requires_grad=True
        # 同时注意W和b都是全局变量,因此在此处调用net()时,会自动调用W和b
        # 计算在测试数据集上模型的精度
        test_acc = evaluate_accuracy(net, test_iter)
        # 在animator中显示,训练误差,训练精度以及测试的精度
        animator.add(epoch + 1, train_metrics + (test_acc,))
    # 训练损失和训练精度
    train_loss, train_acc = train_metrics

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1:

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W,b],lr,batch_size)

训练模型10个迭代周期:

在这里插入图片描述

8. 预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/68680.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

职场日常:测试人员如何快速熟悉新业务?

身处职场,学习新业务在所难免,尤其是测试人员,具备良好的业务知识是我们做好质量保障的前提,不管是职场「新人」还是「老人」,快速熟悉业务的能力都是不可或缺的,这是我们安身立命的根本。 但,…

【第一章 Linux目录结构,网络连接三种模式,vi和vim】

第一章 Linux目录结构,网络连接三种模式,vi和vim 1.Linux和Unix: ①Unix针对于大型,高性能主机或服务器; ②Linux适用于个人计算机。 2.网络连接的三种模式: ①桥接模式:虚拟系统可以和外部系…

[附源码]JAVA毕业设计师生交流平台(系统+LW)

[附源码]JAVA毕业设计师生交流平台(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&…

一起用Go做一个小游戏(上)

引子最近偶然看到一个Go语言库,口号喊出“一个超级简单(dead simple)的2D游戏引擎”,好奇点开了它的官网。官网上已经有很多可以在线体验的小游戏了(利用WASM技术)。例如曾经风靡一时的2048:当然…

「Redis数据结构」列表对象(List)

「Redis数据结构」列表对象(List) 文章目录「Redis数据结构」列表对象(List)一、概述二、结构三、编码转换四、总结一、概述 Redis列表是简单的字符串列表,按照插入顺序排序。你可以添加一个元素到列表的头部&#xf…

(附源码)php丽江旅游服务网站系统 毕业设计 010149

php丽江旅游服务网站系统 摘 要 21世纪时信息化的时代,几乎任何一个行业都离不开计算机,将计算机运用于旅游服务管理也是十分常见的。过去使用手工的管理方式对旅游服务进行管理,造成了管理繁琐、难以维护等问题,如今使用计算机对…

APP自动化测试系列之Appium介绍及运行原理

在面试APP自动化时,有的面试官可能会问Appium的运行原理,以下介绍Appium运行原理。 一、Appium介绍 1.Appium概念 Appium是一个开源测试自动化框架,可用于原生,混合和移动Web应用程序测试。它使用WebDriver协议驱动IOS&#xf…

易基因|m6A去甲基化酶ALKBH5通过降低PHF20 mRNA甲基化抑制结直肠癌进展 | 肿瘤研究

易基因|m6A去甲基化酶ALKBH5通过降低PHF20 mRNA甲基化抑制结直肠癌进展 | 肿瘤研究 大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 2022年8月17日,北京大学人民医院胃肠外科申占龙教授课题组在《Clin Transl M…

(附源码)ssm物流公司员工管理系统 毕业设计 261625

基于ssm物流公司员工管理系统 摘 要 随着互联网大趋势的到来,社会的方方面面,各行各业都在考虑利用互联网作为媒介将自己的信息更及时有效地推广出去,而其中最好的方式就是建立网络管理系统,并对其进行信息管理。由于现在网络的发…

LeetCode简单题之按身高排序

题目 给你一个字符串数组 names ,和一个由 互不相同 的正整数组成的数组 heights 。两个数组的长度均为 n 。 对于每个下标 i,names[i] 和 heights[i] 表示第 i 个人的名字和身高。 请按身高 降序 顺序返回对应的名字数组 names 。 示例 1&#xff1…

聚观早报 | 奈雪成乐乐茶第一大股东;达达与抖音达成战略合作

今日要闻:奈雪成乐乐茶第一大股东;达达与抖音达成战略合作;B 站启动新一轮降本增效;特斯拉上海工厂减产20%;大众将从中国向欧出口汽车奈雪成乐乐茶第一大股东 12 月 6 日消息,乐乐茶与奈雪的茶签署5.25亿元…

主成分分析-书后习题回顾总结

7-2 题目 理论基础 矩阵的特征值和特征向量的定义以及其求法 https://www.cnblogs.com/Peyton-Li/p/9772281.html 特征值和特征向量的定义:设AAA是nnn阶方阵,如果数λ\lambdaλ和nnn维非零列向量α\alphaα使关系式AαλαA\alpha\lambda\alphaAαλα成…

MyBatis一 Mybatis的介绍、基本使用、高级使用

一 数据库操作框架的历程 1.1 JDBC JDBC(Java Data Base Connection,java数据库连接)是一种用于执行SQL语句的Java API,可以为多种关系数据库提供统一访问,它由一组用Java语言编写的类和接口组成.JDBC提供了一种基准,据此可以构建更高级的工具和接口,使数据库开发人员能够编写…

【JDBC】----封装工具类和ORM

分享第二十二篇励志语录 有些烦恼是我们凭空虚构的,而我们却把它当成真实去承受。想得太多只会毁了你,让你陷入忐忑,让实际上本不糟糕的事情,变得糟糕。阳光这么好,何必自寻烦恼。 目录 分享第二十二篇励志语录 一&a…

毕业设计 stm32老人跌倒检测预防系统 - 单片机 物联网 嵌入式

文章目录0 前言1 整体设计2 硬件电路3 软件设计4 跌倒检测算法5 关键代码6 最后0 前言 🔥 这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要求,这两年不断有学弟学妹告诉…

linux常用命令二

1、find 查找文件或目录 find / -size 204800k //在根目录下查找大于200MB的文件 find / -user username//在根目录下查找所有者为username的文件 find / -name filename.txt //根据名称查找/目录下的filename.txt文件。 2、复制文件包括其子文件到自定目录 cp -r sourceF…

[附源码]JAVA毕业设计水果销售管理网站(系统+LW)

[附源码]JAVA毕业设计水果销售管理网站(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术…

AidLux智慧交通AI安全实战

这里写目录标题1.项目背景2. 项目实战流程2.1 YOLOv5车辆检测模型训练及部署2.2 AI对抗攻击与对抗防御2.2.1 AI对抗攻击算法讲解2.2.2 常用AI对抗攻击算法划分2.2.3 对抗攻击主要代码2.2.4 对抗攻击效果验证2.2.4 常用AI对抗防御算法讲解2.2.5 常用AI对抗防御算法划分2.2.6 对抗…

实验十一 级数与方程符号求解(MATLAB)

实验十一 级数与方程符号求解 1.1实验目的 1.2实验内容 1.3流程图 1.4程序清单 1.5运行结果及分析 1.6实验的收获与体会 1.1实验目的 1.2实验内容 实验十一 级数与方程符号求解 课本373页 1.3流程图 1.4程序清单 实验十一 1 clear clc nsym(n);xsym(x); s1…

通关算法题之 ⌈字符串⌋

字符串 171. Excel 表列序号 给你一个字符串 columnTitle ,表示 Excel 表格中的列名称,返回该列名称对应的列序号。 A -> 1 B -> 2 C -> 3 ... Z -> 26 AA -> 27 AB -> 28 ...输入: columnTitle "A" 输出: 1 输入: col…