模型
softmax回归是多类分类模型,用于获取每个分类的置信度,置信度计算方式如下
经过全连接层,得到输出O,将O作为softmax的输入
O是输出向量,每个分量表示一个类别,y_hat_i表示i类别的置信度,softmax回归使得所有类别置信度都为非负数,且相加等于1
损失函数
使用交叉熵来衡量两个概率之间的区别,交叉熵计算方式如下
y_i是真实标签第i个分类的置信度,真实标签y只有一个分量是1,其他是0
损失函数torch实现
- torch.CrossEntroyLoss
//这里的CrossEntroyLoss函数返回batch个样本的总loss值,因此要取个平均值 from torch import nn loss = torch.CrossEntroyLoss(reduction='None') loss.mean().backward()