机器学习深度学习——多层感知机的从零开始实现

news2024/9/24 17:20:43

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——多层感知机
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

为了与之前的softmax回归获得的结果进行比较,将继续使用Fashion-MNIST图像分类数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

多层感知机的从零开始实现

  • 初始化模型参数
  • 激活函数
  • 模型
  • 损失函数
  • 训练
  • 预测

初始化模型参数

数据集的每个图像由28×28=784个灰度像素值组成。所有图像分为10个类别。
忽略像素间的空间结构,我们可以将每个图像视为具有784个输入特征和10个类的简单分类数据集。
首先,我们将实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元。注意,我们可以将这两个变量都视为超参数。通常,我们选择2的若干次幂作为层的宽度。因为内存在硬件的分配和寻址方式,这么做往往可以在计算上更高效。
我们用几个张量来表示我们的参数。注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。并要为这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256
W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
params = [W1, b1, W2, b2]

激活函数

这里就不用内置的了,自己实现一下:

def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)

模型

既然忽略了空间结构,那就直接用reshape将每个二维图像转换为一个长度为num_inputs的向量:

def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X@W1 + b1)  # "@"表示矩阵乘法
    return (H@W2 + b2)

损失函数

之前已经从零实现过了softmax函数,这里直接用内置函数计算softmax和交叉熵损失(为什么要计算这两个,之前在softmax的简洁实现中曾经证明过)

loss = nn.CrossEntropyLoss(reduction='none')

训练

训练过程和softmax一样,直接调用d2l的train_ch3函数就行了,将迭代周期数设为10,学习率设为0.1。

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

预测

对模型进行评估,我们在测试数据上应用这个模型。

d2l.predict_ch3(net, test_iter)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/798357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机二级Python基本操作题-序号41

1. 输入字符串s,按要求把s输出到屏幕,格式要求:宽度为30个字符,星号字符*填充,居中对齐。如果输入字符串超过30位,则全部输出。 例如:键盘输入字符串s为"Congratulations”,屏幕…

基于Kaggle训练集预测的多层人工神经网络的能源消耗的时间序列预测(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🌈3 Matlab代码实现 🎉4 参考文献 💥1 概述 本文为能源消耗的时间序列预测,在Matlab中实现。该预测采用多层人工神经网络,基于Kaggle训练集预测未来能源消耗。 &am…

自动化测试如何做?分层自动化测试如何实施?一篇概全...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 分层的自动化测试…

qt简易闹钟

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);ui->stopBtn->setDisabled(true);this->setFixedSize(this->size()); //设置固定大小this->s…

【Spring Cloud Alibaba】限流--Sentinel

文章目录 概述一、Sentinel 是啥?二、Sentinel 的生态环境三、Sentinel 核心概念3.1、资源3.2、规则 四、Sentinel 限流4.1、单机限流4.1.1、引入依赖4.1.2、定义限流规则4.1.3、定义限流资源4.1.4、运行结果 4.2、控制台限流4.2.1、客户端接入控制台4.2.2、引入依赖…

(双指针) 剑指 Offer 57. 和为s的两个数字 ——【Leetcode每日一题】

❓ 剑指 Offer 57. 和为s的两个数字 难度:简单 输入一个递增排序的数组和一个数字 s,在数组中查找两个数,使得它们的和正好是 s。如果有多对数字的和等于 s,则输出任意一对即可。 示例 1: 输入:nums [2…

1400*C. Strong Password

Example input 5 88005553535123456 2 50 56 123412341234 3 111 444 1234 4 4321 4321 459 2 49 59 00010 2 10 11output YES NO YES NO YES解析: 题目要求有一种密码不在数据库中即可,所以枚举每一位的所有可能的数字,记录这一位数字在数…

帆软 FineReport/FineBI channel反序列化漏洞分析

事件背景 热点漏洞 漏洞说明 1. 漏洞原理:FineReport/FineBI channel接口能接受序列化数据并对其进行反序列化。配合帆软内置CB链会导致任意代码执行。 2. 组件描述:FineReport是一款企业级报表设计和数据分析工具,它提供了丰富多样的组件…

教你轻松又简单的绘制地铁线路图

地铁是主要以地下运输为主的交通系统。其轨道通常在地下隧道内。其中:地铁线路中的M,L,S,R分别代表:MMetro地,LLight轻轨,SSuburb市郊铁路,Rrapid,快速铁路。地铁线路图则可以简单理解为地铁的运行路线轨道…

基于Java+SpringBoot+vue前后端分离足球青训俱乐部管理后台系统设计实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

【自动话化运维】Ansible常见模块的运用

目录 一、Ansible简介二、Ansible安装部署2.1环境准备 三、ansible 命令行模块3.1.command 模块3.2.shell 模块3.3.cron 模块3.4.user 模块3.5.group 模块3.6.copy 模块3.7.file 模块8&#xff…

Java阶段五Day15

Java阶段五Day15 文章目录 Java阶段五Day15分层其他依赖dao-apidao-implinfrustructuredomainadaptermain 测试整合项目main前台师傅功能luban-front配置师傅相关表格ER图ER练习案例鲁班表格ER关系(非常重要) 前台师傅接口——师傅入驻adapterdomaininfr…

【C语言】指针进阶(二)

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

使用xtcp映射穿透指定服务

使用xtcp映射穿透指定服务 管理员Ubuntu配置公网服务端frps配置service自启(可选) 配置内网服务端frpc配置service自启(可选) 使用者配置service自启(可选) 通过frp实现内网client访问另外一个内网服务器 管理员 1)配置公网服务端frps2)配置内网服务端…

2023年 React 最佳学习路线

CSS CSS JavaScript JavaScript TypeScript 目前没有找到比其他文档好很多的文档地址 可以先看官网 React 新版 React 官方文档无敌 React React-router-dom V5 V6 Webpack webpack Antd antd

基于贴花的热力图呈现方法参考

首先是热力图信息类: using System; using UnityEngine;[Serializable] public class HeatMapInfo {public Rect rect Rect.zero;public float pixelDensity 1;public float valBase 0;public float valRange 1;public PercentColor[] ribbon;//色带public Col…

关联分析-Apriori

关联分析-Apriori 1. 定义 关联分析就是从大规模数据中,发现对象之间隐含关系与规律的过程,也称为关联规则学习。 2. 相关概念 2.1 事务、项与项集 订单号购买商品0001可乐、薯片0002口香糖、可乐0003可乐、口香糖、薯片 以上面的订单为例&#xf…

一套流程6个步骤,教你如何正确采购询价

采购询价(RFQ)是一种竞争性投标文件,用于邀请供应商或承包商就标准化或重复生产的产品或服务提交报价。 询价通常用于大批量/低价值项目,买方必须提供技术规格和商业要求,该文件有时也称为招标书或投标邀请书。询价流…

客户问题解决平台-帮助与支持中心

在现代企业中,帮助与支持中心扮演着至关重要的角色。随着市场的竞争日趋激烈,客户对于产品和服务的期望也越来越高。因此,建立一个高效的客户问题解决平台是企业成功的关键之一。本文将探讨帮助与支持中心的作用,介绍其功能和优势…

Linux(二)--Linux基础命令

我们在前面学习的Linux命令,其实它们的本体就是一个个的二进制可执行程序。 和Windows系统中的.exe文件,是一个意思。 一.Linux的目录结构 Linux的目录结构是一个树形结构。 Windows系统可以拥有多个顶级目录,称之为盘符,如C盘&…