代码详解 —— VGG Loss

news2024/11/26 4:51:49

文章目录

  • VGG Loss 的基础概念
  • VGG 的网络结构
  • VGG LOSS 的代码解析
  • 参考

VGG Loss 的基础概念

VGG Loss 是content Loss中的一种。

为了评价图像的perceptual quality,《Perceptual losses for real time style transfer and super-resolution》 和 《GeneratingImageswithPerceptualSimilarityMetricsbasedonDeepNetworks》将content loss引入到SR中。

Content Loss利用预先训练的图像分类网络来度量图像之间的语义差异。将该网络表示为 φ φ φ,提取的第 l l l层high-level representation为 φ ( l ) ( I ) φ^{(l)}(I) φ(l)(I),Content Loss表示为两幅图像high-level representation之间的欧氏距离,如下:
L content  ( I ^ , I ; ϕ , l ) = 1 h l w l c l ∑ i , j , k ( ϕ i , j , k ( l ) ( I ^ ) − ϕ i , j , k ( l ) ( I ) ) 2 , \mathcal{L}_{\text {content }}(\hat{I}, I ; \phi, l)=\frac{1}{h_l w_l c_l} \sqrt{\sum_{i, j, k}\left(\phi_{i, j, k}^{(l)}(\hat{I})-\phi_{i, j, k}^{(l)}(I)\right)^2}, Lcontent (I^,I;ϕ,l)=hlwlcl1i,j,k(ϕi,j,k(l)(I^)ϕi,j,k(l)(I))2 ,
其中 h l h_l hl w l w_l wl c l c_l cl分别为 l l l层上表示的高度、宽度和通道数。
Content Loss 本质上是将learned knowledge of hierarchical image features从分类网络 φ φ φ转移到SR网络中。
与像素损失相比,内容损失促使输出图像 I ^ \hat I I^在感知上与目标图像 I I I相似,而不是强迫它们精确匹配像素。因此,它产生的结果在视觉上更加直观,其中VGG和ResNet是最常用的预训练CNN。

VGG 的网络结构

VGG的论文:《Very deep convolutional networks for large-scale image recognition》

VGG网络采用重复堆叠的小卷积核替代大卷积核,在保证具有相同感受野的条件下,提升了网络的深度,从而提升网络特征提取的能力。
可以把VGG网络看成是数个vgg_block的堆叠,每个vgg_block由几个卷积层+ReLU层,最后加上一层池化层组成。VGG网络名称后面的数字表示整个网络中包含参数层的数量(卷积层或全连接层,不含池化层),如图所示。

在这里插入图片描述

  • VGG16
    5个VGG块的卷积层数量分别为(2, 2, 3, 3, 3),再加上3个全连接层,总的参数层数量为16,因此命名为VGG16。
  • VGG19
    5个VGG块的卷积层数量分别为(2, 2, 4, 4, 4),再加上3个全连接层,总的参数层数量为19,因此命名为VGG19。

VGG LOSS 的代码解析


VGG19的代码实现

class VGG19(torch.nn.Module): # VGG19的网络
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

首先我们先打印出vgg_pretrained_features,对应着VGG19中的各个网络层。
在这里插入图片描述

然后分别打印出sclie1,sclie2,sclie3,sclie4, sclie5。可以看出就是把VGG网络的前30层拆分成不同的分组。
在这里插入图片描述


VGG LOSS的代码实现
上文定义的vgg的输出分别是5个sclice输出组成的列表。
假设输入分别是x和y,vgg loss 的值就是分别将x和y将5个sclice输出计算loss,一共有5个loss。然后再将这5个loss按照一定的权重加权求和得到最终的loss 。

# VGG 特征距离损失
class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        # self.criterion = nn.L1Loss()
        self.criterion = nn.L1Loss(reduction='sum') # 求和
        self.criterion2 = nn.L1Loss()# 求平均
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] # 各个slice的输出权重

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # print(x_vgg[i].shape, y_vgg[i].shape)
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) # 不同slice的loss的加权和
        return loss

    def forward2(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # print(x_vgg[i].shape, y_vgg[i].shape)
            loss += self.weights[i] * self.criterion2(x_vgg[i], y_vgg[i].detach())
        return loss

参考

《SR中的常见的损失函数》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/875993.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机视觉 | 图像分割】arxiv 计算机视觉关于图像分割的学术速递(8 月 8 日论文合集)

文章目录 一、分割|语义相关(19篇)1.1 Mask Frozen-DETR: High Quality Instance Segmentation with One GPU1.2 AdaptiveSAM: Towards Efficient Tuning of SAM for Surgical Scene Segmentation1.3 SEM-GAT: Explainable Semantic Pose Estimation using Learned Graph Atten…

【Vue-Router】路由传参

1. query 传参 list.json {"data": [{"name": "面","price":300,"id": 1},{"name": "水","price":400,"id": 2},{"name": "菜","price":500,"…

eachars 自适应

目录 1. 案例: 2. 原因: 3. 解决: 1. 案例: 默认是正常宽度(如图1),当再次跳转会该页面时,eachars图发生变化(如图2)。 图1 图2 2. 原因: 没有…

三维模型OSGB格式轻量化在三维展示效果上的重要性探讨

三维模型OSGB格式轻量化在三维展示效果上的重要性探讨 三维模型在展示中的效果是十分关键的,因为它直接影响用户对模型的理解和体验。而OSGB格式轻量化是实现优质三维展示效果的关键技术之一。下面将详细介绍轻量化OSGB格式的三维模型在三维展示效果上的重要性。 首…

“多测合一”生产软件-不动产测量(不动产权籍调查测绘软件RESS),房地一体化测量由请湖南来示范

湖南“多测合一”生产软件-不动产测量软件,提取码:RESShttps://pan.baidu.com/s/1OqakLJICIP6buNiZ6j9Npw?pwdRESS 2020年7 月,国务院办公厅印发《 国务院办公厅关于进一步优化营商环境 更好服务市场主体的实施意见》 (国办发〔 …

Android之SQLite数据库的使用总结

一、SQLite数据库基本使用 1、特点 (1)嵌入式数据库,体积小 (2)数据库是由底层的sqlite.c执行程序的代码动态生成的,不用人为去创建数据库 (3)涉及抽象类SQLiteOpenHelper 2、继承…

【Java从0到1学习】08 String类

1. 概述 字符串是由多个字符组成的一串数据(字符序列),字符串可以看成是字符数组。 在实际开发中,字符串的操作是最常见的操作,没有之一。而Java没有内置的字符串类型,所以,就在Java类库中提供了一个类String 供我们…

如何使用CSS实现一个响应式网格布局?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 使用CSS实现响应式网格布局⭐ 设置基本的HTML结构⭐ 创建基本的CSS样式⭐ 添加媒体查询以实现响应式效果⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端…

成集云 | 用友U8采购请购单同步钉钉 | 解决方案

源系统成集云目标系统 方案介绍 用友U8是中国用友集团开发和推出的一款企业级管理软件产品。具有丰富的功能模块,包括财务管理、采购管理、销售管理、库存管理、生产管理、人力资源管理、客户关系管理等,可根据企业的需求选择相应的模块进行集…

Oracle 增加重做日志组、组成员

重做日志文件记录数据所有的修改信息并提供一种数据库失败时的恢复机制 一个Oracle数据库要求至少有两组重做日志文件 组中每个日志文件被称作一个组成员 需求:目前有三组重做日志组,增加一个重做日志组、并且增加两个重做日志组成员 1、查看重做日志组…

Python实现SSA智能麻雀搜索算法优化循环神经网络回归模型(LSTM回归算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 麻雀搜索算法(Sparrow Search Algorithm, SSA)是一种新型的群智能优化算法,在2020年提出&a…

日志采集分析ELK

这里的 ELK其实对应三种不同组件 1.ElasticSearch:基于Java,一个开源的分布式搜索引擎。 2.LogStash:基于Java,开源的用于收集,分析和存储日志的工具。(它和Beats有重叠的功能,Beats出现之后&a…

什么是集成测试?集成测试方法有哪些?

1、基本概念: 将软件集成起来后进行测试。集成测试又叫子系统测试、组装测试、部件测试等。集成测试主要是针对软件高层设计进行测试,一般来说是以模块和子系统为单位进行测试。 2、集成测试包含的层次: 1)模块内的集成&#x…

浅谈防火门监控系统的设计和安装

安科瑞 华楠 摘要:随着现代建筑中火灾成为重要的安全隐患,消防得到了越来越多民众的关注。更好地完成建筑消防设计成为了建筑电气设计中的重要一环。防火门监控系统是《火灾自动报警系统设计规范》(GB50116-2013)中规定在建筑中设置的火灾自动报警系统子…

h3c radius认证测试

客户端ssh连接上交换机进行管理,用radius里面的用户名和密码进行登陆 dis current-configuration vlan 1 vlan 2 interface Vlan-interface1 ip address 192.168.0.105 255.255.255.0 interface Vlan-interface2 ip address 192.168.60.244 255.255.255.0 interfac…

怎么样提升科研水平?

第15期《求是》杂志发表中强调加强基础研究,是实现高水平科技自立自强的迫切要求,是建设世界科技强国的必由之路。强化党建引领,为提升科研水平定向护航。 前一阵,广西审计厅发布的报告揭示了科研经费1.3亿,成果转化为…

CVPR 2023 | 用户可控的条件图像到视频生成方法(基于Diffusion)

注1:本文系“计算机视觉/三维重建论文速递”系列之一,致力于简洁清晰完整地介绍、解读计算机视觉,特别是三维重建领域最新的顶会/顶刊论文(包括但不限于 Nature/Science及其子刊; CVPR, ICCV, ECCV, NeurIPS, ICLR, ICML, TPAMI, IJCV 等)。 本次介绍的论…

学科在线教育元宇宙VR虚拟仿真平台落实更高质量的交互学习

为推动教育数字化,建设全民终身学习的学习型社会、学习型大国,元宇宙企业深圳华锐视点深度融合VR虚拟现实、数字孪生、云计算和三维建模等技术,搭建教育元宇宙平台,为学生提供更加沉浸式的学习体验,提高学习效果和兴趣…

JAVA宝典----输入输出流(理解记忆)

目录 一、 Java IO流的实现机制是什么? 二、Java中有几种类型的流? 三、管理文件和目录的类是什么? 四、Java Socket是什么? 五、什么是 JAVA NIO? 六、 什么是Java序列化? (1)序…

VVIC-据关键词取商品列表

一、接口参数说明: item_search-根据关键词取商品列表,点击更多API调试,请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/vvic/item_search 名称类型必须描述keyString是调用key(点击…