3-2 多层感知机的从零开始实现

news2024/11/16 0:34:50
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256 # 批量大小为256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# load进来训练集和测试集

初始化模型参数

回想一下,Fashion-MNIST中的每个图像由 28 × 28 = 784 28 \times 28 = 784 28×28=784个灰度像素值组成。 所有图像共分为 10 10 10个类别。 忽略像素之间的空间结构, 我们可以将每个图像视为具有 784 784 784个输入特征 和 10 10 10个类的简单分类数据集。
首先,我们将实现一个具有单隐藏层的多层感知机, 它包含 256 256 256个隐藏单元。 注意,我们可以将这两个变量都视为超参数。 通常,我们选择 2 2 2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效
我们用几个张量来表示我们的参数。 注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。 跟以前一样,我们要为损失关于这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(
    num_inputs, num_hiddens, requires_grad=True) * 0.01)
# 初始成一个随机的,行数是num_inputs输入的个数,列数是num_hiddens,requires_grad=True需要算梯度
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
# 偏差是一个隐藏层的个数,是一个向量,初始化为全0
W2 = nn.Parameter(torch.randn(
    num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]
# 第一层参数 w1和b1
# 第二层参数 w2和b2

激活函数

为了确保我们对模型的细节了如指掌, 我们将实现ReLU激活函数, 而不是直接调用内置的relu函数。

def relu(X):
    a = torch.zeros_like(X)
    # 产生一个矩阵a,a的shape与X一模一样,但是元素全为0
    return torch.max(X, a)

模型

因为我们忽略了空间结构, 所以我们使用reshape将每个二维图像转换为一个长度为num_inputs的向量。 只需几行代码就可以实现我们的模型。

def net(X):
    X = X.reshape((-1, num_inputs))
    # 把X拉成一个二维的矩阵
    # -1 让它自己算,实际是batch size
    H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法
    return (H@W2 + b2)

损失函数

由于我们已经从零实现过softmax函数, 因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。

loss = nn.CrossEntropyLoss(reduction='none')

训练

多层感知机的训练过程与softmax回归的训练过程完全相同。 可以直接调用d2l包的train_ch3函数, 将迭代周期数设置为10,并将学习率设置为0.1.

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

请添加图片描述
相比softmax,我的损失更低了!我的精度其实没有发生太大变化。
因为我的模型更大了,所以我的数据拟合性更好,所以我的损失在下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【考研408操作系统】最容易理解的知识体系-文件管理-面向计算机管理

苏泽 “弃工从研”的路上很孤独,于是我记下了些许笔记相伴,希望能够帮助到大家 本篇内容续写上一篇的【考研408&操作系统】最容易理解的知识体系-文件管理-面向人类管理 这一篇将站在计算机如何管理好文件的角度去梳理这一章节的知识点 目录 本…

.欧拉函数.

先介绍欧拉函数: 贴一张 证明: 这里利用容斥原理来进行证明:若要求1~N当中与N互质的个数,则应在1~N当中去除N的质因数的倍数,因为既然是因数,那么一定不与N互质,既然是N的因数,那么…

初识Laravel(Laravel的项目搭建)

初识Laravel(Laravel的项目搭建) 一、项目简单搭建(laravel)1.首先我们确保使用国内的 Composer 加速镜像([加速原理](https://learnku.com/php/wikis/30594)):2.新建一个名为 Laravel 的项目&a…

gfast前端UI:基于Vue3与vue-next-admin适配手机、平板、pc 的后台开源模板

摘要 随着现代软件开发的高效化需求,一个能够快速适应不同设备、简化开发过程的前端模板变得至关重要。gfast前端UI,基于Vue3.x和vue-next-admin,致力于提供这样一个解决方案。本文将深入探讨gfast前端UI的技术栈、设计原则以及它如何适配手机…

(补充):java各种进制、原码、反码、补码和文本、图像、音频在计算机中的存储方式

文章目录 前言一、进制1 逢几进一2 常见进制在java中的表示3 进制中的转换(1)任意进制转十进制(2)十进制转其他进制二、计算机中的存储1 计算机的存储规则(文本数据)(1)ASCII码表(2)编码规则的发展演化2 计算机的存储规则(图片数据)(1)分辨率、像素(2)黑白图与灰度…

Linux 复现Docker NAT网络

Linux 复现Docker NAT网络 docker 网络的构成分为宿主机docker0网桥和为容器创建的veth 对构成。这个默认网络命名空间就是我们登陆后日常使用的命名空间 使用ifconfig命令查看到的就是默认网络命名空间,docker0就是网桥,容器会把docker0当成路由&…

linux nethogs网络监控程序(端口监控、流量监控、上传流量、下载流量、进程监控进程网络)

文章目录 Nethogs 网络监控程序详解1. 引言2. Nethogs 的安装与运行2.1 安装 Nethogs- **Debian/Ubuntu**- **Fedora**- **Arch Linux** 2.2 运行 Nethogs 3. Nethogs 的使用详解3.1 基本界面- **PID**:进程的 ID。- **用户**:运行该进程的用户。- **程序…

【Linux网络】数据链路层【上】{初识数据链路层/以太网/路由表/MAC地址表/ARP表/NAT表}

文章目录 1.初识数据链路层2.认识以太网2.0前导知识以太网帧和MAC帧CMSA/CD以太网的最小帧长限制是64字节IP层和MAC层 2.1以太网帧格式 3.预备知识计算机网络通信以太网和wifi路由表/MAC地址表/ARP表/NAT表/ACL表 用于同一种数据链路节点的两个设备之间进行信息传递。 1.初识数…

美团一面,你碰到过CPU 100%的情况吗?你是怎么处理的?

本文主要分为三部分 分析一下CPU 100%的常见原因 CPU 100%如何排查 回答这个问题的一个参考答案 CPU被打满的常见原因 1. 死循环 在实际工作中,可能每个开发都写过死循环的代码。 死循环有两种: 在 while、for、forEach 循环中的死循环。 无限递…

期末成绩单怎么单独发给家长,这个小工具超简单!

随着期末考试的落幕,老师们再次迎来了成绩处理的高峰期。传统的成绩单分发方式不仅耗时,还容易出错。但如今,有了易查分小程序,这一过程变得简便而高效。 易查分小程序,一个专为教师和家长设计的便捷工具,让…

[ruby on rails]部署时候产生ActiveRecord::PreparedStatementCacheExpired错误的原因及解决方法

一、问题: 有时在 Postgres 上部署 Rails 应用程序时,可能会看到 ActiveRecord::PreparedStatementCacheExpired 错误。仅当在部署中运行迁移时才会发生这种情况。发生这种情况是因为 Rails 利用 Postgres 的缓存准备语句(PreparedStatementCache)功能来…

【Apache Doris】周FAQ集锦:第 10 期

【Apache Doris】周FAQ集锦:第 10 期 SQL问题数据操作问题运维常见问题其它问题关于社区 欢迎查阅本周的 Apache Doris 社区 FAQ 栏目! 在这个栏目中,每周将筛选社区反馈的热门问题和话题,重点回答并进行深入探讨。旨在为广大用户…

算法力扣刷题记录 四十五【110.平衡二叉树】

前言 二叉树篇继续 记录 四十五【110.平衡二叉树】 一、题目阅读 给定一个二叉树,判断它是否是 平衡二叉树。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:true示例 2: 输入:root [1,2,2,3,3…

【鸿蒙学习笔记】尺寸设置・width・height・size・margin・padding・

官方文档:尺寸设置 目录标题 width:设置组件自身的宽度height:设置组件自身的高度size:设置高宽尺寸margin:设置组件的外边距padding:设置组件的内边距 width:设置组件自身的宽度 参数为Length…

【零基础】学JS之APIS第三天

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 非常期待和您一起在这个小…

10分钟快速了解神经网络(Neural Networks)

神经网络是深度学习算法的基本构建模块。神经网络是一种机器学习算法,旨在模拟人脑的行为。它由相互连接的节点组成,也称为人工神经元,这些节点组织成层次结构。 Source: victorzhou.com 神经网络与机器学习有何不同? 神经网络是…

电脑资料丢失不用慌,5招教你恢复数据

在数字化时代,电脑资料的安全与完整对我们而言至关重要。然而,生活中总有一些小插曲,如意外删除、系统故障或病毒攻击等,导致电脑上的重要资料消失得无影无踪。面对这种情况,我们往往感到焦虑和无助。今天,…

LabVIEW心电信号自动测试系统

开发了一种基于LabVIEW的心电信号自动测试系统,通过LabVIEW开发的上位机软件,实现对心电信号的实时采集、分析和自动化测试。系统包括心电信号采集模块、信号处理模块和自动化测试模块,能够高效、准确地完成心电信号的测量与分析。 硬件系统…

在 SwiftUI 中的作用域动画

文章目录 前言简单示例动画视图修饰符使用多个可动画属性使用 ViewBuilder总结 前言 从一开始,动画就是 SwiftUI 最强大的功能之一。你可以在 SwiftUI 中快速构建流畅的动画。唯一的缺点是每当我们需要运行多步动画或将动画范围限定到视图层次结构的特定部分时&…

网络规划设计师教程(第二版) pdf

网络规划设计师教程在网上找了很多都是第一版,没有第二版。 所以去淘宝买了第二版的pdf,与其自己独享不如共享出来,让大家也能看到。 而且这个pdf我已经用WPS扫描件识别过了,可以直接CtrlF搜索关键词,方便查阅。 链接…