从0开始深度学习(12)——多层感知机的逐步实现

news2024/11/26 10:49:17

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过

1 读取数据

import torch
from torchvision import transforms
import torchvision
from torch.utils import data

# 读取数据
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="D:/DL_Data/", train=True, transform=trans, download=False)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="D:/DL_Data/", train=False, transform=trans, download=False)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=12),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=12))
                            
train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

2 初始化模型参数

以单隐藏层的多层感知机为例,选择使用256个隐藏单元

from torch import nn

# 初始化模型参数
num_inputs=784      # 28*28
num_outputs=10
num_hiddens=256     # 我们选择使用256个隐藏单元,注意,一般选择使用2的若干次幂,因为内存的特殊性,可以在计算上更高效

w1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))

w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [w1, b1, w2, b2]

3 激活函数、损失函数、建立模型

# 激活函数
def relu(x):
    a=torch.zeros_like(x) # 保证全零张量和x的形状一致,利于广播计算
    return torch.max(x,a)

# 损失函数
loss = nn.CrossEntropyLoss(reduction='none')

#建立模型
def net(x):
    x=x.reshape((-1,num_inputs))#展开
    H=relu(x@w1+b1)# @表示矩阵乘法
    return (H@w2+b2)

4 训练模型

优化器使用SGD

#训练,优化器使用sgd
num_epochs=5
lr=00.1
updater=torch.optim.SGD(params,lr=lr)

def train_epoch(net, train_iter, loss, updater):
    if isinstance(net, torch.nn.Module):
        net.train()  # 将模型设置为训练模式
    metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y).mean()
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            updater.step()
        else:
            l.backward()
            updater([w, b], lr, batch_size)
        metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]

def train(net, train_iter, test_iter, loss, num_epochs, updater):
    for epoch in range(num_epochs):
        train_metrics = train_epoch(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')
        
class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def compute_accuracy(y_hat, y):  # 预测值、真实值
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别
    cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmp
    return float(cmp.type(y.dtype).sum())

# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(compute_accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

train(net, train_iter, test_iter, loss, num_epochs, updater)

在这里插入图片描述

5 预测

import matplotlib.pyplot as plt
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 预测并显示结果
def predict(net, test_iter, n=6):
    for X, y in test_iter:
        break  # 只取一个批次的数据
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    n = min(n, X.shape[0])
    fig, axs = plt.subplots(1, n, figsize=(12, 3))
    for i in range(n):
        axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')
        axs[i].set_title(titles[i])
        axs[i].axis('off')
    plt.show()

# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2219769.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaCove部署文档

1. 基础配置 1.1服务器: 2 核 2G 1.2. 一个域名 1.3. 项目地址: gitee:https://gitee.com/guo-_jun/JavaCove github:https://github.com/nansheng1212/JavaCove 2. CentOS 安装 Docker 官方网站上有各种环境下的 安装指南,这里主要介绍…

webpack自定义插件 ChangeScriptSrcPlugin

插件文件 class ChangeScriptSrcPlugin {apply(compiler) {const pluginName "ChangeScriptSrcPlugin";compiler.hooks.compilation.tap(pluginName, (compilation, callback) > {compilation.hooks.htmlWebpackPluginAlterAssetTags.tapAsync(pluginName,(html…

SpringCloudStream使用StreamBridge实现延时队列

利用RabbitMQ实现消息的延迟队列 一、安装RabbitMQ 1、安装rabbitmq 安装可以看https://blog.csdn.net/qq_38618691/article/details/118223851,进行安装。 2、安装插件 安装完毕后,exchange是不支持延迟类型的,需要手动安装插件,需要和安装的rabbitmq版本一致 https:…

动态规划:17.简单多状态 dp 问题_买卖股票的最佳时机III_C++

题目链接: 一、题目解析 题目:123. 买卖股票的最佳时机 III - 力扣(LeetCode) 解析: 拿示例1举例: 我们可以如图所示买入卖出股票,以求得最大利润,并且交易次数不超过2次 拿示…

基于SpringBoot设计模式之结构型设计模式·组合模式

文章目录 介绍开始架构图定义条目定义文件定义文件夹 测试样例 总结 介绍 能够使容器与内容具有一致性,创造出递归结构的模式就是 Composite 模式。Composite 在英文中是“混合物”“复合物”的意思。   以目录为例,在计算机中,某个目录下有…

在海外留学/工作,如何报考微软mos认证?

重点首先得强调的是,即使在海外也可以顺利地在国内获取微软MOS认证! 01 微软mos认证简介 Microsoft Office Specialist 简称MOS。是微软公司和第三方国际认证机构、全球三大IT测验与教学中心之一的思递波/Certiport公司于1997年联合推出的,…

2009年国赛高教杯数学建模A题制动器试验台的控制方法分析解题全过程文档及程序

2009年国赛高教杯数学建模 A题 制动器试验台的控制方法分析 汽车的行车制动器(以下简称制动器)联接在车轮上,它的作用是在行驶时使车辆减速或者停止。制动器的设计是车辆设计中最重要的环节之一,直接影响着人身和车辆的安全。为了…

分享一个IDEA里面的Debug调试设置

1.问题来源 其实我们在这个IDEA里面的这个进行调试的时候,这个是只有步入,出去的选项的; 之前学习这个sort的底层源码的时候,进不去,我们是设置了一个取消java*什么的选项,然后使用这个step into就可以进…

计算机网络易混知识点

1.以太网采用曼彻斯特编码;以太网帧最短为64B,其中14个B首部(目的MAC-6B,源MAC-6B,类型-2B)4B尾部 2.OSI协议中,每一层为上一层提供服务,为下一层提供接口 3.帧序号的比特数表示的是发送窗口的大小&#…

java逻辑运算符 C语言结构体定义

1. public static void main(String[] args) {System.out.println(true&true);//&两者均为true才trueSystem.out.println(false|false);// | 两边都是false才是falseSystem.out.println(true^false);//^ 相同为false,不同为trueSystem.out.println(!false)…

(38)MATLAB分析带噪信号的频谱

文章目录 前言一、MATLAB仿真代码二、仿真结果画图总结 前言 本文给出带噪信号的时域和频域分析,指出频域分析在处理带噪信号时的优势。 首先使用MATLAB生成一段信号,并在信号上叠加高斯白噪声得到带噪信号,然后对带噪信号对其进行FFT变换&…

Java面试指南:Java基础介绍

这是《Java面试指南》系列的第1篇,本篇主要是介绍Java的一些基础内容: 1、Java语言的起源 2、Java EE、Java SE、Java ME介绍 3、Java语言的特点 4、Java和C的区别和联系? 5、面向对象和面向过程的比较 6、Java面向对象的三大特性&#xff1a…

云计算-----单机LNMP结构WordPress网站

LNMP结构 博客网站 day1 小伙伴们,LNMP结构在第一二阶段浅浅的学习过,这里我们可以离线部署该结构。L指(虚拟机)服务器,nginx(前端代理服务器)mysql数据库,最后基于php建设动态…

AlDente Pro for Mac电脑 充电限制保护工具 安装教程【简单,轻松上手】

Mac分享吧 文章目录 AlDente Pro for Mac 充电限制保护工具 安装完成,软件打开效果一、AlDente Pro for Mac 充电限制保护工具 Mac电脑版——v1.28.41️⃣:下载软件2️⃣:安装软件,将安装包从左侧拖入右侧文件夹中,等…

Halcon实战——基于NCC模板匹配的芯片检测(附源码)

Halcon实战——基于NCC模板匹配的芯片检测(附源码) 关于作者 作者:小白熊 作者简介:精通python、matlab、c#语言,擅长机器学习,深度学习,机器视觉,目标检测,图像分类&am…

Java | Leetcode Java题解之第493题翻转对

题目&#xff1a; 题解&#xff1a; class Solution {public int reversePairs(int[] nums) {Set<Long> allNumbers new TreeSet<Long>();for (int x : nums) {allNumbers.add((long) x);allNumbers.add((long) x * 2);}// 利用哈希表进行离散化Map<Long, Int…

linux 效率化 - 输入法 - fcitx5

安装 Fcitx5 1. 卸载 ibus 框架 由于 ibus 和 fcitx 可能会冲突&#xff0c;先卸载 ibus&#xff08;暂未确认原因&#xff09; sudo apt remove --purge ibus2. 安装 fcitx5 输入法框架 sudo apt update sudo apt install fcitx5 fcitx5-chinese-addons fcitx5-frontend-gtk…

深入理解Nest的REQUEST范围和TRANSIENT范围

深入理解Nest的REQUEST范围和TRANSIENT范围 单例模式REQUEST范围控制器的REQUEST范围REQUEST范围的冒泡特性场景 TRANSIENT范围例外场景 总结 单例模式 单例模式是指在整个程序执行期间&#xff0c;程序内的类都会实例化&#xff0c;且与应用程序生命周期直接相关&#xff0c;…

javax.el.PropertyNotFoundException: Property ‘XXX‘ not found on type XXX(类的路径)

捣鼓了半小时的bug 在网上找了好多方案,都没有解决 其中一个佬的解决方案:异常&#xff1a;javax.el.PropertyNotFoundException: Property xxx not found on type java.lang.String-CSDN博客 但是还是没有解决我的问题 最终解决方法,在jsp文件头部导入了类包(第三行我导入…

【Nginx系列】Nginx配置超时时间

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…