一文通俗理解torch.nn.Parameter()
一、起源
首先,我写这篇文章的起源是因为,我突然看到了一段有关torch.nn.Parameter()的代码。
因此就去了解了一下这个函数,把自己的一些理解记录下来,希望可以帮到你。
二、官方文档
网址如下:https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html#torch.nn.parameter.Parameter
具体参数解释如下:
torch.nn.parameter.Parameter(data=None, requires_grad=True)
- data:代表一个tensor类似的数据
- requires_grad:是否需要进行梯度计算,默认为True
三、个人理解
- 这个函数的主要作用就是把一个不可训练的Tensor数据转换成可以训练的Tensor数据。
那么这个这个函数怎么实现的呢。
这个函数可以将你输入的数据(你想训练的数据)加入到你模型的参数里面(因为requires_grad=True,如果为False就是不加入),跟着你模型的参数一起训练,一起学习,逐渐达到最优解。
代码实现:
self.w = nn.Parameter(torch.tensor(0.5, dtype=torch.float), requires_grad=True)
"""
初始数据为0.5,且为float类型,进行训练。
"""
四、示例程序
# -*- coding: UTF-8 -*-
# Project :python
# File :test_1.py
# IDE :PyCharm
# Author :小李同学
# Date :2023/10/21 13:44
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
# 创建一个可学习的权重参数,初始值为0.5
weight = torch.nn.Parameter(torch.tensor(0.5, requires_grad=True))
# 定义一个优化器,用于更新权重
optimizer = optim.SGD([weight], lr=0.01)
# 目标值
target = torch.tensor(20.0)
# 存储损失和权重的列表,用于绘制学习曲线
losses = []
weights = []
# 训练循环
for epoch in range(10):
# 模型的预测值
prediction = weight * 5.0 # 假设模型的预测是输入值乘以权重
# 计算损失,这里使用均方误差损失
loss = (prediction - target) ** 2
losses.append(loss.item())
weights.append(weight.item())
# 梯度清零
optimizer.zero_grad()
# 反向传播和权重更新
loss.backward()
optimizer.step()
print(f'Epoch {epoch + 1}: Loss={loss.item():.2f}, Weight={weight.item():.2f}')
# 绘制学习曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.subplot(1, 2, 2)
plt.plot(weights)
plt.xlabel('Epoch')
plt.ylabel('Weight')
plt.title('Weight Curve')
plt.show()
输出结果如下:
Epoch 1: Loss=306.25, Weight=2.25
Epoch 2: Loss=76.56, Weight=3.12
Epoch 3: Loss=19.14, Weight=3.56
Epoch 4: Loss=4.79, Weight=3.78
Epoch 5: Loss=1.20, Weight=3.89
Epoch 6: Loss=0.30, Weight=3.95
Epoch 7: Loss=0.07, Weight=3.97
Epoch 8: Loss=0.02, Weight=3.99
Epoch 9: Loss=0.00, Weight=3.99
Epoch 10: Loss=0.00, Weight=4.00
学习曲线如下:
如果想获得本文的的pdf,请在公众号“冬天的李同学”上回复“2023.10.22”即可获得。
参考文章:
1.https://mp.weixin.qq.com/s/ryfSof2OrGQdJauqmTpK0A
2.https://blog.csdn.net/weixin_44878336/article/details/124733598?
ps://mp.weixin.qq.com/s/ryfSof2OrGQdJauqmTpK0A
2.https://blog.csdn.net/weixin_44878336/article/details/124733598?
3.https://blog.csdn.net/weixin_43145941/article/details/114757673?