【从零开始学习深度学习】7.自己动手实现softmax回归的训练与预测

news2024/12/29 8:55:21

基于上一篇文章读取fashion-minist数据集的基础,本文自己动手实现一个softmax模型对其进行训练与预测。

目录

  • 1. 自己动手实现softmax回归
    • 1.1 读取数据
    • 1.2 初始化模型参数
    • 1.3 实现softmax运算
    • 1.4 定义模型
    • 1.5 定义损失函数
    • 1.6 计算分类准确率
    • 1.7 训练模型
    • 1.8 预测
    • 完整代码
    • 小结

1. 自己动手实现softmax回归

首先导入本节实现所需的包或模块。

import torch
import torchvision
import numpy as np
import sys
import d2lzh_pytorch as d2l

1.1 读取数据

我们将使用Fashion-MNIST数据集,并设置批量大小为256。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

1.2 初始化模型参数

跟线性回归中的例子一样,我们将使用向量表示每个样本。已知每个样本输入是高和宽均为28像素的图像。模型的输入向量的长度是 28 × 28 = 784 28 \times 28 = 784 28×28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为 784 × 10 784 \times 10 784×10 1 × 10 1 \times 10 1×10的矩阵。

num_inputs = 784
num_outputs = 10

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)

同之前一样,我们需要模型参数梯度。

W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

1.3 实现softmax运算

在介绍如何定义softmax回归之前,我们先描述一下对如何对多维Tensor按维度操作。在下面的例子中,给定一个Tensor矩阵X。我们可以只对其中同一列(dim=0)或同一行(dim=1)的元素求和,并在结果中保留行和列这两个维度(keepdim=True)。

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))

输出:

tensor([[5, 7, 9]])
tensor([[ 6],
        [15]])

下面我们就可以定义之前介绍的softmax运算了。在下面的函数中,矩阵X的行数是样本数,列数是输出个数。为了表达样本预测各个输出的概率,softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任意一行元素代表了一个样本在各个输出类别上的预测概率。

def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

可以看到,对于随机输入,我们将每个元素变成了非负数,且每一行和为1。

X = torch.rand((2, 5))
X_prob = softmax(X)
print(X_prob, X_prob.sum(dim=1))

输出:

tensor([[0.2206, 0.1520, 0.1446, 0.2690, 0.2138],
        [0.1540, 0.2290, 0.1387, 0.2019, 0.2765]]) tensor([1., 1.])

1.4 定义模型

有了softmax运算,就可以定义softmax回归模型。这里通过view函数将每张原始图像改成长度为num_inputs的向量。

def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

1.5 定义损失函数

之前介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。与3.4节(softmax回归)数学表述中标签类别离散值从1开始逐一递增不同,在代码中,标签类别的离散值是从0开始逐一递增的。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

输出:

tensor([[0.1000],
        [0.5000]])

下面为(softmax回归)中介绍的交叉熵损失函数。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

1.6 计算分类准确率

给定一个类别的预测概率分布y_hat,我们把预测概率最大的类别作为输出类别。如果它与真实类别y一致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之比。

为了演示准确率的计算,下面定义准确率accuracy函数。其中y_hat.argmax(dim=1)返回矩阵y_hat每行中最大元素的索引,且返回结果与变量y形状相同。相等条件判断式(y_hat.argmax(dim=1) == y)是一个类型为ByteTensorTensor,我们用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensor

def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

让我们继续使用在演示gather函数时定义的变量y_haty,并将它们分别作为预测概率分布和标签。可以看到,第一个样本预测类别为2(该行最大元素0.6在本行的索引为2),与真实标签0不一致;第二个样本预测类别为2(该行最大元素0.5在本行的索引为2),与真实标签2一致。因此,这两个样本上的分类准确率为0.5。

print(accuracy(y_hat, y))

输出:

0.5

类似地,我们可以评价模型net在数据集data_iter上的准确率。

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

因为我们随机初始化了模型net,所以这个随机模型的准确率应该接近于类别个数10的倒数即0.1。

print(evaluate_accuracy(test_iter, net))

输出:

0.0681

1.7 训练模型

使用小批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数num_epochs和学习率lr都是可以调的超参数。改变它们的值可能会得到分类更准确的模型。

num_epochs, lr = 5, 0.1

def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            
            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            
            l.backward()
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()  
            
            
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

输出:

epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

1.8 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。

X, y = iter(test_iter).next()

true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])

完整代码

import torch
import torchvision
import numpy as np
import sys
import d2lzh_pytorch as d2l

# 加载数据
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 初始化模型
num_inputs = 784
num_outputs = 10

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

# 定义softmax回归
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

# 网络计算
def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

# 计算准确率
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

# 交叉熵损失
def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

# 模型训练
num_epochs, lr = 5, 0.1

def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            
            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            
            l.backward()
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()  # “softmax回归的简洁实现”一节将用到
            
            
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

# 预测
X, y = iter(test_iter).next()

true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])

小结

  • 可以使用softmax回归做多类别分类。与训练线性回归相比,你会发现训练softmax回归的步骤和它非常相似:获取并读取数据、定义模型和损失函数并使用优化算法训练模型。

如果内容对你有帮助,感谢点赞+关注哦!

关注下方GZH,可获取更多干货内容~欢迎共同学习交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/63751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面试碰壁15次!作为一个已经27岁的测试工程师,未来在何方....

3年测试经验原来什么都不是,只是给你的简历上画了一笔,一直觉得经验多,无论在哪都能找到满意的工作,但是现实却是给我打了一个大巴掌!事后也不会给糖的那种... 先说一下自己的个人情况,普通二本计算机专业…

LabVIEW编程LabVIEW开发SMP10辐射表例程与相关资料

LabVIEW编程LabVIEW开发SMP10辐射表例程与相关资料 ​​SMP10辐射表是荷兰Kipp&Zonen公司的一种用于测量短波辐射的产品,配有只能型接口,能够提供标准输出,能耗低。 作为一款副基准总辐射表,SMP10结合了CMP 11的传感器技术、SMP 11的智…

2023最新SSM计算机毕业设计选题大全(附源码+LW)之java基于自组网的空地一体化信息系统mf392

面对老师五花八门的设计要求,首先自己要明确好自己的题目方向,并且与老师多多沟通,用什么编程语言,使用到什么数据库,确定好了,在开始着手毕业设计。 1:选择课题的第一选择就是尽量选择指导老师…

[附源码]计算机毕业设计JAVA疫情期间回乡人员管理系统

[附源码]计算机毕业设计JAVA疫情期间回乡人员管理系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM…

基于Java的课程管理系统

摘 要 在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括课程管理系统的网络应用,在外国课程管理已经是很普遍的方式,不过国内的课程管理可能还处于起步阶段。课程管理系统具有下载课件功能。课程管理系…

数据结构(12)Dijkstra算法JAVA版:图的最短路径问题

目录 12.1.概述 12.1.1.无权图的最短路径 12.1.2.带权图的最短路径 1.单源最短路径 2.多源最短路径 12.2.代码实现 12.1.概述 12.1.1.无权图的最短路径 无权图的最短路径,即最少步数,使用BFS贪心算法来求解最短路径,比较好实现&#xf…

04-05 - 主引导程序的扩展(实验未完)

---- 整理自狄泰软件唐佐林老师课程 1. 突破限制的思路 限制:主引导程序的代码不能超过512字节 主引导程序完成: 完成最基本的初始化工作从存储介质中加载程序到内存将控制权交由新加载的程序执行…… 问题: 主引导程序如何加载存储介质中的…

Windows上Qt源码调试(使用VS2017调试qt5.12.0)

环境:vs2017 qt 5.12.0 msvc32和msvc64 1.下载源代码 把所用 Qt 库版本对应源码(qt-everywhere-src-5.12.0)下载来解压(https://download.qt.io/archive/qt/5.12/5.12.0/single/),或者安装时选择把源码&…

一文带你掌握JSP基础知识

✅作者简介:热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏:JAVA开发者…

赋能建筑建材企业物流网络内外联通,B2B交易管理系统打造行业智慧供应链

数据显示,在疫情和行业转型升级的双重压力下,行业中竞争力不强、商业模式老套的建筑建材企业在疫情中产值下降甚至被淘汰出局。随着数字经济的兴起,传统建筑建材产业的发展也带来了巨大的变革。 据有关数据分析指出,数字化已经成…

数据之道读书笔记-08打造“清洁数据”的质量综合管理能力

数据之道读书笔记-08打造“清洁数据”的质量综合管理能力 越来越多的企业应用和服务都基于数据而建,数据质量是数据价值得以发挥的前提。例如企业运营效率主要依赖于数据获取的准确性和及时性,企业客户关系管理系统中的错误或不完整数据将导致客户沟通不…

安卓讲课笔记6.1 共享参数

文章目录零、本讲学习目标一、导入新课二、新课讲解(一)数据存储(二)共享参数1、共享参数概述2、利用共享参数读写文件步骤(三)案例演示:多窗口共享数据1、创建安卓应用2、准备图片素材3、主界面…

【LeetCode每日一题】——141.环形链表

文章目录一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【解题思路】七【题目提示】八【题目进阶】九【时间频度】十【代码实现】十一【提交结果】一【题目类别】 链表 二【题目难度】 简单 三【题目编号】 141.环形链表 四【题目描述】 给…

【gbase8a】docker搭建gbase8a,详细【图文】

docker搭建gbase8a安装docker安装GBase 8a查询安装的版本拉取镜像启动进入容器创建用户dbever测试安装docker 其中具有docker的搭建 搭建docker,docker搭建达梦数据库,详细【图文】 https://blog.csdn.net/weixin_44385419/article/details/127738868 d…

Spark 数据倾斜调优10策

一、数据倾斜概述 1.1 什么是数据倾斜 对Hadoop、Spark、Flink这样的大数据系统来讲,数据量大并不可怕,可怕的是数据倾斜。 何谓数据倾斜?数据倾斜指的是,并行处理的数据集中,某一部分(如Spark或Kafka的…

@SpringBootApplication中的注解

Target(ElementType.TYPE):指示适用注释类型的上下文(即注解的作用目标)这里是接口、类、枚举、注解 Retention(RetentionPolicy.RUNTIME):指示具有注释类型的注释要保留多长时间,这里注解是将被JVM保 留,所以在运行…

无法安装64位版本的office,因为在您的PC上找到以下32位程序

无法安装64位版本的office,因为在您的PC上找到以下32位程序: 请卸载所有32位office程序,然后重试安装64位office。如果想要安装32位office,请运行32位安装程序。 那为什么会出现这种情况呢? 首先,我们要知道我们的电脑是32位的还…

9个发展您的B2B业务的LinkedIn营销策略

没有比在 LinkedIn 上与其他公司建立联系更好的地方了。您可以与数以百万计的品牌和专业人士建立联系并发展您的业务。 您可以尝试多种不同的 B2B LinkedIn营销策略,以便与您的受众建立联系并将他们转变为您的客户。 事实上,根据公司自己的研究&#x…

Vue3.2中的setup语法糖(易懂)

简介 在vue3中删除了vue2中的data函数,因此,vue3.0要在template中使用某些变量就必须在最后return出来,多次声明变量,不太方便。而在vue3.2版本之后,新增了setup语法糖。 直接在script标签中添加setup属性就可以直接使…

Arduino开发实例-DIY电能表

DIY电能表 在本文中,将展示如何制作一个基于 Arduino 的功率和电能表。应用使用 INA219 电流传感器测量电流、功率和能耗,并将其显示在 OLED 显示屏上。 可以在 OLED 显示屏上查看您的电压、电流、功率和能量数据。 1、INA219介绍 INA219 电流传感器是一款支持 I2C 的基于…