PyTorch使用细节

news2025/1/25 4:46:58

model.eval() :让BatchNorm、Dropout等失效;

with torch.no_grad() : 不再缓存activation,节省显存;

这是矩阵乘法:

y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)

y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

这是点乘:

z1 = tensor * tensor
z2 = tensor.mul(tensor)

z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

Tensor如果是1*1大小的,可以转为普通Python变量

agg = tensor.sum()
agg_item = agg.item()

Tensor和numpy之间,是share内存的,改一个另一个也被改动

n = torch.ones(5).numpy()

n = np.ones(5)
t = torch.from_numpy(n)

root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Dataset类:通过index,拿到1条数据;

        数据可以都在磁盘上,用到哪条,就加载哪条;

        自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__

DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...

ToTensor(): 把PIL格式的Image,转成Tensor;

Lambda: 把int的y,转成10维度的1-hot向量;

一切模型层,皆继承自torch.nn.Module

class NeuralNetwork(nn.Module):

Module必须copy到device上

model = NeuralNetwork().to(device)

input data也必须copy到device上

X = torch.rand(1, 28, 28, device=device)

不能直接使用Module.forward,使用Module(input)语法可以使前后的hook起作用

logits = model(X)

model.parameters(): 可训练的参数;

model.named_parameters(): 可训练的参数;包含名称;

state_dict: 可训练的参数、不可训练的参数,都有;

继承自Function类,可以写自定义的forward和backward,input或output可以放在ctx里:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

 构造计算图:

Tensor的几大成员:grad, grad_fn, is_leaf, requires_grad

Tensor.grad_fn,就是用于backward梯度计算的Function:

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

# Output:
Gradient function for z = <AddBackward0 object at 0x7f5e9fb64e20>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f5e99b11b40>

backward时,注意,是累积加和到Tensor.grad上;这样,链式法则有些地方就是要加和的,accumulate step也可以实现;

只有满足这个条件的才会累积其grad: is_leaf==True && requires_grad==True

只有requires_grad==True,但is_leaf==False,则会将梯度传播给上游,自己的grad成员无值;

只用来inference时,可用"with torch.no_grad()"控制其不生成计算图:(好处:forward速度变快一点儿,不保存activation至ctx节省显存)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)

Output: False

某些模型训练,有些parameter要设成frozen不参与权重更新,则手工设其requires_grad=False即可。

用detach()来创造数据引用,脱离了原计算图,原计算图可以被垃圾回收了:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

Output: False

backward DAG,在每次forward阶段,都会被重新搭建;所以每个step,计算图可以任意变化(例如根据Tensor的值来走不同的control flow)

向量对向量求偏导,得到的是雅克比矩阵:

以下例子演示:雅克比矩阵、梯度累积、zero_grad

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

Output:

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

optimizer使用例子

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

def train(...):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # 将所有Tensor.grad清0

torch.save: 使用Python的pickle,将一个dict进行序列化,并存至文件;

torch.load: 读取文件,使用Python的pickle,将字节数组进行反序列化,至一个dict;

torch.nn.Module.state_dict: 一个Python的dict,key是字符串,value是Tensor;包含可学习的parameters,不可学习的buffers(例如batch normalization需要的running mean);

optimizer也有state_dic(learning rate,冲量等)

save下来仅仅用于推理:(注意:必须model.eval(),否则dropout、BN,会出毛病)

# save:
torch.save(model.state_dict(), PATH)

# load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

save下来可用于继续训练:

# save:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()

使用state_dict方式,load之前,model必须初始化好(内存已经被parameters占住了,只是权重是随机的)

map_location、model.to(device)等:Saving and Loading Models — PyTorch Tutorials 2.3.0+cu121 documentation

小众用法:(model不用初始化)

# save:
torch.save(model, PATH)

# load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1935751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

破解反爬虫策略 /_guard/auto.js(一) 原理

背景 当用代码或者postman访问一个网站的时候&#xff0c;访问他的任何地址都会返回<script src"/_guard/auto.js"></script>&#xff0c;但是从浏览器中访问显示的页面是正常的&#xff0c;这种就是网站做了反爬虫策略。本文就是带大家来破解这种策略&…

USB3200N模拟信号采集卡12位8路500K采样带DIO带计数器

1、概述&#xff1a; USB3200N多功能数据采集卡&#xff0c;LabVIEW无缝连接&#xff0c;提供图形化API函数&#xff0c;提供8通道&#xff08;RSE、NRSE&#xff09;、4通道&#xff08;DIFF&#xff09;模拟量输入&#xff0c;4路可编程数字I/O&#xff0c;1路计数器。 USB3…

C/C++蓝屏整人代码

文章目录 &#x1f4d2;程序效果 &#x1f4d2;具体步骤 1.隐藏任务栏 2.调整cmd窗口大小 3.调整cmd窗口屏幕颜色 4.完整代码 &#x1f4d2;代码详解 &#x1f680;欢迎互三&#x1f449;&#xff1a;程序猿方梓燚 &#x1f48e;&#x1f48e; &#x1f680;关注博主&a…

前端实现视频播放添加水印

一、效果如下 二、代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Document</title> </head> <body><style>.container {position: relative;}.base {width: 300px;hei…

VTD学习笔记(一)-启动vtd、基本界面和按钮

写在前面&#xff1a;真快啊&#xff0c;眨眼就毕业上班了&#xff0c;岗位也是做仿真&#xff0c;看来以后就是一直做仿真了&#xff0c;再见了定位~。公司使用的是vtd&#xff0c;看资料是一个很庞大的自动驾驶仿真软件&#xff0c;囊括了车辆动力学到传感器仿真&#xff0c;…

基于Java技术的智慧外贸平台

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;B/S模式、Java技术、SpringBoot框架 工具&#xff1a;Eclipse、MySQL数据库开发工具 系统展示 首…

【网络工具】Charles 实战(下)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/iAmAo &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会整理一些工作或学习中用到的工具介绍给大家~ &#x1f4d8;Charles 系列文章&#xff1a; 【网络工…

昇思25天学习打卡营第02天|张量Tensor

这节学习的张量&#xff08;Tensor&#xff09;的内容进行总结 &#xff1a; 1、张量的概念&#xff1a;张量是一种多线性函数&#xff0c;可以表示矢量、标量和其他张量之间的线性关系。张量是MindSpore网络运算中的基本数据结构&#xff0c;类似于数组和矩阵。 2、可以通过直…

【同行案例】亚马逊精铺卖家,2年跨境选品思路分享!

店雷达年度商家实战经验分享又来啦&#xff01;希望给各位商友一些选品思路参考。该商家主做亚马逊&#xff0c;2年跨境经验&#xff0c;主营类目艺术品&#xff0c;精铺模式。 一、亚马逊选品思路分享 ☛对于中小卖家&#xff0c;选择月销量300-1000可以较好平衡投入和产出&…

Langchain-Chatchat3.1版本docker部署流程——知识库问答

Langchain——chatchat3.1版本docker部署流程Langchain-Chatchat 1. 项目地址 #项目地址 https://github.com/chatchat-space/Langchain-Chatchat #dockerhub地址 https://hub.docker.com/r/chatimage/chatchat/tags2. docker部署 参考官方文档 #官方文档 https://github.c…

.env.local 配置本地环境变量 用于团队开发

.env.local 用途&#xff1a;.env.local 通常用于存储本地开发环境中的环境变量。这些变量可能包括敏感数据或特定于单个开发者的设置&#xff0c;不应该被提交到版本控制系统中。优先级&#xff1a;在大多数框架中&#xff0c;.env.local 文件中的变量会覆盖其他 .env 文件中…

【Git远程操作】向远程仓库推送 | 拉取远程仓库

目录 1.向远程仓库推送 ​1.1本地仓库的配置 1.2remote-gitcode本地仓库 1.3推送至远程仓库 2.拉取远程仓库 现阶段以下操作仅在master主分支上。 1.向远程仓库推送 工作区☞add☞暂存区☞commit☞本地仓库☞推送push☞远程仓库注意&#xff1a;本地仓库的某个分支 ☞推…

在Ubuntu上安装redis

Ubuntu上安装redis 一、通过下载redis的压缩包安装二、通过apt包管理器安装Redis三、修改redis的配置文件四、控制redis启动 Redis是一种开源的内存数据存储&#xff0c;可以用作数据库、缓存和消息代理等。本文将会介绍两种不同的安装方式&#xff0c;包括通过压缩包安装以及通…

钡铼Profinet、EtherCAT、Modbus、MQTT、Ethernet/IP、OPC UA分布式IO系统BL20X系列耦合器

BL20X系列耦合器是钡铼技术开发的一款用于分布式I/O系统的设备&#xff0c;专为工业环境下的高速数据传输和远程设备控制而设计&#xff0c;支持多种工业以太网协议&#xff0c;包括Profinet、EtherCAT、Modbus、MQTT、Ethernet/IP和OPC UA等。如果您正在考虑部署BL20X系列耦合…

如何学习Spark:糙快猛的大数据之旅

作为一名大数据开发者,我深知学习Spark的重要性。今天,我想和大家分享一下我的Spark学习心得,希望能够帮助到正在学习或准备学习Spark的朋友们。 目录 Spark是什么?学习Spark的"糙快猛"之道1. 不要追求完美,在实践中学习2. 利用大模型作为24小时助教3. 根据自己的节…

数据结构课程设计:源代码(C)客房信息管理系统

main.c #include <unistd.h> #include "SeqList.h" #include "User.h"int main() {SL user;SLInit(&user);char ans 0;printf("是否需要导入昨日续住客人的数据&#xff1a;y/n\n");scanf(" %c", &ans);if (ans y){L…

JVM(day2)经典垃圾收集器

经典垃圾收集器 Serial收集 使用一个处理器或一条收集线程去完成垃圾收集工作&#xff0c;更重要的是强调在它进行垃圾收集时&#xff0c;必须暂停其他所有工作线程&#xff0c;直到它收集结束。 ParNew收集器 ParNew 收集器除了支持多线程并行收集之外&#xff0c;其他与 …

C++写一个线程池

C写一个线程池 文章目录 C写一个线程池设计思路测试数据的实现任务类的实现线程池类的实现线程池构造函数线程池入口函数队列中取任务添加任务函数线程池终止函数 源码 之前用C语言写了一个线程池&#xff0c;详情请见&#xff1a; C语言写一个线程池 这次换成C了&#xff01;…

C#知识|账号管理系统-账号信息管理界面[1]:账号分类选择框、Panel面板设置

哈喽,你好啊,我是雷工! 前一节实现了多条件查询后端代码的编写, 接下来继续学习账号信息管理界面的功能编写,本节主要记录账号分类选择框和Panel的设置, 以下为学习笔记。 01 功能说明 本节实现以下功能: ①:账号分类选择框只能选择,无法自由输入; ②:账号分类框默认…

大语言模型与扩散模型的“爱恨情仇”:Kolors和Auraflow的技术解析

近年来&#xff0c;随着深度学习技术的发展&#xff0c;生成模型在多个领域取得了显著进展。特别是大语言模型&#xff08;LLM&#xff09;和扩散模型&#xff08;Diffusion Model&#xff09;这两类模型&#xff0c;在自然语言处理&#xff08;NLP&#xff09;和图像生成任务中…