4 多层感知机

news2024/11/20 7:02:25

多层感知机是一组前向结构的人工神经网络,映射一组输入向量到一组输出向量。除了输入节点,每一个节点都是一个带有非线性激活函数的神经元。多层感知机在输入层和输出层之间添加了一个或者多个隐藏层,并通过激活函数转换隐藏层输出。以下介绍几种激活函数。

4.1 多层感知机

4.1.1 RuLU函数

求导表现好,要么参数消失,要么参数通过,减轻了梯度消失问题。

%matplotlib inline
import torch
from d2l import torch as d2l

x=torch.arange(-8,8,0.1,requires_grad=True)
y=torch.relu(x)
# 此处使用detach().numpy()是因为带有梯度的不需要梯度
d2l.plot(x.detach().numpy(),y.detach().numpy(),"x","rule(x)",figsize=(3,3))

 


# torch.ones_like返回填充了标量值为1的张量
# retain_graph保留梯度,此处我不添加也不影响结果,暂时不知道为啥
y.backward(torch.ones_like(x))
d2l.plot(x.detach(),x.grad,"x","x.grad",figsize=(4,3))

 4.1.2 sigmoid函数

 sigmoid函数

y=torch.sigmoid(x)
d2l.plot(x.detach(),y.detach(),'x','sigmoid(x)',figsize=(4,3))

 sigmoid反向传播函数

# 清除之前的梯度
x.grad.data.zero_()
y.backward(torch.ones_like(x))
d2l.plot(x.detach(),x.grad,'x','grad of sigmoid',figsize=(4,3))

 4.1.3 tanh函数

 tanh函数

y=torch.tanh(x)
d2l.plot(x.detach(),y.detach(),'x','tanh(x)',figsize=(4,3))

 

x.grad.data.zero_()
y.backward(torch.ones_like(x))
d2l.plot(x.detach(),x.grad,'x','grad of tannh',figsize=(3,4))

 4.2 多层感知机简要实现(不使用torch工具包)

net = nn.Sequential(nn.Flatten(),nn.Linear(784,256),nn.ReLU(),nn.Linear(356,10))
def init_weight(m):
    if type(m)==nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
net.apply(init_weight)

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4.4 模型选择、欠拟合和过拟合

我们训练模型的原因是为了提高模型的泛化能力,在未遇到的个体上,也可以很好的评估风险。

样本有限,当在训练数据上拟合比在潜在分布更接近的现象叫做过拟合。用于对抗过拟合的技术叫做正则化。

当训练误差和验证误差都很严重,但他们之间仅有一点差距的现象叫做欠拟合。

4.4.1 数据集

首先用n阶多项式生成训练集和测试集的标签

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

# 创建随机的训练和测试数据,并排成一列
features = np.random.normal(size=(n_train + n_test, 1))
# 打乱数据
np.random.shuffle(features)
# 求出【x^0,x^1,...,x^max_degree-1】,并改成一行
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
# 每个x^i除以i!
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

# NumPy ndarray转换为tensor,这里不能注释
true_w, features, poly_features, labels = [torch.tensor(x, dtype=
    torch.float32) for x in [true_w, features, poly_features, labels]]
features[:2], poly_features[:2, :], labels[:2]

4.4.2 创建评估损失函数

def evaluate_loss(net, data_iter, loss):
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X,y in data_iter:
        out=net(X)
        y=y.reshape(out.shape)
        l = loss(out,y)
        metric.add(l.sum(), l.numel())
    return metric[0]/metric[1]

4.4.3 创建训练函数

每训练20次计算损失比率

def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss(reduction='none')
    # 货期train_features最后一列
    input_shape = train_features.shape[-1]
    # bias=False表示不设置偏置值
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    # 抽取batch_size个数据
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    # 优化算法采用SGD
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    # xlim和ylim代表x轴和y轴的范围
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())

查看训练损失和测试损失

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])

4.5 权重衰减

为了解决过拟合的问题,通过向损失函数中添加权重参数的平方和作为惩罚。损失函数可以这么写:L'=L+λ*||W||^2,λ用来控制惩罚的大小。由于惩罚项和参数的平方成正比,鼓励权重接近0,以此来减小模型复杂度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/927626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity 应用消息中心-MessageCenter

Ps:主要解决耦合问题,把脚本之间的联系通过不同消息类型事件形式进行贯通 1.MessageCenter主脚本 2.DelegateEvent消息类型脚本 3.MC_Default_Data具体接收类脚本 using System; using System.Collections; using System.Collections.Generic; using …

C语言弯道超车必做好题锦集(编程题)

目录 前言: 1.计算日期到天数转换 2.尼科彻斯定理 3.密码检查 4.图片整理 5.寻找数组的中心下标 6.字符个数统计 7.多数元素 前言: 编程想要学的好,刷题少不了,我们不仅要多刷题,还要刷好题!为此我…

压力检测器的基本信息是什么

压力检测器利用了传感器技术、电路处理技术、无线传输技术,能够精准测量气体或者液体等介质的压力,并将测得的数据上传至监控平台。 压力检测器能够适用于供水厂、污水处理厂、消防水系统、输油管道、输气管道等相关场景,拥有自动补偿功能、…

你知道开发程序的流程化、模块化、规范化是怎样的?不同厂商一样吗?

Postive: 我在天津的公司 都是netframwork的 .......... 后来 去北京 就找core 的技术 确实感觉不是一个层次的 .......... Postive: 以前 在天津 就是堆业务 部署iis 点点就完事了 用个redis 就牛逼的不行了 干上core的项目才发现 授权是单独…

[C++] STL_vector使用与常用接口的模拟实现

文章目录 1、vector的介绍2、vector的使用2.1 vector的定义2.2 vector迭代器的使用2.3 vector的空间增长问题 3、vector的增删查改3.1 push_back(重点)3.2 pop_back(重点)3.3 operator[](重点)3.4 insert3.…

腾讯云下一代CDN -- EdgeOne加速MinIO对象存储

省流 使用MinIO作为EdgeOne的源站。 背景介绍 项目中需要一个兼容S3协议的对象存储服务,腾讯云的COS虽然也兼容S3协议,但是也只是支持简单的上传下载,对于上传的时候同时打标签这种需求,就不兼容S3了。所以决定自建一个对象存储…

基于Java+SpringBoot+vue前后端分离在线问卷调查系统设计实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

VR云游同美景邂逅,抓住暑假的小尾巴~

暑期余额不足的宝子们,是不是还没出门玩耍玩个够呢?不如趁这个时候,用VR云游抓住暑假的小尾巴,收获一波开学前的“收心之旅”吧! VR云游相较于传统旅游来说,是通过个性化云服务,为智能景区建立综…

【WSN无线传感器网络恶意节点】使用 MATLAB 进行无线传感器网络部署研究

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

线性代数的学习和整理8:行列式相关

目录 1 从2元一次方程组求解说起 1.1 直接用方程组消元法求解 1.2 有没有其他方法呢?有:比如2阶行列式方法 1.3 3阶行列式 2 行列式的定义 2.1 矩阵里的方阵 2.2 行列式定义:返回值为标量的一个函数 2.3 行列式的计算公式 2.4 克拉…

189. 轮转数组

189. 轮转数组 class Solution { public:void rotate(vector<int>& nums, int k) {int n nums.size();k k % n;reverse(nums.begin(),nums.end());reverse(nums.begin(),nums.begin()k);reverse(nums.begin()k,nums.end());} };

2023年高教社杯数学建模思路 - 案例:FPTree-频繁模式树算法

文章目录 算法介绍FP树表示法构建FP树实现代码 建模资料 ## 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 算法介绍 FP-Tree算法全称是FrequentPattern Tree算法&#xff0c;就是频繁模式树算法&#xff0c…

最新活动报名表单系统源码 支持表单自定义+在线支付报名

分享一款最新活动报名表单系统源码&#xff0c;支持任意行业各种活动在线支付报名&#xff0c;配合万能自定义表单&#xff0c;适用于各种活动报名、课程招生、会议报名统计等等。 功能特点一览&#xff1a; 表单自定义&#xff1a;该报名系统允许组织者根据活动的需求自定义报…

Mybatis与Spring整合以及Aop整合pagehelper插件

一. Mybatis与Spring的集成 将MyBatis与Spring进行整合&#xff0c;主要解决的问题就是将SqlSessionFactory对象交由Spring容器来管理&#xff0c;所以&#xff0c;该整合&#xff0c;只需要将SqlSessionFactory的对象生成器SqlSessionFactoryBean注册在Spring容器中&#xff…

从头开始:将新项目上传至Git仓库的简易指南

无论您是一个经验丰富的开发者还是一个刚刚起步的新手&#xff0c;使用Git来管理您的项目是一个明智的选择。Git是一个强大的版本控制系统&#xff0c;它可以帮助您跟踪项目的变化、合并代码以及与团队成员协作。在本文中&#xff0c;我们将为您提供一步步的指南&#xff0c;教…

opengl shader nv格式转换

可以参考&#xff1a; OpenGL: 如何利用 Shader 实现 RGBA 到 NV21 图像格式转换&#xff1f;&#xff08;全网首次开源&#xff09; - 知乎 nv12 #extension GL_OES_EGL_image_external : require precision mediump float; varying vec2 vTextureCoord; uniform sampler2D…

Matplotlib | 高阶绘图案例【1】

文章目录 &#x1f3f3;️‍&#x1f308; 1. 绘制图布&#xff0c;设置坐标范围&#x1f3f3;️‍&#x1f308; 2. 绘制圆角矩形&#x1f3f3;️‍&#x1f308; 3. 添加水滴&#x1f3f3;️‍&#x1f308; 4. 添加时间线&#x1f3f3;️‍&#x1f308; 5. 添加文本、配色&…

ssm+vue海鲜自助餐厅系统源码和论文

ssmvue海鲜自助餐厅系统源码和论文068 开发工具&#xff1a;idea 数据库mysql5.7 数据库链接工具&#xff1a;navcat,小海豚等 技术&#xff1a;ssm 摘 要 网络技术和计算机技术发展至今&#xff0c;已经拥有了深厚的理论基础&#xff0c;并在现实中进行了充分运用&…

如何写好新闻稿

写好新闻稿是一门技巧和艺术的结合。一个有效的新闻稿应该能够快速吸引读者的注意力&#xff0c;并为他们提供有价值的信息。以下是如何写好新闻稿的步骤和建议&#xff1a; 1.吸引眼球的标题 简短明了&#xff1a;标题应该简洁&#xff0c;一眼就能告诉读者新闻的核心内容。使…

Python打包exe和生成安装程序

1.打包exe python打包成exe文件的一般步骤如下&#xff1a; 安装pyinstaller模块&#xff0c;可以使用pip install pyinstaller命令来安装或更新pyinstaller模块。在cmd中切换到要打包的python文件所在的目录&#xff0c;输入pyinstaller -F 文件名.py命令来生成单个exe文件。…