Semantic Segmentation using Adversarial Networks代码

news2024/11/14 20:03:51

代码来源
首先看一下模型架构:
在这里插入图片描述
损失计算:
在这里插入图片描述在这里插入图片描述

class GANUpdater(chainer.training.StandardUpdater, UpdaterMixin):

    def __init__(self, *args, **kwargs):
        self.model = kwargs.pop('model') # set for exeptions.Evaluator
        self.gen, self.dis = self.model['gen'], self.model['dis'] 
        self.L_bce_weight = kwargs.pop('L_bce_weight')
        self.n_class = kwargs.pop('n_class')
        self.xp = chainer.cuda.cupy if kwargs['device'] >= 0 else np
        kwargs = self._standard_updater_kwargs(**kwargs)
        super(GANUpdater, self).__init__(*args, **kwargs)

    def _get_loss_dis(self):
        batchsize = self.y_fake.data.shape[0]
        loss = F.softmax_cross_entropy(self.y_real, Variable(self.xp.ones(batchsize, dtype=self.xp.int32), volatile=not self.gen.train))
        loss += F.softmax_cross_entropy(self.y_fake, Variable(self.xp.zeros(batchsize, dtype=self.xp.int32), volatile=not self.gen.train))
        chainer.report({'loss': loss}, self.dis)
        return loss

    def _get_loss_gen(self):
        batchsize = self.y_fake.data.shape[0]
        L_mce = F.softmax_cross_entropy(self.pred_label_map, self.ground_truth, normalize=False)
        L_bce = F.softmax_cross_entropy(self.y_fake, Variable(self.xp.ones(batchsize, dtype=self.xp.int32), volatile=not self.gen.train))
        loss = L_mce + self.L_bce_weight * L_bce

        # log report
        label_true = chainer.cuda.to_cpu(self.ground_truth.data)
        label_pred = chainer.cuda.to_cpu(self.pred_label_map.data).argmax(axis=1)
        logs = []
        for i in six.moves.range(batchsize):
            acc, acc_cls, iu, fwavacc = utils.label_accuracy_score(
                label_true[i], label_pred[i], self.n_class)
            logs.append((acc, acc_cls, iu, fwavacc))
        log = np.array(logs).mean(axis=0)
        values = {
            'loss': loss,
            'accuracy': log[0],
            'accuracy_cls': log[1],
            'iu': log[2],
            'fwavacc': log[3],
        }
        chainer.report(values, self.gen)

        return loss

    def _make_dis_input(self, input_img, label_map):
        b = F.broadcast_to(input_img[:,0,:,:], shape=label_map.shape)
        g = F.broadcast_to(input_img[:,1,:,:], shape=label_map.shape)
        r = F.broadcast_to(input_img[:,2,:,:], shape=label_map.shape)
        product_b = label_map * b
        product_g = label_map * g
        product_r = label_map * r
        dis_input = F.concat([product_b, product_g, product_r], axis=1)
        return dis_input

    def _onehot_encode(self, label_map):
        for i, c in enumerate(six.moves.range(self.n_class)):
            mask = label_map==c
            mask = mask.reshape(1,mask.shape[0],mask.shape[1])
            if i==0:
                onehot = mask
            else:
                onehot = np.concatenate([onehot, mask]) 
        return onehot.astype(self.xp.float32)

    def forward(self, batch):
        label_onehot_batch = [self._onehot_encode(pair[1]) for pair in batch]

        input_img, ground_truth = self.converter(batch, self.device)
        ground_truth_onehot = self.converter(label_onehot_batch, self.device)
        input_img = Variable(input_img, volatile=not self.gen.train)
        ground_truth = Variable(ground_truth, volatile=not self.gen.train)
        ground_truth_onehot = Variable(ground_truth_onehot, volatile=not self.gen.train)
        
        x_real = self._make_dis_input(input_img, ground_truth_onehot)
        y_real = self.dis(x_real)

        pred_label_map = self.gen(input_img)
        x_fake = self._make_dis_input(input_img, F.softmax(pred_label_map))
        y_fake = self.dis(x_fake)

        self.y_fake = y_fake
        self.y_real = y_real
        self.pred_label_map = pred_label_map
        self.ground_truth = ground_truth
        
    def calc_loss(self):
        self.loss_dis = self._get_loss_dis()
        self.loss_gen = self._get_loss_gen()
        
    def backprop(self):
        self.dis.cleargrads()
        self.gen.cleargrads()
        self.loss_dis.backward()
        self.loss_gen.backward()
        self.get_optimizer('dis').update()
        self.get_optimizer('gen').update()

    def update_core(self):
        batch = self.get_iterator('main').next()
        self.forward(batch)
        self.calc_loss()
        self.backprop()

首先看生成器的损失:由两项组成,第一项计算分割的label_map和GT之间的损失,第二项计算进过生成器的输出和1之间的损失。

    def _get_loss_gen(self):
        batchsize = self.y_fake.data.shape[0]
        L_mce = F.softmax_cross_entropy(self.pred_label_map, self.ground_truth, normalize=False)
        L_bce = F.softmax_cross_entropy(self.y_fake, Variable(self.xp.ones(batchsize, dtype=self.xp.int32), volatile=not self.gen.train))
        loss = L_mce + self.L_bce_weight * L_bce

生成器的输入为x_fake。是输入图片和经过softmax之后predict_label进行concat之后的结果。如果是原始的GAN就是predict_label直接输入到辨别器中。x_fake输入到辨别器产生的为y_fake。
在这里插入图片描述
辨别器的损失:y_real即GT和原始的RGB图concat之后输入到辨别器的结果。那么希望分辨器能够分辨出来,所以与0进行损失计算。
这里的concat并非RGB和GT直接Concat,而是RGB广播到label大小后与label逐通道相乘再concat。
在这里插入图片描述
经过分割模型后生成的GT进行one-hot编码,即numclass个通道,每个通道由0,1组成。每个通道即为RGB中的每个类别,用1组成其余的由0组成。那么与原始的RGB相乘后,选择出RGB中对应的类别。
在这里插入图片描述
在这里插入图片描述
y_fake同上所述,我们希望分辨器能够辨别出来他是分割的结果而非原始的GT,因此与0计算损失。
在这里插入图片描述

与传统的GAN不同的地方是:segmentation的输出并非直接输入到adversarial model中,而是真实的GT和原始的RGB相乘,通道由C变为3C。
在这里插入图片描述
为了防止混淆画一下流程图:原始GAN。
在这里插入图片描述
本文:
在这里插入图片描述
接着是生成器和辨别器的组成:和DCGAN区别的地方是生成器输入不再是噪声,而是图片。和DCGAN类似的地方是用卷积进行下采样和上采样。
生成器:

import os,sys

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

sys.path.append(os.path.split(os.path.split(os.getcwd())[0])[0])
import functions as f


class FCN32s(chainer.Chain):
    """Fully Convolutional Network 32s"""

    def __init__(self, n_class=21):
        self.train=True
        super(FCN32s, self).__init__(
            conv1_1=L.Convolution2D(3, 64, 3, stride=1, pad=100),
            conv1_2=L.Convolution2D(64, 64, 3, stride=1, pad=1),
            conv2_1=L.Convolution2D(64, 128, 3, stride=1, pad=1),
            conv2_2=L.Convolution2D(128, 128, 3, stride=1, pad=1),
            conv3_1=L.Convolution2D(128, 256, 3, stride=1, pad=1),
            conv3_2=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv3_3=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv4_1=L.Convolution2D(256, 512, 3, stride=1, pad=1),
            conv4_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv4_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_1=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            fc6    =L.Convolution2D(512, 4096, 7, stride=1, pad=0),
            fc7    =L.Convolution2D(4096, 4096, 1, stride=1, pad=0),
            score_fr=L.Convolution2D(4096, n_class, 1, stride=1, pad=0,nobias=True, initialW=np.zeros((n_class, 4096, 1, 1))),
            upscore=L.Deconvolution2D(n_class, n_class, 64, stride=32, pad=0,nobias=True, initialW=f.bilinear_interpolation_kernel(n_class, n_class, ksize=64)),)

    def __call__(self, x):
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = F.max_pooling_2d(h, 2, stride=2, pad=0)
        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = F.max_pooling_2d(h, 2, stride=2, pad=0)
        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        h = F.max_pooling_2d(h, 2, stride=2, pad=0)
        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        h = F.max_pooling_2d(h, 2, stride=2, pad=0)
        h = F.relu(self.conv5_1(h))
        h = F.relu(self.conv5_2(h))
        h = F.relu(self.conv5_3(h))
        h = F.max_pooling_2d(h, 2, stride=2, pad=0)
        h = F.relu(self.fc6(h))
        h = F.dropout(h, ratio=.5, train=self.train)
        h = F.relu(self.fc7(h))
        h = F.dropout(h, ratio=.5, train=self.train)
        score_fr = self.score_fr(h)

        upscore = self.upscore(score_fr)
        score = f.crop_to_target(upscore, target=x)

        return score

辨别器:四种变形,主要的区别就是卷积的通道不一致。

import os, sys

import chainer
import chainer.functions as F
import chainer.links as L

sys.path.append(os.path.split(os.path.split(os.getcwd())[0])[0])
import functions as f


class LargeFOV(chainer.Chain):

    def __init__(self, n_class=21):
        super(LargeFOV, self).__init__(
            conv1_1=L.Convolution2D(3*n_class, 96, 3, stride=1, pad=1),
            conv1_2=L.Convolution2D(96,  128, 3, stride=1, pad=1),
            conv1_3=L.Convolution2D(128, 128, 3, stride=1, pad=1),
            conv2_1=L.Convolution2D(128, 256, 3, stride=1, pad=1),
            conv2_2=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv3_1=L.Convolution2D(256, 512, 3, stride=1, pad=1),
            conv3_2=L.Convolution2D(512, 2,   3, stride=1, pad=1),
        )

    def __call__(self, x):
        h = F.relu(self.conv1_1(x))
        h = F.relu(self.conv1_2(h))
        h = F.relu(self.conv1_3(h))
        h = F.max_pooling_2d(h, 2, stride=2)
        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        h = F.max_pooling_2d(h, 2, stride=2)
        h = F.relu(self.conv3_1(h))
        h = self.conv3_2(h)
        h = f.global_average_pooling_2d(h) #B,2,1,1
        h = F.reshape(h, (h.shape[0],h.shape[1]))# B,2
        return h

有一个疑惑的地方是输出的通道为2,输出的是一个概率,那输出不应该为1?
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505812.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

O2OA中如何使用PostgreSQL + Citus 实现分布式数据库实现方案?

虽然 O2OA 数据表高效的表结构以及索引的设计已经极大程度地保障了数据存取操作的性能,但是随着使用时间从增长,数据表存放的数据量也会急剧增长。此时,仍然需要有合适的方案来解决数据量产生的系统性能瓶颈。本文介绍通过 PostgreSQL Citus…

2023年5月DAMA-CDGA/CDGP数据治理认证开班啦,我要报名学习

6月18日DAMA-CDGA/CDGP数据治理认证考试开放报名中! 考试开放地区:北京、上海、广州、深圳、长沙、呼和浩特、杭州、南京、济南、成都、西安。其他地区凑人数中… DAMA-CDGA/CDGP数据治理认证班进行中,报名从速! DAMA认证为数据…

【刷题之路】LeetCode 234. 回文链表

【刷题之路】LeetCode 234. 回文链表 一、题目描述二、解题1、方法1——复制值到数组后用双指针1.1、思路分析1.2、代码实现 2、方法2——反转另一半链表2.1、思路分析2.2、代码实现2.3、补充 3、方法3——递归3.1、思路分析3.2、代码实现 一、题目描述 原题连接: …

计算机图形学 | 裁剪与屏幕映射

计算机图形学 | 裁剪与屏幕映射 计算机图形学 | 裁剪与屏幕映射8.1 裁剪思想裁剪的概念编码裁剪法中点裁剪法Liang-Barsky算法 8.2 真正的裁剪——在三维空间遇见多边形真正的裁剪多边形的裁剪Weiler-Atherton算法三维空间中的裁剪 8.3 几何阶段的完结:屏幕映射屏幕…

API 接口的使用和功能

随着互联网的快速发展,API接口已经成为了现代开发中不可或缺的一部分。API接口可以让你的应用程序与其他应用程序、系统或服务进行数据交流和集成。如果你正在开发应用程序,那么最好的方法就是使用API接口来增强功能和性能。 我们的API接口是为您的应用…

上财黄烨:金融科技人才的吸引与培养

“金融科技企业在吸引人才前,应先完善人才培养机制,建立员工画像,有针对性地培训提高成员综合素质。” ——上海金融智能工程技术研究中心上海财经大学金融科技研究院秘书长&院长助理黄烨老师 01.何为数字人才? 目前大多数研…

什么,你不会Windows本地账户和本地组账户的管理加固?没意思

什么,你不会Windows本地账户和本地组账户的管理加固?没意思 1.图形化界面方式管理用户2.图形化界面方式管理用户组3.命令行界面方式管理用户4.命令行界面方式管理账户组5.账户安全基线加固账户检查口令检查 1.图形化界面方式管理用户 1、打开管理界面 …

运维自动化工具 Ansible的安装部署和常用模块介绍

ansible安装 ansible的安装有很多种方式 官方文档:https://docs.ansible.com/ansible/latest/installation_guide/intro_installation.ht ml https://docs.ansible.com/ansible/latest/installation_guide/index.html 下载 https://releases.ansible.com/ansible…

Java入门全网最详细 - 从入门到转行

Java基础入门 - 坚持 Java 基本介绍Java 学习须知Java 学习文档Java 基础Java Hello WorldJava 变量Java 数据类型Java 运算符Java 修饰符Java 表达式 & 语句 & 代码块Java 注释--------------------------------------------------------------------------Java 控制语…

在vue中引入高德地图

既然要用到高德地图首先要申请成为高德地图开发者,并申请使用高德地图的key这两点在这篇文章就不过多赘述,有需要的小伙伴可以查查资料,或者去高德地图api官网都有很详细的介绍。高德地图官网 简单提一下申请秘钥流程(web端&#…

Python入门教程+项目实战-12.2节: 字典的操作方法

目录 12.2.1 字典的常用操作方法 12.2.2 字典的查找 12.2.3 字典的修改 12.2.4 字典的添加 12.2.5 字典的删除 12.2.6 知识要点 12.2.7 系统学习python 12.2.1 字典的常用操作方法 字典类型是一种抽象数据类型,抽象数据类型定义了数据类型的操作方法&#x…

想成为神经网络大师?这些常用算法和框架必须掌握!

神经网络是机器学习和人工智能领域中的一种常用算法,它在图像识别、自然语言处理等方面都有广泛的应用。如果你想入门神经网络,那么这篇文章就是为你准备的。 首先,了解基本概念是入门神经网络的基础。神经元是神经网络的基本组成部分&#x…

AQS底层源码解析

可重入锁 又叫递归锁,同一个线程在外层方法获得锁的时候,再进入该线程内层方法会自动获取锁,(前提锁对象是同一个对象)。不会因为之前已经获取过还没释放而阻塞。 Synchronized和ReentrantLock都是可重入锁&#xff…

玩游戏时突然弹出”显示器驱动程序已停止响应并且已恢复”怎么办

随着3A游戏大作不断面市,用户也不断地提升着自己的硬件设备。但是硬件更上了,却还会出现一些突如其来的情况,比如正准备开启某款游戏时,电脑右下角突然出现“显示器驱动程序已停止响应并且已恢复”。遇事不慌,驱动人生…

创新指南|5大策略让创新业务扩张最大避免“增长痛苦”

公司在开发和孵化新业务计划方面进行了大量投资,但很少有公司遵循严格的途径来扩大新业务规模。虽然80%的公司声称构思和孵化新企业,但只有16%的公司成功扩大了规模。典型案例是百思买在许多失败倒闭的扩大新业务取得了成功。它经历了建立新业务所需的3个…

如何使用 Python+selenium 进行 web 自动化测试?

使用Pythonselenium进行web自动化测试主要分为以下步骤: 在华为工作了10年的大佬出的Web自动化测试教程,华为现用技术教程!_哔哩哔哩_bilibili在华为工作了10年的大佬出的Web自动化测试教程,华为现用技术教程!共计16条…

VMware ESXi 7.0 U3m macOS Unlocker OEM BIOS (标准版和厂商定制版)

VMware ESXi 7.0 U3m macOS Unlocker & OEM BIOS (标准版和厂商定制版) 提供标准版和 Dell (戴尔)、HPE (慧与)、Lenovo (联想)、Inspur (浪潮)、Cisco (思科) 定制版镜像 请访问原文链接:https://sysin.org/blog/vmware-esxi-7-u3-oem/,查看最新版…

AC/DC、DC/DC转换器

什么是AC? Alternating Current(交流)的首字母缩写。 AC是大小和极性(方向)随时间呈周期性变化的电流。 电流极性在1秒内的变化次数被称为频率,以Hz为单位表示。 什么是DC? Direct Current&…

C语言的存储类别,链接和内存管理

目录 1.1作用域 1.2链接 1.3存储期 1.4存储类别 1.4.1自动变量 1.4.2寄存器变量 1.4.3块作用域的静态变量 1.4.4外部链接的静态变量 1.4.5内部链接的静态变量 1.4.6存储类别说明符 1.5动态内存管理 1.5.1出现原因 栈内存 数据段与代码段 堆内存 1.5.2动态内存函…

Flink第二章:基本操作

系列文章目录 Flink第一章:环境搭建 Flink第二章:基本操作 文章目录 系列文章目录前言一、Source1.读取无界数据流2.读取无界流数据3.从Kafka读取数据 二、Transform1.map(映射)2.filter(过滤)3.flatmap(扁平映射)4.keyBy(按键聚合)5.reduce(归约聚合)6.UDF(用户自定义函数)7.…