【深度学习笔记】3_6 代码实现softmax-regression

news2025/1/4 17:44:42

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.6 softmax回归的从零开始实现

这一节我们来动手实现softmax回归。首先导入本节实现所需的包或模块。

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

3.6.1 获取和读取数据

我们将使用Fashion-MNIST数据集,并设置批量大小为256。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.2 初始化模型参数

跟线性回归中的例子一样,我们将使用向量表示每个样本。已知每个样本输入是高和宽均为28像素的图像。模型的输入向量的长度是 28 × 28 = 784 28 \times 28 = 784 28×28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为 784 × 10 784 \times 10 784×10 1 × 10 1 \times 10 1×10的矩阵。

num_inputs = 784
num_outputs = 10

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)

同之前一样,我们需要模型参数梯度。

W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

3.6.3 实现softmax运算

在介绍如何定义softmax回归之前,我们先描述一下对如何对多维Tensor按维度操作。在下面的例子中,给定一个Tensor矩阵X。我们可以只对其中同一列(dim=0)或同一行(dim=1)的元素求和,并在结果中保留行和列这两个维度(keepdim=True)。

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))

输出:

tensor([[5, 7, 9]])
tensor([[ 6],
        [15]])

下面我们就可以定义前面小节里介绍的softmax运算了。在下面的函数中,矩阵X的行数是样本数,列数是输出个数。为了表达样本预测各个输出的概率,softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任意一行元素代表了一个样本在各个输出类别上的预测概率。

def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

可以看到,对于随机输入,我们将每个元素变成了非负数,且每一行和为1。

X = torch.rand((2, 5))
X_prob = softmax(X)
print(X_prob, X_prob.sum(dim=1))

输出:

tensor([[0.2206, 0.1520, 0.1446, 0.2690, 0.2138],
        [0.1540, 0.2290, 0.1387, 0.2019, 0.2765]]) tensor([1., 1.])

3.6.4 定义模型

有了softmax运算,我们可以定义上节描述的softmax回归模型了。这里通过view函数将每张原始图像改成长度为num_inputs的向量。

def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

3.6.5 定义损失函数

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。与3.4节(softmax回归)数学表述中标签类别离散值从1开始逐一递增不同,在代码中,标签类别的离散值是从0开始逐一递增的。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

输出:

tensor([[0.1000],
        [0.5000]])

下面实现了3.4节(softmax回归)中介绍的交叉熵损失函数。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

3.6.6 计算分类准确率

给定一个类别的预测概率分布y_hat,我们把预测概率最大的类别作为输出类别。如果它与真实类别y一致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之比。

为了演示准确率的计算,下面定义准确率accuracy函数。其中y_hat.argmax(dim=1)返回矩阵y_hat每行中最大元素的索引,且返回结果与变量y形状相同。相等条件判断式(y_hat.argmax(dim=1) == y)是一个类型为ByteTensorTensor,我们用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensor

def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

让我们继续使用在演示gather函数时定义的变量y_haty,并将它们分别作为预测概率分布和标签。可以看到,第一个样本预测类别为2(该行最大元素0.6在本行的索引为2),与真实标签0不一致;第二个样本预测类别为2(该行最大元素0.5在本行的索引为2),与真实标签2一致。因此,这两个样本上的分类准确率为0.5。

print(accuracy(y_hat, y))

输出:

0.5

类似地,我们可以评价模型net在数据集data_iter上的准确率。

# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

因为我们随机初始化了模型net,所以这个随机模型的准确率应该接近于类别个数10的倒数即0.1。

print(evaluate_accuracy(test_iter, net))

输出:

0.0681

3.6.7 训练模型

训练softmax回归的实现跟3.2(线性回归的从零开始实现)一节介绍的线性回归中的实现非常相似。我们同样使用小批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数num_epochs和学习率lr都是可以调的超参数。改变它们的值可能会得到分类更准确的模型。

num_epochs, lr = 5, 0.1

# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            
            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            
            l.backward()
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()  # “softmax回归的简洁实现”一节将用到
            
            
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

输出:

epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

3.6.8 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。

X, y = next(iter(test_iter))# 注意这里有坑,pytorch版本不一样next写法可能不一样,可根据自己所使用的版本修改

true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])

在这里插入图片描述

小结

  • 可以使用softmax回归做多类别分类。与训练线性回归相比,你会发现训练softmax回归的步骤和它非常相似:获取并读取数据、定义模型和损失函数并使用优化算法训练模型。事实上,绝大多数深度学习模型的训练都有着类似的步骤。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1468191.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

神经网络系列---权重初始化方法

文章目录 权重初始化方法Xavier初始化(Xavier initialization)Kaiming初始化,也称为He初始化LeCun 初始化正态分布与均匀分布Orthogonal InitializationSparse Initializationn_in和n_out代码实现 权重初始化方法 Xavier初始化(X…

java面试题之SpringMVC篇

Spring MVC的工作原理 Spring MVC的工作原理如下: DispatcherServlet 接收用户的请求找到用于处理request的 handler 和 Interceptors,构造成 HandlerExecutionChain 执行链找到 handler 相对应的 HandlerAdapter执行所有注册拦截器的preHandler方法调…

Java知识点一

hello,大家好!我们今天开启Java语言的学习之路,与C语言的学习内容有些许异同,今天我们来简单了解一下Java的基础知识。 一、数据类型 分两种:基本数据类型 引用数据类型 (1)整型 八种基本数…

计算机网络-网络互联与互联网(一)

1.常用网络互联设备: 1层物理层:中继器、集线器2层链路层:网桥、交换机3层网络层:路由器、三层交换机4层以上高层:网关 2.网络互联设备: 中继器Repeater、集线器Hub(又叫多端口中继器&#xf…

一分钟 由浅入深 学会Navigation

目录 1.官网正式概念 1.1 初认知 2.导入依赖 2.1 使用navigation 2.2 safe Args插件-> 传递数据时用 3.使用Navigation 3.1 搭建初始框架 3.2 确定action箭头的属性 3.3 为Activity添加NavHostFragment控件 3.4 NavController 管理应用导航的对象 3.5 数据传递(单…

leetcode单调栈

739. 每日温度 请根据每日 气温 列表,重新生成一个列表。对应位置的输出为:要想观测到更高的气温,至少需要等待的天数。如果气温在这之后都不会升高,请在该位置用 0 来代替。 例如,给定一个列表 temperatures [73, …

基于SpringBoot的停车场管理系统

基于SpringBootVue的停车场管理系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringBootMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 前台首页 停车位 个人中心 管理员界面 摘要 摘要:随着城市化进程的…

基于django的购物商城系统

摘要 本文介绍了基于Django框架开发的购物商城系统。随着电子商务的兴起,购物商城系统成为了许多企业和个人创业者的首选。Django作为一个高效、稳定且易于扩展的Python web框架,为开发者提供了便捷的开发环境和丰富的功能模块,使得开发购物商…

Java零基础 - 条件运算符

哈喽,各位小伙伴们,你们好呀,我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一个人虽可以走的更快,但一群人可以走的更远。 我是一名后…

Vue(学习笔记)

什么是Vue Vue是一套构建用户界面的渐进式框架 构建用户界面: 基于数据渲染出用户可以看到的界面 渐进式: 所谓渐进式就是循序渐进,不一定非得把Vue中的所有API都学完才能开发Vue,可以学一点开发一点 创建Vue实例 比如就上面…

k8s学习笔记-基础概念

(作者:陈玓玏) deployment特别的地方在于replica和selector,docker根据镜像起容器,pod控制容器,job、cronjob、deployment控制pod,job做离线任务,pod大多一次性的,cronj…

汽车常识网:电脑主机如何算功率的计算方法?

今天汽车知识网就给大家讲解一下如何计算一台主机的功率。 它还会解释如何计算计算机主机所需的功率? ? (如何计算电脑主机所需的功率)进行说明。 如果它恰好解决了您现在面临的问题,请不要忘记关注本站。 让我们现在就…

vue3 vite 经纬度逆地址解析

在web端测试经纬度逆地址解析有2中方式,先准备好两个应用key 第一种,使用“浏览器端”应用类型 const address ref() const latitude ref() // 经度 const longitude ref() // 纬度 const ak 你的key // 浏览器端 function getAddressWeb() {// 创建…

【读博杂记】:近期日常240223

近期日常 最近莫名其妙,小导悄悄卷起来,说要早上八点半开始打卡,我感觉这是要针对我们在学校住的,想让我们自己妥协来这边租房子住,但我感觉这是在逼我养成规律作息啊!现在基本上就是6~7点撤退,…

【Spring】 AOP面向切面编程

文章目录 AOP是什么?一、AOP术语名词介绍二、Spring AOP框架介绍和关系梳理三、Spring AOP基于注解方式实现和细节3.1 Spring AOP底层技术组成3.2 初步实现3.3 获取通知细节信息3.4 切点表达式语法3.5 重用(提取)切点表达式3.6 环绕通知3.7 切…

R语言入门笔记2.6

描述统计 分类数据与顺序数据的图表展示 为了下面代码便于看出颜色参数所对应的值,在这里先集中介绍, col1是黑色,2是粉红,3是绿色,4是天蓝,5是浅蓝,6是紫红,7是黄色,…

前沿科技速递——YOLOv9

随着YOLO系列的不断迭代更新,前几天,YOLO系列也迎来了第九个大型号的更新!YOLOv9正式推出了!附上原论文链接。 arxiv.org/pdf/2402.13616.pdf 同样是使用MS COCO数据集进行对比比较,通过折线图可看出AP曲线在全方面都…

一、系统架构师考试介绍

一、系统架构设计师介绍 系统架构设计师在软考体系中,属于高级资格。(不需要先考中级可以直接报考高级,我之前不知道还考了软件设计师T.T不如当初直接考系统架构师) 考试时间: 每年11月份的第二个周六 报名方式: 网上报名 报名网址 http://wwwruankao.…

C++常见问题

C常见问题 引用模板STLvector原理移动语义与右值引用New delete与malloc freeinlineconststaticexplicit 的作用lambda 表达式友元public、protected、private的区别封装继承多态虚函数重载、重写、隐藏的区别智能指针C 11新特性深拷贝与浅拷贝虚拟内存内存对齐及内存泄漏C内存…

解决ubuntu系统cannot find -lc++abi: No such file or directory

随着CentOS的没落,使用ubuntu的越来越多,而且国外貌似也比较流行使用ubuntu,像LLVM/Clang就有专门针对ubuntu编译二进制发布文件: ubuntu本身也可以直接通过apt install命令来安装编译好的clang编译器。不过目前22.04版本下最高…