代码以及视频讲解
本文所涉及所有资源均在传知代码平台可获取
概述
本文复现论文 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence[1] 提出的半监督学习方法。
半监督学习(Semi-supervised Learning)是一种机器学习方法,它将少量的标注数据(带有标签的数据)和大量的未标注数据(不带标签的数据)结合起来训练模型。在许多实际应用中,标注数据获取成本高且困难,而未标注数据通常较为丰富和容易获取。因此,半监督学习方法被引入并被用于利用未标注数据来提高模型的性能和泛化能力。
该论文介绍了一种基于一致性和置信度的半监督学习方法 FixMatch。FixMatch首先使用模型为弱增强后的未标注图像生成伪标签。对于给定图像,只有当模型产生高置信度预测时才保留伪标签。然后,模型在输入同一图像的强增强版本时被训练去预测伪标签。FixMatch 在各种半监督学习数据集上实现了先进的性能。
算法原理
FixMatch 结合了两种半监督学习方法:一致性正则化和伪标签。其主要创新点在于这两种方法的结合以及在执行一致性正则化时分别使用了弱增强和强增强。
FixMatch 的损失函数由两个交叉熵损失项组成:一个用于有标签数据的监督损失
l
s
l_s
ls 和一个用于无标签数据的无监督损失
l
u
l_u
lu 。具体来说,
l
s
l_s
ls 只是对弱增强有标签样本应用的标准交叉熵损失:
l
s
=
1
B
∑
b
=
1
B
H
(
p
b
,
p
m
(
y
∣
α
(
x
b
)
)
)
l_s=\frac{1}{B}\sum_{b=1}^B H(p_b,p_m(y|\alpha(x_b)))
ls=B1b=1∑BH(pb,pm(y∣α(xb)))
其中
B
B
B 表示 batch size,
H
H
H 表示交叉熵损失,
p
b
p_b
pb 表示标记,
p
m
(
y
∣
α
(
x
b
)
)
p_m(y|\alpha(x_b))
pm(y∣α(xb)) 表示模型对弱增强样本的预测结果。
FixMatch 对每个无标签样本计算一个伪标签,然后在标准交叉熵损失中使用该标签。为了获得伪标签,我们首先计算模型对给定无标签图像的弱增强版本的预测类别分布:
q
b
=
p
m
(
y
∣
α
(
u
b
)
)
q_b=p_m(y|\alpha(u_b))
qb=pm(y∣α(ub))。然后,我们使用
q
^
b
=
arg
max
q
b
\hat q_b=\arg\max q_b
q^b=argmaxqb 作为伪标签,但我们在交叉熵损失中对模型对
u
b
u_b
ub 的强增强版本的输出进行约束:
l
u
=
1
μ
B
∑
b
=
1
μ
B
1
(
m
a
x
(
q
b
)
>
τ
)
H
(
q
^
b
,
p
m
(
y
∣
A
(
u
b
)
)
)
l_u=\frac{1}{\mu B}\sum_{b=1}^{\mu B} 1(max(q_b)>\tau)H(\hat q_b,p_m(y|A(u_b)))
lu=μB1b=1∑μB1(max(qb)>τ)H(q^b,pm(y∣A(ub)))
其中
μ
\mu
μ 表示无标签样本与有标签样本数量之比,
1
(
m
a
x
(
q
b
)
>
τ
)
1(max(q_b)>\tau)
1(max(qb)>τ) 当前仅当
m
a
x
(
q
b
)
>
τ
max(q_b)>\tau
max(qb)>τ 成立时为 1 否则为 0,
τ
\tau
τ 表示置信度阈值,
A
(
u
b
)
A(u_b)
A(ub) 表示对无标签样本的强增强。
FixMatch的总损失就是 l s + λ u l u l_s + \lambda_u l_u ls+λulu,其中 λ u \lambda_u λu 是表示无标签损失相对权重的标量超参数。
FixMatch 利用两种增强方法:“弱增强”和“强增强”。论文所使用的弱增强是一种标准的翻转和位移增强策略。具体来说,除了SVHN数据集之外,我们在所有数据集上以50%的概率随机水平翻转图像,并随机在垂直和水平方向上平移图像最多12.5%。对于“强增强”,我采用了基于随机幅度采样的 RandAugment,然后进行了 Cutout 处理。
我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 数据集上对 FixMatch 进行了实验。关于使用的神经网络,我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。实验结果如下表所示:
数据集 | 准确率(%) |
---|---|
CIFAR-10 | 86.39 |
CIFAR-100 | 68.88 |
SVHN | 91.25 |
FER2013 | 68.57 |
为了直观展示 FixMatch 的效果,我在线部署了基于 FER2013 数据集训练的 Wide ResNe-37-2 模型。FER2013[2] 是一个面部表情识别数据集,其包含约 30000 张不同表情的面部 RGB 图像,尺寸限制为 48×48。其主要标签可分为 7 种类型:愤怒(Angry),厌恶(Disgust),恐惧(Fear),快乐(Happy),悲伤(Sad),惊讶(Surprise),中性(Neutral)。厌恶表情的图像数量最少,只有 600 张,而其他标签的样本数量均接近 5,000 张。
核心逻辑
具体的核心逻辑如下所示:
for epoch in range(epochs):
model.train()
train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)
for labeled_batch, unlabeled_batch in train_tqdm:
optimizer.zero_grad()
# 利用标记样本计算损失
data = labeled_batch[0].to(device)
labels = labeled_batch[1].to(device)
logits = model(normalize(strong_aug(data)))
loss = F.cross_entropy(logits, labels)
# 计算未标记样本伪标签
with torch.no_grad():
data = unlabeled_batch[0].to(device)
logits = model(normalize(weak_aug(data)))
probs = F.softmax(logits, dim=-1)
trusted = torch.max(probs, dim=-1).values > threshold
pseudo_labels = torch.argmax(probs[trusted], dim=-1)
loss_factor = weight * torch.sum(trusted).item() / data.shape[0]
# 利用未标记样本计算损失
logits = model(normalize(strong_aug(data[trusted])))
loss += loss_factor * F.cross_entropy(logits, pseudo_labels)
# 反向梯度传播并更新模型参数
loss.backward()
optimizer.step()
以上代码仅作展示,更详细的代码文件请参见附件。
效果演示
网站提供了在线体验功能。用户需要输入一张长宽尽可能相等且大小不超过 1MB 的正面脸部 JPG 图像,网站就会返回图片中人物表情所表达的情绪。
使用方式
- 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip FixMatch.zip
cd FixMatch
- 代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt
- 如果希望在本地运行程序,请运行如下命令:
python main.py
- 如果希望在线部署,请运行如下命令:
python main-flask.py
(以上内容皆为原创,请勿转载)
参考文献
[1] Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in neural information processing systems, 2020, 33: 596-608.
[2] Wang L, Xu S, Wang X, et al. Eavesdrop the composition proportion of training labels in federated learning[J]. arXiv preprint arXiv:1910.06044, 2019.
源码下载