【深度学习实验】线性模型(四):使用Pytorch实现线性模型:使用随机梯度下降优化器训练模型

news2025/4/6 6:19:42

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 线性模型linear_model

2. 损失函数loss_function

3. 定义数据

4. 初始化权重和偏置

5. 模型训练

6. 迭代

7. 实验结果

8. 完整代码


一、实验介绍

        使用随机梯度下降优化器训练线性模型,并输出优化后的参数

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

         随机梯度下降(Stochastic Gradient Descent,SGD)是一种常用的优化算法,用于训练机器学习模型。它是梯度下降算法的一种变体,主要用于大规模数据集或高维参数空间的情况下。

        梯度下降算法的目标是通过最小化损失函数来寻找模型参数的最优解。在传统的梯度下降算法中,每个训练周期(epoch)都需要计算整个训练集的梯度,然后更新模型参数。这种方法在大规模数据集上计算量较大,因为每个训练周期都需要遍历整个数据集。

        与传统的梯度下降不同,随机梯度下降每次迭代仅使用一个样本(或一小批样本)来计算梯度,并更新模型参数。具体步骤如下:

  1. 初始化模型参数。
  2. 将训练数据集随机打乱顺序。
  3. 对于每个训练样本(或小批量样本):
    • 计算模型对于当前样本的预测值。
    • 计算损失函数对于当前样本的梯度。
    • 根据梯度和学习率更新模型参数。
  4. 重复步骤3,直到达到预定的训练周期数或满足停止条件。

        随机梯度下降的主要优点是计算效率高,尤其适用于大规模数据集。它也可以在每个训练周期中进行参数更新,因此可以更快地收敛。然而,由于每次迭代仅使用一个样本(或小批量样本),因此随机梯度下降的更新方向可能会更加不稳定,导致训练过程中的损失函数波动较大。为了缓解这个问题,可以使用学习率衰减、动量等技巧来改进算法。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

0. 导入库

import torch

1. 线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):
    return torch.matmul(x, w) + b

2. 损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss

3. 定义数据

  • 生成一个随机的输入张量 x,形状为 (5, 1),表示有 5 个样本,每个样本的特征维度为 1。

  • 生成一个目标张量 y,形状为 (5, 1),表示对应的真实标签。

  • 打印数据的信息,包括每个样本的输入值x和目标值y
x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):
    print("Item " + str(i), "x:", x[i][0], "y:", y[i])

4. 初始化权重和偏置

w = torch.rand(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

5. 模型训练

model = linear_model(x, w, b)
optimizer = optim.SGD([w, b], lr=0.01)  # 使用SGD优化器

6. 迭代

num_epochs = 100  # 迭代次数
for epoch in range(num_epochs):
    optimizer.zero_grad()   # 梯度清零
    prediction = linear_model(x, w, b)
    loss = loss_function(y, prediction)
    loss.mean().backward()  # 计算梯度
    optimizer.step()        # 更新参数
    # print(w, b)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.mean().item()}")
  • 在每个迭代中:

    • 将优化器的梯度缓存清零,然后使用当前的权重和偏置对输入 x 进行预测,得到预测结果 prediction

    • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

    • 调用 loss.mean().backward() 计算损失的平均值,并根据计算得到的梯度进行反向传播。

    • 调用 optimizer.step() 更新权重和偏置,使用优化器进行梯度下降更新。

    • 每隔 10 个迭代输出当前迭代的序号、总迭代次数和损失的平均值。

7. 实验结果

print("The optimized parameters are:")
print("w:", model[0].item())
print("b:", model[1].item())

8. 完整代码

import torch
import torch.optim as optim


def linear_model(x, w, b):
    return torch.matmul(x, w) + b


def loss_function(y_true, y_pred):
    loss = (y_pred - y_true) ** 2
    return loss


x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):
    print("Item " + str(i), "x:", x[i][0], "y:", y[i])

w = torch.rand(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
model = linear_model(x, w, b)
optimizer = optim.SGD([w, b], lr=0.01)  # 使用SGD优化器

num_epochs = 100  # 迭代次数
for epoch in range(num_epochs):
    optimizer.zero_grad()   # 梯度清零
    prediction = linear_model(x, w, b)
    loss = loss_function(y, prediction)
    loss.mean().backward()  # 计算梯度
    optimizer.step()        # 更新参数
    # print(w, b)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.mean().item()}")

print("The optimized parameters are:")
print("w:", model[0].item())
print("b:", model[1].item())

注意:

        本实验使用随机瞎生成的数据,所以训练起来没有任何意义,下文将基于经典的鸢尾花数据集进行实验,并对模型进行评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1021824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ES6知识】Iterator迭代器与 class类

文章目录 一、Iterator迭代器1.1 基础知识概述1.2 工作原理1.3 Symbol.iterator1.4 Generator函数来实现Symbol.iterator接口 二、ES6 Class 类2.1 概述2.2 ES6中的继承2.3 面向对象应用 - React 一、Iterator迭代器 1.1 基础知识概述 迭代器(Iterator&#xff09…

JAVA实现PDF转图片

前言 使用wps自带的转换工具&#xff0c;需要花钱&#xff0c;不花钱的话还带水印。于是&#xff0c;使用java程序将pdf转换为图片。 代码 依赖 <dependencies><dependency><groupId>org.apache.pdfbox</groupId><artifactId>fontbox</ar…

几个国内可用的强大的GPT工具

前言&#xff1a; 人工智能发布至今&#xff0c;过去了九个多月&#xff0c;已经成为了我们不管是工作还是生活中一个重要的辅助工具&#xff0c;大大提升了效率&#xff0c;作为一个人工智能的自然语言处理工具&#xff0c;它给各大行业的提供了一个巨大的生产工具&#xff0c…

Vue Grid Layout -️ 适用Vue.js的栅格布局系统,在vue3+上使用

文章目录 1、官网简介2、在vue3中使用1)、需要导入vue3支持的版本插件2)、在mian.js里引入&#xff1a;3)、在组件中使用 3、layout布局的计算逻辑4、 gridLayout 的属性 该栅格系统目前对 vue2 的支持是最好的&#xff0c;vue3 是需要用插件支持的&#xff0c;会在小节详细讲解…

目标检测:Edge Based Oriented Object Detection

论文作者&#xff1a;Jianghu Shen,Xiaojun Wu 作者单位&#xff1a;Harbin Institute of Technology Shenzhen 论文链接&#xff1a;http://arxiv.org/abs/2309.08265v1 内容简介&#xff1a; 1&#xff09;方向&#xff1a;遥感领域中的目标检测技术 2&#xff09;应用&…

肖sir__mysql之存储__012

mysql之存储 一、存储 1、什么是存储过程 定义&#xff1a;存储过程就是实现某个特定功能的sql语句的集合&#xff0c;编译后的存储过程会保存在数据库中&#xff0c;通过存储过程的名称可以反复的调用执行。 2、存储过程的优点? (1)存储创建后&#xff0c;可以反复的调用&am…

液晶LCD显示驱动VK2C23可支持最大416点(52SEGx8COM)的LCD屏提供LQFP48、LQFP64的封装

VK2C23是一个点阵式存储映射的LCD驱动器&#xff0c;可支持最大224点&#xff08;56SEGx4COM&#xff09;或者最大416点&#xff08;52SEGx8COM&#xff09;的LCD屏。单片机可通过I2C接口配置显示参数和读写显示数据&#xff0c;也可通过指令进入省电模式。其高抗干扰&#xff…

Zoho Projects登顶福布斯2023年项目管理工具排行榜

Zoho Projects一款集成度较高的项目管理工具&#xff0c;近日入选福布斯2023年十佳项目管理工具榜单。作为一款为中小企业量身定制的解决方案&#xff0c;Zoho Projects以其卓越的功能和性价比&#xff0c;赢得了市场的广泛认可。 在当今竞争激烈的市场环境中&#xff0c;选择一…

F.cross_entropy的使用困惑--终结

1.由于按照cross_entropy的计算公式&#xff0c;F.cross_entropy(outs,label)这里面的outs和label的应该都是batch_size * one_hot_vec的形状&#xff0c;但是&#xff0c;这个函数呢&#xff0c;第二个label只要是batch_size * scalar即可&#xff0c;那个scalar就是index的位…

如果你用Markdown写公众号文章,试试我做的在线转换神器

大家好&#xff0c;我希望向大家介绍一款我最近开发的实用工具——"MarkdownConvert"&#xff0c;它能够以极简和高效的方式将Markdown文本转换为微信公众号所支持的HTML格式。它是我基于一个开源项目改进优化了页面&#xff0c;增加了自己喜欢的文章主题。 先聊聊这…

微信小程序传参的五种方式

文章目录 前言一、URL参数传递1.api跳转2.组件跳转 二、Storage本地存储三、全局变量globalData四、页面跳转时传参五、页面栈传参总结结语 前言 大家好&#xff0c;今天和大家分享一下微信小程序页面之间传参的五种方式&#xff0c;这个的话也是有人问了我一嘴&#xff0c;然…

Python进阶复习-文件与异常

目录 文件打开文件路径打开模式字符编码 文件读取逐行读取读取所有行 文件写入既读又写两种数据存储结构csv文件json文件 程序异常Exception万能捕捉 文件打开 文件路径 完整路径 with open("E:\hello.txt", "r", encoding"UTF-8") as file:c…

Socks5与HTTP的区别与应用场景

在网络访问中&#xff0c;代理服务器扮演着重要角色&#xff0c;用于保护用户隐私、提高访问速度等。Socks5代理和HTTP代理是两种常见的代理协议&#xff0c;它们在功能和应用场景上有所不同。本文将深入解析Socks5代理和HTTP代理的区别&#xff0c;帮助您更好地了解并选择适合…

Linux 信号捕捉函数 signal sigaction

signal函数 #include <signal.h> typedef void (*sighandler_t)(int); sighandler_t signal(int signum, sighandler_t handler); 功能&#xff1a;设置某个信号的捕捉行为 参数&#xff1a; -signum&#xff1a;要捕捉的信号 handler&#xff1a;对捕捉到的信号怎么处理…

安装了多个版本VS导致无法安装vsix

如&#xff0c;先后安装了VS2015、VS2019&#xff0c;现在想给VS2015安装一个qt-vsaddin插件&#xff0c;运行vsix报错 “View Install Log”里显示 2023/9/19 10:03:46 - 正在搜索适用的产品... 2023/9/19 10:03:46 - 找到的已安装产品 - Microsoft Visual Studio Ultimate …

中标麒麟--国产操作系统-九五小庞

那么&#xff0c;我国国产操作系统现状到底如何呢&#xff1f; 自 1999 年徐冠华部长一语点破我们的产业软肋之后&#xff0c;国产操作系统起步于国家“七五”计划期间&#xff0c;目前国产操作系统均是基于Linux内核进行的二次开发&#xff0c;中国国产操作系统进入Linux元年…

多线程详解(上)

文章目录 一、线程的概念1&#xff09;线程是什么2&#xff09;为甚要有线程&#xff08;1&#xff09;“并发编程”成为“刚需”&#xff08;2&#xff09;在并发编程中, 线程比进程更轻量. 3&#xff09;线程和进程的区别 二、Thread的使用1&#xff09;线程的创建继承Thread…

【随想】每日两题Day.7

题目&#xff1a;面试题 02.07.链表相交 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 …

Excel中的宏、VBA

一、宏是什么&#xff1f; EXCEL MACRO 是一种记录和播放工具&#xff0c;它仅记录您的 Excel 步骤&#xff0c;并且宏将根据需要播放任意多次。 VBA 宏可自动执行重复任务&#xff0c;从而节省了时间。 这是一段可在 Excel 环境中运行的编程代码&#xff0c;但您无需成为编码…

EtherCAT 总线型 4 轴电机控制卡解决方案

 技术特点  支持标准 100M/s 带宽全双工 EtherCAT 总线网络接口及 CoE 通信协议一 进一出&#xff08;RJ45 接口&#xff09;&#xff0c;支持多组动态 PDO 分组和对象字典的自动映射&#xff0c;支持站 号 ID 的自动设置与保存&#xff0c;支持 SDO 的电机参数设置与…