跟着StatQuest学知识07-张量与PyTorch

news2025/3/26 16:41:20

一、张量tensor

张量重新命名一些数据概念,存储数据以及权重和偏置。

张量还允许与数据相关的数学计算能够相对快速的完成。

通常,张量及其进行的数学计算会通过成为图形处理单元(GPUs)的特殊芯片来加速。但还有张量处理单元(TPUs)专门处理张量,使得神经网络运行相当更快。

另外,张量通过自动微分处理反向传播。

二、PyTorch

以下部分参考 【深度学习基础】用PyTorch从零开始搭建DNN深度神经网络

 图中的这个神经网络的参数都是训练优化好的,下面我们简便起见,假设最后一个参数b_final没有优化过,初始化为0,我们尝试用Pytorch实现一下对这个参数的优化,将final_bias初始化为0,看看最终这个-16可否被优化出来的。首先引入一些相关的库:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
import matplotlib.pyplot as plt
import seaborn as sns

其中torch就是PyTorch框架,matplotlib和seaborn都是用来绘图的库。然后我们定义对照着图中的各个参数,搭建神经网络如下: 

class BasicNN_train(nn.Module):  # 继承父类nn.Module
    def __init__(self):
        super().__init__()  # 对父类的成员进行初始化

        self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)
        self.b00 = nn.Parameter(torch.tensor(-0.85), requires_grad=False)
        self.w01 = nn.Parameter(torch.tensor(-40.8), requires_grad=False)

        self.w10 = nn.Parameter(torch.tensor(12.6), requires_grad=False)
        self.b10 = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.w11 = nn.Parameter(torch.tensor(2.7), requires_grad=False)

        self.final_bias = nn.Parameter(torch.tensor(0.0), requires_grad=True)
        # requires_grad=True 表示需要优化

    def forward(self, input):  # 前向传播
        input_to_top_relu = input * self.w00 + self.b00
        top_relu_output = F.relu(input_to_top_relu)
        scaled_top_relu_output = top_relu_output * self.w01

        input_to_bottom_relu = input * self.w10 + self.b10
        bottom_relu_output = F.relu(input_to_bottom_relu)
        scaled_bottom_relu_output = bottom_relu_output * self.w11

        input_to_final_relu = scaled_top_relu_output + scaled_bottom_relu_output + self.final_bias
        output = F.relu(input_to_final_relu)

        return output

然后我们实例化这个网路,设定epoch=100,即最多进行100次前向和反向传播,定义损失函数就是预测值和实际值的平方误差,当损失函数之和低于0.0001时,我们就停止训练(最多训练100轮次),代码如下:

if __name__ == '__main__':
    model = BasicNN_train()  # 实例化神经网络模型
    inputs = torch.tensor([0., 0.5, 1.])  # 输入张量
    labels = torch.tensor([0., 1., 0.])  # 输出张量
    # 定义一个优化器 optimizer,使用随机梯度下降(SGD)算法来更新模型的参数
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # 学习率为0.1
    print("优化前的final_bias是:" + str(model.final_bias.data) + '\n')
    # 开始训练,最多100轮次
    for epoch in range(100):
        total_loss = 0  # 累积当前 epoch 中所有样本的损失值
        for iteration in range(len(inputs)): # len(inputs) 表示数据集中样本的数量
            input_i = inputs[iteration]
            label_i = labels[iteration]

            output_i = model(input_i) # 前向传播
            loss = (output_i - label_i) ** 2

            loss.backward() # 反向传播
            # 通过反向传播,PyTorch 会自动计算每个参数的梯度,并存储在参数的 .grad 属性中
            total_loss += float(loss)# 将每个样本的loss加和
  •  backward() 的功能:

        backward() 使用链式法则计算损失函数 loss 对模型参数的梯度。

        loss.backward() 是从 loss 开始,沿着计算图反向传播梯度,最终得到每个参数的梯度值。这些梯度值(数据)会被存储在模型参数的 .grad 属性中,用于后续的参数更新。

  • 正向传播是怎么实现的?

        model(input_i) 会自动调用 model 中定义的 forward 方法。

        在 Python 中,当一个类的实例被“调用”时(例如 model(input_i)),Python 会尝试调用该实例的 __call__ 方法。

        PyTorch 的 nn.Module 类实现了 __call__ 方法。当你调用 model(input_i) 时,实际上是调用了 model.__call__(input_i)。

        if total_loss < 0.0001:
            print(f"当前是第{epoch}轮次,已经满足total_loss < 0.0001,结束程序。")
            break
        optimizer.step()  # 使用优化器(如 SGD)更新模型的权重和偏置,以最小化损失函数。
        optimizer.zero_grad()  # 清除模型参数的梯度。
        print(f"当前是第{epoch}轮次,此时的final_bias值为{model.final_bias.data},total_loss为{total_loss}")
        # 画图如下
        input_doses = torch.linspace(start=0, end=1, steps=11)
        output_values = model(input_doses)
        sns.set(style="whitegrid")
        sns.lineplot(x=input_doses,
                     y=output_values.detach(),
                     color='green',
                     linewidth=2.5)
        plt.ylabel('Effectiveness')
        plt.xlabel('Dose')
        plt.show()
    print(f"优化后的final_bias值为:{model.final_bias.data}")

最终的输出结果如下:

  一共34轮训练后,就实现了总损失小于0.001的要求,也看到最终的优化结果final_bia大概是-16,与之前我们的结论一致。 损失函数变化曲线如下:

    最终迭代到第34轮次后,实现了最终的效果: 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2320924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端字段名和后端不一致?解锁 JSON 映射的“隐藏规则” !!!

&#x1f680; 前端字段名和后端不一致&#xff1f;解锁 JSON 映射的“隐藏规则” &#x1f31f; 嘿&#xff0c;技术冒险家们&#xff01;&#x1f44b; 今天我们要聊一个开发中常见的“坑”&#xff1a;前端传来的 JSON 参数字段名和后端对象字段名不一致&#xff0c;会发生…

基于springboot的新闻推荐系统(045)

摘要 随着信息互联网购物的飞速发展&#xff0c;国内放开了自媒体的政策&#xff0c;一般企业都开始开发属于自己内容分发平台的网站。本文介绍了新闻推荐系统的开发全过程。通过分析企业对于新闻推荐系统的需求&#xff0c;创建了一个计算机管理新闻推荐系统的方案。文章介绍了…

2024年数维杯数学建模C题天然气水合物资源量评价解题全过程论文及程序

2024年数维杯数学建模 C题 天然气水合物资源量评价 原题再现&#xff1a; 天然气水合物&#xff08;Natural Gas Hydrate/Gas Hydrate&#xff09;即可燃冰&#xff0c;是天然气与水在高压低温条件下形成的类冰状结晶物质&#xff0c;因其外观像冰&#xff0c;遇火即燃&#…

Linux与HTTP中的Cookie和Session

HTTP中的Cookie和Session 本篇介绍 前面几篇已经基本介绍了HTTP协议的大部分内容&#xff0c;但是前面提到了一点「HTTP是无连接、无状态的协议」&#xff0c;那么到底有什么无连接以及什么是无状态。基于这两个问题&#xff0c;随后解释什么是Cookie和Session&#xff0c;以…

linux 备份工具,常用的Linux备份工具及其备份数据的语法

在Linux系统中&#xff0c;备份数据是确保数据安全性和完整性的关键步骤。以下是一些常用的Linux备份工具及其备份数据的语法&#xff1a; 1. tar命令 tar命令是Linux系统中常用的打包和压缩工具&#xff0c;可以将多个文件或目录打包成一个文件&#xff0c;并可以选择添加压…

C++核心语法快速整理

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文主要为学过多门语言玩家快速入门C 没有基础的就放弃吧。 全部都是精华&#xff0c;看完能直接上手改别人的项目。 输出内容 std::代表了这里的cout使用的标准库&#xff0c;避免不同库中的相同命名导致混乱 …

使用HAI来打通DeepSeek的任督二脉

一、什么是HAI HAI是一款专注于AI与科学计算领域的云服务产品&#xff0c;旨在为开发者、企业及科研人员提供高效、易用的算力支持与全栈解决方案。主要使用场景为&#xff1a; AI作画&#xff0c;AI对话/写作、AI开发/测试。 二、开通HAI 选择CPU算力 16核32GB&#xff0c;这…

【day2】数据结构刷题 栈

一 有效的括号 给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &#xff0c;判断字符串是否有效。 有效字符串需满足&#xff1a; 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的…

第16章:基于CNN和Transformer对心脏左心室的实验分析及改进策略

目录 1. 项目需求 2. 网络选择 2.1 UNet模块 2.2 TransUnet 2.2.1 SE模块 2.2.2 CBAM 2.3 关键代码 3 对比试验 3.1 unet 3.2 transformerSE 3.3 transformerCBAM 4. 结果分析 5. 推理 6. 下载 1. 项目需求 本文需要做的工作是基于CNN和Transformer的心脏左心室…

云上 Redis 迁移至本地机房

文章目录 摘要在 IDC 搭建读写分离 redis 集群一、环境准备二、部署主从架构1. 安装Redis2. 配置主节点3. 配置从节点4. 所有 Redis 节点设置开机自启动三、部署代理层(读写分离)1. 安装Twemproxy2. 配置Twemproxy3. 配置开机自启动四、高可用配置(哨兵模式)1. 配置哨兵节点…

SQL Server——表数据的插入、修改和删除

目录 一、引言 二、表数据的插入、修改和删除 &#xff08;一&#xff09;方法一&#xff1a;在SSMS控制台上进行操作 1.向表中添加数据 2.对表中的数据进行修改 3.对表中的数据进行删除 &#xff08;二&#xff09;方法二&#xff1a;使用 SQL 代码进行操作 1.向表中添…

deepSeek-SSE流式推送数据

1、背景 DeepSeek作为当前最火的AI大模型&#xff0c; 使用的时候用户在输入框输入问题&#xff0c;大模型进行思考回答你&#xff0c;然后会有一个逐步显示的过程效果&#xff0c;而不是一次性返回整个答案给前端页面进行展示&#xff0c;为了搞清楚其中的原理&#xff0c;我们…

【北京迅为】iTOP-RK3568开发板OpenHarmony系统南向驱动开发UART接口运作机制

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

C#实现自己的Json解析器(LALR(1)+miniDFA)

C#实现自己的Json解析器(LALR(1)miniDFA) Json是一个用处广泛、文法简单的数据格式。本文介绍如何用bitParser&#xff08;拥有自己的解析器&#xff08;C#实现LALR(1)语法解析器和miniDFA词法分析器的生成器&#xff09;迅速实现一个简单高效的Json解析器。 读者可在&#xf…

机器学习——KNN数据均一化

在KNN&#xff08;K-近邻&#xff09;算法中&#xff0c;数据均一化&#xff08;归一化&#xff09;是预处理的关键步骤&#xff0c;用于消除不同特征量纲差异对距离计算的影响。以下是两种常用的归一化操作及其核心要点&#xff1a; 质押 一 、主要思想 1. 最值归一化&#…

异步编程与流水线架构:从理论到高并发

目录 一、异步编程核心机制解析 1.1 同步与异步的本质区别 1.1.1 控制流模型 1.1.2 资源利用对比 1.2 阻塞与非阻塞的技术实现 1.2.1 阻塞I/O模型 1.2.2 非阻塞I/O模型 1.3 异步编程关键技术 1.3.1 事件循环机制 1.3.2 Future/Promise模式 1.3.3 协程&#xff08;Cor…

哈尔滨工业大学DeepSeek公开课人工智能:大模型原理 技术与应用-从GPT到DeepSeek|附视频下载方法

导 读INTRODUCTION 今天继续哈尔滨工业大学车万翔教授带来了一场主题为“DeepSeek 技术前沿与应用”的报告。 本报告深入探讨了大语言模型在自然语言处理&#xff08;NLP&#xff09;领域的核心地位及其发展历程&#xff0c;从基础概念出发&#xff0c;延伸至语言模型在机器翻…

Excel处理控件Spire.XLS系列教程:C# 在 Excel 中添加或删除单元格边框

单元格边框是指在单元格或单元格区域周围添加的线条。它们可用于不同的目的&#xff0c;如分隔工作表中的部分、吸引读者注意重要的单元格或使工作表看起来更美观。本文将介绍如何使用 Spire.XLS for .NET 在 C# 中添加或删除 Excel 单元格边框。 安装 Spire.XLS for .NET E-…

Web开发-JS应用NodeJS原型链污染文件系统Express模块数据库通讯

知识点&#xff1a; 1、安全开发-NodeJS-开发环境&功能实现 2、安全开发-NodeJS-安全漏洞&案例分析 3、安全开发-NodeJS-特有漏洞 node.js就是专门运行javascript的一个应用程序&#xff0c;区别于以往用浏览器解析原生js代码&#xff0c;node.js本身就可以解析执行js代…

国产达梦(DM)数据库的安装(Linux系统)

目录 一、安装前的准备工作 1.1 导包 1.2 创建用户和组 1.3 修改文件打开最大数 1.4 目录规划 1.5 修改目录权限 二、安装DM8 2.1 挂载镜像 2.2 命令行安装 2.3 配置环境变量 2.4 启动图形化界面 三、配置实例 四、注册服务 五、启动 停止 查看状态 六、数据库客…