【动手学深度学习】softmax回归的简洁实现详情

news2025/1/16 4:43:50

目录

🌊1. 研究目的

🌊2. 研究准备

🌊3. 研究内容

🌍3.1 softmax回归的简洁实现

🌍3.2 基础练习

🌊4. 研究体会


🌊1. 研究目的

  • 理解softmax回归的原理和基本实现方式;
  • 学习如何从零开始实现softmax回归,并了解其关键步骤;
  • 通过简洁实现softmax回归,掌握使用现有深度学习框架的能力;
  • 探索softmax回归在分类问题中的应用,并评估其性能。

🌊2. 研究准备

  • 根据GPU安装pytorch版本实现GPU运行研究代码;
  • 配置环境用来运行 Python、Jupyter Notebook和相关库等相关库。

🌊3. 研究内容

启动jupyter notebook,使用新增的pytorch环境新建ipynb文件,为了检查环境配置是否合理,输入import torch以及torch.cuda.is_available() ,若返回TRUE则说明研究环境配置正确,若返回False但可以正确导入torch则说明pytorch配置成功,但研究运行是在CPU进行的,结果如下:


🌍3.1 softmax回归的简洁实现

完成softmax回归的简洁实现的研究代码及练习内容如下:

导入必要库及模型:

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
        
net.apply(init_weights)

重新审视Softmax的实现

loss = nn.CrossEntropyLoss(reduction='mean')  # 将reduction设置为'mean'或'sum'

优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)


🌍3.2 基础练习

1.尝试调整超参数,例如批量大小、迭代周期数和学习率,并查看结果。

在这个示例中,我将批量大小调整为128,迭代周期数调整为20,学习率调整为0.01。

import torch
from torch import nn
from d2l import torch as d2l

# 超参数调整
batch_size = 128  # 调整批量大小
num_epochs = 20  # 调整迭代周期数
learning_rate = 0.01  # 调整学习率

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
net.apply(init_weights)

loss = nn.CrossEntropyLoss(reduction='mean')

trainer = torch.optim.SGD(net.parameters(), lr=learning_rate)

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.增加迭代周期的数量。为什么测试精度会在一段时间后降低?我们怎么解决这个问题?

当增加迭代周期的数量时,训练过程会继续进行更多的迭代,模型会有更多的机会学习训练数据中的模式和特征。通常情况下,增加迭代周期数量可以提高模型的训练精度。然而,如果过度训练,测试精度可能会在一段时间后开始降低。

这种情况被称为"过拟合"(overfitting)。过拟合发生时,模型在训练数据上表现得很好,但在新数据(测试数据)上表现较差。过拟合是由于模型过于复杂,过度记住了训练数据中的噪声和细节,而无法泛化到新数据。

为了解决过拟合问题,可以尝试以下几种方法:

  • 提前停止(Early Stopping):在训练过程中,跟踪训练误差和测试误差。一旦测试误差开始上升,就停止训练。这样可以防止模型过度拟合训练数据。
  • 正则化(Regularization):通过向损失函数添加正则化项,限制模型参数的大小,防止过度拟合。常见的正则化方法包括L1正则化和L2正则化。
  • 数据增强(Data Augmentation):通过对训练数据进行随机变换(如旋转、翻转、缩放等),增加训练样本的多样性,有助于提高模型的泛化能力。
  • 减小模型复杂度:减少模型的层数、节点数或参数量,使其更简单。简化模型可以降低过拟合的风险。
  • 使用更多的训练数据:增加训练数据量可以减少过拟合的可能性,因为模型将有更多的样本进行学习。

通过组合使用这些方法,可以有效地解决过拟合问题并提高模型的泛化能力。


🌊4. 研究体会

通过这次研究,我深入学习了softmax回归模型,理解了它的原理和基本实现方式。开始了解softmax回归的背景和用途,它在多类别分类问题中的应用广泛;学习了如何从零开始实现softmax回归,并掌握了其中的关键步骤。

通过简洁实现softmax回归,更加熟悉了深度学习框架的使用。可以通过几行代码完成模型的定义、数据的加载和训练过程。还学会了使用框架提供的工具来评估模型的性能,如计算准确率和绘制混淆矩阵。这使能够更方便地对模型进行调试和优化,以获得更好的分类结果。

最后,通过实验探索了softmax回归在分类问题中的应用,并评估了其性能。使用了一些真实的数据集,如MNIST手写数字数据集,来进行实验。在实验中,将数据集划分为训练集和测试集,用训练集来训练模型,然后用测试集来评估模型的性能。

在从零开始实现的实验中,对模型的性能进行了一些调优,比如调整学习率和迭代次数。观察到随着迭代次数的增加,模型的训练损失逐渐下降,同时在测试集上的准确率也在提升。这证明了的模型在一定程度上学习到了数据的规律,并能够泛化到新的样本。而在简洁实现的实验中,由于深度学习框架的优化算法和自动求导功能,模型的训练速度明显快于从零开始实现。同时,框架提供了更多的网络结构和调优方法,使能够更加灵活地构建和调整模型。在简洁实现中,我还尝试了一些不同的模型结构,比如加入隐藏层或使用更复杂的优化算法,以探索更高效的模型设计。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1799383.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开发人员必备的常用工具合集-lombok

Project Lombok 是一个 java 库,它会自动插入您的编辑器和构建工具,为您的 Java 增添趣味。 再也不用编写另一个 getter 或 equals 方法了,只需一个注释,您的类就拥有了一个功能齐全的构建器,自动化了您的日志记录变量…

从零开始手把手Vue3+TypeScript+ElementPlus管理后台项目实战五(引入vue-router,并给注册功能加上美丽的外衣el-form)

安装vue-router pnpm install vue-router创建router src下新增router目录,ruoter目录中新增index.ts import { createRouter, createWebHashHistory } from "vue-router"; const routes [{path: "/",name: "Home",component: () …

SQL语句练习每日5题(四)

题目1——查找GPA最高值 想要知道复旦大学学生gpa最高值是多少,请你取出相应数据 题解: 1、使用MAX select MAX(gpa) FROM user_profile WHERE university 复旦大学 2、使用降序排序组合limit select gpa FROM user_profile WHERE university 复…

当C++的static遇上了继承

比如我们想要统计下当前类被实例化了多少次,我们通常会这么写 class A { public:A() { Count_; }~A() { Count_--; }int GetCount() { return Count_; }private:static int Count_; };class B { public:B() { Count_; }~B() { Count_--; }int GetCount() { return …

LeetCode1143最长公共子序列

题目描述 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符&#xff08…

Type-C音频转接器方案

在数字化时代,音频设备作为我们生活中不可或缺的一部分,其连接方式的便捷性和高效性显得尤为重要。Type-C音频转接器,作为一种新型的音频连接解决方案,正逐渐走进我们的生活,以其独特的优势改变着我们的音频体验。 一、…

数据库设计步骤、E-R图转关系模式、E-R图的画法

一、数据库设计步骤 ①需求分析阶段 准确了解与分析用户需求。 ②概念结构设计阶段 通过对用户需求进行综合、归纳与抽象,形成一个独立于具体数据库管理系统的概念模型。 ③逻辑结构设计阶段 将概念结构转换为某个数据库管理系统所支持的数据模型&am…

LCM — Least Common Multiple 最小公倍数

因为任何一个数都可以表示为若干个质数幂的乘积。 比如75 3*5*5,即 2^0 * 3^1 * 5^2 * 7^0 ... 那么对于两个数来说,gcd就是他们取每个质数的较小幂的乘积,lcm则相反。显然,这些幂加起来就是他们乘积。 gcd(a,b) * lcm(a,b) a…

C++学习/复习13--list概述

一、list概念 1.带头双向链表 2.构造函数 3.迭代器(其迭代器需尤其注意) 4、size 5.front/back 6.插入删除 删除时的迭代器失效 由于list的节点特殊,既有数据又有指针,其实现需要节点/迭代器/list各成一类再组合

【TB作品】MSP430F5529 单片机,智能温控系统,DS18B20

作品功能 本项目设计并实现了一个基于MSP430单片机的智能温控系统。系统可以实时显示当前温度,并且可以根据设置的临界值对环境进行加热或降温。主要功能如下: 实时显示当前温度。显示并调整温度临界值,临界值可在20~35摄氏度之间调节。当前…

chorme浏览器查看shadow-root配置

F12打开控制台,点击设置图标 点击偏好设置-> 勾选显示用户代理 Shadow DOM

C++STL---stack queue模拟实现

前言 对于这两个容器适配器的模拟实现非常简单,因为stack和queue只是对其他容器的接口进行了包装,在STL中,若我们不指明用哪种容器作为底层实现,栈和队列都默认是又deque作为底层实现的。 也就是说,stack和queue不管是…

Django学习一:创建Django框架,介绍Django的项目结构和开发逻辑。创建应用,编写主包和应用中的helloworld

文章目录 前言一、Django环境配置1、python 环境2、Django环境3、mysql环境4、IDE:pycharm 二、第一次创建Django项目1、创建项目door_web_django_system2、运行启动 三、Django项目介绍1、介绍Django项目结构2、第一个helloword4、django的项目逻辑(和j…

为什么PPT录制没有声音 电脑ppt录屏没有声音怎么办

一、为什么PPT录制没有声音 1.软件问题 我们下载软件的时候可能遇到软件损坏的问题,导致录制没有声音,但其他功能还是可以使用的。我建议使用PPT的隐藏功能,下载插件,在PPT界面的加载项选项卡中就能使用。我推荐一款可以解决录屏…

探索风电机组:关键软件工具全解析

探索风电机组:关键软件工具全解析 随着可再生能源市场的迅猛发展,风电作为一种重要的可再生能源,其相关技术和工具也越来越受到重视。风电机组的设计、仿真、优化及运维等方面,都需要依靠一系列专业软件工具来实现。这些软件涵盖…

链表的回文结构OJ

链表的回文结构_牛客题霸_牛客网对于一个链表,请设计一个时间复杂度为O(n),额外空间复杂度为O(1)的算法,判断其是否为。题目来自【牛客题霸】https://www.nowcoder.com/practice/d281619e4b3e4a60a2cc66ea32855bfa?tpId49&&tqId29370&rp1&a…

k8s怎么监听资源的变更

监听k8s所有的 Deployment 资源 package mainimport ("context""fmt"v1 "k8s.io/api/apps/v1""k8s.io/apimachinery/pkg/util/json""k8s.io/client-go/informers""k8s.io/client-go/kubernetes""k8s.io/cli…

RFID测温技术在电力行业的革命性应用

随着科技的快速发展, RFID技术在各个领域的应用越来越广泛,而其中的一个重要领域就是电力行业。这一无线测温技术以其非接触、实时、高精度的特点,为电力设备的温度监测带来了革命性的改变。电力行业作为国家基础设施建设的重要支柱,设备的安…

使用GitHub托管静态网页

前言​: 如果没有服务器,也没有域名,又想部署静态网页的同学,那就可以尝试使用GitHub托管自己的网页​。 正文: 首先要有自己的GitHub的账号,如果没有可以自己搜索官网进行注册登录,国内对Gi…

git【工具软件】分布式版本控制工具软件

一、Git 的介绍 git软件的作用:管理软件开发项目中的源代码文件。 常用功能: 仓库管理、文件管理、分支管理、标签管理、远程操作 功能指令: add,commit,log,branch,tag,remote…