【pytorch】在多个batch中如何使用nn.CrossEntropyLoss

news2025/2/21 4:00:30

问题

有的时候我们需要计算多个batch的CrossEntropyLoss, 如下面的代码片段

....
criterion = nn.CrossEntropyLoss()

....

for input, target in self.dataloader:
            optimizer.zero_grad()

            .....
            # output shape (5,4,14)
            # target shape (5,4)
            loss = criterion(output, target)

从官网上的例子来看, 一般input为(Number of Batch, Features), 而target一般为 (N,)

Example of target with class indices

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

上面是一个batch的执行,但是在一些实际的训练过程中,可能是多个batch

如果直接执行开头的代码,会抛出如下错误
ValueError: Expected target size (5, 14), got torch.Size([5, 4])

这是因为开头的例子是一个nlp任务,input的shape是(5,4,14), 即(Number of Batch, Sequence length, Embedding size),这里多处一维,sequence length


分析

把output和target的数据通过debug获取出来单独计算尝试一下,下面的代码中,同时我使用numpy自己实现了一遍CrossEntropyLoss的计算,可以直接跳过查看最后调用nn.CrossEntropyLoss的部分。

import torch
import numpy as np


def my_softmax(x):
    output = np.zeros(x.shape)
    for n in range(x.shape[0]):
        exp_x = np.exp(x[n, :])
        output[n] = (exp_x / np.sum(exp_x))

    return output


def my_log_softmax(x):
    return np.log(my_softmax(x))


def my_nll_loss(P, Y, reduction='mean', ignore_index=-100):
    loss = []
    for n in range(len(Y)):
        if Y[n] == ignore_index:
            loss.append(-0.)
            continue
        p_n = P[n][Y[n]]
        loss.append(p_n.item())

    if reduction == 'mean':
        return -np.mean(loss)

    if reduction == 'sum':
        return -np.sum(loss)

    return -np.array(loss)


def batch_cross_entropy():
    # [B,S,E]
    output = np.array([[[-6.9800e-01, 6.8742e-01, 2.5055e-01, -6.6209e-01, -4.6491e-01,
                         -1.3935e-01, -1.7100e-01, 4.0013e-02, -3.6995e-01, -8.5358e-01,
                         -4.9449e-01, -4.5180e-01, -2.7848e-01, -1.1511e+00],
                        [-7.7217e-01, 5.0190e-01, 3.3348e-01, -4.0213e-01, -4.6606e-01,
                         -6.0082e-02, 4.7225e-01, 1.4079e-01, -1.7741e-01, -7.9565e-01,
                         -5.7972e-01, -4.8082e-01, -1.8605e-02, -9.5264e-01],
                        [-6.8221e-01, 3.7776e-01, 3.4762e-02, -6.9478e-01, -2.2510e-01,
                         3.0994e-01, -1.3499e-01, -1.6287e-01, -1.6151e-01, -2.4974e-01,
                         -4.6694e-01, -6.1922e-01, 2.4364e-01, -9.0690e-01],
                        [-8.0960e-01, 5.0074e-01, -1.8677e-01, -7.8651e-01, -4.1738e-01,
                         4.1874e-01, -2.3718e-01, -2.1826e-01, -3.3325e-01, -9.2656e-02,
                         -4.6586e-01, -8.4838e-01, 1.6432e-01, -6.5928e-01]],

                       [[-2.4560e-01, 6.9763e-01, 1.8138e-01, -3.2625e-02, -2.4262e-01,
                         -2.5643e-01, 1.1205e-01, 2.4543e-02, -4.4613e-01, -1.0645e+00,
                         -3.6831e-01, -4.1188e-02, -2.0788e-02, -1.0442e+00],
                        [-2.8846e-01, 8.2847e-01, -5.4134e-02, -7.8471e-01, 1.3351e-02,
                         -7.4033e-01, -6.3344e-01, -3.5146e-01, -8.5599e-01, -1.0859e+00,
                         -1.6991e-01, 4.7074e-02, 1.0111e-01, -5.1003e-01],
                        [-6.1263e-01, 7.3131e-01, 5.7170e-01, -3.8304e-02, 2.6139e-02,
                         -1.1358e-01, 5.1920e-01, 3.4961e-01, -2.8680e-01, -8.5890e-01,
                         -5.1087e-01, -3.2754e-01, 2.2287e-01, -6.6090e-01],
                        [-5.7762e-01, -1.6064e-01, -5.4849e-01, -5.2790e-02, -3.1316e-01,
                         5.7697e-01, 1.8820e-01, 1.9771e-03, 2.3494e-01, 4.6401e-02,
                         -6.0379e-01, -5.6362e-01, 1.0715e-01, -6.7643e-01]],

                       [[-7.5844e-01, 8.9643e-01, 4.2627e-02, -3.2765e-01, -3.2391e-01,
                         -3.7126e-01, 1.3792e-02, 1.6282e-03, -5.8745e-01, -4.6443e-01,
                         -2.7597e-01, -3.4279e-01, 1.0330e-03, -6.5268e-01],
                        [-6.7271e-01, 8.8120e-01, 4.4617e-01, -9.2040e-01, -3.0459e-01,
                         -3.1417e-01, -3.9815e-01, 1.0694e-01, -7.2992e-01, -5.3737e-01,
                         -1.6901e-01, -3.7259e-01, 9.2190e-02, -9.0215e-01],
                        [-6.4774e-01, 7.2040e-01, 7.7526e-01, -1.0923e+00, -8.9171e-02,
                         -6.2309e-05, 3.4601e-01, -6.7397e-02, -5.2992e-01, -4.7396e-01,
                         -2.0592e-01, -2.9428e-01, 2.7567e-01, -1.0032e+00],
                        [-9.6423e-01, 6.1445e-01, -6.5032e-01, -5.5757e-01, -6.0174e-01,
                         -1.6667e-01, 1.9756e-01, -5.3273e-01, -2.6795e-01, -1.6678e-01,
                         -4.7283e-01, -7.7119e-01, 8.7784e-02, -4.2825e-01]],

                       [[-2.8459e-01, 6.0364e-01, 5.0745e-01, -1.1500e-01, -2.8906e-01,
                         -2.1891e-01, 3.1818e-01, 2.6412e-01, -3.1559e-01, -9.2631e-01,
                         -2.5491e-01, -1.3816e-02, -2.7776e-01, -1.3621e+00],
                        [-6.3529e-01, 8.0968e-01, 5.9280e-01, -6.2296e-01, -3.4726e-01,
                         -1.6531e-01, 6.7529e-02, 3.7592e-01, -7.3573e-01, -1.0816e+00,
                         -3.1254e-01, -4.2386e-01, -2.4192e-01, -1.1896e+00],
                        [-7.9503e-01, 5.1963e-01, 5.1673e-01, -6.4723e-01, -8.6342e-02,
                         -2.1490e-01, 2.7284e-02, 2.6488e-01, -7.0478e-01, -1.1432e+00,
                         -2.9212e-01, -5.3028e-01, -4.8153e-01, -8.5909e-01],
                        [-7.9562e-01, 5.3502e-01, 1.2687e-01, -6.4034e-01, -1.4381e-01,
                         1.0957e-01, 2.4598e-02, 2.3910e-02, -6.8106e-01, -5.3939e-01,
                         -2.7420e-01, -4.9182e-01, 5.0746e-02, -8.6493e-01]],

                       [[-2.5208e-01, 9.5292e-02, 1.4688e-01, 4.0238e-01, -3.0913e-01,
                         -2.0094e-02, 3.9704e-01, 5.1999e-01, 1.2463e-01, -6.6643e-01,
                         -4.4233e-01, 4.3938e-03, -3.6015e-01, -1.0695e+00],
                        [-4.2988e-01, 3.2485e-01, 1.2833e-01, -7.1189e-01, -1.7690e-01,
                         -3.1612e-01, -4.5157e-01, -1.4707e-01, -2.3045e-01, -9.6345e-01,
                         -3.4908e-01, -4.5350e-01, -1.7349e-01, -7.9216e-01],
                        [-6.3809e-02, -5.2756e-02, -2.1734e-01, -1.5490e-01, -5.1187e-02,
                         -2.3425e-01, -3.4012e-01, -1.7033e-01, 5.0935e-02, -2.8938e-01,
                         -6.8729e-02, -2.7069e-01, -3.3257e-01, -4.0449e-01],
                        [-4.5155e-01, 1.0152e-01, -4.5864e-01, -2.4100e-01, -3.2433e-01,
                         3.0919e-01, 1.1523e-01, -3.3954e-01, 2.0666e-01, -1.9090e-01,
                         -4.4507e-01, -8.4536e-01, 2.5585e-01, -6.3963e-01]]])

    # [B, S]
    target = np.array([[0, 4, 1, -100],
                       [0, 8, 4, 1],
                       [0, 8, 7, 1],
                       [0, 11, 5, 1],
                       [0, 8, 6, 1]])

    my_crossentropy = np.zeros((output.shape[0], output.shape[1]))
    for i in range(output.shape[0]):
        my_crossentropy[i] = my_nll_loss(my_log_softmax(output[i]), target[i], reduction='none')
    
    # 注意这里一定要将reduction改为none,如果采用默认mean,那么所有的值都会混合到一起做平均
    # 这有时是合理的,有的时候却不是;所以最好的方式是自己做reduction
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    loss = criterion(torch.from_numpy(output).permute(0, 2, 1), torch.from_numpy(target).long())

    print("my_crossentropy:", my_crossentropy)
    print("crossentropy:", loss)


batch_cross_entropy()

在这里插入图片描述

这里需要把index标记为-100的去处计算,所以在做reduction的时候需要单独处理一下。


参考

  • 【pytorch】使用numpy实现pytorch的softmax函数与cross_entropy函数
  • stack overflow: torch-nn-crossentropyloss-over-multiple-batches

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/86795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文本预处理方法总结

数据的预处理 项目需要,需要进行词库训练与样本向量化处理,总结后有以下4种方法: 方法1:tf 1.xx版本: 词汇样本的处理:使用tensorflow.contrib.learn模块 vocab_process learn.preprocessing.Vocabula…

Docker安装RabbitMQ

文章目录1.下载Rabbitmq镜像2.创建并运行 RabbitMQ 容器3.启动rabbitmq_management4.访问前端页面5.开通端口1.下载Rabbitmq镜像 下载最新版本的镜像: docker pull rabbitmq如何想要其他版本可以访问 Docker 官网 https://hub.docker.com/_/rabbitmq?tabtags 2.…

AOP注解实现接口敏感字段加密

AOP注解实现接口敏感字段加密 文章目录AOP注解实现接口敏感字段加密定义方法注解EncryptMethod定义字段注解EncryptField新增加密解密工具定义AOP核心处理类EncryptFieldAop使用注解项目如果不允许明文存储敏感数据(例如身份证号、银行卡号,手机号等&…

ShuffleNetV2 结构(附源码)

本文不细看paper,只看网络结构和源码实现。 看下ShuffleNetV2的结构吧。 image是3通道进去,经过conv1和maxpool, 然后stage2~4则是主题,里面stride 2和 stride 1的shuffleBlock分别重复几次。 shuffleBlock如下,左边是stride…

搭建Kubord管理k8s/EKS以及Harbor私有仓库教程

eks首先要去aws后台进行创建,这里不再讲解详细的过程,下面讲解如果通过命令行以及kuboard调度esk服务。 安装docker以及docker-compose yum install docker service docker start curl https://get.daocloud.io/docker/compose/releases/download/1.24…

零食商城小程序开发,建立商家良好品牌形象

相信很多人都无法拒绝来自零食的诱惑,尤其是在闲暇刷剧时,一边看剧一边享受着味蕾的满足,简直不要太幸福。现在人们对于零食的要求越来越高,不仅注重口感,更讲究包装,这就让零食行业逐渐走向精细化。而零食…

ssm+Vue计算机毕业设计校园统一网络授课平台(程序+LW文档)

ssmVue计算机毕业设计校园统一网络授课平台(程序LW文档) 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项…

SpringMVC-狂神

SpringMVC优点: 轻量级,简单易学 高效,基于请求响应的MVC框架 与Spring无缝结合 功能强大:RESTful风格,数据验证,格式化,本地化,主题等 简单灵活 SpringMVC全部围绕DispatchSer…

AI(人工智能),时代的风口

你知道AI并非一个新词吗? 你知道 AI 正在影响着包括数学、物理学、生命科学等诸多领域的前沿科学研究吗? “AI是一个具有魅力的词,也是一个很古老的词”。 我们通常所说的AI (Artificial intelligence) 翻译为“人工…

安卓玩机搞机技巧综合资源-----不亮屏幕导资料 有屏幕锁保数据刷机等 多种方式【十五】

接上篇 安卓玩机搞机技巧综合资源------如何提取手机分区 小米机型代码分享等等 【一】 安卓玩机搞机技巧综合资源------开机英文提示解决dm-verity corruption your device is corrupt. 设备内部报错 AB分区等等【二】 安卓玩机搞机技巧综合资源------EROFS分区格式 小米红…

C#打开摄像头后获取图片,调用face_recognition进行人脸识别

运行效果如截图:左边和保存的图片做对比,打印相似度,部分打印内容为python中的打印输出,可以用来做结果判断。右边打开摄像头后,可以单张图片进行人脸识别,或者一直截图镜头中的图片进行比对。期中python是…

ReSharper添加对最新C#11特性的支持

ReSharper添加对最新C#11特性的支持 C#11 UTF-8文字-增加了对UTF-8文字的基本支持。代码分析现在建议对文字使用u8后缀,而不是System.Text.Encoding.UTF8.GetBytes()方法或具有适当UTF8符号的字节数组。还有一组UTF-8文本的编译器警告和错误。 文件本地类型-添加了对…

服务器公网带宽1M能同时接受多少人访问?

文章目录1、什么是服务器的带宽?2、服务器带宽多少?3、服务器带宽1M能同时接受多少人访问?1、什么是服务器的带宽? 在服务器托管中,服务器带宽指在特定时间段从或向网站/服务器传输的数据量,例如,单月内的累积消耗“带宽”,实…

【开源掌机】百问网DShanMCU-Mio开源掌机(爻-澪)项目,完美支持运行10多个模拟器!

众筹说明 定金翻倍,即定金19.9元,在付尾款时可抵40元(成品售价不会超过120元)!达标当天就开搞,满100人加速搞尽量在年前发货,让大家先玩起来!如果不达标则原路退款,项目取消。 众筹时间&#…

利用Matlab进行图像分割和边缘检测

本文章包含以下内容: 1、灰度阀值分割 (1)单阈值分割图像 先将一幅彩色图像转换为灰度图像,显示其直方图,参考直方图中灰度的分布,尝试确定阈值;应反复调节阈值的大小,直至二值化的效果最为满意…

LDR6035PD快充快放带数据还要啥莲花清翁

随着Type-C的普及和推广,目前市面上的移动电源正在慢慢淘汰micro-USB接口,逐渐都更新成了Type-C接口,micro-USB接口从2007年上市,已经陪伴我们走过十多个年头,自从2015年Type-C登场,micro-USB也开始渐渐淡出…

写给前端开发者的「Promise备忘手册」

前言 大家好,我是HoMeTown,Promise想必大家都知道,在平时的开发工程中也经常会有用到,但是Promise作为ES6的重要特性,其实还拥有很多丰富的知识,本文面向比较初级一些的同学,可以帮你搞懂Promi…

金庸群侠传3DUnity重置入门-Mods开发

金庸3DUnity重置入门系列文章 金庸3dUnity重置入门 - lua 语法 金庸3dUnity重置入门 - UniTask插件 金庸3dUnity重置入门 - Mods开发 金庸3dUnity重置入门 - Cinemachine 动画 金庸3dUnity重置入门 - 大世界实现方案 金庸3dUnity重置入门 - 素材极限压缩 (部分可能放到付…

[附源码]Nodejs计算机毕业设计基于web的社团管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

机器学习——01基础知识

机器学习——01基础知识 github地址:https://github.com/yijunquan-afk/machine-learning 参考资料 [1] 庞善民.西安交通大学机器学习导论2022春PPT [2] 周志华. 机器学习.北京:清华大学出版社,2016 [3] AIlearning 一、机器学习算法的应用 目前,机…