计算机视觉之SSD改进版本(平滑L1范数损失与焦点损失)《4》

news2025/1/15 23:20:30

        在 计算机视觉之单发多框检测(Single Shot MultiBox Detector)模型《3》中我们使用到的是L1范数损失,L1范数损失也叫做平均绝对误差(MAE),目标值与预测值之差的绝对值的和,表示的是预测值的平均误差幅度。它的缺点就是0点附近不能求导,不方便求解,而且不光滑,网络也不是很稳定,所以我们设计一个在0点附近使用一个平方函数,让它显得很平滑,这里通过一个超参数\sigma来控制平滑区域。

平滑L1范数损失(Smooth_L1)

我们可以先看下数学公式:

f(x) = \begin{cases} (\sigma x)^2/2,& \text{if }x < 1/\sigma^2\\ |x|-0.5/\sigma^2,& \text{otherwise} \end{cases}

代码实现很简单,我们看个小栗子: 

from mxnet import nd

print(nd.smooth_l1(nd.array([-1,0.1, 0.5, 1, 2, 3, 4]), scalar=1))
#[0.5   0.005 0.125 0.5   1.5   2.5   3.5  ]

这里的sigma是参数scalar,可以看出对输入做了分两步讨论并计算的。我们来画图更直观的看下:

sigmas = [10, 1, 0.5]
lines = ['-', '--', '-.']
x = nd.arange(-2, 2, 0.1)
d2l.set_figsize()

for l, s in zip(lines, sigmas):
    y = nd.smooth_l1(x, scalar=s)
    d2l.plt.plot(x.asnumpy(),y.asnumpy(),l,label='sigma=%.1f'%s)
d2l.plt.legend()
d2l.plt.show()

可以看到这个超参数sigma很大的时候,图形就是这个L1范数损失了,在零点显得很尖锐,当sigma比较小的时候,这个图形就显得很平滑。
另外:L2范数损失函数,也叫最小平方误差(LSE),就是目标值与预测值之差的平方和最小化。一般的回归问题用这个比较多,但是有了平方,你想下,这个误差就平方了,也就是说模型对于这个样本的敏感性高很多了,相对来说不是很稳定,所以L1范数损失的鲁棒性比L2范数损失要好点。
我们来看下在MXNet中的实现:

class SmoothL1Loss(gluon.loss.Loss):
   def __init__(self, batch_axis=0, **kwargs):
       super(SmoothL1Loss, self).__init__(None, batch_axis, **kwargs)

   def hybrid_forward(self, F, output, label, mask):
       loss = F.smooth_l1((output - label) * mask, scalar=1.0)
       return F.mean(loss, self._batch_axis, exclude=True)

然后应用到皮卡丘的检测中,我们发现图中的检测框要精细点了,多了几个是吧,虽然只检测到了一个皮卡丘,原因可能是迭代次数比较少的问题,使用的损失函数比较平滑的缘故,不像那个L1范数损失那么强烈。

焦点损失(Focal Loss,FL)

我们除了使用交叉熵损失:设真实类别j的预测概率是p_j,交叉熵损失为:-logp_{j}
还可以使用焦点损失(focus loss):给定正的超参数γ和α,该损失函数的定义为:

-\alpha (1-p_j)^{\gamma} logp_{j}

从公式可以看到,增大γ可以有效减小正类预测概率较大时的损失。同样的画图来直观的看下:

gammas = [0, 1, 5]
lines = ['-', '--', '-.']
d2l.set_figsize()

def focal_loss(gamma, x):
    return -(1-x)**gamma * x.log()

x = nd.arange(0.01, 1, 0.01)
for l, gamma in zip(lines, gammas):
    y = d2l.plt.plot(x.asnumpy(), focal_loss(gamma, x).asnumpy(), l, label='gamma=%.1f' % gamma)
d2l.plt.legend()
d2l.plt.show()

当然这里是α设为1了,我们看下在MXNet中的实现:

class FocalLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
    
    def hybrid_forward(self, F, output, label):
        output = F.softmax(output)
        pj = F.pick(output, label, axis=self._axis, keepdims=True)
        loss = -self._alpha * ((1 - pj) ** self._gamma) * F.log(pj)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

其中F.pick来根据标签选取各个概率,也就是上面公式中的p_j。返回值跟上面的SmoothL1Loss一样都是按照选定的维度求均值。这个焦点损失可以弥补上面平滑损失由于平滑而迭代慢的问题,现在我们结合这两个损失函数来训练模型。

SSD改进版本

我们使用上述的改进过的两个损失函数来替换原来的,发现效果要好很多,最起码多出了很多检测框,置信度依然不是很正确,不过发现0.1的置信度有很多锚框了,于是将置信度调整为0.3,减少锚框数,我们来看下全部代码:

import d2lzh as d2l
from mxnet import gluon, image, nd, init, contrib, autograd
from mxnet.gluon import loss as gloss, nn
import time

def cls_predictor(num_anchors, num_classes):
    '''
    类别预测层

    参数
    ------
    通道数:num_anchors*(num_classes+1)
        其中类别数需要加一个背景
    卷积核大小:3,填充:1
        可以保持输出的高宽不变
    '''
    return nn.Conv2D(num_anchors*(num_classes+1), kernel_size=3, padding=1)

def bbox_predictor(num_anchors):
    '''
    边界框预测层

    参数
    -----
    通道数:num_anchors*4
        为每个锚框预测4个偏移量
    '''
    return nn.Conv2D(num_anchors*4, kernel_size=3, padding=1)

def forward(x, block):
    block.initialize()
    return block(x)

Y1 = forward(nd.zeros((2, 8, 20, 20)), cls_predictor(5, 10))
Y2 = forward(nd.zeros((2, 16, 10, 10)), cls_predictor(3, 10))

def flatten_pred(pred):
    return pred.transpose((0, 2, 3, 1)).flatten()

def concat_preds(preds):
    return nd.concat(*[flatten_pred(p) for p in preds], dim=1)

def down_sample_blk(num_channels):
    '''
    高宽减半块
        步幅为2的2x2的最大池化层将特征图的高宽减半
    串联两个卷积层和一个最大池化层
    '''
    blk = nn.Sequential()
    for _ in range(2):
        blk.add(nn.Conv2D(num_channels, kernel_size=3, padding=1),
                nn.BatchNorm(in_channels=num_channels), nn.Activation('relu'))
    blk.add(nn.MaxPool2D(pool_size=(2, 2), strides=2))
    return blk

def base_net():
    '''
    基础网络块
        串联3个高宽减半块,以及通道数翻倍
    '''
    blk = nn.Sequential()
    for n in [16, 32, 64]:
        blk.add(down_sample_blk(n))
    return blk

def get_blk(i):
    if i == 0:
        blk = base_net()
    elif i == 4:
        blk = nn.GlobalMaxPool2D()
    else:
        blk = down_sample_blk(128)
    return blk

def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    Y = blk(X)
    anchors = contrib.nd.MultiBoxPrior(Y, sizes=size, ratios=ratio)
    cls_preds = cls_predictor(Y)
    bbox_preds = bbox_predictor(Y)
    return (Y, anchors, cls_preds, bbox_preds)

sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79], [0.88, 0.961]]
ratios = [[1, 2, 0.5]]*5
num_anchors = len(sizes[0])+len(ratios[0])-1

class TinySSD(nn.Block):
    def __init__(self, num_classes, **kwargs):
        super(TinySSD, self).__init__(**kwargs)
        self.num_classes = num_classes
        for i in range(5):
            setattr(self, 'blk_%d' % i, get_blk(i))
            setattr(self, 'cls_%d' %i, cls_predictor(num_anchors, num_classes))
            setattr(self, 'bbox_%d' % i, bbox_predictor(num_anchors))

    def forward(self, X):
        anchors, cls_preds, bbox_preds = [None]*5, [None]*5, [None]*5
        for i in range(5):
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(X, getattr(self, 'blk_%d' % i), sizes[i], ratios[i],
                                                                     getattr(self, 'cls_%d' % i),
                                                                     getattr(self, 'bbox_%d' % i))
        return (nd.concat(*anchors, dim=1),
                concat_preds(cls_preds).reshape(0, -1, self.num_classes+1),
                concat_preds(bbox_preds))

net=TinySSD(num_classes=1)
net.initialize()
X=nd.zeros((32,3,256,256))
anchors,cls_preds,bbox_preds=net(X)

#加载皮卡丘数据集并初始化模型
batch_size=8#本人配置不行,批处理大小调小点
train_iter,_=d2l.load_data_pikachu(batch_size)
ctx,net=d2l.try_gpu(),TinySSD(num_classes=1)
net.initialize(init=init.Xavier(),ctx=ctx)
trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.2,'wd':5e-4})

#定义损失函数
class SmoothL1Loss(gloss.Loss):
   def __init__(self, batch_axis=0, **kwargs):
       super(SmoothL1Loss, self).__init__(None, batch_axis, **kwargs)

   def hybrid_forward(self, F, output, label, mask):
       loss = F.smooth_l1((output - label) * mask, scalar=1.0)
       return F.mean(loss, self._batch_axis, exclude=True)

class FocalLoss(gloss.Loss):
    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
    
    def hybrid_forward(self, F, output, label):
        output = F.softmax(output)
        pj = F.pick(output, label, axis=self._axis, keepdims=True)
        loss = -self._alpha * ((1 - pj) ** self._gamma) * F.log(pj)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

#cls_loss = gloss.SoftmaxCrossEntropyLoss()
#bbox_loss = gloss.L1Loss()
cls_loss = FocalLoss()
bbox_loss = SmoothL1Loss()

def calc_loss(cls_preds,cls_labels,bbox_preds,bbox_labels,bbox_masks):
    cls=cls_loss(cls_preds,cls_labels)
    #bbox=bbox_loss(bbox_preds*bbox_masks,bbox_labels*bbox_masks)
    bbox=bbox_loss(bbox_preds,bbox_labels,bbox_masks)
    return cls+bbox
#评价函数
def cls_eval(cls_preds,cls_labels):
    #类别预测结果放在最后一维,所以argmax需指定最后一维
    return (cls_preds.argmax(axis=-1)==cls_labels).sum().asscalar()
def bbox_eval(bbox_preds,bbox_labels,bbox_masks):
    return ((bbox_labels-bbox_preds)*bbox_masks).abs().sum().asscalar()

#训练模型
for epoch  in range(20):
    acc_sum,mae_sum,n,m=0.0,0.0,0,0
    train_iter.reset()
    start=time.time()
    for batch in train_iter:
        X=batch.data[0].as_in_context(ctx)
        Y=batch.label[0].as_in_context(ctx)
        with autograd.record():
            #生成多尺度锚框,为每个锚框预测类别和偏移量
            anchors,cls_preds,bbox_preds=net(X)
            bbox_labels,bbox_masks,cls_labels=contrib.nd.MultiBoxTarget(anchors,Y,cls_preds.transpose((0,2,1)))
            l=calc_loss(cls_preds,cls_labels,bbox_preds,bbox_labels,bbox_masks)
        l.backward()
        trainer.step(batch_size)
        acc_sum+=cls_eval(cls_preds,cls_labels)
        n+=cls_labels.size
        mae_sum+=bbox_eval(bbox_preds,bbox_labels,bbox_masks)
        m+=bbox_labels.size
    if(epoch+1)%5==0:
        print('迭代次数:%2d,类别损失误差:%.2e,正类锚框偏移量平均绝对误差:%.2e,耗时:%.1f秒'%(epoch+1,1-acc_sum/n,mae_sum/m,time.time()-start))

net.collect_params().save('ssd.params')

img=image.imread('pkq.png')
feature=image.imresize(img,256,256).astype('float32')
X=feature.transpose((2,0,1)).expand_dims(axis=0)#转换成卷积层需要的四维格式
#net.collect_params().load('ssd.params')
def predict(X):
    anchors,cls_preds,bbox_preds=net(X.as_in_context(ctx))
    cls_probs=cls_preds.softmax().transpose((0,2,1))
    output=contrib.nd.MultiBoxDetection(cls_probs,bbox_preds,anchors)
    idx=[i for i,row in enumerate(output[0]) if row[0].asscalar()!=-1]
    return output[0,idx]

output=predict(X)

d2l.set_figsize((5,5))

def display(img,output,threshold):
    fig=d2l.plt.imshow(img.asnumpy())
    for row in output:
        score=row[1].asscalar()
        if score<threshold:
            continue
        h,w=img.shape[0:2]
        bbox=[row[2:6]*nd.array((w,h,w,h),ctx=row.context)]
        d2l.show_bboxes(fig.axes,bbox,'%.2f'%score,'r')

display(img,output,threshold=0.3)
d2l.plt.show()

output[0,idx]

附带上节的一个知识点的解释

idx=[i for i,row in enumerate(output[0]) if row[0].asscalar()!=-1]
return output[0,idx]

对这部分增补个解释,对于初学者可能有点不理解,output是来自MultiBoxDetection函数的返回:第一列是类别;第二列是置信度;后面的四列就是左上角与右下角的坐标(相对坐标)
if row[0].asscalar()!=-1这里就是将第一列是-1(背景或非极大值抑制中被移除)的情况筛选掉。所以返回的idx就是所有正类的行索引
return output[0,idx]输出正类的锚框
其中ids是列表,我们举个例子:

from mxnet import nd
a=a.reshape(3,4,5)
'''
[[[ 0.  1.  2.  3.  4.]
  [ 5.  6.  7.  8.  9.]
  [10. 11. 12. 13. 14.]
  [15. 16. 17. 18. 19.]]

 [[20. 21. 22. 23. 24.]
  [25. 26. 27. 28. 29.]
  [30. 31. 32. 33. 34.]
  [35. 36. 37. 38. 39.]]

 [[40. 41. 42. 43. 44.]
  [45. 46. 47. 48. 49.]
  [50. 51. 52. 53. 54.]
  [55. 56. 57. 58. 59.]]]
<NDArray 3x4x5 @cpu(0)>
'''
a[0]
'''
[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]]
<NDArray 4x5 @cpu(0)>
'''

a[0,[2,3]]
'''
[[10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]]
<NDArray 2x5 @cpu(0)>
'''

另外a[0,2,3]这样的写法等价于a[0][2][3]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/74435.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

U3D中使用SPINE疑难杂症和解决办法

我使用的SPINE是3.8.99&#xff0c;项目当中SPINE需要使用特别多&#xff0c;网上都千篇一律&#xff0c;找不到一些实际遇到的问题&#xff0c;下面都是我遇到的一些稀奇古怪的问题。 1.SPINE导入U3D&#xff0c;拖到场景里&#xff0c;可以选择创建为2D或者UGUI组件&#xf…

分片上传—webloader

最近研究大文件上传方案的时候偶然间发现的一个百度开源的工具&#xff1a;webloader&#xff0c;用了一下&#xff0c;确实还不错&#xff0c;下面带着大家一起使用一下。 1.引入资源 使用Web Uploader文件上传需要引入三种资源&#xff1a;JS, CSS, SWF。 所以我们需要先下…

tomcat学习笔记

1.tomcat使用的方法有很多种&#xff0c;我这边使用的是直接解压使用的版本 tomcat 9.0.45版本免安装版下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1c6NN2Z-McuN4uw6JGWZmrA?pwdrl7t 提取码&#xff1a;rl7t 2.启动方式是在bin目录下找到startup.bat运行&…

用HTML+CSS做一个漂亮简单的花店网页【免费的学生网页设计成品】

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

学习笔记-4-ANN-1-Preliminaries

细节内容请关注微信公众号&#xff1a;运筹优化与数据科学 ID: pomelo_tree_opt Outline 1. ANN与SVM 2. ANN的数学基础 3. ANN history 4. Deep neural network ------------------------------ 1. ANN与SVM的区别 SVM, SVR中有很多数学推导的过程&#xff0c; 例如pri…

Spring 使用指南 ~ 3、Spring 中 bean 的生命周期详解

Spring 中 bean 的生命周期详解 一、bean 的生命周期图解 二、bean 创建时使用 Spring 的资源 实现 aware 类型接口的 bean&#xff0c;可以在 bean 实例化的时候获取到一些相对应的资源&#xff0c;如实现 BeanNameAware 的 bean&#xff0c;就可以获取到 beanName。Spring …

[附源码]JAVA毕业设计无人驾驶汽车管理系统(系统+LW)

[附源码]JAVA毕业设计无人驾驶汽车管理系统&#xff08;系统LW&#xff09; 项目运行 环境项配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目…

[附源码]JAVA毕业设计西藏民族大学论文管理系统(系统+LW)

[附源码]JAVA毕业设计西藏民族大学论文管理系统&#xff08;系统LW&#xff09; 项目运行 环境项配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 …

KubeSphere v3.3.1 权限控制详解

作者&#xff1a;周文浩&#xff0c;青云科技研发工程师&#xff0c;KubeSphere Maintainer。热爱云原生&#xff0c;热爱开源&#xff0c;目前负责 KubeSphere 权限控制的开发与维护。 KubeSphere 3.3.1 已经发布一个多月了。 3.3.1 版本对于 KubeSphere 来说只是一个小的 Pat…

ADI Blackfin DSP处理器-BF533的开发详解10:SPORT-IIS口驱动和应用(含源代码)

硬件准备 ADSP-EDU-BF533&#xff1a;BF533开发板 AD-HP530ICE&#xff1a;ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 接口功能介绍 ADSP-BF533上有两个 SPORT 口&#xff0c;SPORT&#xff08;synchronous serial ports&#xff09;接口是 ADSP-BF53x 上速度最快的…

执法仪物联卡在哪里采购靠谱

在这个万物互联的时代&#xff0c;针对于企业设备联网的物联卡就显得格外重要了&#xff0c;而共享单车&#xff0c;移动支付&#xff0c;智慧城市&#xff0c;自动售卖机等企业采购物联卡会面临着各种问题&#xff0c;低价陷阱&#xff0c;流量虚假&#xff0c;管理混乱&#…

基于LSTM递归神经网络的多特征电能负荷预测(Python代码实现)

&#x1f468;‍&#x1f393;个人主页&#xff1a;研学社的博客 &#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜…

Kubernetes 系统化学习之 集群安全篇(七)

Kubernetes 作为一个分布式集群管理的工具&#xff0c;保证集群的安全性是其一个重要的任务。API Server 是集群内部各个组件通信的中介&#xff0c;也是外部控制的入口&#xff0c;所以 K8s 的安全机制就是围绕保护 API Server 来设计的。K8s 使用了认证&#xff08;Authentic…

计算机毕业设计springboot+vue基本微信小程序的大学生竞赛信息发布与报名小程序

项目介绍 大学生竞赛是提升大学生综合能力和专业素质的重要手段和途径,越来越受到广大高校师生的关注与重视。大学生学科竞赛活动不仅有利于提升大学生的专业素养,也有利于提升大学生的创新、实践能力、竞争意识与团队精神。 各类学科竞赛汇总、信息发布、信息收集、报名、备赛…

针对前端项目node版本问题导致依赖安装异常的处理办法

Mac如何切换版本 前端项目开发过程中&#xff0c;多人开发会遇到由于node版本不同造成的依赖不适配。 比如: node 16.xx 大多都会遇到依赖版本与node版本不适配导致安装报错等问题&#xff0c;并且你不管如何更新还是使用稳定版本的node.js都不起作用&#xff0c;此时你需要修…

看直播怎么录屏?这2个方法,一看就会!

​现在很多人在斗鱼、虎牙、斗鱼、腾讯课堂等平台上直播&#xff0c;有些人在视频聊天平台上&#xff0c;如微信上直播。我们如何保存这些直播视频&#xff1f;看直播怎么录屏&#xff1f;今天小编就分享2个方法&#xff0c;教你如何看直播的同时录屏。 看直播怎么录屏方法一&a…

Font字体属性

Font字体属性 源代码 font font属性用于定义字体系列、大小、粗细、和文字样式(如斜体) font-family font-family属性用于定义文本字体&#xff0c;多个字体用 ” , ” 号隔开&#xff0c;一般情况下&#xff0c;有空格隔开的多个单词组成的字体&a…

Eziriz .NET Reactor保护NET代码

Eziriz .NET Reactor保护NET代码 NET Reactor软件是一个简单而小巧的工具&#xff0c;但对保护NET代码非常强大。会的。编程数据可以通过使用本程序、编写的代码和程序来保护&#xff0c;并禁止复制和使用它们。 Eziriz.NET Reactor软件的功能和特点&#xff1a; -支持收集和模…

MODBUS-ASCII协议

MODBUS协议在RS485总线上有RTU和ASCII两种传输格式。 其中ASCII协议应用比较少&#xff0c;主要还是以RTU格式为主。 相比较于RTU的2进制传输&#xff0c;ASCII使用的是文本传输&#xff0c;整个数据包都是可打印字符。 下面是示例&#xff1a; :010300000001FB\r\n 帧头是冒…