【PYG】Cora数据集分类任务计算损失,cross_entropy为什么不能直接替换成mse_loss

news2025/1/14 10:50:33
  • cross_entropy计算误差方式,输入向量z为[1,2,3],预测y为[1],选择数为2,计算出一大坨e的式子为3.405,再用-2+3.405计算得到1.405
  • MSE计算误差方式,输入z为[1,2,3],预测向量应该是[1,0,0],和输入向量维度相同

在这里插入图片描述
将cross_entropy直接替换成mse_loss报错RuntimeError: The size of tensor a (7) must match the size of tensor b (140) at non-singleton dimension 1

cross_entropy 换成 mse_loss 会报错的原因是,这两个损失函数的输入和输出形状要求不同。cross_entropy 是一个分类损失函数,它期望输入是未归一化的logits(形状为 [batch_size, num_classes]),而标签是整数类别(形状为 [batch_size])。mse_loss 是一个回归损失函数,它期望输入和标签的形状相同。

如果你想使用 mse_loss 来替代 cross_entropy,你需要对标签进行one-hot编码,使它们与模型的输出形状匹配。下面是如何修改代码以使用 mse_loss 的示例:

修改代码以使用 mse_loss

  1. 加载必要的库
    你需要一个工具来将标签转换为one-hot编码。这里我们使用 torch.nn.functional.one_hot

  2. 修改训练函数
    在训练函数中,将标签转换为one-hot编码,然后计算 mse_loss

核心测试代码讲解

out=model(data)模型输出形状为torch.Size([140, 7])
data.y中测试数据输出形状为torch.Size([140]),打印第一个数据为3,7个类别中的第4个类别
将3转化为7位置独热码计算MSE,对应train_labels_one_hot第一个数据[0., 0., 0., 1., 0., 0., 0.]为4
out形状为torch.Size([140, 7]),train_labels_one_hot的形状为[140, 7]

torch.Size([140, 7]) torch.Size([140])
tensor([-0.0166,  0.0191, -0.0036, -0.0053, -0.0160,  0.0071, -0.0042],
       device='cuda:0', grad_fn=<SelectBackward0>) tensor(3, device='cuda:0')
tensor([[0., 0., 0., 1., 0., 0., 0.],
		...
		[0., 1., 0., 0., 0., 0., 0.]], device='cuda:0')
train_labels_one_hot shape torch.Size([140, 7])
test out torch.Size([2708, 7])
train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
print(out[data.train_mask].shape, data.y[data.train_mask].shape)
print(out[data.train_mask][0], data.y[data.train_mask][0])
print(train_labels_one_hot)
print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)

解释

  1. 加载库:我们使用 torch.nn.functional.one_hot 将标签转换为one-hot编码。
  2. 修改训练函数
    • 将标签 train_labels 转换为one-hot编码,train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()
    • 使用 mse_loss 计算均方误差损失 loss = F.mse_loss(train_out, train_labels_one_hot)
  3. 保持评估函数不变:评估函数仍然使用 argmax 提取预测类别,并计算准确性。

魔改完整代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora',  transform=NormalizeFeatures())
data = dataset[0]

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x
        # return F.log_softmax(x, dim=1)

# 初始化模型和优化器
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')

# 打印归一化后的特征
print(data.x[0])

print(f"data.train_mask{data.train_mask}")

# 训练模型
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    # print(f"out[data.train_mask] {data.train_mask.shape} {out[data.train_mask].shape} {out[data.train_mask]}")
    # loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
    print(out[data.train_mask].shape, data.y[data.train_mask].shape)
    print(out[data.train_mask][0], data.y[data.train_mask][0])
    print(train_labels_one_hot)
    print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
    loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)
    loss.backward()
    optimizer.step()
    return loss.item()

# 评估模型
def test():
    model.eval()
    out = model(data)
    print(f"test out {out.shape}")
    print(f"test out[0] {out[0].shape} {out[0]}")
    print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")
    print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")
    pred = out.argmax(dim=1)
    print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")
    print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    return acc

for epoch in range(1):
    loss = train()
    acc = test()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

原始代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载Cora数据集,并应用NormalizeFeatures变换
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

# 计算训练、验证和测试集的大小
num_train = data.train_mask.sum().item()
num_val = data.val_mask.sum().item()
num_test = data.test_mask.sum().item()

print(f'Number of training nodes: {num_train}')
print(f'Number of validation nodes: {num_val}')
print(f'Number of test nodes: {num_test}')

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # 返回未归一化的logits

# 初始化模型和优化器
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')

# 训练模型
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)  # out 的形状是 [num_nodes, num_classes]
    train_out = out[data.train_mask]  # 选择训练集节点的输出
    train_labels = data.y[data.train_mask]  # 选择训练集节点的标签

    # 将标签转换为one-hot编码
    train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()

    # 计算均方误差损失
    loss = F.mse_loss(train_out, train_labels_one_hot)
    loss.backward()
    optimizer.step()
    return loss.item()

# 评估模型
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 提取预测类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    return acc

for epoch in range(200):
    loss = train()
    acc = test()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

通过这些修改,你可以将交叉熵损失函数替换为均方误差损失函数,并确保输入和标签的形状匹配,从而避免报错。

  • 简单版本的的答案

Cross Entropy vs. MSE Loss

  1. Cross Entropy Loss:

    • 输入:模型的logits,形状为 ([N, C]),其中 (N) 是批次大小,(C) 是类别数量。
    • 目标:目标类别的索引,形状为 ([N])。
  2. MSE Loss:

    • 输入:模型的预测值,形状为 ([N, C])。
    • 目标:实际值,形状为 ([N, C])(通常是 one-hot 编码)。

要将 cross_entropy 换成 mse_loss,需要确保输入和目标的形状匹配。具体来说,你需要将目标类别索引转换为 one-hot 编码。

示例代码

假设你有一个分类任务,其中模型输出的是 logits,目标是类别索引。我们将这个设置转换为使用 MSE Loss。

import torch
import torch.nn.functional as F

# 假设有一个批次的模型输出和目标标签
logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], requires_grad=True)  # 模型输出
target = torch.tensor([0, 2])  # 目标类别

# 使用 cross_entropy
cross_entropy_loss = F.cross_entropy(logits, target)
print("Cross-Entropy Loss:")
print(cross_entropy_loss)

# 转换目标类别为 one-hot 编码
target_one_hot = F.one_hot(target, num_classes=logits.size(1)).float()
print("One-Hot Encoded Targets:")
print(target_one_hot)

# 计算 MSE Loss
mse_loss = F.mse_loss(F.softmax(logits, dim=1), target_one_hot)
print("MSE Loss:")
print(mse_loss)

输出

Cross-Entropy Loss:
tensor(1.4076, grad_fn=<NllLossBackward>)
One-Hot Encoded Targets:
tensor([[1., 0., 0.],
        [0., 0., 1.]])
MSE Loss:
tensor(0.2181, grad_fn=<MseLossBackward>)

解释

  1. logits: 模型的原始输出,形状为 ([N, C])。
  2. target: 原始目标类别索引,形状为 ([N])。
  3. target_one_hot: 将目标类别索引转换为 one-hot 编码,形状为 ([N, C])。
  4. F.mse_loss: 使用 F.softmax(logits, dim=1) 计算模型的概率分布,然后与 target_one_hot 计算 MSE 损失。

通过将目标类别转换为 one-hot 编码并确保输入和目标的形状匹配,可以成功地将 cross_entropy 换成 mse_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1887442.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IAR工程目录移动报错(改变文件目录结构)

刚开始用IAR&#xff0c;记录一下。 工作中使用华大单片机&#xff0c;例程的文件目录结构太复杂了想精简一点。 1.如果原本的C文件相对工程文件&#xff08;.eww文件&#xff09;路径变化了&#xff0c;需要先打开工程&#xff0c;再将所有的.c文件右键Add添加进工程&#xf…

【Godot4.2】Godot中的贝塞尔曲线

概述 通过指定平面上的多个点&#xff0c;然后顺次连接&#xff0c;我们可以得到折线段&#xff0c;如果闭合图形&#xff0c;就可以获得多边形。通过向量旋转我们可以获得圆等特殊图形。 但是对于任意曲线&#xff0c;我们无法使用简单的方式来获取其顶点&#xff0c;好在计…

X-ObjectMount: 对象存储访问接入的新选择

XEOS 自 2017 年发布面世以来&#xff0c;历经 7 年的研发迭代&#xff0c;上个月正式发布了 XSKY SDS 6.4 版本&#xff0c;包含了最新的多站点统一命名空间能力&#xff0c;也标志了 XEOS 在对象存储领域的全方面优势和领先市场地位。 在 XSKY 过去对象存储服务历程里&#…

mysql 命令 —— 查看表信息(show table status)

查询表信息&#xff0c;如整个表的数据量大小、表的索引占用空间大小等 1、查询某个库下面的所有表信息&#xff1a; SHOW TABLE STATUS FROM your_database_name;2、查询指定的表信息&#xff1a; SHOW TABLE STATUS LIKE your_table_name;如&#xff1a;Data_length 显示表…

openGauss真的比PostgreSQL差了10年?

前不久写了MogDB针对PostgreSQL的兼容性文章&#xff0c;我在文中提到针对PostgreSQL而言&#xff0c;MogDB兼容性还是不错的&#xff0c;其中也给出了其中一个能源客户之前POC的迁移报告数据。 But很快我发现总有人回留言喷我&#xff0c;而且我发现每次喷的这帮人是根本不看文…

Python基础003

Python流程控制基础 1.条件语句 内置函数input a input("请输入一段内容&#xff1a;") print(a) print(type(a))代码执行的时候遇到input函数&#xff0c;就会等键盘输入结果&#xff0c;已回车为结束标志&#xff0c;也就时说输入回车后代码才会执行 2.顺序执行…

【问题记录】如何在xftp上查看隐藏文件。

显示隐藏的文件夹 用xftp连接到服务器后&#xff0c;发现有些隐藏的文件夹并未显示出来&#xff0c;通过以下配置&#xff0c;即可使隐藏的文件夹给显示出来。 1.点击菜单栏的"小齿轮"按钮&#xff1a; 2.勾选显示隐藏的文件夹&#xff1a; 3.点击确定即可。

古韵流光:探秘五代耀州窑青瓷提梁倒灌壶的奇妙设计

在陕西历史博物馆的静谧展厅中&#xff0c;一件千年前的瓷器静静陈列&#xff0c;它不仅承载着历史的沉淀&#xff0c;更凝聚了古代匠人的非凡智慧。这便是五代时期的耀州窑青瓷提梁倒灌壶&#xff0c;一件巧夺天工的艺术品&#xff0c;其独特的设计至今仍让人叹为观止。 一、倒…

算法mq 交互通用校验模块设计

背景 当前与算法交互均通过rocketMQ异步交互&#xff0c;绝大部分场景一条请求mq消息应对应一条返回mq&#xff0c;但由于各种原因&#xff08;消息积压、程序bug&#xff09;&#xff0c;可能会导致返回mq超时未返回或者消息丢失。工程侧针对一些重要场景 case by case的通过…

【web3】分享一个web入门学习平台-HackQuest

前言 一直想进入web3行业&#xff0c;但是没有什么途径&#xff0c;偶然在电鸭平台看到HackQuest的共学营&#xff0c;发现真的不错&#xff0c;并且还接触到了黑客松这种形式。 链接地址&#xff1a;HackQuest 平台功能 学习路径&#xff1a;平台有完整的学习路径&#xff…

VS2022+Qt+OpenCV Debug模式下,循环中格式转换引起的内存异常问题 debug_heap.cpp

文章目录 前言一、问题二、报错1.提示图片2.提示堆栈3.反汇编位置 三、解决办法总结 前言 最近在使用VS2022&#xff0c;C&#xff0c;OpenCV&#xff0c;Qt开发时&#xff0c;遇到了一个疑难杂症-在循环中执行字符串格式转换会触发内存异常&#xff0c;经过痛苦的排查过程&am…

Ubuntu下反弹shell的思考

目录 Ubuntu的命令执行环境 bash (Bourne Again SHell): sh (Bourne SHell): dash (Debian Almquist SHell): 它们之间的关系&#xff1a; 可能遇到的问题 一、脚本权限问题 二、命令执行环境(shell解释器)问题 如何解决&#xff1f; 1.修改/bin/sh软连接的指向为bas…

C++字体库开发

建议根据字体需求&#xff0c;多个组合使用。高度定制可基于freeTypeharfbuzz基础库完成。 GitHub - GNOME/pango: Read-only mirror of https://gitlab.gnome.org/GNOME/pango GitHub - googlefonts/fontview: Demo app that displays fonts with a free/libre/open-source …

Java_多线程:线程和死锁

一、线程 1、线程的状态流转 新建状态&#xff08;New&#xff09;&#xff1a;当线程对象对创建后&#xff0c;即进入了新建状态&#xff0c;如&#xff1a;Thread t new MyThread();就绪状态&#xff08;Runnable&#xff09;&#xff1a;当调用线程对象的start()方法&…

JAVA极简图书管理系统,初识springboot后端项目

前提条件&#xff1a; 具备基础的springboot 知识 Java基础 废话不多说&#xff01; 创建项目 配置所需环境 将application.properties>application.yml 配置以下环境 数据库连接MySQL 自己创建的数据库名称为book_test server:port: 8080 spring:datasource:url:…

搜索型数据库的技术发展历程与趋势前瞻

概述 随着数字科技的飞速发展和信息量的爆炸性增长&#xff0c;搜索引擎已成为我们获取信息的首选途径之一&#xff0c;典型的代表厂商如 Google。然而&#xff0c;随着用户需求的不断演变&#xff0c;传统的搜索技术已经无法满足人们对信息的实时性、个性化和多样性的需求。 …

C++基础知识-编译相关

记录C语言相关的基础知识 1 C源码到可执行文件的四个阶段 预处理(.i)、编译(.s)、汇编(.obj)、链接。 1.1 预处理 预处理阶段&#xff0c;主要完成宏替换、文件展开、注释删除、条件编译展开、添加行号和文件名标识&#xff0c;输出.i/.ii预处理文件。 宏替换&#xff0c;…

AI的价值——不再那么“废”人,保险行业用AI人员流失减少20%

最近有个热点挺让人唏嘘的&#xff0c;某咖啡店员工对顾客泼咖啡粉&#xff0c;我们对于这个事件不予评价。但是对员工这种情绪崩溃&#xff0c;我们所接触的行业中也有不少例子&#xff0c;比如保险行业&#xff0c;相信大家经常会被打保险电话&#xff0c;这类电话很容易就被…

K8s 的最后一片拼图:dbPaaS

K8s 的发展使得私有云跟公共云之间的技术差不断的缩小&#xff0c;不管是在私有云还是公共云&#xff0c;大家今天都在基于 K8s 去开发 PaaS 系统。而 K8s 作为构建 PaaS 的基础&#xff0c;其全景图里还缺最后一块“拼图”——dbPaaS。作为一个云数据库行业干了十几年的资深从…

Swin Transformer:最佳论文,准确率和性能双佳的视觉Transformer | ICCV 2021

论文提出了经典的Vision Transormer模型Swin Transformer&#xff0c;能够构建层级特征提高任务准确率&#xff0c;而且其计算复杂度经过各种加速设计&#xff0c;能够与输入图片大小成线性关系。从实验结果来看&#xff0c;Swin Transormer在各视觉任务上都有很不错的准确率&a…