【PyTorch】多层感知机

news2024/11/16 17:50:55

文章目录

  • 1. 模型和代码实现
    • 1.1. 模型
      • 1.1.1. 背景
      • 1.1.2. 多层感知机
      • 1.1.3. 激活函数
    • 1.2. 代码实现
      • 1.2.1. 完整代码
      • 1.2.2. 输出结果
  • 2. Q&A

1. 模型和代码实现

1.1. 模型

1.1.1. 背景

许多问题要使用线性模型,但无法简单地通过预处理来实现。此时我们可以通过在网络中加入一个或多个隐藏层来克服线性模型的限制, 使其能处理更普遍的函数关系类型。

1.1.2. 多层感知机

将许多全连接层堆叠在一起。 每一层都输出到上面的层,直到生成最后的输出,我们可以把前层看作表示,把最后一层看作线性预测器。 这种架构通常称为多层感知机,通常缩写为MLP。
多层感知机

1.1.3. 激活函数

我们需要在仿射变换之后对每个隐藏单元应用非线性的激活函数,这样就不可能再将我们的多层感知机退化成线性模型,使得模型具有更强的表达能力。
激活函数是通过计算加权和并加上偏置来确定神经元是否应该被激活, 并将输入信号转换为输出的可微运算的函数。

  • ReLU函数

    • 修正线性单元(Rectified linear unit,ReLU)。

    • 最受欢迎的激活函数。

    • 定义: R e L U ( x ) = m a x ( 0 , x ) \mathrm{ReLU}(x)=\mathrm{max}(0,x) ReLU(x)=max(0,x)
      relu

    • 当输入接近0时,sigmoid函数接近线性变换。
      gradofrelu

    • 当输入值精确等于0时,ReLU函数不可导。 在此时,我们默认使用左侧的导数,即当输入为0时导数为0。 我们可以忽略这种情况,因为输入可能永远都不会是0。

    • 变体:参数化的ReLU(Parameterized ReLU,pReLU),允许即使参数是负的,某些信息依然可以通过,其定义如下: p R e L U ( x ) = m a x ( 0 , x ) + α m i n ( 0 , x ) \mathrm{pReLU}(x)=\mathrm{max}(0,x)+\alpha\mathrm{min}(0,x) pReLU(x)=max(0,x)+αmin(0,x)等等。

  • sigmoid函数

    • 将输入变换为区间(0, 1)上的输出。

    • 在隐藏层中已经较少使用, 它在大部分时候被更简单、更容易训练的ReLU所取代。

    • 定义: s i g m o i d ( x ) = 1 1 + e x p ( − x ) \mathrm{sigmoid}(x)=\frac{1}{1+\mathrm{exp}(-x)} sigmoid(x)=1+exp(x)1
      sigmoid

    • 导数: d d x s i g m o i d ( x ) = s i g m o i d ( x ) ( 1 − s i g m o i d ( x ) ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{sigmoid}(x)=\mathrm{sigmoid}(x)(1-\mathrm{sigmoid}(x)) dxdsigmoid(x)=sigmoid(x)(1sigmoid(x))
      gradofsigmoid

  • tanh函数

    • 将其输入压缩转换到区间(-1, 1)上。

    • 定义: t a n h ( x ) = 1 − e x p ( − 2 x ) 1 + e x p ( − 2 x ) \mathrm{tanh}(x)=\frac{1-\mathrm{exp}(-2x)}{1+\mathrm{exp}(-2x)} tanh(x)=1+exp(2x)1exp(2x)
      tanh

    • 当输入接近0时,tanh函数接近线性变换。

    • 导数: d d x t a n h ( x ) = 1 − t a n h 2 ( x ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{tanh}(x)=1-\mathrm{tanh}^2(x) dxdtanh(x)=1tanh2(x)
      gradoftanh

1.2. 代码实现

1.2.1. 完整代码

import torch
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter

def load_dataset(batch_size, num_workers):
    """加载数据集"""
    root = "./dataset"
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = FashionMNIST(
        root=root, 
        train=True, 
        transform=transform, 
        download=True
    )
    mnist_test = FashionMNIST(
        root=root, 
        train=False, 
        transform=transform, 
        download=True
    )
    dataloader_train = DataLoader(
        mnist_train, 
        batch_size, 
        shuffle=True, 
        num_workers=num_workers
    )
    dataloader_test = DataLoader(
        mnist_test, 
        batch_size, 
        shuffle=False,
        num_workers=num_workers
    )
    return dataloader_train, dataloader_test

def init_network(net):
    """初始化模型参数"""
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, val=0)
    if isinstance(net, nn.Module):
        net.apply(init_weights)

class Accumulator:
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def train(net, dataloader_train, criterion, optimizer, device):
    """训练模型"""
    if isinstance(net, nn.Module):
        net.train()
    train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    for X, y in dataloader_train:
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        loss = criterion(y_hat, y)
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())
    train_loss = train_metrics[0] / train_metrics[2]
    train_acc = train_metrics[1] / train_metrics[2]
    return train_loss, train_acc

def test(net, dataloader_test, device):
    """测试模型"""
    if isinstance(net, nn.Module):
        net.eval()
    with torch.no_grad():    
        test_metrics = Accumulator(2)   # 测试准确度总和、样本数
        for X, y in dataloader_test:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            test_metrics.add(accuracy(y_hat, y), y.numel())
        test_acc = test_metrics[0] / test_metrics[1]
        return test_acc


if __name__ == "__main__":
    # 全局参数设置
    batch_size = 256
    num_workers = 4
    num_epochs = 10
    learning_rate = 0.1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 创建记录器
    writer = SummaryWriter()

    # 加载数据集
    dataloader_train, dataloader_test = load_dataset(batch_size, num_workers)

    # 定义神经网络
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    ).to(device)

    # 初始化神经网络
    init_network(net)

    # 定义损失函数
    criterion = nn.CrossEntropyLoss(reduction='none')

    # 定义优化器
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        train_loss, train_acc = train(net, dataloader_train, criterion, optimizer, device)
        test_acc = test(net, dataloader_test, device)
        writer.add_scalars("metrics", {
            'train_loss': train_loss, 
            'train_acc': train_acc, 
            'test_acc': test_acc
            }, 
            epoch)
        
    writer.close()

1.2.2. 输出结果

多层感知机

2. Q&A

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElasticSearch基础知识

ElasticSearch是一个高扩展的分布式全文搜索引擎,基于Lucene作为核心实现所有索引和搜索的功能。 使用场景: (1)搜索领域:如维基百科、谷歌,全文检索等。 (2)网站具体内容&#xf…

计算机网络:传输层——多路复用与解复用

文章目录 前言一、Socket(套接字)二、多路复用/解复用三、多路解复用(1)多路解复用原理(2)无连接(UDP)多路解复用(3)面向连接(TCP)的多…

普通策略梯度算法原理及PyTorch实现【VPG】

有没有想过强化学习 (RL) 是如何工作的? 在本文中,我们将从头开始构建最简单的强化学习形式之一 —普通策略梯度(VPG)算法。 然后,我们将训练它完成著名的 CartPole 挑战 — 学习从左向右移动购物车以平衡杆子。 在此…

哈希与哈希表

哈希表的概念 哈希表又名散列表,官话一点讲就是: 散列表(Hash table,也叫哈希表),是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记…

MySQL的多表查询

多表关系 一对多(多对一)-> 多对多-> 一对一-> 概述 概述 多表查询分类 内连接 代码演示--> -- 内连接演示 -- 1.查询每一个员工的姓名,及关联的部门的名称(隐式内连接实现) select emp.name, dept.name from emp,dept where emp.dept_id dept.id; …

10、外观模式(Facade Pattern,不常用)

外观模式(Facade Pattern)也叫作门面模式,通过一个门面(Facade)向客户端提供一个访问系统的统一接口,客户端无须关心和知晓系统内部各子模块(系统)之间的复杂关系,其主要…

sql面试题之“互相关注的人”(方法三)

题目:某社交平台有关注这个功能,关注的同时也会被关注。现有需求需要找出平台上哪些用户之间互相关注。 文章目录 题目如下:一、数据准备二、建表并导入数据1.建表2.导入数据3.数据分析和实现思路小结: 题目如下: 某社…

[RK-Linux] 移植Linux-5.10到RK3399(三)| 检查eMMC与SD卡配置

这个专题主要记录把 RK Linux-5.10 移植到 ROC-RK3399-PC Pro 的过程。 文章目录 一、eMMC二、SD 卡三、两个接口的区别一、eMMC RK3399 的 eMMC 接口如图: datasheet 介绍: 实际上,连接 eMMC 存储器用的是 SDHCI 接口。SDHCI(Secure Digital Host Controller Interface)…

【数据结构】最短路径——Floyd算法

一.问题描述 给定带权有向图G(V,E),对任意顶点 V (ij),求顶点到顶点的最短路径。 转化为: 多源点最短路径求解问题 解决方案一: 每次以一个顶点为源点调用Dijksra算法。时间复杂…

香港虚拟信用卡如何办理,支持香港apple id

什么是虚拟信用卡? 虚拟信用卡,英文称之为Virtual Credit Card Numbers,就是指没有实体卡片,是基于银行卡上面的BIN码所生成的虚拟账号。通常用于进行网络交易,使用起来很方便,也很安全。 它与实体信用卡…

算法-01-递归

1-理解递归 斐波那契数列(Fibonacci sequence),又称黄金分割数列 ,以兔子繁殖为例子而引入,故又称“兔子数列”,其数值为:1、1、2、3、5、8、13、21、34……特点是 从第三个数开始,第…

HOST文件被挟持,无法上网,如何解决。

问题: 晚上开机,突然发现无法联网,提示网络异常 解决: 首先网络诊断,host文件被劫持,修复后,仍然不行。 然后测试手机热点,发现仍然无法联网 尝试用火绒修复,无果。 所有…

Linux Camera Driver(2):CIS设备注册(DTS)

一:MIPI接口 1、硬件接口 MIPI接口以rv1109和gc2053的硬件为例进行说明: 2、ISP驱动 注意配置事项: endpoint配置,必须指定data-lanes,否则无法识别为mipi类型 链接方式:sensor->csi_dphy->isp->ispp (1)sensor节点配置 根据原理图可知:mipicsi_clk0即引…

Linux系统安装Python3环境

1、默认情况下,Linux会自带安装Python,可以运行python --version命令查看,如图: 我们看到Linux中已经自带了Python2.7.5。再次运行python命令后就可以使用python命令窗口了(CtrlD退出python命令窗口)。 2…

STM32F407-14.3.11-01互补输出和死区插入

互补输出和死区插入 高级控制定时器(TIM1 和 TIM8)可以输出两路互补信号,并管理输出的关断与接通瞬间。 这段时间通常称为死区,用户必须根据与输出相连接的器件及其特性(电平转换器的固有延迟、开关器件产生的延迟...&…

MySQL之时间戳(DateTime和TimeStamp)

MySQL之时间戳(DateTime和TimeStamp) 文章目录: MySQL之时间戳(DateTime和TimeStamp)一、DateTime类型二、TimeStamp类型三、DateTime和TimeStamp的区别 当插入数据时,需要自动记录一个时间时候&#xff0c…

llama.cpp部署(windows)

一、下载源码和模型 下载源码和模型 # 下载源码 git clone https://github.com/ggerganov/llama.cpp.git# 下载llama-7b模型 git clone https://www.modelscope.cn/skyline2006/llama-7b.git查看cmake版本: D:\pyworkspace\llama_cpp\llama.cpp\build>cmake --…

git 本地改动无法删除

1. 问题 记录下git遇到奇怪的问题,本地有些改动不知道什么原因无法删除 git stash, git reset --hard HEAD 等都无法生效,最终通过强制拉取线上解决 如下图: 2. 解决 git fetch --all git reset --hard origin/master执行这两…

LeedCode刷题---双指针问题

顾得泉:个人主页 个人专栏:《Linux操作系统》 《C/C》 《LeedCode刷题》 键盘敲烂,年薪百万! 双指针简介 常见的双指针有两种形式,一种是对撞指针,一种是左右指针。 对撞指针:一般用于顺序结构中&…

马斯克极简5步工作法 —— 筑梦之路

马斯克的五步流程法则: 第一步:确定需求 第二步:极力删除零件或过程 第三步:简化和优化 第四步:加快周期时间 第五步:自动化特别注意:完成前三步之前,千万不要考虑加速和自动化&…