【动手学习深度学习--逐行代码解析合集】04softmax回归的从零开始实现

news2025/1/4 16:30:36

【动手学习深度学习】逐行代码解析合集

04softmax回归的从零开始实现


视频链接:动手学习深度学习–softmax回归的从零开始实现
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

1、 softmax网络架构

在这里插入图片描述
2、 softmax运算

在这里插入图片描述
3、 交叉熵损失函数

3.1、 对数似然函数

在这里插入图片描述
在这里插入图片描述

3.2、 softmax及其导数
在这里插入图片描述
3.3、 交叉熵损失

在这里插入图片描述
4、代码

以下代码是在PyCharm中运行的

实用程序类Accumulator若不懂,参考文章链接

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

"====================1、定义初始化模型参数===================="
# 输入28*28=784 ,拉成一条向量;数据集10个类别,所以网络输出维度为10
num_inputs = 784
num_outputs = 10

# W初始化为高斯随机分布的值
# 权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

"====================2、定义softmax操作===================="
'''
实现softmax由三个步骤组成:
1、对每个项求幂(使用exp);
2、对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
3、将每一行除以其规范化常数,确保结果的和为1。
'''
def softmax(X):
    # 做指数运算
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)  # 对每一行求和
    return X_exp / partition  # 这里应用了广播机制
# 正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

"====================3、定义模型===================="
def net(X):
    # W.shape[0]=784,X==>256*784
    # reshape中的-1表示系统帮助计算(结果为256)
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b) # 交叉熵损失

"====================4、定义损失函数===================="
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y] # numpy高级索引 : 拿出对应真实标号的预测值
" 输出:tensor([0.1000, 0.5000]) "

# 一行代码就可以实现交叉熵损失函数
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)
" 输出:tensor([2.3026, 0.6931]) "

"====================5、分类精度===================="
def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    # 现在 y_hat是一个256*10的一个矩阵
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        # 对每一行中元素值最大的下标存到y_hat中
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y  # 作比较转为布尔类型
    return float(cmp.type(y.dtype).sum())  # 预测正确的个数总和
# 预测正确的概率
accuracy(y_hat, y) / len(y)
"输出:0.5"

"评估在任意模型net上的准确率"
def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            # 将所有预测正确的样本数,样本总数量加入迭代器中
            metric.add(accuracy(net(X), y), y.numel())
    # 返回预测正确的样本数和样本总数
    return metric[0] / metric[1]

"在Accumulator实例中创建了2个变量, 分别用于存储正确预测的数量和预测的总数量"
class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n
        # n=2时,self.data = [0.0,0.0]

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
        # 若args接收的传参为(4, 5),那么for a, b in zip(self.data, args)表示(0,4)(0,5)
        # a = 0.0,b = 4,然后执行a + float(b),得到结果4.0,此时self.data = [4.0, 0.0],
        # a = 0.0, b = 5,然后执行a + float(b) 得到结果5.0,最后self.data = [4.0, 5.0]。

    def reset(self):  # 重新设置空间大小并初始化。
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):  # 实现类似数组的取操作。
        return self.data[idx]
# 由于我们使用随机权重初始化net模型,因此该模型的精度应接近于随机猜测.例如在有10个类别情况下的精度为0.1。
evaluate_accuracy(net, test_iter)
"输出: 0.0598"

训练部分相关代码

"====================6、训练===================="
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

"定义一个在动画中绘制数据的实用程序类"
class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型(定义见第3章)"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)  # 得到训练损失和训练精度
        test_acc = evaluate_accuracy(net, test_iter)  # 在测试数据集上评估模型精度
        # 可视化训练误差,训练精度,测试精度
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    # 判断训练损失(train_loss)是否小于0.5,如果不满足条件就会抛出异常并打印出train_loss的值
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

"小批量随机梯度下降来优化模型的损失函数"
lr = 0.1
def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

"训练模型10个迭代周期"
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

"====================7、预测===================="
def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)  # 真实标号
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))  # 预测标号
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
    d2l.plt.show()

predict_ch3(net, test_iter)

训练损失、训练准确率、测试准确率可视化

在这里插入图片描述

预测结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/710072.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实验三:运算类编程实验

实验目的 阐明本实验的目的。 一实验目的掌握学精变C6应) 乘汉运第《结购 3位)江程产设计方法穿握双精度[3位)加法运算乡饰程产设计方法穿握双精度园) 朱讼运算 [结果为64位)汇编往广设计方讯 实验要求 说明实现本实验需要掌握的知识及本实验需要的实验环境 、实强要求 了解简单…

20+个小而精的Python实战案例(附源码和数据)

公众号&#xff1a;尤而小屋作者&#xff1a;Peter编辑&#xff1a;Peter 大家好&#xff0c;我是Peter~ 最近小编认真整理了20个基于python的实战案例&#xff0c;主要包含&#xff1a;数据分析、可视化、机器学习/深度学习、时序预测等&#xff0c;案例的主要特点&#xff1…

spring boot security验证码登录示例

前言 在spring boot security自定义认证一文&#xff0c;基本给出了一个完整的自定义的用户登录认证的示例&#xff0c;但是未涉及到验证的使用&#xff0c;本文介绍登录的时候如何使用验证码。 本文介绍一个验证码生成工具&#xff0c;比较老的一个库了&#xff0c;仅作demo…

rust warp框架教程1-helloworld

warp框架简介 warp is a super-easy, composable, web server framework for warp speeds. warp建立在hyper之上&#xff0c;因此&#xff0c;warp天生支持异步&#xff0c;HTTP/2&#xff0c;以及“正确的HTTP实现”。 warp的强大之处在于其提供的filter系统&#xff0c;它…

软件设计模式与体系结构-设计模式-生成模式单例模式

目录 二、生成器模式1. 生成者模式概念实例一&#xff1a;房屋选购系统题目时序图类图 优缺点适用场景 2. 生成器模式与抽象工厂模式3. 课程作业*** 三、单例模式1. 单例模式要点&#xff1a;基本思路实例一&#xff1a;互联网连接问题 2. 多线程情况3. 优缺点4. 适用场景5. 课…

leetcode 88.合并两个有序数组

⭐️ 题目描述 &#x1f31f; leetcode链接&#xff1a;合并两个有序数组 ⭕️ 代码&#xff1a; /*思路&#xff1a;双指针问题1.从前往后拷贝依次比较两个数组元素的较小值&#xff0c;较小值先拷贝- 问题&#xff1a;从前拷贝会造成覆盖(有问题)2.从后往前拷贝依次比较两个…

SpringBoot(五)SpringBoot事务

在实际开发项目时&#xff0c;程序并不是总会按照正常的流程去执行&#xff0c;有时候线上可能出现一些无法预知的问题&#xff0c;任何一步操作都有可能发生异常&#xff0c;异常则会导致后续的操作无法完成。此时由于业务逻辑并未正确的完成&#xff0c;所以在之前操作过数据…

单臂路由实现不同VLAN之间数据转发

实验环境&#xff1a; 思科模拟器&#xff0c;Cisco Packet Tracer 实验拓扑&#xff1a; 实验配置&#xff1a; &#xff08;1&#xff09;PC配置 IP地址子网掩码网关PC1192.168.10.1255.255.255.0192.168.10.254PC2192.168.10.2255.255.255.0192.168.10.254PC3192.168.20…

串口通讯监控方法

当我们调试硬件的时候&#xff0c;发现串口数据异常&#xff0c;用示波器和逻辑分析仪的话会比较麻烦&#xff0c;此时可以并一个监控串口&#xff0c;如下图所示 232串口&#xff0c;我们是不能直接并一个串口上去的&#xff1b;但是我们的监控串口&#xff0c;可以只接一根R…

【玩转循环】探索Python中的无限可能性

前言 循环可能是每个编程语言中使用比较多的语法了&#xff0c;如果能合理利用好循环&#xff0c;就会出现意想不到的结果&#xff0c;大大地减少代码量&#xff0c;让机器做那些简单枯燥的循环过程&#xff0c;今天我将为大家分享 python 中的循环语法使用。&#x1f697;&am…

数据结构--栈的链式存储

数据结构–栈的链式存储 推荐使用不带头结点的单链表 \color{green}推荐使用不带头结点的单链表 推荐使用不带头结点的单链表 typedef struct LNode {ElemType data;struct LNode* next; } LNode, *LinkList;bool InitList(LinkList &L) {L->next NULL; }后插操作&…

python网络编程(二)模拟ssh远程执行命令

1、项目需求&#xff1a; 要实现一个像ssh远程连接工具一样&#xff0c;在终端输入命令&#xff0c;返回对应的结果。 比如window的dos命令&#xff1a; dir &#xff1a;查看目录下的文件 ipconfig : 查看网卡信息 tasklist : 查看进程列表 linux的命令&#xff1a; ls : 查看…

Jenkins与CI/CD

简介 CI&#xff08;持续集成&#xff09; Continuous Integration是一种软件开发实践&#xff0c;即团队开发成员经常集成他们的工作&#xff0c;通常每个成员每天至少集成一次&#xff0c;也就意味着每天可能会发生多次集成。每次集成都通过自动化的构建&#xff08;包括编…

Debian 环境使用 docker compose 部署 sentry

Debian 环境使用 docker compose 部署 sentry Sentry 简介什么是 Sentry &#xff1f;Sentry 开发语言及支持的 SDKSentry 功能架构 前置准备条件规格配置说明Dcoker Desktop 安装WSL2/Debian11 环境准备 Sentry 安装步骤docker 部署 sentry 步骤演示过程说明 总结 Sentry 简介…

python机器学习在气象模式订正、短临预报、气候预测等场景的应用

基于机器学习的天河机场物流预测研究 全球经济快速增长的形势下,八大区域性枢纽之一的武汉天河机场的物流需求也在攀升。文章针对天河机场的货邮吞吐量,运用机器学习中的线性回归模型通过Python对其进行需求预测,并用二次指数平滑法与之对比,在平均绝对百分误差比较下得出机器…

需求分析引言:架构漫谈(四)性能专题

前文介绍了非功能性需求里的可靠性和可用性&#xff0c; 本文对非功能性需求里的性能&#xff0c;进行一些详细的说明&#xff0c;和如何度量系统的性能问题。 1、概念 性能通常是指一个软件系统的处理能力和速度&#xff0c;一般通过 延迟 和 吞吐量 这两个指标进行度量。 不…

分布式软件架构——域名解析系统

透明多级分流系统的设计原则 用户在使用信息系统的过程中&#xff0c;请求首先是从浏览器出发&#xff0c;在DNS的指引下找到系统的入口&#xff0c;然后经过了网关、负载均衡器、缓存、服务集群等一系列设施&#xff0c;最后接触到了系统末端存储于数据库服务器中的信息&…

云计算——容器

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​ 目录 前言 一.容器简介 二.主流容器技术 1.docker &#xff08;1&#xff09;容器的组…

HTML5+ Runtime提示

使用的环境 vue-cli框架&#xff0c;Andriod调试、云打包都会出现该弹框 1.我遇到的问题 上述弹框提示&#xff0c;HBuilderX3.8.2 &#xff0c; 手机SDK版本是3.8.4&#xff0c;不匹配 解决目的&#xff1a;需要让两个版本匹配 2. 点击“查看详情”&#xff0c;查看原因 …

JS文件UTF8格式乱码问题

UTF8格式的JS文件在IE中显示乱码问题的解决 这种情况通常是由于JS文件头缺少BOM标志引起的,解决方式: 方法1:用系统自带记事本,另存为 UTF-8,覆盖原文件,会自动加上BOM标志(就是文件开头的EF BB BF 三个字节) 方法2: 用notepad 打开,编码菜单,由UTF8编码改为 UTF8-BOM编码