Loss模块

news2024/11/17 3:27:06

导入包

from utils import cos_sim, euclidean_dist

方法 EucSoftmax

变量

“”"Calculate cos distance loss.

Args:
    protos: protos vector in now episode (**class_size, hidden_size**)
    querys: queres vector to classify **(querys_len, hidden_size)**
    querys_y: corresponding y for query samples (**querys_len,** )
    t: (FloatTensor) temperature

Returns:
    return loss, neg dist (N*M)
"""

euclidean_dist(欧式距离)

欧氏距离也称为欧几里得距离,衡量得是多维空间中两个点之间绝对距离

以古希腊数学家欧几里得命名的距离,也就是我们直观的两点之间直线最短的直线距离

在这里插入图片描述
欧式距离的定义:是一个通常采用的距离定义,其是在 ∗ ∗ m **m m维空间中两个点之间的真实距离**。

在二维和三维空间中的欧氏距离就是两点之间的距离。二维的公式是:
d = ( x 1 − x w ) 2 + ( y 1 − y 2 ) 2 d = \sqrt{(x_1 - x_w)^2 + (y_1 - y_2)^2} d=(x1xw)2+(y1y2)2

三维度的公式是:
d = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 + ( z 1 − z 2 ) 2 d = \sqrt{(x_1 - x_2)^2 + (y_1- y_2)^2 + (z1 - z2)^2} d=(x1x2)2+y1y2)2+(z1z2)2

推广到n维空间欧氏距离公式为:
∑ i = 1 n ( y i − x i ) 2 \sqrt{\sum^n_{i = 1}(y_i - x_i)^2} i=1n(yixi)2

在这里插入图片描述
n 维 欧 氏 空 间 是 一 个 点 集 . n维欧氏空间是一个点集. n.
实现方式euclidean_dist(querys, protos)

F.cross_entropy

  • F.cross_entropy(x,y)
      cross_entropy(x,y)是交叉熵损失函数,一般用于在全连接层之后,做loss的计算

其中x是二维张量,是全连接层的输出;y是样本标签值。x[batch_size,type_num];y[batch_size]。

cross_entropy(x,y)计算结果是一个小数,表示loss的值

损失函数-交叉熵损失函数公式

在这里插入图片描述
实现方式:
F . c r o s s e n t r o p y ( ∗ ∗ n e g e u d i s t s , q u e r y s y ∗ ∗ ) F.cross_entropy(**neg_eu_dists, querys_y**) F.crossentropy(negeudists,querysy)
返回值:loss_sfm, neg_eu_dists 损失值和欧氏距离的负数。

utils模块中的欧氏距离

def euclidean_dist(x, y):
    """Compute euclidean distance between two tensors

    Args:
        x: (Tensor) N x D
        y: (Tensor) M x D

    Returns:
        Euclidean distance of x and y, a float
    """
    flag = False
    if y is None:
        y = x
        flag = True
    if len(list(x.size())) == len(list(y.size())) == 1:
        return torch.pow(x - y, 2).sum()
    x_norm = (x ** 2).sum(1).view(-1, 1)
    y_t = torch.transpose(y, 0, 1)
    y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Ensure diagonal is zero if x=y
    if flag:
        dist = dist - torch.diag(dist.diag())
    return torch.clamp(dist, 0.0, np.inf)

方法CosT

"""Calculate cos distance loss.

Args:
    protos: protos vector in now episode (class_size, hidden_size)
    querys: queres vector to classify (querys_len, hidden_size)
    querys_y: corresponding y for query samples (querys_len, )
    t: (FloatTensor) temperature

Returns:
    return loss, neg dist (N*M)
"""

cos_sim(余弦相似度)

  • 计算余弦相似度公式:
    c o s = a ⋅ b ∣ ∣ a ∣ ∣ ∣ b ∣ ∣ ∣ cos = \frac{a \cdot b}{||a|||b|||} cos=abab
    其中 a 和 b 代 表 两 个 向 量 a和b代表两个向量 ab,(注意:向量是在空间中具有大小和方向的量

如果在二维空间中,余弦相似度的值通过以下公式计算:
在这里插入图片描述
在这里插入图片描述
余弦定理:
二维空间下计算两个向量的余弦相似性公式是根据余弦定理推导出来的,请感兴趣的可以自己推演一下(可根据以上两图)。
如果假设空间是多维的,则余弦相似度公式可扩展如下图:
在这里插入图片描述

代码实现

  • cos = cos_sim(querys, protos)
  • t_cos = t * cos
  • loss_cos = F.cross_entropy(t_cos, querys_y)
  • return loss_cos, cos

余弦相似度代码为:

def cos_sim(x, y=None, eps=1e-8):
    """Cosine Similarity calculate. Add eps for numerical stability"""
    y = x if y is None else y
    if len(list(x.size())) == len(list(y.size())) == 1:
        return torch.nn.CosineSimilarity(dim=-1, eps=eps)(x, y)
    x_n, y_n = x.norm(dim=1)[:, None], y.norm(dim=1)[:, None]
    x_norm = x / torch.max(x_n, eps * torch.ones_like(x_n))
    y_norm = y / torch.max(y_n, eps * torch.ones_like(y_n))
    sim_mt = torch.mm(x_norm, y_norm.transpose(0, 1))
    return sim_mt

总结

先将代码复制下来,后续在继续深耕,各章节的代码及其模块实现方式都行的回事与打算。

EucTriplet

“”"Calculate cos distance loss.

Args:
    protos: protos vector in now episode (class_size, hidden_size)
    querys: queres vector to classify (querys_len, hidden_size)
    querys_y: corresponding y for query samples (querys_len, )
    t: (FloatTensor) temperature

Returns:
    return loss, neg dist (N*M)
"""

a = torch.arange(querys_size).to(device)

将所有最开始读取数据时候的 t e n s o r tensor tensor变量,copy一份到device所指定的GPU上去,之后的运算都在GPU上进行
这句话需要写的次数等于需要保存GPU上的tensor变量的个数;一般情况下这些tensor变量都是最开始读数据时的tensor变量,后面衍生的变量自然也都在GPU

代码展示

 device = protos.device
    querys_size = querys.size(0)
    a = torch.arange(querys_size).to(device)

    eu_dists = euclidean_dist(querys, protos)
    neg_eu_dists = -eu_dists

    masked_dists = eu_dists.clone()
    masked_dists[a, querys_y] = float('inf')

    neg_samples_dists = masked_dists.min(dim=-1)[0]
    pos_samples_dists = eu_dists[a, querys_y]

    loss_triplet = F.relu(pos_samples_dists - neg_samples_dists + 1.0).mean()
    return loss_triplet, neg_eu_dists

慢慢的将各种代码及包,全部搞定都行啦的里哟

F.relu

是一种人工神经网咯中常用的激活函数,通常意义下,其指的是:
数学中的斜坡函数: 即:
f ( x ) = m a x ( 0 , x ) f(x) = max(0,x) f(x)=max(0,x)
对应的函数图像如下所示:
在这里插入图片描述

  • 会自己搞清楚,这篇论文的代码中如何使用激活函数。以及如何传惨的
  • 全部都将其搞定都行啦的样子与打算,慢慢的学会如何使用激活函数以及如何传参。
  • 慢慢的研究一波都行啦的样子与打算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/55554.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深浅拷贝的区别?如何实现一个深拷贝?

一.数据类型存储 js中存在两大数据类型: 基本数据类型:保存在栈内存中; 引用数据类型:保存在堆内存中,引用数据类型的变量是一个指向堆内存中实际对象的引用,存在栈中。 二.浅拷贝 浅拷贝:…

CRM系统的功能有哪些?

**CRM系统**的功能有:1、联系人管理;2、沟通跟踪;3、潜客管理;4、电子邮件集成;5、文档管理;6、报价/提案管理;7、商机管理;8、工作流自动化;9、报表/分析;10…

利用Seagate service获得system shell

这是挖掘 CVE-2022-40286 漏洞的记录。 闲来无事,我上网随便找了一个驱动来进行测试。我想找一个知名公司的产品,但是又不能是太偏太难懂的东西。 我最先发现了一个叫"Seagate Media Sync"的软件,这是一个将文件复制到希捷无线硬…

SR-MPLS技术基础讲解

目录 SR-MPLS基础概念 使用Segment Routeing MPLS技术的优点 Segment Routeing MPLS的基本原理 SRGB Segment ID Bind SID 粘连标签 OSPF对于SR-MPLS的扩展 OSPF对邻接SID做了细分 10类LSA定义的TLV类型 10类LSA定义的TLV的报文格式 ISIS对SR-MPLS的扩展…

pyinstaller瘦身指南

目录说明无优化直接打包优化:创建专用虚拟环境原因分析和总结说明 之前写了一个自动化办公的python脚本,按需求打包exe。经过不断优化打包过程,把26.1MB的文件变成了9.5MB的文件。 打包工具pyinstaller。 安装: pip install pyi…

Ubuntu1804里进行KITTI数据集可视化操作

需要做的准备工作 1、需要提前安装kitti2bag(终端输入即可安装) pip install kitti2bag 如果没有pip,按照Ubuntu给的提示先安装pip 2、下载kitti数据集(下载圈出的两部分) kitti数据集的百度网盘链接 kitti数据集链接_FYY2LHH的博客-CSDN博客 文件存放位置如图 上图…

Android Material Design之Chip, ChipGroup(十二)

效果图 资源引入 implementation com.google.android.material:material:1.4.0属性 Chip 属性描述android:id控件idstyle样式属性系统默认4种 1.style/Widget.MaterialComponents.Chip.Entry 2.style/Widget.MaterialComponents.Chip.Choice3.style/Widget.MaterialCompon…

集团资金管理BI分析的三个关键节点

集团资金管理方面的商业智能BI分析怎么做?从财务角度来说,企业的管理是以财务管理为中心,财务管理以资金管理为中心,资金管理以现金流量为中心。围绕资金管理至少需要考虑三个方面的内容:安全、收益和效率。 在商业智…

【JavaEE】JavaScript(WebAPI)

努力经营当下,直至未来明朗! 文章目录前言一、前置知识二、【DOM】【获取元素】【事件】【操作元素】1.【获取/修改元素的内容】2.【获取/修改元素属性】3.【获取/修改 表单元素属性】4.【获取/修改样式属性】【操作节点】1.【新增节点】2.【删除节点】&…

【2-3个月左右录用】物联网、无线通信类、人工智能、传感器、人机交互等领域必投快刊,进展顺利,12月截稿

【期刊简介】3.0-4.0,JCR2/3区,中科院4区 【检索情况】SCI在检,正刊 【征稿领域】安全和隐私雾云辅助物联网网络 【参考周期】2-3个月左右 【截稿日期】2022年12月31日 【期刊简介】2.0-3.0,JCR3区,中科院4区 【检索情…

【白嫖】如何低价续费服务器

背景 现在各大云服务商的学生价服务器都已经关闭了,华为云、阿里云、百度云,以前都有学生价服务器,一年只要99。现在我找半天都没找到入口,而原价的一年得500起步。。。 但是!!!腾讯虽然也取消了…

【系统性学习】Linux Shell易忘重点整理

本文主要基于《实用Linux Shell编程》总结,并加入一些网上查询资料和博主自己的推断。 其中命令相关的,已抽取出来在另一篇系统性学习】Linux Shell常用命令中,可以一起使用。 文章目录一、基础知识二、命令与环境三、变量和数组四、条件流程…

Linux多线程C++版(八) 线程同步方式-----条件变量

目录1.条件变量基本概念2.条件变量创建和销毁3.条件变量等待操作4.条件变量通知(唤醒)操作5.代码了解线程同步6.线程的状态转换7.代码改进--从一对一到一对多1.条件变量基本概念 互斥锁的缺点是它只有两种状态:锁定和非锁定条件变量通过允许线程阻塞和等待另一个线…

Kamiya丨Kamiya艾美捷抗FLAG多克隆说明书

Kamiya艾美捷抗FLAG多克隆化学性质: 程序:用FLAG肽免疫家兔与KLH偶联。多次免疫后在弗氏佐剂中收集血清使用固定在固相上的肽。 规范: 使用氨基末端分析抗体Met FLAG BAP、氨基末端FLAG-BAP和羧基末端FLAG-BAP融合蛋白和Invitrogen Posite…

跳槽,从这一个坑,跳进另外一个坑

软件测试员跳槽有一个奇怪的现象:那些跳槽的测试员们,总是从一个坑,跳进另一个坑中,无论怎么折腾,也没能拿到更好的offer,更别说,薪资实现爆炸式增长,自身价值得到升华~ 在如今经验…

【Web安全】注入攻击

目录 前言 1、注入攻击 1.1 SQL注入 1.2 数据库攻击技巧 1.2.1 常见的攻击技巧 1.2.2 命令执行 1.2.3 攻击存储过程 1.2.4 编码问题 1.2.5 SQL Column Truncation 1.3 正确防御SQL注入 1.4 其他注入攻击 1.4.1 XML注入 1.4.2 代码注入 1.4.3 CRLF注入 前言 年…

Kotlin高仿微信-第53篇-添加好友

Kotlin高仿微信-项目实践58篇详细讲解了各个功能点,包括:注册、登录、主页、单聊(文本、表情、语音、图片、小视频、视频通话、语音通话、红包、转账)、群聊、个人信息、朋友圈、支付服务、扫一扫、搜索好友、添加好友、开通VIP等众多功能。 Kotlin高仿…

数商云SRM系统询比价有何优势?供应商平台助力汽车零部件企业快速查找供应商

随着中国汽车行业的高速发展、汽车保有量的增加以及汽车零部件市场的扩大,我国汽车零部件行业得到了迅速发展,增长速度整体高于我国整车行业。数据显示,我国汽车零部件的销售收入从2016年3.46万亿元增长至2020年的4.57万亿元,年均…

世界杯——手动为梅西标名

梅西的铁粉来集赞啦。 今天带来了一个为图片添加字样的小功能,我们的测试目标图片是: 我们的测试目标是: 我们使用的是Python语言,使用了Image包用作图片处理,matplotlib包用作坐标查阅,这个坐标还是很好看…

微服务框架 SpringCloud微服务架构 8 Gateway 网关 8.5 全局过滤器

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式,系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构8 Gateway 网关8.5 全局过滤器8.5.1 全局过滤器 GlobalFilter8.5.2 案例8.…