李沐动手学习深度学习——4.2练习

news2024/12/25 2:23:09

1. 在所有其他参数保持不变的情况下,更改超参数num_hiddens的值,并查看此超参数的变化对结果有何影响。确定此超参数的最佳值。

通过改变隐藏层的数量,导致就是函数拟合复杂度下降,隐藏层过多可能导致过拟合,而过少导致欠拟合。
我们将层数改为128可得:
在这里插入图片描述

2. 尝试添加更多的隐藏层,并查看它对结果有何影响。

过拟合,导致测试机精确度下降。

3. 改变学习速率会如何影响结果?保持模型架构和其他超参数(包括轮数)不变,学习率设置为多少会带来最好的结果?

过高的学习率导致,梯度跨度过大,使得降低不到对应的驻点。
过低的学习率导致训练缓慢,需要增加epoch。
在训练轮数不变的情况下,我们可以通过for 设置不同的学习率找出最合适的学习率。一般来说设置为0.01或者0.1足以

4. 通过对所有超参数(学习率、轮数、隐藏层数、每层的隐藏单元数)进行联合优化,可以得到的最佳结果是什么?

跑了一次学习率lr=0.01的情况:
在这里插入图片描述

需要大量的训练,但是目前我训练结果是学习率lr=0.1、轮数是num_epochs=10,隐藏层数为1,隐藏层数单元num_hiddens=128。

5. 描述为什么涉及多个超参数更具挑战性。

因为组合的情况更多,当层数越多时,训练时间也更多,这玩意就是炼丹了,看你自己的GPU还有时间、运气。

6. 如果想要构建多个超参数的搜索方法,请想出一个聪明的策略。

套用for 循环暴力破解,时间上肯定慢的要死,我们可以先固定其他变量,挑选一个变量寻找最优解,以此类推对所有的超参数这样使用,但是这种做法肯定不是最优的,只是能够较好的找出比较好的超参数。

由于学校穷逼所以没有闲置GPU服务器,所有的模型只能在colab上进行运行,其中遇到了d2l的版本对应问题,所以对于d2l.train_ch3跑不起来,只能使用自写进行替代如下:

import torch.nn
from d2l import torch as d2l
from IPython import display

class Accumulator:
    """
    在n个变量上累加
    """
    def __init__(self, n):
        self.data = [0.0] * n       # 创建一个长度为 n 的列表,初始化所有元素为0.0。

    def add(self, *args):           # 累加
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):                # 重置累加器的状态,将所有元素重置为0.0
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):     # 获取所有数据
        return self.data[idx]


def accuracy(y_hat, y):
    """
    计算正确的数量
    :param y_hat:
    :param y:
    :return:
    """
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)            # 在每行中找到最大值的索引,以确定每个样本的预测类别
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())


def evaluate_accuracy(net, data_iter):
    """
    计算指定数据集的精度
    :param net:
    :param data_iter:
    :return:
    """
    if isinstance(net, torch.nn.Module):
        net.eval()                  # 通常会关闭一些在训练时启用的行为
    metric = Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]



class Animator:
    """
    在动画中绘制数据
    """
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量的绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend
        )
        self.X, self.Y, self.fmts = None, None, fmts


    def add(self, x, y):
        """
        向图表中添加多个数据点
        :param x:
        :param y:
        :return:
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)


def train_epoch_ch3(net, train_iter, loss, updater):
    """
    训练模型一轮
    :param net:是要训练的神经网络模型
    :param train_iter:是训练数据的数据迭代器,用于遍历训练数据集
    :param loss:是用于计算损失的损失函数
    :param updater:是用于更新模型参数的优化器
    :return:
    """
    if isinstance(net, torch.nn.Module):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。
        net.train()

    # 训练损失总和, 训练准确总和, 样本数
    metric = Accumulator(3)

    for X, y in train_iter:  # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):  # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。
            # 使用pytorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()  # 方法用于计算损失的平均值
            updater.step()
        else:
            # 使用定制(自定义)的优化器和损失函数
            l.sum().backward()
            updater(X.shape())
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]


def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """
    训练模型()
    :param net:
    :param train_iter:
    :param test_iter:
    :param loss:
    :param num_epochs:
    :param updater:
    :return:
    """
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        trans_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, trans_metrics + (test_acc,))
        train_loss, train_acc = trans_metrics
        print(trans_metrics)


def predict_ch3(net, test_iter, n=6):
    """
    进行预测
    :param net:
    :param test_iter:
    :param n:
    :return:
    """
    global X, y
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + "\n" + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n]
    )
    d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1484804.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

KubeSphere平台安装系列之二【Linux单节点部署KubeSphere】(2/3)

**《KubeSphere平台安装系列》** 【Kubernetes上安装KubeSphere(亲测–实操完整版)】(1/3) 【Linux单节点部署KubeSphere】(2/3) 【Linux多节点部署KubeSphere】(3/3) **《KubeS…

Matlab 机器人工具箱 运动学

文章目录 R.fkine()R.ikine()R.ikine6s()R.jacob0、R.jacobn、R.jacob_dotjtrajctraj参考链接官网:Robotics Toolbox - Peter Corke R.fkine() 正运动学,根据关节坐标求末端执行器位姿 mdl_puma560; % 加载puma560模型 qz % 零角度 qr

备战蓝桥杯---状态压缩DP基础2之TSP问题

先来一个题衔接一下: 与上一题的思路差不多,不过这里有几点需要注意: 1.因为某一列的状态还与上上一行有关,因此我们令f[i][j][k]表示第i行状态为j,第i-1行状态为k的最大炮兵数。 因此,我们可以得到状态转移方程&…

IDEA的安装教程

1、下载软件安装包 官网下载:https://www.jetbrains.com/idea/ 2、开始安装IDEA软件 解压安装包,找到对应的idea可执行文件,右键选择以管理员身份运行,执行安装操作 3、运行之后,点击NEXT,进入下一步 4、…

电子科技大学《数据库原理及应用》(持续更新)

前言 电子科技大学的数据库课程缩减了部分的课时,因此,可能并不适合所有要学习数据库的宝子们,但是,本人尽量将所有数据库的内容写出来。本文章适用于本科生的期中和期末的复习,电子科技大学的考生请在复习前先看必读…

【算法】顺时针打印矩阵(图文详解,代码详细注释

目录 题目 代码如下: 题目 输入一个矩阵,按照从外向里以顺时针的顺序依次打印出每一个数字。例如:如果输入如下矩阵: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 则打印出数字:1 2 3 4 8 12 16 15 14 13 9 5 6 7 11 10 这一道题乍一看,没有包含任何复杂的数据结构和…

Android Gradle开发与应用 (三) : Groovy语法概念与闭包

1. Groovy介绍 Groovy是一种基于Java平台的动态编程语言,与Java是完全兼容,除此之外有很多的语法糖来方便我们开发。Groovy代码能够直接运行在Java虚拟机(JVM)上,也可以被编译成Java字节码文件。 以下是Groovy的一些…

驾辰龙跨Llama持Wasm,玩转Yi模型迎新春

今年新年很特别,AI工具添光彩。今天就来感受下最新的AI神器天选组合“WasmEdgeYi-34B”,只要短短三步,为这个甲辰龙年带来一份九紫离火运的科技感。 环境准备 这次用的算力是OpenBayes提供的英伟达RTX_4090*1、24GB显存、20核CPU、80GB内存…

【论文笔记】Mamba:挑战Transformer地位的新架构

Mamba Mamba: Linear-Time Sequence Modeling with Selective State Spaces Mamba Mamba摘要背景存在的问题本文的做法实验结果 文章内容Transformer的缺点Structured state space sequence models (SSMs)介绍本文的工作模型介绍State Space ModelsSelective State Space Mod…

docker三剑客compose+machine+swarm小结

背景 在容器领域,不少公司会使用docker三剑客composemachineswarm进行容器编排和部署,本文就简单记录下这几个工具的用法 三剑客composemachineswarm compose compose主要是用于容器编排,我们部署容器时,容器之间会有依赖&…

git的安装、使用

文章目录 安装gitgit学习网站git初始配置具体配置信息 新建版本库(仓库)git的工作区域和文件状态工作区域文件状态git文件提交的基础指令 git基础指令1. 版本提交2. 分支创建3. 分支切换4. 分支合并(1) git merge(2) git rebase 5. 在git的提交树上移动(…

vue+springboot项目部署服务器

项目仓库:vuespringboot-demo: vuespringboot增删改查的demo (gitee.com) ①vue中修改配置 在public文件夹下新建config.json文件: {"serverUrl": "http://localhost:9090"//这里localhost在打包后记得修改为服务器公网ip } 然后…

三天学会阿里分布式事务框架Seata-seata事务日志mysql持久化配置

锋哥原创的分布式事务框架Seata视频教程: 实战阿里分布式事务框架Seata视频教程(无废话,通俗易懂版)_哔哩哔哩_bilibili实战阿里分布式事务框架Seata视频教程(无废话,通俗易懂版)共计10条视频&…

Java中的Collection

Collection Collection 集合概述和使用 Collection集合概述 是单例集合的顶层接口,它表示一组对象,这些对象也称为Collection的元素 JDK 不提供此接口的任何直接实现.它提供更具体的子接口(如Set和List)实现 创建Collection集合的对象 多态的方式 具体的实现类ArrayList C…

Pycharm的下载安装与汉化

一.下载安装包 1.接下来按照步骤来就行 2.然后就能在桌面上找到打开了 3.先建立一个文件夹 二.Pycharm的汉化

javaweb day9 day10

昨天序号标错了 vue的组件库Elent 快速入门 写法 常见组件 复制粘贴 打包部署

修改一个教材上的网站源码使它能在www服务器子目录上正常运行

修改一个教材上的网站源码,使它能在www服务器子目录上正常运行。 该网站源码是教材《PHPMySQL网站开发项目式教程》上带的网站源码。该源码是用 php html 写的。该源码包含对mysql数据库进行操作的php代码。以前该网站源码只能在www服务器的根目录上正常运行&…

一文认识蓝牙(验证基于Aduino IDE的ESP32)

1、简介 蓝牙技术是一种无线通信的方式,利用特定频率的波段(2.4GHz-2.485GHz左右),进行电磁波传输,总共有83.5MHz的带宽资源。 1.1、背景 蓝牙(Bluetooth)一词取自于十世纪丹麦国王哈拉尔Haral…

[技巧]Arcgis之图斑四至点批量计算

前言 上一篇介绍了arcgis之图斑四至范围计算,这里介绍的图斑四至点的计算及获取,两者之间还是有差异的。 [技巧]Arcgis之图斑四至范围计算 这里说的四至点指的是图斑最东、最西、最南、最北的四个地理位置点坐标,如下图: 四至点…

SCP命令行向服务器端上传文件或下载文件

环境要求 使用scp(Secure Copy Protocol)命令在本地和远程系统之间安全地复制文件和目录,需要满足以下环境要求: SSH服务:scp依赖于SSH(Secure Shell)协议来安全地传输文件。因此,…