一文搞懂pytorch hook机制

news2024/12/23 19:34:58

pytorch的hook机制允许我们在不修改模型class的情况下,去debug backward、查看forward的activations和修改梯度。hook是一个在forward和backward计算时可以被执行的函数。在pytorch中,可以对Tensornn.Module添加hook。hook有两种类型,forward hookbackward hook

1. 对Tensors添加hook

对于Tensors来说,只有backward hook,没有forward hook。对于backward hook来说,其函数输入输出形式是 hook(grad) -> Tensor or None。其中,grad是pytorch执行backward之后,一个tensor的grad属性值。

例如:

import torch 
a = torch.ones(5)
a.requires_grad = True

b = 2*a
c = b.mean()
c.backward()

print(f'a.grad = {a.grad}, b.grad = {b.grad}')

输出:

a.grad = tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]), b.grad = None

由于b不是叶子节点,因此在计算完梯度后,b的grad会被释放。因此,b.grad=None。这里,我们要显式的指定不释放掉非叶子节点的grad。代码改为下面这样:

import torch 
a = torch.ones(5)
a.requires_grad = True

b = 2*a

b.retain_grad()   # 让非叶子节点b的梯度保持
c = b.mean()
c.backward()

print(f'a.grad = {a.grad}, b.grad = {b.grad}')

输出:

a.grad = tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]), b.grad = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

我们可以通过加print的方式来查看一个tensor的梯度值,也可以通过加hook的方式来实现这点。

import torch

a = torch.ones(5)

a.requires_grad = True

b = 2*a

a.register_hook(lambda x:print(f'a.grad = {x}'))
b.register_hook(lambda x: print(f'b.grad = {x}'))  

c = b.mean()

c.backward() 

输出:

b.grad = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
a.grad = tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])

使用hook的一个好处是:代码中的b.retain_grad() # 让非叶子节点b的梯度保持 这句可以删除掉,同样可以记录到非叶子节点的值。对于不方便修改源码的程序,可以通过对tensors添加hook查看梯度。同时,.retain_grad()操作会增加显存的使用。

另外一点对Tensors使用hook的好处是,可以对backward时的梯度进行修改。来看一个更加实际具体的例子:

import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  def forward(self, x):
    x = self.relu(self.conv(x))
    
    # 修改反向传播时,conv输出的梯度不小于0
    x.register_hook(lambda grad : torch.clamp(grad, min = 0))
      
    # 打印确认是否有小于0的梯度
    x.register_hook(lambda grad: print("Gradients less than zero:", bool((grad < 0).any())))  
    return self.fc1(self.flatten(x))
  

net = myNet()

for name, param in net.named_parameters():
  # 使用named_parameters对fc和bias添加修改,使其梯度全部为0
  if "fc" in name and "bias" in name:
    param.register_hook(lambda grad: torch.zeros(grad.shape))


out = net(torch.randn(1,3,8,8)) 

(1 - out).mean().backward()

print("The biases are", net.fc1.bias.grad)

输出为:

Gradients less than zero: False
The biases are tensor([0., 0., 0., 0., 0.])

2. 对nn.Module添加hook

对nn.Module添加hook的函数输入输出形式为:

backward hook:hook(module, grad_input, grad_output) -> Tensor or None

forward hook:hook(module, input, output) -> None

对nn.Module添加backward hook,非常容易造成困扰。看下面的例子:

import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    return self.fc1(self.flatten(x))
  

net = myNet()

def hook_fn(m, i, o):
  print(m)
  print("------------Input Grad------------")

  for grad in i:
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")

  print("------------Output Grad------------")
  for grad in o:  
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")
  print("\n")
  
net.conv.register_backward_hook(hook_fn)
net.fc1.register_backward_hook(hook_fn)
inp = torch.randn(1,3,8,8)
out = net(inp)

(1 - out.mean()).backward()

输出为:

Linear(in_features=160, out_features=5, bias=True)
------------Input Grad------------
torch.Size([5])
torch.Size([5])
------------Output Grad------------
torch.Size([5])


Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
------------Input Grad------------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
------------Output Grad------------
torch.Size([1, 10, 4, 4])

可以看到对nn.Module添加的backward hook,对于Input Grad和Output Grad,对于弄清其具体指代的梯度,是比较难以搞清楚的。

对nn.Module添加forward hook,对于我们查看每层的激活值(输出,activations)是非常方便的。

import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
    self.seq = nn.Sequential(nn.Linear(5,3), nn.Linear(3,2))
    
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    x = self.fc1(self.flatten(x))
    x = self.seq(x)
  

net = myNet()
visualisation = {}

def hook_fn(m, i, o):
  visualisation[m] = o 

def get_all_layers(net):
  for name, layer in net._modules.items():
    #If it is a sequential, don't register a hook on it
    # but recursively register hook on all it's module children
    if isinstance(layer, nn.Sequential):
      get_all_layers(layer)
    else:
      # it's a non sequential. Register a hook
      layer.register_forward_hook(hook_fn)

get_all_layers(net)

  
out = net(torch.randn(1,3,8,8))

# Just to check whether we got all layers
print(visualisation.keys())      #output includes sequential layers
print(visualisation)

输出为:

dict_keys([Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2)), ReLU(), Linear(in_features=160, out_features=5, bias=True), Linear(in_features=5, out_features=3, bias=True), Linear(in_features=3, out_features=2, bias=True)])

{Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2)): tensor([[[[ 0.8381,  0.3751,  0.0268, -0.1155],
           [-0.2221,  1.1316,  1.1800, -0.1370],
           [ 1.1750, -0.6800, -0.1855,  0.3174],
           [-0.3929,  0.1941,  0.8611, -0.4447]],
 
          [[ 0.2377,  0.5215,  1.2715, -0.1600],
           [-0.7852, -0.2954, -0.0898,  0.0045],
           [-0.6077, -0.0088, -0.0572, -0.4161],
           [-0.6604,  0.7242, -0.7878,  0.0525]],
 
          [[-0.7283, -0.2644, -1.0609,  0.4960],
           [ 0.7989, -1.2582, -0.4996,  0.4377],
           [ 0.0798,  1.3804, -0.2886, -0.1540],
           [ 1.4034, -0.6836, -0.0658,  0.5268]],
 
          [[-0.6073, -0.3875, -0.3015,  0.7174],
           [-1.2842,  0.7734, -0.6014,  0.4114],
           [-0.3582, -1.4564, -0.6590, -1.0223],
           [-0.7667,  0.6816,  0.0602, -0.2622]],
 
          [[-0.6175, -0.3179, -1.2208, -0.8645],
           [ 1.1918, -0.3578, -0.7223, -1.1834],
           [ 0.1654, -0.1522,  0.0066,  0.0934],
           [ 0.7423, -0.7827,  0.2465,  0.4299]],
 
...
           [0.5625, 0.4753, 0.0000, 0.0000],
           [0.6904, 0.1533, 0.6416, 0.0000]]]], grad_fn=<ReluBackward0>),
 Linear(in_features=160, out_features=5, bias=True): tensor([-0.0816, -0.1588, -0.0201, -0.4695,  0.2911], grad_fn=<AddBackward0>),
 Linear(in_features=5, out_features=3, bias=True): tensor([-0.3199,  0.0220, -0.3564], grad_fn=<AddBackward0>),
 Linear(in_features=3, out_features=2, bias=True): tensor([ 0.5371, -0.5260], grad_fn=<AddBackward0>)}

下面通过一个例子来展示forward hook以及对hook出的activation进行可视化。

import torch
from torchvision.models import resnet34
from PIL import Image
from torchvision import transforms as T
import matplotlib.pyplot as plt


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = resnet34(pretrained=True)
model = model.to(device)

# 定义hook
class SaveOutput:
    def __init__(self):
        self.outputs = []
        
    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)
        
    def clear(self):
        self.outputs = []
        
# 对Conv2d注册hook
save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
    if isinstance(layer, torch.nn.modules.conv.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)


image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

out = model(X)

print(len(save_output.outputs))  # 输出应该是36


def module_output_to_numpy(tensor):
    return tensor.detach().to('cpu').numpy()    

images = module_output_to_numpy(save_output.outputs[0])

with plt.style.context("seaborn-white"):
    plt.figure(figsize=(20, 20), frameon=False)
    for idx in range(64):   # 这里根据输出通道数,不止可以索引到64,可以通过打印images的channels来查看最大的输出通道数
        plt.subplot(8, 8, idx+1)
        plt.imshow(images[0, idx])
    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[]);

matplotlib画出第一层的activation为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们修改代码如下,来查看比较靠后层的activation:

images = module_output_to_numpy(save_output.outputs[30]) # 将此处的索引改为30,查看第30层的activation

with plt.style.context("seaborn-white"):
    plt.figure(figsize=(20, 20), frameon=False)
    for idx in range(64):   # 这里根据输出通道数,不止可以索引到64,可以通过打印images的channels来查看最大的输出通道数
        plt.subplot(8, 8, idx+1)
        plt.imshow(images[0, idx])
    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[]);

我们同样查看中间层,例如第15层的activation。

可以看到随着网络层的加深,activation越来越抽象。

除了上述的对forward加hook查看activation、对backward加hook、对Tensors加hook进行梯度相关的操作外,还可以参考kaggle的文章进行一些更深层次的理解,比如对backward过程的详细解释以及配合backward hook使用GRAD-CAM来查看网络等方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1067177.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开环模块化多电平换流器仿真(MMC)N=6(Simulink仿真)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

【Vue面试题六】为什么Vue中的 v-if 和 v-for 不建议一起用?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;v-if和v-for的优先级是什…

企业建设数字化工厂的四个要点

在当今的制造业领域&#xff0c;数字化技术的应用越来越广泛&#xff0c;数字化工厂管理系统的概念也随之兴起。数字化工厂是一种全新的生产模式&#xff0c;它将信息技术、制造技术和网络技术深度融合&#xff0c;实现了从产品设计到生产制造再到企业管理全过程数字化。本文将…

stack和queque

1.stack 1.1定义 T 是容器内的数据类型&#xff1b; Container是数据类型的容器适配器 vector和list和stack的区别 1.2 stack的功能 注意这里没有迭代器&#xff1b;原因stack是先进后出的规律&#xff1b;这就规定该容器不可以随机访问&#xff1b; 2. queue

热迁移中VirtIO-PCI设备的配置空间处理

文章目录 问题现象定位过程日志分析源端目的端 原理分析基本原理上下文分析复现分析patch分析 总结解决方案 问题现象 集群升级虚拟化组件版本&#xff0c;升级前存量运行并挂载了virtio磁盘的虚拟机集群内热迁移到升级后的节点失败&#xff0c;QEMU报错如下&#xff1a; 202…

KdMapper扩展实现之Dell(pcdsrvc_x64.pkms)

1.背景 KdMapper是一个利用intel的驱动漏洞可以无痕的加载未经签名的驱动&#xff0c;本文是利用其它漏洞&#xff08;参考《【转载】利用签名驱动漏洞加载未签名驱动》&#xff09;做相应的修改以实现类似功能。需要大家对KdMapper的代码有一定了解。 2.驱动信息 驱动名称pcds…

数据中台实战(11)-数据中台的数据安全解决方案

0 微盟删库跑路 除了快、准和省&#xff0c;数据中台须安全&#xff0c;避免“微盟删库跑路”。 2020年2月23日19点&#xff0c;国内最大精准营销服务商微盟出现大面积系统故障&#xff0c;旗下300万商户线上业务全停&#xff0c;商铺后台所有数据被清。始作俑者是一位运维&a…

Java常见设计模式

单例模式&#xff1a;程序自始至终只创建一个对象。 应用场景&#xff1a;1.整个程序运行中只允许一个类的实例时 2.需要频繁实例化然后销毁的对象 3.创建对象时耗时过多但又经常用到的对象 4.方便资源相互通信的环境 懒汉式线程不安全问题解决方案&#xff1a; 双重检查加锁机…

HTTPS 加密工作过程

引言 HTTP 协议内容都是按照文本的方式明文传输的&#xff0c;这就导致在传输过程中出现一些被篡改的情况。例如臭名昭著的运营商劫持。显然&#xff0c; 明文传输是比较危险的事情&#xff0c;为此引入 HTTPS &#xff0c;HTTPS 就是在 HTTP 的基础上进行了加密, 进一步的来保…

计算机毕业设计 基于Java的同城宠物帮(宠物领养平台\系统)的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

SpringSecurity源码学习一:过滤器执行原理

目录 1. web过滤器Filter1.1 filter核心类1.2 GenericFilterBean1.3 DelegatingFilterProxy1.3.1 原理1.3.2 DelegatingFilterProxy源码 2. FilterChainProxy源码学习2.1 源码2.1.1 doFilterInternal方法源码2.1.1.1 getFilters()方法源码2.1.1.2 VirtualFilterChain方法源码 3…

C++学习day1

一>整理思维导图 二>提示并输入一个字符串 &#xff0c;统计活该字符串中的大写&#xff0c;小写&#xff0c;数字&#xff0c;空格&#xff0c;以及其他字符的个数要求使用c风格符完成 #include <iostream>int main() {std::string input;std::cout << &qu…

网络和系统操作命令

目录 ping&#xff1a;用于检测网络是否通畅&#xff0c;以及网络时延情况。ipconfig&#xff1a;查看计算机的IP参数配置信息&#xff0c;如IP地址、默认网关、子网掩码等信息。netstat&#xff1a;显示协议统计信息和当前TCP/IP网络连接。tasklist&#xff1a;显示当前运行的…

Go复合类型之数组类型

Go复合类型之数组 文章目录 Go复合类型之数组一、数组(Array)介绍1.1 基本介绍1.2 数组的特点 二、数组的声明与初始化2.1 数组声明2.2 常见的数据类型声明方法2.3 数组的初始化方式一&#xff1a;使用初始值列表初始化数组方法二&#xff1a;根据初始值个数自动推断数组长度方…

温故知新:dfs模板-843. n-皇后问题

n−n−皇后问题是指将 nn 个皇后放在 nnnn 的国际象棋棋盘上&#xff0c;使得皇后不能相互攻击到&#xff0c;即任意两个皇后都不能处于同一行、同一列或同一斜线上。 现在给定整数 nn&#xff0c;请你输出所有的满足条件的棋子摆法。 输入格式 共一行&#xff0c;包含整数 n…

通过位运算,实现单字段标识多个状态位

可能经常有如下这种需求: 需要一张表,来记录学员课程的通过与否. 课程数量不确定,往往很多,且会有变动,随时可能新增一门课. 这种情况下,在设计表结构时,一门课对应一个字段,就有些不合适, 因为不知道课程的具体数量,也无法应对后期课程的增加. 考虑只用一个状态标志位,利用位运…

精确到区县级街道乡镇行政边界geojson格式矢量数据的获取拼接实现Echarts数据可视化大屏地理坐标信息地图的解决方案

在Echarts制作地理信息坐标地图时&#xff0c;最麻烦的就是街道乡镇级别的行政geojson的获取&#xff0c; 文件大小 788M 文件格式 .json格式&#xff0c;由于是大文件数据&#xff0c;无法直接使用记事本或者IDE编辑器打开&#xff0c;推荐Dadroit Viewer&#xff08;国外…

【代码实践】HAT代码Window平台下运行实践记录

HAT是CVPR2023上的自然图像超分辨率重建论文《activating More Pixels in Image Super-Resolution Transformer》所提出的模型。本文旨在记录在Window系统下运行该官方代码&#xff08;https://github.com/XPixelGroup/HAT&#xff09;的过程&#xff0c;中间会遇到一些问题&am…

linux系统中常见注册函数的使用方法

大家好&#xff0c;今天给大家分享一下&#xff0c;linux系统中常见的注册函数register_chrdev_region()、register_chrdev()、 alloc_chrdev_region()的使用方法​。 一、函数包含的头文件&#xff1a; 分配设备编号&#xff0c;注册设备与注销设备的函数均在fs.h中申明&…

OMV6 安装Extras 插件失败的解决方法

# Time: 2023/10/07 #Author: Xiaohong # 运行环境: OS: OMV6 # 功能: 安装Extras 插件失败的解决方法 问题描述&#xff1a;OMV6 安装插件omv-extras&#xff0c;只能按如下提示的命令行&#xff0c;但安装过程中&#xff0c;会提示raw.githubusercontent.com 无法访问插…