YOLOv8 Channel Pruning Based on BN Layers
1. Sparsity-Constrained Training
In the loss we add a sparsity constraint on the scale factors $\gamma$ and bias terms $\beta$ of the BN layers; the larger the $\lambda$ coefficient, the stronger the sparsity constraint:
$$L = \sum_{(x,y)} l(f(x), y) + \lambda_1 \sum_{\gamma} g(\gamma) + \lambda_2 \sum_{\beta} g(\beta)$$
For an $L_1$ sparsity constraint, we have:
$$g(\gamma) = |\gamma|, \quad g(\beta) = |\beta|$$
Modifying YOLOv8's loss directly makes it awkward to ensure that the penalty only drives updates to the BN parameters, so instead we apply the constraint by modifying the BN gradients.
Relative to the original gradients, the BN scale factors and bias terms receive the following additional terms:
$$\Delta\gamma = \frac{\partial (\lambda_1 \cdot g(\gamma))}{\partial \gamma} = \lambda_1 \cdot \operatorname{sign}(\gamma), \qquad \Delta\beta = \frac{\partial (\lambda_2 \cdot g(\beta))}{\partial \beta} = \lambda_2 \cdot \operatorname{sign}(\beta)$$
During training, the $\lambda_1$ coefficient is gradually decreased, relaxing the constraint on $\gamma$ (this stabilizes training and keeps training and fine-tuning consistent):
$$\lambda_1 = 0.01 \cdot \left(1 - 0.9 \cdot \frac{e}{n_e}\right)$$
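Here $e$ is the current epoch and $n_e$ the total number of epochs (matching epoch and self.epochs in the code below). For example, with 50 training epochs, $\lambda_1$ starts at 0.01 and decays linearly to roughly 0.001 by the final epoch.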
For YOLOv8, we only need to find where the gradients are updated and insert the extra terms there.
Modify the YOLOv8 code at ultralytics/engine/trainer.py, around line 390:
# Backward
self.scaler.scale(self.loss).backward()
# ========== added: L1 sparsity on BN scale/bias gradients ==========
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# ========== added ==========
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
Then start training with the following code:
from ultralytics import YOLO

yolo = YOLO("yolov8n.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)
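Before moving on to pruning, it can be worth checking that the sparsity constraint actually took effect. Below is a minimal sketch (assuming the sparsely trained weights were saved as last.pt, as in the pruning code later) that counts how many BN scale factors were pushed toward zero:

# Sketch: inspect the sparsity of BN scale factors after sparse training
import torch
from ultralytics import YOLO

model = YOLO("last.pt").model
gammas = torch.cat([m.weight.abs().detach()
                    for m in model.modules()
                    if isinstance(m, torch.nn.BatchNorm2d)])
print(f"total BN channels: {gammas.numel()}, below 0.01: {(gammas < 0.01).sum().item()}")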
2. Pruning
After sparse training we have a best.pt and a last.pt; since the model will be fine-tuned afterwards, starting from last.pt is the better choice.
The YOLOv8 architecture is as follows:
Every Conv block in this architecture contains a BN layer. When pruning BN channels we must, on one hand, remove the pruned output channels (and the corresponding weights) of the current Conv, and on the other hand remove the matching input channels (and weights) of the next Conv, as the toy sketch below illustrates.
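To make the channel bookkeeping concrete, here is a minimal toy sketch using plain PyTorch layers (not the actual YOLOv8 Conv/C2f modules); the layer sizes and kept indices are made up purely for illustration:

# Toy illustration: pruning the output channels of conv1 also requires pruning
# the BN that follows it and the input channels of conv2.
import torch
import torch.nn as nn

conv1 = nn.Conv2d(16, 32, 3, padding=1, bias=False)
bn1 = nn.BatchNorm2d(32)
conv2 = nn.Conv2d(32, 64, 3, padding=1, bias=False)

keep = torch.tensor([0, 1, 5, 7])  # channels of conv1/bn1 to keep (illustrative)
conv1.weight.data = conv1.weight.data[keep]  # slice the output-channel dim
conv1.out_channels = len(keep)
for attr in ("weight", "bias", "running_mean", "running_var"):
    getattr(bn1, attr).data = getattr(bn1, attr).data[keep]
bn1.num_features = len(keep)
conv2.weight.data = conv2.weight.data[:, keep]  # slice the input-channel dim
conv2.in_channels = len(keep)

x = torch.randn(1, 16, 8, 8)
print(conv2(bn1(conv1(x))).shape)  # torch.Size([1, 64, 8, 8])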
The first three layers (0, 1, 2) have few channels, so every channel matters for feature extraction; they are not pruned.
The outputs of layers 4, 6 and 9 feed the channel concatenations in the head; that structure is complex and awkward to prune, so they are not pruned either.
In addition, non-consecutive Conv connections are left alone: inside C2f there is a split operation between the Conv layers and the Bottlenecks, and in the FPN the C2f blocks are interleaved with Upsample and Concat operations. These parts are not pruned.
With that in mind, the locations we can prune include:
Between modules
Backbone:
Conv(3) => C2f(4)
Conv(5) => C2f(6)
Conv(7) => C2f(8)
C2f(8) => SPPF(9)
Head:
C2f(15) => [Conv(16),Conv(Detect.cv2[0][0]),Conv(Detect.cv3[0][0])]
C2f(18) => [Conv(19),Conv(Detect.cv2[1][0]),Conv(Detect.cv3[1][0])]
C2f(21) => [Conv(Detect.cv2[2][0]),Conv(Detect.cv3[2][0])]
Within modules
Besides the junctions between modules listed above, the consecutive Convs inside modules fall into two groups:
Bottleneck in C2f
Conv(Bottleneck.cv1) => Conv(Bottleneck.cv2)
cv2, cv3 in Detect
Conv(Detect.cv2[0][0]) => Conv(Detect.cv2[0][1])
Conv(Detect.cv2[0][1]) => Conv2d(Detect.cv2[0][2])
Conv(Detect.cv3[0][0]) => Conv(Detect.cv3[0][1])
Conv(Detect.cv3[0][1]) => Conv2d(Detect.cv3[0][2])
Conv(Detect.cv2[1][0]) => Conv(Detect.cv2[1][1])
Conv(Detect.cv2[1][1]) => Conv2d(Detect.cv2[1][2])
Conv(Detect.cv3[1][0]) => Conv(Detect.cv3[1][1])
Conv(Detect.cv3[1][1]) => Conv2d(Detect.cv3[1][2])
Conv(Detect.cv2[2][0]) => Conv(Detect.cv2[2][1])
Conv(Detect.cv2[2][1]) => Conv2d(Detect.cv2[2][2])
Conv(Detect.cv3[2][0]) => Conv(Detect.cv3[2][1])
Conv(Detect.cv3[2][1]) => Conv2d(Detect.cv3[2][2])
The pruning code is as follows:
import torch
from ultralytics import YOLO
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect


def prune_conv(conv1: Conv, conv2: Conv, threshold=0.01):
    # Prune a top->bottom Conv pair.
    # First, prune conv1's BN layer and conv layer.
    # The BN weights and biases of conv1 serve as the pruning criterion.
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    # Indices of the channels kept after pruning
    keep_idxs = []
    local_threshold = threshold
    # Keep at least 8 channels, which is friendlier for hardware acceleration
    while len(keep_idxs) < 8:
        # Keep channels whose |gamma| is at or above the threshold
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        # Lower the threshold if too few channels survive
        local_threshold = local_threshold * 0.5
        # print(local_threshold)
    # Number of channels after pruning
    n = len(keep_idxs)
    # Update the BN parameters
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    # Update the conv weights (output-channel dimension)
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    # Update the conv output channel count
    conv1.conv.out_channels = n
    # Update the conv bias, if present
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    # Then, prune the conv layer(s) of conv2
    if not isinstance(conv2, list):
        conv2 = [conv2]
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            # Update the input channel count
            conv.in_channels = n
            # Update the weights (input-channel dimension)
            conv.weight.data = conv.weight.data[:, keep_idxs]


def prune_module(m1, m2, threshold=0.01):
    # Prune the junction between two modules: m1 contributes its bottom conv,
    # m2 contributes its top conv.
    # Print the names of m1 and m2
    print(m1.__class__.__name__, end="->")
    if isinstance(m2, list):
        print([item.__class__.__name__ for item in m2])
    else:
        print(m2.__class__.__name__)
    if isinstance(m1, C2f):  # C2f as the producing module: use its cv2
        m1 = m1.cv2
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
    prune_conv(m1, m2, threshold)


def prune():
    # Load a model
    yolo = YOLO("last.pt")
    model = yolo.model

    # Collect the weights and biases of all BN layers
    ws = []
    bs = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            w = m.weight.abs().detach()
            b = m.bias.abs().detach()
            ws.append(w)
            bs.append(b)
            # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
    # Keep roughly 80% of the channels
    factor = 0.8
    ws = torch.cat(ws)
    # Sort in descending order and take the value at the 80% mark as the threshold
    threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
    print(threshold)

    # First prune inside every Bottleneck module of the network
    for name, m in model.named_modules():
        if isinstance(m, Bottleneck):
            prune_conv(m.cv1, m.cv2, threshold)

    # Then prune the junctions between backbone modules
    seq = model.model
    for i in range(3, 9):
        if i in [6, 4, 9]:
            continue
        prune_module(seq[i], seq[i + 1], threshold)

    # Then prune the junctions between head modules.
    # This has two parts: connections to the next adjacent layer, and skip
    # connections into the Detect head. last_inputs -> colasts are the adjacent
    # connections; last_inputs -> detect go straight to the final outputs.
    detect: Detect = seq[-1]
    last_inputs = [seq[15], seq[18], seq[21]]
    colasts = [seq[16], seq[19], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        prune_module(last_input, [colast, cv2[0], cv3[0]], threshold)
        # Prune the junctions inside the Detect head
        prune_module(cv2[0], cv2[1], threshold)
        prune_module(cv2[1], cv2[2], threshold)
        prune_module(cv3[0], cv3[1], threshold)
        prune_module(cv3[1], cv3[2], threshold)

    # Make all parameters trainable again, ready for retraining
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    # Save the pruned model
    yolo.save("prune.pt")


if __name__ == '__main__':
    prune()
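As a quick sanity check after pruning (a sketch, assuming prune.pt was saved by the script above), we can print the pruned model's summary and run a dummy forward pass to confirm the channel bookkeeping is consistent:

# Sketch: verify that the pruned model still runs end to end
import torch
from ultralytics import YOLO

yolo = YOLO("prune.pt")
yolo.info()  # layers / parameters / GFLOPs of the pruned model
with torch.no_grad():
    yolo.model(torch.zeros(1, 3, 640, 640))  # should complete without shape errors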
3. Fine-tuning
After pruning, the model needs to be fine-tuned. First remove the sparsity constraint, i.e. comment out the constraint code we added to the trainer.
Next, during fine-tuning we must stop the code from rebuilding the model from the yaml config; it should load the pruned weights directly instead.
Modification: in ultralytics/engine/model.py, add after line 808:
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# added ===================================
self.trainer.model.model = self.model.model
# added ===================================
self.model = self.trainer.model
Then run fine-tuning with the following code:
from ultralytics import YOLO

yolo = YOLO("prune.pt")
yolo.train(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640, epochs=50)
4. Comparison
We can compare the accuracy, parameter count, and FLOPs of the sparsity-trained original model, the pruned model, and the fine-tuned model (here retrain.pt refers to the weights produced by the fine-tuning run in step 3):
def compare_prune():
    # Compare parameter count, accuracy and FLOPs before and after compression
    yolo = YOLO("last.pt")
    before_results = yolo.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)
    yolo_prune = YOLO("prune.pt")
    prune_results = yolo_prune.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)
    yolo_retrain = YOLO("retrain.pt")
    retrain_results = yolo_retrain.val(data='ultralytics/cfg/datasets/diagram.yaml', imgsz=640)
    # Print parameter count, accuracy and FLOPs before and after compression
    n_l, n_p, n_g, flops = yolo.info()
    prune_n_l, prune_n_p, prune_n_g, prune_flops = yolo_prune.info()
    retrain_n_l, retrain_n_p, retrain_n_g, retrain_flops = yolo_retrain.info()
    acc = before_results.box.map
    prune_acc = prune_results.box.map
    retrain_acc = retrain_results.box.map
    print(f"{'':<10}{'Before':<10}{'Prune':<10}{'Retrain':<10}")
    print(f"{'Params':<10}{n_p:<10}{prune_n_p:<10}{retrain_n_p:<10}")
    print(f"{'FLOPs':<10}{flops:<10}{prune_flops:<10}{retrain_flops:<10}")
    print(f"{'Acc':<10}{acc:<10}{prune_acc:<10}{retrain_acc:<10}")