PyTorch自定义损失函数实现

news2025/1/10 18:01:47

在机器学习中,损失函数是衡量预测输出与实际输出之间差异的关键组成部分。 它在模型训练中起着至关重要的作用,因为它通过指示模型应该改进的方向来指导优化过程。 损失函数的选择取决于具体的任务和数据类型。 在本文中,我们将以用于手写数字分类的 MNIST 数据集为例,深入研究 PyTorch 中自定义损失函数的理论和实现。
在这里插入图片描述

推荐:使用 NSDT场景设计器 快速搭建 3D场景。

1、概述

MNIST 数据集是广泛用于图像分类任务的数据集,它包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。 此任务旨在根据 MNIST 数据集中提供的训练示例训练一个模型,该模型可以准确地对手写数字的新图像进行分类。

此任务的典型方法是使用多类逻辑回归模型,它是一个 softmax 分类器。 softmax 函数将模型的输出映射到 10 个类别的概率分布。 交叉熵损失通常用作此类模型的损失函数。 交叉熵损失计算预测概率分布与实际概率分布之间的差异。

然而,在某些情况下,交叉熵损失可能不是特定任务的最佳选择。 例如,考虑一个场景,其中错误分类某些类的成本比其他类高得多。 在这种情况下,有必要使用考虑到每个类的相对重要性的自定义损失函数。

在本文中,我将向你展示如何为 MNIST 数据集实现自定义损失函数,其中误分类数字 9 的成本远高于其他数字。 我们将使用 Pytorch 作为框架,首先讨论自定义损失函数背后的理论,然后我们将展示使用 Pytorch 实现自定义损失函数。 最后,我们将使用自定义损失函数在 MNIST 数据集上训练线性模型,并评估模型的性能。

2、自定义损失函数:为什么

出于以下几个原因,实现自定义损失函数很重要:

  • Problem-specific:损失函数的选择取决于具体任务和数据类型。 可以设计自定义损失函数以更好地适应手头问题的特征,从而提高模型性能。
  • 类不平衡:在许多现实世界的数据集中,每个类中的样本数量可能非常不同。 可以设计一个自定义损失函数来考虑类别不平衡,并为不同的类别分配不同的成本。
  • 成本敏感:在某些任务中,错误分类某些类别的成本可能比其他类别高得多。 可以设计自定义损失函数以考虑每个类的相对重要性,从而产生更稳健的模型。
  • 多任务学习:可以设计自定义损失函数来同时处理多个任务。 这在需要单个模型来执行多个相关任务的情况下非常有用。
  • 正则化:自定义损失函数也可以用于正则化,有助于防止过拟合,提高模型的泛化能力。
  • 对抗性训练:自定义损失函数也可用于训练模型以抵抗对抗性攻击。
    总之,自定义损失函数可以提供一种更好地针对特定问题优化模型的方法,并且可以提供更好的性能和泛化能力。

3、PyTorch 中的自定义损失函数

MNIST 数据集包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。 此任务的典型方法是使用多类逻辑回归模型,它是一个 softmax 分类器。 softmax 函数将模型的输出映射到 10 个类别的概率分布。 交叉熵损失通常用作此类模型的损失函数。

交叉熵损失计算预测概率分布与实际概率分布之间的差异。 通过将 softmax 函数应用于模型的输出来获得预测的概率分布。 实际的概率分布是一个one-hot vector,其中正确类别对应的元素值为1,其他元素值为0。交叉熵损失定义为:

    L = -∑(y_i * log(p_i))

其中 y_i 是类别 i 的实际概率,p_i 是类别 i 的预测概率。

然而,在某些情况下,交叉熵损失可能不是特定任务的最佳选择。 例如,考虑一个场景,其中错误分类某些类的成本比其他类高得多。 在这种情况下,有必要使用考虑到每个类的相对重要性的自定义损失函数。

在 PyTorch 中,可以通过创建 nn.Module 类的子类并覆盖 forward 方法来实现自定义损失函数。 forward 方法将预测输出和实际输出作为输入,并返回损失值。

下面是 MNIST 分类任务的自定义损失函数示例,其中错误分类数字 9 的成本远高于其他数字:

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        target = torch.LongTensor(target)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        mask = target == 9
        high_cost = (loss * mask.float()).mean()
        return loss + high_cost

在这个例子中,我们首先使用 nn.CrossEntropyLoss() 函数计算交叉熵损失。 接下来,我们为属于类别 9 的样本创建掩码 1,为其他样本创建掩码 0。 然后我们计算属于类别 9 的样本的平均损失。最后,我们将这个高成本损失添加到原始损失中以获得最终损失。

要使用自定义损失函数,我们需要将其实例化并将其作为参数传递给训练循环中优化器的标准参数。 以下是如何使用自定义损失函数在 MNIST 数据集上训练模型的示例:

import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import os

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        target = torch.LongTensor(target)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        mask = target == 9
        high_cost = (loss * mask.float()).mean()
        return loss + high_cost




# Load the MNIST dataset
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=True)


# Define the model, loss function and optimizer
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

network = Net()
optimizer = optim.SGD(network.parameters(), lr=0.01,
                      momentum=0.5)
criterion = CustomLoss()

# Training loop
n_epochs = 10

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

if os.path.exists('results'):
  os.system('rm -r results')

os.mkdir('results')

def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 1000 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
      torch.save(network.state_dict(), 'results/model.pth')
      torch.save(optimizer.state_dict(), 'results/optimizer.pth')

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += criterion(output, target).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

此代码是 PyTorch 中 MNIST 数据集的自定义损失函数的实现。 MNIST 数据集包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。

第一个代码块通过继承 PyTorch nn.Module 创建一个名为“CustomLoss”的自定义损失函数。 它有一个前向方法,接受两个输入; 模型的输出和目标标签。 forward 方法首先将目标标签转换为长整数张量。 然后它创建一个内置 PyTorch 交叉熵损失函数的实例,并使用它来计算模型输出和目标标签之间的损失。 接下来,它创建一个标识等于 9 的目标标签的掩码,然后将损失乘以该掩码并计算所得张量的平均值。 最后,它返回原始损失和高成本损失的均值之和。

下一个代码块使用 PyTorch 的内置数据加载实用程序加载 MNIST 数据集。 train_loader 加载训练数据集并对图像应用指定的变换,例如将图像转换为张量并归一化像素值。 test_loader 加载测试数据集并应用相同的转换。

以下代码块通过对 PyTorch nn.Module 进行子类化来定义一个名为“Net”的卷积神经网络 (CNN)。 CNN 由 2 个卷积层、2 个线性层和一些用于正则化的 dropout 层组成。 Net 类的 forward 方法依次应用卷积层和线性层,将输出传递给 ReLU 激活函数和最大池化层。 它还将 dropout 层应用于输出并返回最终输出的 log-softmax。

下一个代码块创建 Net 类的一个实例、一个优化器(随机梯度下降)和一个自定义损失函数的实例。

最后的代码块是训练循环,其中模型训练了 10 个时期。 在每个时期,模型迭代训练数据集,通过网络传递图像,使用自定义损失函数计算损失并反向传播梯度。 然后它使用优化器更新模型的参数。 它还跟踪训练损失和测试损失,并定期将当前损失打印到控制台。 此外,它会创建一个名为“results”的新目录来存储训练过程的结果和输出。

import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()

在这里插入图片描述

此代码在训练过程中为 MNIST 数据集创建自定义损失函数图。 该图将显示训练集和测试集的自定义损失。

它首先导入 Matplotlib 库,这是一个用于 Python 的绘图库。 然后,它使用 plt.figure() 函数创建一个具有指定大小的图形对象。

下一行代码使用 plt.plot() 函数绘制训练集的自定义损失。 它使用 train_counter 和 train_losses 变量分别作为 x 和 y 轴值。 使用 color 参数将图的颜色设置为蓝色。

然后,它使用 plt.scatter() 函数绘制测试集的自定义损失。 它使用 test_counter 和 test_losses 变量分别作为 x 和 y 轴值。 使用 color 参数将图的颜色设置为红色。

plt.legend() 函数为绘图添加图例,指示哪个绘图对应于训练损失,哪个对应于测试损失。 loc 参数设置为“右上角”,这意味着图例将位于绘图的右上角。

plt.xlabel() 和 plt.ylabel() 函数分别向绘图的 x 轴和 y 轴添加标签。 x 轴标签设置为“看到的训练示例数”,y 轴标签设置为“自定义损失”。

最后,plt.show() 函数用于显示绘图。

此代码将显示一个图,显示所见训练示例的自定义损失函数。 蓝线代表训练集的自定义损失,红点代表测试集的自定义损失。 该图将允许你查看自定义损失函数在训练过程中的表现,并评估模型的性能。

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
  output = network(example_data)
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Prediction: {}".format(
    output.data.max(1, keepdim=True)[1][i].item()))
  plt.xticks([])
  plt.yticks([])

plt.show()

在这里插入图片描述

此代码显示一个图形,其中包含来自测试集的 6 个图像以及训练网络做出的相应预测。

它首先使用 enumerate() 函数循环遍历 test_loader,这是一个批量加载测试数据集的迭代器。 next() 函数用于从测试集中获取第一批示例。

example_data 变量包含图像,example_targets 变量包含相应的标签。

然后它使用 Pytorch 的 torch.no_grad() 函数,它用于临时将 requires_grad 标志设置为 false。 它将减少内存使用并加快计算速度,但也不会跟踪操作。

下一个代码块使用 plt.figure() 函数创建一个新的图形对象。 然后,它使用 for 循环迭代测试集中的前 6 个示例。 对于每个示例,它使用 plt.subplot() 函数在当前图窗中创建一个子图。 plt.tight_layout() 函数用于调整子图之间的间距。

然后它使用 plt.imshow() 函数在当前子图中显示图像。 cmap 参数设置为“灰色”以灰度显示图像,插值参数设置为“无”以显示图像而不进行任何插值。

plt.title() 函数用于为当前子图添加标题。 标题显示了网络对当前示例所做的预测。 网络的输出通过 output.data.max(1, keepdim=True)[1] 传递,它返回预测类的索引。 [i].item() 提取预测类的整数值。

plt.xticks() 和 plt.yticks() 函数分别用于从当前子图中删除 x 轴和 y 轴刻度。

最后,plt.show() 函数用于显示图形。 此代码将显示一个图形,其中包含来自测试集的 6 张图像以及经过训练的网络对其做出的相应预测。 图像以灰度显示且没有任何插值,预测类别显示为每张图像上方的标题。 这可能是一个有用的工具,可用于可视化模型在测试集上的性能并识别任何潜在问题或错误分类。

4、结束语

在本文中,我们以用于数字分类的 MNIST 数据集为例,讨论了 PyTorch 中自定义损失函数的理论和实现。 我们已经展示了如何通过继承 nn.Module 类并覆盖 forward 方法来创建自定义损失函数。 我们还提供了一个示例,说明如何使用自定义损失函数在 MNIST 数据集上训练模型。 在错误分类某些类的成本远高于其他类的情况下,自定义损失函数可能很有用。 重要的是要注意,在实现自定义损失函数时应格外小心,因为它们会对模型的性能产生重大影响。

— ‌
原文链接:Pytorch自定义损失函数 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/337667.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VHDL语言基础-时序逻辑电路-概述

目录 时序逻辑电路-概述: 时序逻辑电路: 时序逻辑电路——有记忆功能: 时序电路的分类: 按照触发器的动作特点: 按照输出信号的特点: 同步时序逻辑电路: 异步时序逻辑电路: 时序逻辑电路-概述: 数字电路按其完成逻辑功能的不同特点,划分为组合逻辑电路和时序…

福利篇1——嵌入式软件行业与公司汇总

前言 汇总嵌入式软件行业与公司,供参考。 文章目录 前言一、嵌入式软件行业和公司汇总1、芯片行业代表性公司2、人工智能代表性公司1)智能驾驶方向代表性公司2)机器人方向代表性公司3、消费电子领域代表性公司4、传统电子电器领域代表性公司5、国企和军工领域代表性公司6、网…

嵌入式系统那些事——aarch64 backtrace嵌入式汇编实现

0 背景 在aarch64嵌入式应用开发中,经常会遇到段错误(segmentation fault),但是通常情况下系统报错后直接退出,没有异常调用打印信息,定位出错原因十分困难。经确认,该问题是由于没有设置捕获段错误,并调用…

推荐3dMax三维设计十大插件

3dMax是一款功能非常强大的三维设计软件,但无论它的功能多么强大,也不可能包含所有三维方面的功能,这时候,第三方插件可以很好的弥补和增强3dMax的基本功能,下面就给大家介绍十款非常不错的3dMax插件。 森林包&#xf…

Unsupervised Question Answering 简单综述

Unsupervised Question Answering by Cloze Translation, ACL 2019 随机从文本中抽取noun phrases或者named entity作为答案将答案部分mask掉,生成cloze question利用无监督翻译,将cloze question转化为natural question 缺点: 直接利用原句…

Android 进阶——Framework核心 之Binder Native成员类详解(二)

文章大纲引言一、Native 家族核心成员关系图二、Native 家族核心成员源码概述1、IInterface1.1、DECLARE_META_INTERFACE 宏1.2、IMPLEMENT_META_INTERFACE(INTERFACE, NAME) 宏1.3、sp< IInterface > BnInterface< INTERFACE >::queryLocalInterface(const String…

微前端qiankun架构 (基于vue2实现)使用教程

工具使用版本 node --> 16vue/cli --> 5 创建文件 创建文件夹qiankun-test。 使用vue脚手架创建主应用main和子应用dev 主应用 安装 qiankun: yarn add qiankun 或者 npm i qiankun -S 使用qiankun&#xff1a; 在 utils 内创建 微应用文件夹 microApp,在该文件夹…

_Linux (线程池)

文章目录线程池概述&#xff1a;线程池示例&#xff1a;代码细节代码结果展示线程池概述&#xff1a; 一种线程使用模式。 线程过多会带来调度开销&#xff0c;进而影响缓存局部性和整体性能。而线程池维护着多个线程&#xff0c;等待着监督管理者分配可并发执行的任务。这避…

Linux下文档类型转PDF的总结

我的环境 centos8 先说思路:先把字体上传到服务器,然后更新字体库 ,代码里面配置字体地址。 如果导出的还是乱码,要么没字体,要么检查代码里面的路径。 目录 1.上传windows字体到linux 2.建立索引信息,更新字体缓存

【基于ChatGPT+Python】快速打造前后端分离的OpenAI人工智能聊天机器人

&#x1f680; ChatGPT是最近很热门的AI智能聊天机器人 &#x1f680; 用途方面相比于普通的聊天AI更加的广泛&#xff0c;甚至可以帮助你改BUG&#xff0c;写代码&#xff01;&#xff01;&#xff01; &#x1f680; 下面是使用pythonChatGPTVue实现的在线聊天机器人&#xf…

shell脚本免交互与expect

目录 Here Document 定义 格式 注意 例子 统计行数 修改密码​编辑 expect 定义 基本命令 实验 免交互ssh主机 Here Document 定义 使用I/O重定向的方式将命令列表提供给交互式程序 格式 命令<< 标记....标记 注意 标记可以使用任意的合法字符&#xf…

SpringBoot笔记【JavaEE】

SpringBoot概念、创建和运行 1.什么是SpringBoot&#xff1f;为什么学习SpringBoot&#xff1f; Spring Boot 就是 Spring 框架的脚⼿架&#xff0c;它就是为了快速开发 Spring 框架⽽诞⽣的。 2.Spring Boot优点 快速集成框架【提供启动添加依赖的功能】内容运行容器【无需…

从零开始,打造属于你的 ChatGPT 机器人!

大家好&#xff01;我是韩老师。不得不说&#xff0c;最近 OpenAI/ChatGPT 真的是太火了。前几天&#xff0c;微软宣布推出全新的 Bing 和 Edge&#xff0c;集成了 OpenAI/ChatGPT 相关的技术&#xff0c;带动股价大涨&#xff1a;微软市值一夜飙涨 5450 亿国内外各家大厂也是纷…

为什么神经网络做不了2次函数拟合,网上的都是骗人的吗?

环境&#xff1a;tensorflow2 kaggle 这几天突发奇想&#xff0c;用深度学习训练2次函数。先在网上找找相同的资料这方面资料太少了。大多数如下&#xff1a; 。 给我的感觉就是&#xff0c;用深度学习来做&#xff0c;真的很容易。 网上写出代码分析的比较少。但是也找到了…

云计算|OpenStack|社区版OpenStack安装部署文档(十二--- openstack的网络模型解析---Rocky版)

前言&#xff1a; https://zskjohn.blog.csdn.net/article/details/128846360 云计算|OpenStack|社区版OpenStack安装部署文档&#xff08;六 --- 网络服务neutron的安装部署---Rocky版&#xff09; &#xff08;######注&#xff1a;以上文章使用的是openstack的provider网…

【Vue3】电商网站吸顶功能

头部分类导航-吸顶功能 电商网站的首页内容会比较多&#xff0c;页面比较长&#xff0c;为了能让用户在滚动浏览内容的过程中都能够快速的切换到其它分类。需要分类导航一直可见&#xff0c;所以需要一个吸顶导航的效果。 目标:完成头部组件吸顶效果的实现 交互要求 滚动距离大…

计算机视觉 对比学习13篇经典论文、解读、代码

为了快速对 机器视觉中的对比学习有一个快速了解&#xff0c;或者后续复习&#xff0c;此处收录了 13篇经典论文、一些讲解地较好的博客和相应的Github代码&#xff0c;用不同颜色标记。 ​ 对比学习 13篇经典论文 论文代码和博客http://​www.webhub123.com/#/home/detail?p…

Nextjs了解内容

目录Next.jsnext.js的实现1&#xff0c;nextjs初始化2&#xff0c; 项目结构3&#xff0c; 数据注入getInitialPropsgetServerSidePropsgetStaticProps客户端注入3&#xff0c;CSS Modules4&#xff0c;layout组件5&#xff0c;文件式路由6&#xff0c;BFF层的文件式路由7&…

爬虫笔记之——selenium安装与使用(1)

爬虫笔记之——selenium安装与使用&#xff08;1&#xff09;一、安装环境1、下载Chrome浏览器驱动&#xff08;1&#xff09;查看Chrome版本&#xff08;2&#xff09;下载相匹配的Chrome驱动程序地址&#xff1a;https://chromedriver.storage.googleapis.com/index.html2、学…

vue83-103

vue全局路由拦截路由懒加载路由原理swiper组件选项卡封装电影导航组件正在热映获取数据渲染axios封装详情渲染详情轮播详情Header-组件影院组件渲染全局路由拦截 即使路径对&#xff0c;也会被拦截 router.beforeEach((to,from, next) > { console.log(to) if&#xff08;…