文章目录
- 前言
- 一、原始程序---计算原型,开始训练,计算损失
- 二、每一行代码的详细解释
- 2.1 粗略分析
- 2.2 每一行代码详细分析
前言
承接系列4
,此部分属于原型类中的计算原型,开始训练,计算损失函数。
一、原始程序—计算原型,开始训练,计算损失
def compute_center(self,data_set): #data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
center = 0
for i in range(self.Ns):
data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
data = Variable(torch.from_numpy(data))
data = self.model(data)[0] #将查询点嵌入另一个空间
if i == 0:
center = data
else:
center += data
center /= self.Ns
return center
def train(self,labels_data,class_number): #网络的训练
#Select class indices for episode
class_index = list(range(class_number))
random.shuffle(class_index)
choss_class_index = class_index[:self.Nc]#选20个类
sample = {'xc':[],'xq':[]}
for label in choss_class_index:
D_set = labels_data[label]
#从D_set随机取支持集和查询集
support_set,query_set = self.randomSample(D_set)
#计算中心点
self.center[label] = self.compute_center(support_set)
#将中心和查询集存储在list中
sample['xc'].append(self.center[label]) #list
sample['xq'].append(query_set)
#优化器
optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
optimizer.zero_grad()
protonets_loss = self.loss(sample)
protonets_loss.backward()
optimizer.step()
def loss(self,sample): #自定义loss
loss_1 = autograd.Variable(torch.FloatTensor([0]))
for i in range(self.Nc):
query_dataSet = sample['xq'][i]
for n in range(self.Nq):
data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
data = Variable(torch.from_numpy(data))
data = self.model(data)[0] #将查询点嵌入另一个空间
#查询点与每个中心点逐个计算欧氏距离
predict = 0
for j in range(self.Nc):
center_j = sample['xc'][j]
if j == 0:
predict = eucli_tensor(data,center_j)
else:
predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
#为loss叠加
loss_1 += -1*F.log_softmax(predict,dim=0)[i]
loss_1 /= self.Nq*self.Nc
return loss_1
二、每一行代码的详细解释
2.1 粗略分析
第一个函数 compute_center(self,data_set)
用于计算支持集中心点的坐标。输入参数 data_set 是一个 numpy 对象,代表支持集。该函数中用了一个 for 循环遍历了每一个支持集中的样本,将其嵌入到另一个空间,并计算其总和来求得所有样本的中心点。最后返回计算出的中心点的坐标。
第二个函数 train(self,labels_data,class_number)
是网络的训练函数。其中 labels_data
是标签数据,class_number
是类别数。首先从 class_number
中随机选取出 Nc 个类,对于每个选出来的类,从其标签数据 D_set
中随机选取出支持集和查询集,并将支持集传入 compute_center()
函数计算中心点。接着将计算出的中心点和查询集存储在样本字典 sample 中。最后使用 Adam 优化器对模型进行优化,并计算损失(调用了 loss 函数),将反向传播得到的梯度更新到模型中。
第三个函数def loss(self,sample)
是一个自定义的损失函数,它的作用是计算样本的损失值。在这个损失函数中,使用了欧氏距离和softmax函数。
2.2 每一行代码详细分析
def compute_center(self,data_set)
: - 这是一个方法,用于计算给定数据集(支持集)的中心点。
2-4. center = 0
- 初始化中心点的变量为0。
5-8. for i in range(self.Ns)
: - 遍历数据集中的每个数据点。
9-14.
这部分代码将数据集中的每个数据点重塑为适应模型输入的形状,并将其转换为PyTorch的Variable。然后,使用模型将查询点嵌入另一个空间。
if i == 0:
- 如果这是第一个数据点,则将查询点设置为中心点。
16-19
. 否则,将查询点添加到中心点。
center /= self.Ns
- 计算中心点,这是所有数据点的平均值。
return center
- 返回计算得到的中心点。
接下来是 train
方法:
23-24.
从给定的标签数据中选择类别索引并随机洗牌。选择特定数量的类别(self.Nc)。
25-30
. 对于所选类别中的每一个,从其数据中随机选择支持集和查询集。
31-33.
使用 compute_center 方法计算每个类的中心点,并将其存储在列表中。同时将查询集也存储在列表中。
34-37.
初始化优化器,这里使用Adam优化算法,学习率设置为0.001。然后清空梯度缓存。
38-42.
计算损失函数值,该损失函数是根据自定义的损失函数计算的。然后进行反向传播以计算梯度。
optimizer.step()
- 使用优化器更新模型的参数。
最后是自定义的损失函数 loss
:
45-46
. 初始化一个张量 loss_1 为0,它用于累计损失值。
47-52
. 对于每个类别(self.Nc),遍历查询集中的每个数据点。对于每个查询点,将其嵌入到另一个空间中,并计算它与每个中心点之间的欧氏距离。
53-57
. 将所有的距离组合在一起,并使用softmax函数将其转换为概率值。然后,对于每个查询点,累加其与所有中心点的负对数似然损失值。
loss_1 /= self.Nq*self.Nc
- 将损失值除以查询集中的数据点数量和类别数量以获得平均损失值。
return loss_1
- 返回计算得到的损失值。