深度学习基础知识 register_buffer 与 register_parameter用法分析
- 1、问题引入
- 2、register_parameter()
- 2.1 作用
- 2.2 用法
- 3、register_buffer()
- 3.1 作用
- 3.2 用法
1、问题引入
思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习
验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数
import torch
import torch.nn as nn
# 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习
class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__()
self.conv1=nn.Conv2d(in_channels= 3,
out_channels= 6,
kernel_size=3,
stride = 1,
padding=1,
bias=False)
self.conv2=nn.Conv2d(in_channels= 6,
out_channels= 9,
kernel_size=3,
stride = 1,
padding=1,
bias=False)
self.waight=torch.ones(10,10)
self.bias=torch.zeros(10)
def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x = x * self.weight + self.bias
return x
net=MyModule()
for name,param in net.named_parameters(): # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数
print(name,param.shape)
print("\n","-"*40,"\n")
for key,val in net.state_dict().items(): # 说明weight与bias是不会被state_dict转化为字典中的元素的
print(key,val.shape)
打印分析结果:
可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数
2、register_parameter()
register_parameter()是 torch.nn.Module 类中的一个方法
2.1 作用
1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中
2.2 用法
register_parameter(name,param)
- name:参数名称
- param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,
否则报错如下
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)
self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x * self.weight + self.bias
return x
net = MyModule()
for name, param in net.named_parameters():
print(name, param.shape)
print('\n', '*'*40, '\n')
for key, val in net.state_dict().items():
print(key, val.shape)
结果显示:
3、register_buffer()
register_buffer()是 torch.nn.Module() 类中的一个方法
3.1 作用
-
将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习
-
self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中
它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)
- 参数:可以被优化器更新 (requires_grad=False / True)
- buffer 中的数据 : 不会被优化器更新
3.2 用法
register_buffer(name,tensor)
- name:参数名称
- tensor:张量
代码:
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)
self.register_buffer('weight', torch.ones(10, 10)) # 注意:定义的方式
self.register_buffer('bias', torch.zeros(10))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x * self.weight + self.bias
return x
net = MyModule()
for name, param in net.named_parameters():
print(name, param.shape)
print('\n', '*'*40, '\n')
for key, val in net.state_dict().items():
print(key, val.shape)
效果如下所示: