采用自动微分进行模型的训练

news2024/9/24 13:19:38

 自动微分训练模型

 简单代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的线性回归模型
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1

    def forward(self, x):
        return self.linear(x)

# 准备训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])

# 实例化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 训练模型
epochs = 1000
for epoch in range(epochs):
    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播
    optimizer.zero_grad()  # 清空之前的梯度
    loss.backward()  # 自动计算梯度
    optimizer.step()  # 更新模型参数

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 测试模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'预测值: {predicted.item():.4f}')

代码分解:

1.定义一个简单的线性回归模型:

  • LinearRegression 类继承自nn.Module,这是所有神经网络模型的基类
  • 在 __init__ 方法中,定义了一个线性层 self.linear,它的输入维度是1,输出维度也是1。
  • forward 方法定义了数据在模型中的传播路径,即输入 x 经过 self.linear 层后得到输出。
    class LinearRegression(nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1
    
        def forward(self, x):
            return self.linear(x)
    

2.准备训练数据:

  • x_train 和 y_train 分别是输入和目标输出的训练数据。每个张量表示一个样本,x_train 中的每个元素是一个维度为1的张量,因为模型的输入维度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.实例化模型,损失函数和优化器:

  • model 是我们定义的 LinearRegression 类的一个实例,即我们要训练的线性回归模型。
  • criterion 是损失函数,这里选择了均方误差损失(MSE Loss),用于衡量预测值与实际值之间的差异。
  • optimizer 是优化器,这里选择了随机梯度下降(SGD),用于更新模型参数以最小化损失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    

4.训练模型:

  • 这里进行了1000次迭代的训练过程。
  • 在每个迭代中,首先进行前向传播,计算模型对 x_train 的预测输出 outputs,然后计算损失 loss
  • 调用 optimizer.zero_grad() 来清空之前的梯度,然后调用 loss.backward() 自动计算梯度,最后调用 optimizer.step() 来更新模型参数
    epochs = 1000
    for epoch in range(epochs):
        # 前向传播
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
    
        # 反向传播
        optimizer.zero_grad()  # 清空之前的梯度
        loss.backward()  # 自动计算梯度
        optimizer.step()  # 更新模型参数
    
        if (epoch+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.测试模型:

  • x_test 是用来测试模型的输入数据,这里表示输入为4.0。
  • model(x_test) 对 x_test 进行前向传播,得到预测结果 predicted
  • predicted.item() 取出预测结果的标量值并打印出来。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'预测值: {predicted.item():.4f}')
    

运行结果:

运行结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924461.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python:使用matplotlib库绘制图像(四)

作者是跟着http://t.csdnimg.cn/4fVW0学习的,matplotlib系列文章是http://t.csdnimg.cn/4fVW0的自己学习过程中整理的详细说明版本,对小白更友好哦! 四、条形图 1. 一个数据样本的条形图 条形图:常用于比较不同类别的数量或值&…

DockerCompose介绍,安装,使用

DockerCompose 1、Compose介绍 将单机服务-通过Dockerfile 构建为镜像 -docker run 成为一个服务 user 8080 net 7000 pay 8181 admin 5000 监控 .... docker run 单机版、一个个容器启动和停止问题: 前面我们使用Docker的时候,定义 Dockerfil…

深入理解Java泛型:概念、用法与案例分析

个人名片** 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4…

Transformer模型:Encoder的self-attention mask实现

前言 这是对Transformer模型的Word Embedding、Postion Embedding内容的续篇。 视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili 文章链接:Transformer模型:WordEmbedding实现-CSDN博客 Transformer模型…

docker-compose安装PolarDB-PG数据库

文章目录 一. Mac1.1 docker-compose.yaml1.2 部署1.3 卸载4. 连接 二. Win102.1 docker-compose.yaml2.2 部署2.3 卸载 参考官方文档 基于单机文件系统部署 一. Mac 1.1 docker-compose.yaml mkdir -p /Users/wanfei/docker-compose/polardb-pg && cd /Users/wanfei…

Linux - 综合使用shell脚本,输出网站有效数据

综合示例: shell脚本实现查看网站分数 使用编辑器编辑文件jw.sh为如下内容: #!/bin/bash save_file"score" # 临时文件 semester20102 # 查分的学期, 20102代表2010年第二学期 jw_home"http://jwas3.nju.edu.cn:8080/jiaowu" # 测试网站首页地址 jw_logi…

zigbee开发工具:2、zigbee工程建立与配置

本文演示基于IAR for 8051(版本10.10.1)如何建立一个开发芯片cc2530的zigbee的工程,并配置这个工程,使其能够将编译的代码进行烧录,生成.hex文件。IAR for 8051(版本10.10.1)支持工程使用C语言&…

STM32智能交通灯系统教程

目录 引言环境准备智能交通灯系统基础代码实现:实现智能交通灯系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景:交通管理与优化问题解决方案与优化收尾与总结 1. 引言 智能交通灯系统通过STM…

Python游戏开发:四连珠(内附完整代码)

四连珠(Connect Four)是一款经典的棋类游戏,由两名玩家在7列6行的网格上轮流下棋。玩家的目标是将自己的棋子在垂直、水平或对角线上连成一条线,通常是四个棋子。如果一方成功做到这一点,那么他就赢得了游戏。如果所有…

视频监控汇聚平台:通过SDK接入大华DSS视频监控平台的源代码解释和分享

目录 一、视频监控汇聚平台 1、概述 2、视频接入能力 3、视频汇聚能力 二、大华DSS平台 1、DSS平台概述 2、大华DSS平台的主要特点 (1)高可用性 (2)高可靠性 (3)易维护性 (4&#xf…

《昇思25天学习打卡营第2天|02快速入门》

课程目标 这节课准备再学习下训练模型的基本流程,因此还是选择快速入门课程。 整体流程 整体介绍下流程: 数据处理构建网络模型训练模型保存模型加载模型 思路是比较清晰的,看来文档写的是比较连贯合理的。 数据处理 看数据也是手写体数…

【算法】平衡二叉树

难度:简单 题目 给定一个二叉树,判断它是否是 平衡二叉树 示例: 示例1: 输入:root [3,9,20,null,null,15,7] 输出:true 示例2: 输入:root [1,2,2,3,3,null,null,4,4] 输出&…

炒鸡清晰的防御综合实验(内含区域划分,安全策略,用户认证,NAT认证,智能选路,域名访问)

实验拓扑图如下: 前面六个条件在之间的实验中做过了,详细步骤可以去之前的文章看 这里简写一下大致步骤 第一步: 先将防火墙之外的配置给配置好,比如,PC的IP,交换上的Vlan划分。 第二步: 在浏览器上登…

用SurfaceView实现落花动画效果

上篇文章 Android子线程真的不能刷新UI吗?(一)复现异常 中可以看出子线程更新main线程创建的View,会抛出异常。SurfaceView不依赖main线程,可以直接使用自己的线程控制绘制逻辑。具体代码怎么实现了? 这篇文章用Surfa…

【算法专题】快速排序

1. 颜色分类 75. 颜色分类 - 力扣(LeetCode) 依据题意,我们需要把只包含0、1、2的数组划分为三个部分,事实上,在我们前面学习过的【算法专题】双指针算法-CSDN博客中,有一道题叫做移动零,题目要…

小公司的Git工作流程

项目初始化 git init并添加.gitignore文件 Git使用 通过git add . 把代码推到暂存区通过git commit -m “你的说明”,将暂存区的代码推到本地仓库将本地仓库的代码通过git push 推到远程仓库远程仓库(gitee/gitlab/github)同事就可以通过命令git pull将你推上去的…

信息学奥赛初赛天天练-46-CSP-J2020阅读程序2-进制转换、十进制转k进制、等比数列通项公式、等比数列求和公式应用

PDF文档公众号回复关键字:20240713 2020 CSP-J 阅读程序2 1阅读程序(程序输入不超过数组或字符串定义的范围&#xff1b;判断题正确填 √&#xff0c;错误填 。除特殊说明外&#xff0c;判断题 1.5 分&#xff0c;选择题 3 分&#xff0c;共计 40 分) 01 #include <iostre…

java各种锁介绍

在 Java 中&#xff0c;锁是用来控制多个线程对共享资源进行访问的机制。主要有以下几种类型的锁&#xff1a; 1.互斥锁&#xff08;Mutex Lock)&#xff1a;最简单的锁&#xff0c;一次只允许一个线程访问共享资源。如果一个线程获得了锁&#xff0c;其他线程必须等待锁被释放…

DEBUG:jeston卡 远程ssh编程

问题 jeston 打开网页 gpt都不方便 而且只需要敲命令就行 解决 下载MobaXterm(window执行) liunx需要虚拟机 软件 远程快速复制命令

【kubernetes】Helm包管理器基本概念与Chart实战

概念&#xff1a;基础架构与常用命令 三个重要概念&#xff1a; 1.chart 创建Kukernetes应用程序所必需的一组信息。 2.config 包含了可以合并到打包的chart中的配置信息&#xff0c;用于创建一个可发布的对象。 3.release 是一个与特走配置相结合的chart的运行实例。 常用命…