Pytorch 多层感知机

news2024/12/23 9:23:43

一、什么是多层感知机

多层感知机由感知机推广而来,最主要的特点是有多个神经元层,因此也叫深度神经网络(DNN: Deep Neural Networks)。

二、如何实现多层感知机

1、导入所需库并加载fashion_mnist数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torch import nn
from torchvision import transforms
from d2l import torch as d2l
import warnings
warnings.filterwarnings("ignore")
# 用SVG清晰度高
d2l.use_svg_display()
from IPython import display
def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)# 用于集中图像处理方法
    mnist_train = torchvision.datasets.FashionMNIST(
        root="data/", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="data/", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(256)

2、设置变量以及要更新的参数

Fashion-MNIST中的每个图像由 28×28=784个灰度像素值组成。 所有图像共分为10个类别。 忽略像素之间的空间结构, 我们可以将每个图像视为具有784个输入特征 和10个类的简单分类数据集。 首先,我们将实现一个具有单隐藏层的多层感知机, 它包含256个隐藏单元。

通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

torch.nn.Parameter()将一个不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面。即在定义网络时这个tensor就是一个可以训练的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化

3、定义rulu激活函数

def relu(X):
    a = torch.zeros_like(X) #用于生成和X相同shape的全零tensor
    return torch.max(X, a)

4、搭建模型

def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法如上图所示乘法
    return (H@W2 + b2)

5、定义损失函数

这里我们使用API内置的CrossEntropyLoss即包含了softmax也包含了交叉熵损失

loss = nn.CrossEntropyLoss(reduction='none')

6、进行训练

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

 

7、模型测试

d2l.predict_ch3(net, test_iter)

 

8、简单实现

使用Pytorch内置的高级API搭建网络架构

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torch import nn
from torchvision import transforms
from d2l import torch as d2l
import warnings
warnings.filterwarnings("ignore")
# 用SVG清晰度高
d2l.use_svg_display()
from IPython import display
def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="data/", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="data/", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(256)
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
lr, num_epochs =  0.1, 10
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

这里面使用Sequential搭建了网络架构,其中先将图片进行展平(nn.Flatten)然后传入线形层,在经过relu激活函数,最后使用Linear进行输出。

使用init_weights进行初始化线行层参数,使其符合正态分布(如果全0模型训练不动)

最后进行训练就可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/147212.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java弹幕视频网站源码

简介 Java基于ssm的弹幕视频系统,用户注册后可以上传视频进行投稿,也可以浏览视频发送弹幕,在个人中心管理视频、管理弹幕、管理评论等。管理员可以管理视频弹幕评论,查看统计图。 演示视频: https://www.bilibili.c…

CVE-2020-0014 Toast组件点击事件截获漏洞

文章目录前言漏洞分析组件源码触摸属性漏洞利用POC分析漏洞复现漏洞修复总结前言 Toast 组件是 Android 系统一个消息提示组件,比如你可以通过以下代码弹出提示用户“该睡觉了…”: Toast.makeText(this, "该睡觉了…", Toast.LENGTH_SHORT)…

C语言文件操作-从知识到实践全程

目录 引入 文件的打开和关闭 文件如何使用程序来打开? 绝对路径需要转义字符 fopen函数 fclose函数 文件的打开方式(fopen第二参数const char* mode): 文件的顺序读写 fgetc和fputc的使用 fputc fgetc fgets和fputs的使用 fputs fgets perror的使用 fprint…

哪些程序员适合自由工作?(附平台推荐)

在早些时候进行远程办公,接私活或者跨国进行编程,赚点外快等也不是什么奇怪的事情。但是那时候没有人想到会把这些工作完全变成自己的主要业务——也就是我们说的自由工作。也不知道是哪一个第1个吃了螃蟹的人发现自由工作还不错,于是经过后面…

【JavaScript】DOM 学习总结-基础知识

获取元素方法: // 获取三个非常规的标签 console.log(document.documentElement) console.log(document.head) console.log(document.body)通过id/class获取:getElementById / getElementsByClassName // 获取常规的用id,class,tag var boxdocument.g…

Android 自定义Activity的主题

一. 前言 当在某个app中做一个新界面时, 我们要考虑一下主题风格相符合一致. 本篇文章讲解的是,如何新创建的Activity 与整个app主题符合, 特别是状态栏的颜色需要和这个app的状态栏颜色保持一致. 在读本篇文章之前, 可以移步一下笔者之前写的文章:Android style&#xf…

代码随想录算法训练营第十一天字符串 java :20. 有效的括号 1047. 删除字符串中的所有相邻重复项 150. 逆波兰表达式求值

文章目录Leetcode 20. 有效的括号题目详解数据结构 双端队列(deque)Deque有三种用途:思路报错Ac代码Leetcode1047. 删除字符串中的所有相邻重复项题目详解数据结构 ArrayDeque类思路AC代码150. 逆波兰表达式求值题目详解报错难点AC代码收获Leetcode 20. 有效的括号 …

系分 - 系统设计

个人总结,仅供参考,欢迎加好友一起讨论 文章目录系分 - 系统设计考点摘要系统设计软件设计软件架构设计结构化设计概要设计详细设计处理流程设计流程工作流活动及其所有者工作项工作流管理系统WFMS的基本功能WFMS的组成WRM流程设计工具用户界面设计/人机…

python算法与数据结构1-算法、数据结构、链表

目录1、算法的概念1.1 举例:1.2 算法的五大特性:1.3 时间复杂度1.4 空间复杂度2、数据结构2.1 内存的存储结构2.2 数据结构的分类2.3 顺序表存储方式3、链表3.1链表实现3.2链表的方法3.3链表增加节点3.4链表删除节点3.5链表总结1、算法的概念 算法与数据…

(Java高级教程)第三章Java网络编程-第三节:UDP数据报套接字(DatagramSocket)编程

文章目录一:Java数据报套接字通信模型二:相关API详解(1)DatagramSocket(2)DatagramPacket三:UDP通信示例一:客户端发送什么服务端就返回什么(1)代码&#xff…

k8s之ConfigMap和secret

写在前面 我们知道k8s的数据都是存储到kv数据库etcd中的,那么我们程序中使用到各种配置信息是否可以也存储到etcd,然后在pod中使用呢?是可以的,k8s为了实现将自定义的数据存储到etcd,定义了ConfigMap 和secret两种API…

《后端技术面试 38 讲》学习笔记 Day 01

《后端技术面试 38 讲》学习笔记 Day 01 学习目标 在2022年春节将至(半个月),适合在这个冬天里,温故知新。通过学习一门覆盖面较广的课程,来夯实基础,完善自己的知识体系,是一个很棒的选择。 …

LCHub:未来,低代码产品矩阵是500强企业的绝佳选择

近日,国内知名咨询机LCHub发布2022《中国大型企业数字化升级路径研究》。 报告认为由于大型企业的数字化需求旺盛、购买力充足,因此国内成熟的数字化服务商普遍以大型企业为核心客户。大型企业与数字化服务商的供需磨合决定了我国数字化市场的形态,造就了我国数字化市场与海…

go map 源码逐行阅读

map粗略介绍 源码开头注释: A map is just a hash table. The data is arranged into an array of buckets. Each bucket contains up to 8 key/elem pairs. The low-order bits of the hash are used to select a bucket. Each bucket contains a few high-order…

Linux学习笔记——RabbitMQ安装部署

5.4、RabbitMQ安装部署 5.4.1、简介 RabbitMQ是一款知名的开源消息列队系统,为企业提供消息的发布、订阅、点对点传输等消息服务。 RabbitMQ在企业开发中十分常见,课程为大家演示快速搭建RabbitMQ环境。 5.4.2、安装 RabbitMQ在yum仓库中的版本比较老…

用于从单细胞FRET数据中提取灵敏度分布的Matlab代码

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 对于分子生物学来讲,生物分析手段的发展,是阐明机理的必要条件。在研究分子间相互作用的道路上&#xf…

Leetcode - 106 - 相交链表

160. 相交链表 题目描述 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中不存在环。 注意&am…

【阶段三】Python机器学习04篇:机器学习项目实战:多元线性回归模型、岭回归模型与套索回归模型

本篇的思维导图: 一元线性回归的数学原理 一元线性回归模型又称为简单线性回归模型,其形式可以表示为如下所示的公式。 y=ax+b 其中,y为因变量,x为自变量,a为回归系数,b为截距。 如下图所示,其中y(i)为实际值,y(i)为预测值,一元线性回归的目的就是拟合出一…

Vision-Only Robot Navigation in a Neural Radiance World

Paper name Vision-Only Robot Navigation in a Neural Radiance World Paper Reading Note URL: https://arxiv.org/abs/2110.00168 TL;DR RA-L 2022 文章,提出在 NeRF 表达的环境中进行机器人导航及循迹任务的方法 Introduction 背景 基于 on-board 传感器…

HTML与CSS基础(五)—— CSS布局(盒子模型)、PxCook使用

目标能够认识 盒子模型 的组成部分 能够掌握盒子模型的 边框、内边距、外边距 的作用及简写形式能够计算盒子的 实际大小 能够了解 外边距折叠现象,并知道如何解决 盒子塌陷问题一、PxCook的基本使用1. 通过软件打开设计图① 打开软件 ② 拖拽入设计图 ③ 新建项目2…