PyTorch 的 10 条内部用法

news2024/9/22 9:56:09
alt

欢迎阅读这份有关 PyTorch 原理的简明指南[1]。无论您是初学者还是有一定经验,了解这些原则都可以让您的旅程更加顺利。让我们开始吧!

1. 张量:构建模块

PyTorch 中的张量是多维数组。它们与 NumPy 的 ndarray 类似,但可以在 GPU 上运行。

import torch

# Create a 2x3 tensor
tensor = torch.tensor([[123], [456]])
print(tensor)

2. 动态计算图

PyTorch 使用动态计算图,这意味着该图是在执行操作时即时构建的。这为在运行时修改图形提供了灵活性。

# Define two tensors
a = torch.tensor([2.], requires_grad=True)
b = torch.tensor([3.], requires_grad=True)

# Compute result
c = a * b
c.backward()

# Gradients
print(a.grad)  # Gradient w.r.t a

3.GPU加速

PyTorch 允许在 CPU 和 GPU 之间轻松切换。利用 .to(device) 获得最佳性能。

device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = tensor.to(device)

4. Autograd:自动微分

PyTorch 的 autograd 为张量上的所有操作提供自动微分。设置 require_grad=True 来跟踪计算。

x = torch.tensor([2.], requires_grad=True)
y = x**2
y.backward()
print(x.grad)  # Gradient of y w.r.t x

5. 带有 nn.Module 的模块化神经网络

PyTorch 提供 nn.Module 类来定义神经网络架构。通过子类化创建自定义层。

import torch.nn as nn

class SimpleNN(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(11)
        
    def forward(self, x):
        return self.fc(x)

6. 预定义层和损失函数

PyTorch 在 nn 模块中提供了各种预定义层、损失函数和优化算法。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

7. 数据集和DataLoader

为了高效的数据处理和批处理,PyTorch 提供了 Dataset 和 DataLoader 类。

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    # ... (methods to define)
    
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

8.模型训练循环

通常,PyTorch 中的训练遵循以下模式:前向传递、计算损失、后向传递和参数更新。

for epoch in range(epochs):
    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

9. 模型序列化

使用 torch.save() 和 torch.load() 保存和加载模型。

# Save
torch.save(model.state_dict(), 'model_weights.pth')

# Load
model.load_state_dict(torch.load('model_weights.pth'))

10. Eager Execution and JIT

虽然 PyTorch 默认情况下以 eager 模式运行,但它为生产就绪模型提供即时 (JIT) 编译。

scripted_model = torch.jit.script(model)
scripted_model.save("model_jit.pt")

Reference

[1]

Source: https://medium.com/@kasperjuunge/10-principles-of-pytorch-bbe4bf0c42cd

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1311723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【方法】如何给PDF文件添加“限制密码”?

PDF是很多人工作中经常用到的电子文档格式,它可以保留原始文档的所有格式和布局,也不容易修改,想要编辑修改PDF文件需要用到编辑器。 但如果给PDF文件添加“限制密码”,就可以保护文件不被随意修改,即使使用编辑器也需…

基于PCIe的NVMe学习

一:基本概念 1.UltraScale:是Xilinx ZYNQ 系列产品 2.spec:大家现在别纠结于具体的命令,了解一下就好。老板交代干活的时候,再找spec一个一个看吧————猜测估计是命令表之类的。 3.TLP报文部分: 二&…

STM32--中断使用(超详细!)

写在前面:前面的学习中,我们接触了STM32的第一个外设GPIO,这也是最常用的一个外设;而除了GPIO外,中断也是一个十分重要且常用的外设;只有掌握了中断,再处理程序时才能掌握好解决实际问题的逻辑思…

云上丝绸之路| 云轴科技ZStack成功实践精选(西北)

古有“丝绸之路” 今有丝绸之路经济带 丝路焕发新生,数智助力经济 云轴科技ZStack用“云”护航千行百业 沿丝绸之路,领略西北数字化。 古丝绸之路起点-陕西 集历史与现代交融,不仅拥有悠久的历史文化积淀,而且现代化、数字化发…

论文润色降重哪个平台好 papergpt

大家好,今天来聊聊论文润色降重哪个平台好,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧: 标题:论文润色降重哪个平台好――专业、高效、可靠的学术支持 一、引言 在学术研究中&…

SQLE 3.0 部署实践

来自 1024 活动的投稿系列 第一篇《SQLE 3.0 部署实践》 . 作者:张昇,河北东软软件有限公司高级软件工程师,腾讯云社区作者。 爱可生开源社区出品,原创内容未经授权不得随意使用,转载请联系小编并注明来源。 本文共 32…

机器学习--归一化处理

归一化 归一化的目的 归一化的一个目的是,使得梯度下降在不同维度 θ \theta θ 参数(不同数量级)上,可以步调一致协同的进行梯度下降。这就好比社会主义,一小部分人先富裕起来了,先富带后富&#xff0c…

XUbuntu22.04之npm解决pm WARN deprecated(一百九十九)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

基于vue实现的疫情数据可视化分析及预测系统-计算机毕业设计推荐django

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

Kafka-Kafka基本原理与集群快速搭建

一、Kafka介绍 ​ ChatGPT对于Apache Kafka的介绍: Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发并于2011年开源。它主要用于解决大规模数据的实时流式处理和数据管道问题。 Kafka是一个分布式的发布-订阅消息系统,可以快速地处理…

从计算机底层深入Golang高并发

从计算机底层深入Golang高并发 1.源码流程架构图 2.源码解读 runtime/proc.go下的newpro() func newproc(fn *funcval) {//计算额外参数的地址argpgp : getg()pc : getcallerpc()//s1使用systemstack调用newproc1 systemstack(func() {newg : newproc1(fn, gp, pc)_p_ : getg…

经典文献阅读之--SST-Calib(激光雷达与相机的同步时空参数标定法)

0. 简介 借助多种输入模态的信息,基于传感器融合的算法通常优于单模态。具有互补语义和深度信息的相机和激光雷达是复杂驾驶环境中的典型传感器配置。然而,对于大多数相机和激光雷达融合的算法,传感器的标定将极大地影响性能。具体来说&…

RabbitMq的详细使用

消息队列RabbitMQ详细使用 文章目录 消息队列RabbitMQ详细使用MQ 的相关概念什么是MQ为什么要用MQMQ 的分类MQ 的选择 RabbitMQRabbitMQ 的概念四大核心概念各个名词介绍安装RabbitMQWeb管理界面及授权操作Docker 安装Hello world简单示例 Work Queues轮训分发消息消息应答自动…

JWT令牌的作用和生成

JWT令牌(JSON Web Token)是一种用于身份验证和授权的安全令牌。它由三部分组成:头部、载荷和签名。 JWT令牌的作用如下: 身份验证:JWT令牌可以验证用户身份。当用户登录后,服务器会生成一个JWT令牌并返回…

Turtle绘制菱形-第11届蓝桥杯选拔赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第16讲。 Turtle绘制菱形&a…

bugku--- 比赛真题1-3

第一题 查看源代码直接就有 第二题 万能密码直接填 第三题

Vue2.x源码:new Vue()做了啥

例子1new Vue做了啥?new Vue做了啥,源码解析 initMixin函数 初始化 – 初始化Vue实例的配置initLifecycle函数 – 初始化生命周期钩子函数initEvents – 初始化事件系统初始化渲染 initRender初始化inject选项 例子1 <div id"app"><div class"home&…

JVM之堆学习

一、Java虚拟机内存结构图 二、堆的介绍 1. 前面学习的程序计数器&#xff0c;虚拟机栈和本地方法栈都是线程私有的&#xff0c;堆是线程共享的&#xff1b; 2. 通过 new 关键字&#xff0c;创建的对象都会使用堆内存&#xff0c;其特点是&#xff1a; 它是线程共享的&#x…

pytorch文本分类(二):引入pytorch处理文本数据

pytorch文本数据处理 目录 pytorch文本数据处理1. Pytorch背景2. 数据分割3. 数据加载Dataset代码分析字典的用途代码修改的目的 Dataloader 4. 练习 原学习任务链接 相关数据链接&#xff1a;https://pan.baidu.com/s/1iwE3LdRv3uAkGGI2fF9BjA?pwdro0v 提取码&#xff1a;ro…

flume系列之:监控flume agent channel的填充百分比

flume系列之:监控flume agent channel的填充百分比 一、监控效果二、获取flume agent三、飞书告警四、获取每个flume agent channel的填充百分比一、监控效果 二、获取flume agent def getKafkaFlumeAgent():# 腾讯云10.130.112.60zk = KazooClient(hosts