本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。
目录
- 官方定义
- demo1
- demo2
官方定义
nn.Linear
是 PyTorch 中用于创建线性层的类。线性层也被称为全连接层,它将输入与权重矩阵相乘并加上偏置,然后通过激活函数进行非线性变换。
官方的文档如下,torch.nn.Linear:
demo1
下面是一个官方文档给出的例子:
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())
输出的结果:
首先,输出[128, 20]的张量,经过一个[20, 30]的线性层,变成[128, 30]的张量。
可以理解为矩阵的乘法,也就是矩阵的"外积",矩阵的叉乘,第一个矩阵的行数与第二个矩阵的列数相同。
demo2
input_data = torch.Tensor([[1, 2, 3], [4, 5, 6]]) # [2, 3]
m = nn.Linear(3, 2)
output = m(input_data)
print(output) # [2, 2]
输出:
可以看看nn.Linear(3, 2)的参数:
for param in m.parameters():
print(param)
输出:
结合参数,其实本身它们的计算就是矩阵的乘法:
输入X为[n, i]的矩阵,经过W为[i,0]的矩阵,加上b的偏置得到Y为[n,o]的矩阵。
计算的思路也比较简单:
output[0][0]
= [1, 2, 3] * [0.2888, -0.4596, -0,4896] + 0.3740 = -1.7253
output[0][1]
= [1, 2, 3] * [0.4730, -0.4033, -0.4739] + 0.3182 = -1.4370
output[1][0]
= [4, 5, 6] * [0.2888, -0.4596, -0,4896] + 0.3740 = -3.7066
output[1][1]
= [4, 5, 6] * [0.4730, -0.4033, -0.4739] + 0.3182 = -2.6495
通过input和param的对比,我们可以很轻松地理解实际上就是矩阵的乘法操作。而模型在训练过程中就是不断调整param的参数使得输出的张量符合训练集的需求。