LaneAF论文解读和代码讲解

news2024/10/3 2:23:35

论文地址:https://arxiv.org/abs/2103.12040

一、论文简介

LaneAF是一个语义分割+聚类后处理的一种方法。相对于之前的用聚类算法对embedding分支聚类的方法,该论文提出了水平和垂直两个向量场,用来取缔之前的普通聚类。根据向量场就可以完成聚类问题。

二、网络结构

 使用了DLA-34作为骨干网络,然后经过DLAUP和IDAUP的上采样处理,然后接三个预测分支。

具体的网络结构如下所示:

c是DLA的架构,d是论文里改进的。相当于加了更多的跨层连接。

后面的分支一个是二值分割网络channel=1,一个channel=2的VAF,一个channel=1的HAF.

 (hm): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (vaf): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
  )
  (haf): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
  )

 损失函数:分割使用的加权交叉熵损失和IOU损失,两个向量场分支只对前景点使用L1loss。

 三、网络中的VAF和HAF的生成与推理

1、VAF和HAF的生成

GT的生成,自底向上逐行扫描,对每一行属于当前车道线实例的的像素点根据公式计算其VAF和HAF。公式如下:

 首先对每一行计算HAF。其中\overline{x_{y}^{l}}表示当前行属于该车道线实例的像素中心(一行属于一个车道线的像素有好几个,求平均就是中心值)。这样每行的HAF,要么指向左,要么指向右,要么为0.

 

 再来看看VAF的计算公式,对于属于车道线实例l的像素i,其VAF表示指向上一行车道线实例中心像素的单位向量。这样HAF把每一行的像素分成一个一个实例,然后VAF再把他们串起来成为车道线实例。

2、VAF和HAF的解码

预测完VAF和HAF之后,对于每个前景点,需要对这两个向量解码来实现车道线实例分割。自底向上,先根据HAF对每一行进行聚类,然后再根据VAF对不同行进行关联。行聚类的判断标准是:只有前面像素指向左并且当前像素指向右时,才会为当前像素重新分配一个cluster,这样就可以完成了行聚类,那么不同行之间怎么关联?本行的一簇像素与上一行的簇中心计算平均距离,然后取最小的进行匹配。具体的看下面的代码讲解。

四、代码分析

heads = {'hm': 1, 'vaf': 2, 'haf': 1}
model = get_pose_net(num_layers=34, heads=heads, head_conv=256, down_ratio=4)

def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
    model = DLASeg('dla{}'.format(num_layers), heads,
                   pretrained=True,
                   down_ratio=down_ratio,
                   final_kernel=1,
                   last_level=5,
                   head_conv=head_conv)
    return model

class DLASeg(nn.Module):
    def __init__(self, base_name, heads, pretrained, down_ratio, final_kernel,
                 last_level, head_conv, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16] # 4
        self.first_level = int(np.log2(down_ratio)) # 2
        self.last_level = last_level # 5
        self.base = globals()[base_name](pretrained=pretrained) # DLA-34
        channels = self.base.channels # [16, 32, 64, 128, 256, 512]
        scales = [2 ** i for i in range(len(channels[self.first_level:]))] # [1, 2, 4, 8]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)

        if out_channel == 0:
            out_channel = channels[self.first_level] # 64

        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level], 
                            [2 ** i for i in range(self.last_level - self.first_level)])
        
        self.heads = heads
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
              fc = nn.Sequential(
                  nn.Conv2d(channels[self.first_level], head_conv,
                    kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(head_conv, classes, 
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))
              if 'hm' in head:
                fc[-1].bias.data.fill_(-2.19)
              else:
                fill_fc_weights(fc)
            else:
              fc = nn.Conv2d(channels[self.first_level], classes, 
                  kernel_size=final_kernel, stride=1, 
                  padding=final_kernel // 2, bias=True)
              if 'hm' in head:
                fc.bias.data.fill_(-2.19)
              else:
                fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x) # DLA-34 backbone
        x = self.dla_up(x)

        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        self.ida_up(y, 0, len(y))

        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(y[-1])
        return [z]
def generateAFs(label, viz=False):
    # creating AF arrays
    num_lanes = np.amax(label)
    VAF = np.zeros((label.shape[0], label.shape[1], 2))
    HAF = np.zeros((label.shape[0], label.shape[1], 2))

    # loop over each lane
    for l in range(1, num_lanes+1):
        # initialize previous row/cols
        prev_cols = np.array([], dtype=np.int64)
        prev_row = label.shape[0]

        # parse row by row, from second last to first
        for row in range(label.shape[0]-1, -1, -1):
            cols = np.where(label[row, :] == l)[0] # get fg cols

            # get horizontal vector
            for c in cols:
                if c < np.mean(cols):
                    HAF[row, c, 0] = 1.0 # point to right
                elif c > np.mean(cols):
                    HAF[row, c, 0] = -1.0 # point to left
                else:
                    HAF[row, c, 0] = 0.0 # point to left

            # check if both previous cols and current cols are non-empty
            if prev_cols.size == 0: # if no previous row/cols, update and continue
                prev_cols = cols
                prev_row = row
                continue
            if cols.size == 0: # if no current cols, continue
                continue
            col = np.mean(cols) # calculate mean

            # get vertical vector
            for c in prev_cols:
                # calculate location direction vector
                vec = np.array([col - c, row - prev_row], dtype=np.float32)
                # unit normalize
                vec = vec / np.linalg.norm(vec)
                VAF[prev_row, c, 0] = vec[0]
                VAF[prev_row, c, 1] = vec[1]

            # update previous row/cols with current row/cols
            prev_cols = cols
            prev_row = row

    if viz: # visualization
        down_rate = 1 # downsample visualization by this factor
        fig, (ax1, ax2) = plt.subplots(1, 2)
        # visualize VAF
        q = ax1.quiver(np.arange(0, label.shape[1], down_rate), -np.arange(0, label.shape[0], down_rate), 
            VAF[::down_rate, ::down_rate, 0], -VAF[::down_rate, ::down_rate, 1], scale=120)
        # visualize HAF
        q = ax2.quiver(np.arange(0, label.shape[1], down_rate), -np.arange(0, label.shape[0], down_rate), 
            HAF[::down_rate, ::down_rate, 0], -HAF[::down_rate, ::down_rate, 1], scale=120)
        plt.show()

    return VAF, HAF
def decodeAFs(BW, VAF, HAF, fg_thresh=128, err_thresh=5, viz=False):
    output = np.zeros_like(BW, dtype=np.uint8) # initialize output array
    lane_end_pts = [] # keep track of latest lane points
    next_lane_id = 1 # next available lane ID

    if viz:
        im_color = cv2.applyColorMap(BW, cv2.COLORMAP_JET)
        cv2.imshow('BW', im_color)
        ret = cv2.waitKey(0)

    # start decoding from last row to first
    for row in range(BW.shape[0]-1, -1, -1):
        cols = np.where(BW[row, :] > fg_thresh)[0] # get fg cols
        clusters = [[]]
        if cols.size > 0:
            prev_col = cols[0]

        # parse horizontally
        for col in cols:
            if col - prev_col > err_thresh: # if too far away from last point
                clusters.append([])
                clusters[-1].append(col)
                prev_col = col
                continue
            if HAF[row, prev_col] >= 0 and HAF[row, col] >= 0: # keep moving to the right
                clusters[-1].append(col)
                prev_col = col
                continue
            elif HAF[row, prev_col] >= 0 and HAF[row, col] < 0: # found lane center, process VAF
                clusters[-1].append(col)
                prev_col = col
            elif HAF[row, prev_col] < 0 and HAF[row, col] >= 0: # found lane end, spawn new lane
                clusters.append([])
                clusters[-1].append(col)
                prev_col = col
                continue
            elif HAF[row, prev_col] < 0 and HAF[row, col] < 0: # keep moving to the right
                clusters[-1].append(col)
                prev_col = col
                continue

            # if col - prev_col > err_thresh: # if too far away from last point
            #     clusters.append([])
            #     clusters[-1].append(col)
            #     prev_col = col
            #     continue
            # if HAF[row, prev_col] < 0 and HAF[row, col] >= 0: # found lane end, spawn new lane
            #     clusters.append([])
            #     clusters[-1].append(col)
            #     prev_col = col
            # else: # keep moving to the right
            #     clusters[-1].append(col)
            #     prev_col = col

        # parse vertically
        # assign existing lanes
        assigned = [False for _ in clusters]
        C = np.Inf*np.ones((len(lane_end_pts), len(clusters)), dtype=np.float64)
        for r, pts in enumerate(lane_end_pts): # for each end point in an active lane
            for c, cluster in enumerate(clusters):
                if len(cluster) == 0:
                    continue
                # mean of current cluster
                cluster_mean = np.array([[np.mean(cluster), row]], dtype=np.float32)
                # get vafs from lane end points
                vafs = np.array([VAF[int(round(x[1])), int(round(x[0])), :] for x in pts], dtype=np.float32)
                vafs = vafs / np.linalg.norm(vafs, axis=1, keepdims=True)
                # get predicted cluster center by adding vafs
                pred_points = pts + vafs*np.linalg.norm(pts - cluster_mean, axis=1, keepdims=True)
                # get error between prediceted cluster center and actual cluster center
                error = np.mean(np.linalg.norm(pred_points - cluster_mean, axis=1))
                C[r, c] = error
        
     # assign clusters to lane (in acsending order of error)
        row_ind, col_ind = np.unravel_index(np.argsort(C, axis=None), C.shape)
        for r, c in zip(row_ind, col_ind):
            if C[r, c] >= err_thresh:
                break # 升序,后面的肯定都不满足阈值要求,直接跳出循环
            if assigned[c]:
                continue
            assigned[c] = True
            # update best lane match with current pixel
            output[row, clusters[c]] = r+1
            lane_end_pts[r] = np.stack((np.array(clusters[c], dtype=np.float32), row*np.ones_like(clusters[c])), axis=1)
        
     # initialize unassigned clusters to new lanes
        for c, cluster in enumerate(clusters):
            if len(cluster) == 0:
                continue
            if not assigned[c]:
                output[row, cluster] = next_lane_id
                lane_end_pts.append(np.stack((np.array(cluster, dtype=np.float32), row*np.ones_like(cluster)), axis=1))
                next_lane_id += 1

    if viz:
        im_color = cv2.applyColorMap(40*output, cv2.COLORMAP_JET)
        cv2.imshow('Output', im_color)
        ret = cv2.waitKey(0)

    return output

 

五、Reference

[Paper Reading]LaneAF: Robust Multi-Lane Detection with Affinity Fields - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/369616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HashMap红黑树源码解读

链表转换为红黑树节点当往hashMap中添加元素&#xff0c;在同一个hash槽位挂载的元素超过8个后&#xff0c;执行treeifyBin方法。在treeifyBin方法中&#xff0c;只有当tab数组&#xff08;hash槽位&#xff09;的长度不小于MIN_TREEIFY_CAPACITY&#xff08;默认64&#xff09…

modbus协议

1 MODBUS 应用层报文传输协议 ADU应用数据单元PDU协议数据单元MBMODBUS协议MBAPMODBUS协议 ADU&#xff1a;地址域 PDU 差错校验PDU&#xff1a;功能码数据 串行链路&#xff1a; 最大RS485 ADU 256 字节PDU 256 - 服务器地址&#xff08;1字节&#xff09;- CRC&#xf…

Linux学习(7)文件权限与目录配置

目录 1. 使用者与群组 1&#xff0c;文件拥有者 2&#xff0c;群组概念 3&#xff0c;其他人 Linux 用户身份与群组记录的文件 2.Linux 文件权限概念 Linux的文件属性 第一栏代表这个文件的类型与权限(permission)&#xff1a; 第二栏表示有多少档名连结到此节点(i-no…

linux CUDAtoolkit+cudnn+tensorrt 的安装

windows上 CUDAtoolkitcudnn的安装 CUDAtoolkitcudnn的安装 须知 use command ubuntu-drivers devices查看你的显卡类型和推荐的驱动版本百度 nvidia-driver-*** 支持的 cuda 或 去文档查找驱动(比如450&#xff0c;460)匹配的cuda版本 下载 网盘下载 https://www.aliyundr…

Python实现贝叶斯优化器(Bayes_opt)优化Catboost回归模型(CatBoostRegressor算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。1.项目背景贝叶斯优化器 (BayesianOptimization) 是一种黑盒子优化器&#xff0c;用来寻找最优参数。贝叶斯优化器是…

18523-47-2,3-Azidopropionic Acid,叠氮基丙酸,可以与炔烃发生点击化学反应

【中文名称】3-叠氮基丙酸【英文名称】 3-Azidopropionic Acid&#xff0c;3-Azidopropionic COOH【结 构 式】【CAS】18523-47-2【分子式】C3H5N3O2【分子量】115.09【纯度标准】95%【包装规格】1g&#xff0c;5g&#xff0c;10g【是否接受定制】可进行定制&#xff0c;定制时…

龙蜥开发者说:为爱发电!当一个龙蜥社区打包 Contributor 是怎样的体验?| 第16期

「龙蜥开发者说」第 16 期来了&#xff01;开发者与开源社区相辅相成&#xff0c;相互成就&#xff0c;这些个人在龙蜥社区的使用心得、实践总结和技术成长经历都是宝贵的&#xff0c;我们希望在这里让更多人看见技术的力量。本期故事&#xff0c;我们邀请了龙蜥社区开发者 Fun…

无线通信时代的新技术----信标( Beacon)

随着IT技术的发展&#xff0c;无线通信技术也在不断发展。 现已根据预期用途开发了各种无线通信技术&#xff0c;例如 NFC、WIFI、Bluetooth和 RFID。 车辆内部结构的复杂化和数字化&#xff0c;车载通信网络技术的重要性也越来越高。 一个典型的例子是远程信息处理。 远程信息…

注重邮件数据信息安全 保障企业稳步发展

近年来&#xff0c;世界各地的政府、银行、电信公司、制造业以及零售业等&#xff0c;不断发生数据泄密事件。 就企业而言&#xff0c;邮件数据很容易成为竞争对手或者诈骗者窃取的目标。 电子邮件是企业中一种重要的沟通工具但是随着网络攻击手段的不断升级&#xff0c;电子邮…

RN面试题

RN面试题1.React Native相对于原生的ios和Android有哪些优势&#xff1f;1.性能媲美原生APP 2.使用JavaScript编码&#xff0c;只要学习这一种语言 3.绝大部分代码安卓和IOS都能共用 4.组件式开发&#xff0c;代码重用性很高 5.跟编写网页一般&#xff0c;修改代码后即可自动刷…

关系数据库

关系的三类完整性约束实体完整性规则• 保证关系中的每个元组都是可识别的和惟一的 • 指关系数据库中所有的表都必须有主键&#xff0c;而且表中不允许存在如下记录&#xff1a;– 无主键值的记录– 主键值相同的记录• 原因&#xff1a;实体必须可区分• 就像实体-学生&#…

谷歌外推留痕,谷歌搜索留痕快速收录怎么做出来的?

本文主要分享谷歌搜索留痕的收录效果是怎么做的&#xff0c;让你对谷歌留痕技术有一个全面的了解。 本文由光算创作&#xff0c;有可能会被修改和剽窃&#xff0c;我们佛系对待这样的行为吧。 谷歌搜索留痕快速收录怎么做出来的&#xff1f; 答案是&#xff1a;通过谷歌蜘蛛…

XLSX.utils读取日期格式错误

表格中的时间为2023/2/16调用 XLSX.utils.sheet_to_json 读取到的时间为2/16/23时间格式不对-期待的时间格式为2023-02-16 00:00增加代码 cellDates: true, dateNF: "yyyy-MM-dd HH:mm" 解决问题readerData (rawFile) {this.loading truethis.isFile true // 流程结…

透射电镜测试样品的制备要求和方法

透射电镜&#xff08;Transmission Electron Microscope&#xff0c;TEM&#xff09;是一种高分辨率的显微镜&#xff0c;能够对样品进行高精度的成像和分析。为了得到高质量的TEM图像&#xff0c;样品制备是非常重要的。 ​ 样品选择 TEM样品应该是具有明确结构和化学成分的…

《分布式技术原理与算法解析》学习笔记Day21

分布式数据存储三要素 什么是分布式数据存储系统&#xff1f; 分布式存储系统的核心逻辑&#xff0c;就是将用户需要存储的数据根据某种规则存储到不同的机器上&#xff0c;当用户想要获取指定数据时&#xff0c;再按照规则到存储数据的机器中获取。 分布式存储系统的三要素…

苏州市软件行业协会第五届第四次理事会暨元宇宙专委会成立决议会在苏召开

2月17日&#xff0c;2022年度苏州市软件行业协会第五届第四次理事会暨苏州市软件行业协会元宇宙专委会成立决议会在西交利物浦大学顺利召开。会议选举西交利物浦大学担任苏州市软件行业协会元宇宙专委会第一届轮值会长单位。 苏州市工信局大数据处处长&#xff08;信息化和软件…

python+pytest接口自动化(1)-接口测试基础

接口定义一般我们所说的接口即API&#xff0c;那什么又是API呢&#xff0c;百度给的定义如下&#xff1a;API&#xff08;Application Programming Interface&#xff0c;应用程序接口&#xff09;是一些预先定义的接口&#xff08;如函数、HTTP接口&#xff09;&#xff0c;或…

MySQL锁篇

文章目录说明&#xff1a;锁篇一、MySQL有那些锁&#xff1f;二、MySQL 是怎么加锁的&#xff1f;三、update 没加索引会锁全表&#xff1f;四、MySQL 记录锁间隙锁可以防止删除操作而导致的幻读吗&#xff1f;五、MySQL 死锁了&#xff0c;怎么办&#xff1f;六、字节面试&…

【单例模式】单例模式创建的几种方式

一、饿汉模式饿汉模式是在类加载的时候就初始化了一份单例对象&#xff0c;所以他不存在线程安全问题。优点&#xff1a;不存在线程安全问题&#xff0c;天然的线程安全缺点&#xff1a;在类加载的时候就已经创建了对象&#xff0c;如果后续代码里没有使用到单例&#xff0c;就…

跟20%的同行去竞争80%的蓝海市场不香吗?

近年来&#xff0c;由于科技的发展等诸多因素&#xff0c;跨境电商行业有了长足的发展空间&#xff0c;不少人也有想要入行的打算。对于不是很了解这一行业的新手来说&#xff0c;如何选择合适的跨境电商市场与平台就显得至关重要。 一直以来&#xff0c;作为全球第四大电商市…