学习Pytorch深度学习运行AlexNet代码时关于在Pycharm中解决 “t >= 0 t < n_classes” 的断言错误方法

news2025/1/18 20:26:50

在学习深度学习的过程中,遇到了一个报错:

 这跑的代码是AlexNet的代码实现。

运行时出现报错:

C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\cuda\Loss.cu:257: block: [0,0,0], thread: [4,0,0] Assertion `t >= 0 && t < n_classes` failed.

解决方案:

当你遇到 CUDA error: device-side assert triggered 和具体的断言错误 t >= 0 && t < n_classes,这通常指示在 CUDA 上运行的某些操作遇到了问题,大多数情况下是由于标签值 t 超出了期望的范围。这里的问题发生在执行损失函数计算时,具体来说是在 PyTorch 的底层 CUDA 代码中。

问题的主要源头:

由于错误提示 t >= 0 && t < n_classes,你需要确保所有标签值都在正确的范围内。对于分类任务,标签值 t 应该是一个非负整数,并且小于类别总数 n_classes。如果你的数据集标签不是从 0 开始的,你需要将它们转换为从 0 开始。

在合适的位置定义:确保 n_classes 在你尝试使用它进行断言检查之前已经被定义。这通常意味着你需要在加载数据集、初始化数据加载器之前,或在定义模型之处确定 n_classes 的值。

这里我的代码中已经指定:

由于错误报告显示断言失败发生在损失计算时,一个可能的原因是某些样本的标签不在 [0, n_classes-1] 的范围内。你可以添加一些代码在损失计算之前检查标签值:

增添代码段:

        if not (labels.min() >= 0 and labels.max() < n_classes):
            print("不满足条件的标签值:", labels[labels < 0], labels[labels >= n_classes])

 以及:

n_classes = 102  # 根据你的具体任务设置这个值, 这里对应num_classes=102

继续运动代码进行测试,出现如下报错:

 从提供的输出信息来看,断言错误是因为存在标签值等于 102,这超出了预期的类别范围 [0, n_classes-1]。假设 n_classes 应该是 102(意味着有效的标签范围是从 0101),标签值 102 显然是无效的,因为它等于类别总数,超出了最大有效索引。

解决方案

  1. 校正类别总数:首先确认 n_classes 的值是否正确。如果你的任务确实有 102 个类别(例如,Flowers102 数据集),那么 n_classes 应该设置为 102,并且你需要确保所有标签都在 [0, 101] 的范围内。

  2. 修正数据标签:由于出现了 102 作为标签值,这可能是由于数据标签在某个步骤中被错误地分配或转换。你需要回溯到数据处理的步骤,找出为什么会有 102 这样的标签值出现,并进行修正。如果是因为数据集自带的标签从 1 开始计数,那么你需要将所有标签减 1 以转换为从 0 开始计数:

增添代码:

# 假设 `labels` 是你的标签张量
labels = labels - 1

测试和验证

在进行了上述修正之后,再次运行你的代码,并使用之前添加的打印语句来验证所有标签值是否都在正确的范围内。如果没有进一步的断言错误,那么这意味着问题已经被解决。如果问题依然存在,可能需要进一步调查数据处理流程中的每一个步骤,确保在任何地方都没有引入标签错误。

再一次运行代码,发现没有报错,代码运行正常:

模型开始了训练,问题得到了解决!!!

最后给出优化后的完整python代码:

文件:main_AlexNet.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import *
import numpy as np
import matplotlib.pyplot as plt
import sys
from AlexNet import AlexNet
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# 设备检测,若未检测到cuda设备则在CPU上运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 设置随机种子
torch.manual_seed(0)

# 定义模型、优化器、损失函数
model = AlexNet(num_classes=102).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.002, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 设置训练集的数据变换,进行数据增强
transform_train = transforms.Compose([
    transforms.RandomRotation(30),  # 随机旋转 -30度到30度之间
    transforms.RandomResizedCrop((224, 224)),  # 随机比例裁剪并进行resize
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
    transforms.ToTensor(),  # 将数据转换为张量
    # 对三通道数据进行归一化(均值,标准差), 数值是从ImageNet数据集上的百万张图片中随机抽样计算得到
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 对数据进行归一化
])

# 设置测试集的数据变换,进行数据增强
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),  # resize
    transforms.ToTensor(),  # 将数据转化为张量
    # 对三通道数据进行归一化(均值,标准差),数值是从ImageNet数据集上的百万张图片中随机抽样计算得到
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载训练数据,需要特别注意的是Flowers102数据集,test簇的数据量较多些,所以这里使用"test"作为训练集
train_dataset = datasets.Flowers102(root='./data/flowers102', split="test",
                                    download=False, transform=transform_train)
# 实例化训练数据加载器
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=6, drop_last=False)
# 加载测试数据,使用“train”作为测试集
test_dataset = datasets.Flowers102(root='./data/flowers102', split="train",
                                   download=False, transform=transform_test)
# 实例化测试数据加载器
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=6, drop_last=False)

# 设置epoch数并开始训练
num_epochs = 500  # 设置epoch数
n_classes = 102  # 根据你的具体任务设置这个值, 这里对应num_classes=102
loss_history = []  # 创建损失历史记录列表
acc_history = []  # 创建准确率历史记录列表

# tqdm用于显示进度条并评估任务时间开销
for epoch in tqdm(range(num_epochs), file=sys.stdout):
    # 记录损失和预测正确数
    total_loss = 0
    total_correct = 0

    # 批量训练
    model.train()
    for inputs, labels in train_loader:

        labels = labels - 1

        if not (labels.min() >= 0 and labels.max() < n_classes):
            print("不满足条件的标签值:", labels[labels < 0], labels[labels >= n_classes])

        # 将数据转换到指定计算资源设备上
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 预测、损失函数、反向传播
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 记录训练集loss
        total_loss += loss.item()

    # 测试模型,不计算梯度
    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            # 将数据转换到指定计算资源设备上
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 预测
            outputs = model(inputs)
            # 记录测试集预测正确数
            total_correct += (outputs.argmax(1) == labels).sum().item()

    # 记录训练集损失和测试集准确率
    loss_history.append(np.log10(total_loss))  # 将损失加入损失历史记录列表,由于数值有时较大,这里取对数
    acc_history.append(total_correct / len(test_dataset))  # 将准确率加入准确率历史记录列表

    # 打印中间值
    # 每50个epoch打印一次中间值
    if epoch % 50 == 0:
        tqdm.write("Epoch: {0} Loss: {1} Acc: {2}".format(epoch, loss_history[-1], acc_history[-1]))

# 使用Matplotlib绘制损失和准确率的曲线图
plt.plot(loss_history, label='loss')
plt.plot(acc_history, label=' ')
plt.legend()
plt.show()

# 输出准确率
print("Accuracy:", acc_history[-1])

文件:AlexNet.py

import torch
import torch.nn as nn
from torchinfo import summary

# 定义AlexNet的网络结构
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, dropout=0.5):
        super().__init__()
        # 定义卷积层
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # 定义全连接层
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# 查看模型结构以及参数量,input_size表示示例输入数据的维度信息
# summary(AlexNet(), input_size=(1,3,224,224))

将epoch改为100,得到如下训练结果:

 可见模型还未收敛,大家可以自行调节参数来尝试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1441943.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[职场] 公务员面试停顿磕巴常见吗 #学习方法#知识分享#知识分享

公务员面试停顿磕巴常见吗 面试时说话磕巴简直是太常见了&#xff0c;对于一个新问题&#xff0c;让人在短时间内&#xff0c;并且仅仅是三分钟内&#xff0c;就组织起一个答案&#xff0c;还无法全部打手稿&#xff0c;这对于连上个讲台都会脸发红的人来说&#xff0c;简直是一…

前端JavaScript篇之如何获得对象非原型链上的属性?

目录 如何获得对象非原型链上的属性&#xff1f; 如何获得对象非原型链上的属性&#xff1f; 要获取对象上非原型链上的属性&#xff0c;可以使用 hasOwnProperty() 方法。这个方法是 JavaScript 内置的对象方法&#xff0c;用于检查一个对象是否包含指定名称的属性&#xff0…

TP-LINK今年的年终奖。。

TP-LINK 年终奖 如果说昨天爆料的「浦发银行年终奖&#xff0c;一书抵万金」还稍有争议&#xff08;有些说没发&#xff0c;有些说 3/4/5 折&#xff09;&#xff0c;那今天的 TP-LINK 则是毫无悬念。 据在职的 TP-LINK 技术员工爆料&#xff1a;入职时说好的 16 薪&#xff0c…

day45_maven_tomcat

今日内容 0 复习昨日 1 maven 2 tomcat 3 创建项目 0 复习昨日 1 单词写5遍 argument 参数 parameter 参数 access 访问 field 字段 invoke 调用 illegal 非法 invalid 无效 column 列 property 属性 DataSource 数据源 2 数据库连接池有啥好处 3 获得字节码文件的方式 Class.f…

ChatGPT高效提问—prompt常见用法(续篇七)

ChatGPT高效提问—prompt常见用法&#xff08;续篇七&#xff09; 1.1 零样本、单样本和多样本 ​ ChatGPT拥有令人惊叹的功能和能力&#xff0c;允许用户自由向其提问&#xff0c;无须提供任何具体的示例样本&#xff0c;就可以获得精准的回答。这种特性被称为零样本&#x…

使用CHATGPT进行论文写作的缺点和风险

为了真正感受 ChatGPT 的写作潜力&#xff0c;让我们先将其与传统的论文写作方法进行一下比较分析 CHATGPT论文写作的缺点和风险 传统论文写作的考验和磨难很深&#xff1a;费力的研究、组织想法和精心设计的逻辑论证&#xff0c;往往以牺牲你的理智为代价。 进入ChatGPT&am…

Linux下的多用户管理和认证:从入门到精通(附实例)

Linux操作系统以其强大的多用户管理和认证机制而著称。这种机制不仅允许多个用户同时登录并执行各种任务&#xff0c;还能确保每个用户的数据安全和隐私。本文将通过一系列实例&#xff0c;带你逐步掌握Linux下的多用户管理和认证。 一、Linux多用户管理的基础知识 在Linux中&…

Bootstrap学习三

Bootstrap学习三 文章目录 前言四、Bootstrap插件4.1. 插件概览4.1.1. data属性4.1.2. 编程方式的API4.1.3. 避免命名空间冲突4.1.4. 事件 4.2. 模态框4.2.1. 引入4.2.2. 基本结构4.2.3. 基本使用4.2.4. 触发模态框的方法 4.3. 下拉菜单和滚动监听4.3.1. 下拉菜单4.3.2. 滚动监…

祝大家春节快乐

文章目录 祝福年俗交流 祝福 今天是大年三十&#xff0c;也就是除夕&#xff0c;这是全画人民欢庆春节的日子&#xff0c;在此辞旧迎新之际&#xff0c;我祝愿所有的粉丝们春节快乐&#xff0c;身体健康&#xff0c;万事如意。也祝愿我们伟大的祖国繁荣昌盛&#xff0c;龙腾虎…

《MySQL 简易速速上手小册》第9章:高级 MySQL 特性和技巧(2024 最新版)

文章目录 9.1 使用存储过程和触发器9.1.1 基础知识9.1.2 重点案例&#xff1a;使用 Python 调用存储过程实现用户注册9.1.3 拓展案例 1&#xff1a;利用触发器自动记录数据更改历史9.1.4 拓展案例 2&#xff1a;使用 Python 和触发器实现数据完整性检查 9.2 管理和查询 JSON 数…

基于LLM的数据漂移和异常检测

大型语言模型 (LLM) 的最新进展被证明是许多领域的颠覆性力量&#xff08;请参阅&#xff1a;通用人工智能的火花&#xff1a;GPT-4 的早期实验&#xff09;。 和许多人一样&#xff0c;我们非常感兴趣地关注这些发展&#xff0c;并探索LLM影响数据科学和机器学习领域的工作流程…

你的立身之本是什么?

去年发生的一切&#xff0c;大到疫情、政治经济形势、行业的萎靡和震荡&#xff0c;小到身边的跳槽、裁员、公司倒闭……似乎都在告诉我们&#xff1a; 当冲击到来的时候&#xff0c;它是不会提前跟你打招呼的。 接下来的10年&#xff0c;我们所面临的不确定性&#xff0c;比起…

技术精英求职必备:Java开发工程师简历制作全指南

投简历找工作嘛&#xff0c;这事儿其实就跟相亲差不多&#xff0c;得让对方一眼就看上你。 在这场职场的‘相亲’中&#xff0c;怎样才能让你的简历脱颖而出&#xff0c;成为HR眼中的理想‘对象’呢&#xff1f;来&#xff0c;我给你支几招&#xff0c;让你的简历更吸引人。 …

windows编程-系统编程入门

1.进程线程概念&#xff08;简略版&#xff09; 1.1 进程 1.1.1 概念 我们编写的代码只是一个存储在硬盘的静态文件&#xff0c;通过编译后就会生成二进制可执行文件&#xff0c;当我们运行这个可执行文件后&#xff0c;它会被装载到内存中&#xff0c;接着 CPU 会执行程序中…

除夕快乐(前端小烟花)

家人们&#xff0c;新的一年好运常在&#xff0c;愿大家在新的一年里得偿所愿&#xff0c;发财暴富&#xff0c;愿大家找到属于自己的那个公主&#xff0c;下面就给大家展示一下给公主的烟花 前端烟花 新的一年&#xff0c;新的挑战&#xff0c;愿我们不忘初心&#xff0c;砥砺…

HarmonyOS 开发学习笔记

HarmonyOS 开发学习笔记 一、开发准备1.1、了解ArkTs语言1.2、TypeScript语法1.2.1、变量声明1.2.2、条件控制1.2.3、函数1.2.4、类和接口1.2.5、模块开发 1.3、快速入门 二、ArkUI组件2.1、Image组件2.2、Text文本显示组件2.3、TextInput文本输入框组件2.4、Button按钮组件2.5…

备战蓝桥杯---搜索(完结篇)

再看一道不完全是搜索的题&#xff1a; 解法1&#xff1a;贪心并查集&#xff1a; 把冲突事件从大到小排&#xff0c;判断是否两个在同一集合&#xff0c;在的话就返回&#xff0c;不在的话就合并。 下面是AC代码&#xff1a; #include<bits/stdc.h> using namespace …

Bee+SpringBoot稳定的Sharding、Mongodb ORM功能(同步 Maven)

Hibernate/MyBatis plus Sharding JDBC Jpa Spring data GraphQL App ORM (Android, 鸿蒙) Bee 小巧玲珑&#xff01;仅 860K, 还不到 1M, 但却是功能强大&#xff01; V2.2 (2024春节・LTS 版) 1.Javabean 实体支持继承 (配置 bee.osql.openEntityCanExtendtrue) 2. 增强批…

放飞梦想,扬帆起航——1888粉丝福利总结

目录 1.祝福 2.准备 3.抽奖 4.制作 5.添加 6.成果 7.感谢 8.福利 9.祝福 1.祝福 马上就是除夕了&#xff0c;在这里提前预祝大家春节快乐&#xff0c;小芒果在这里给大家拜年了&#xff01; 2.准备 其实很早之前我就在幻想着哪一天我的粉丝量能突破1888&#xff0c;…

Redis -- 安装客户端redis-plus-plus

目录 访问reids客户端github链接 安装git 如何安装&#xff1f; 下载/编译、安装客户端 安装过程中可能遇到的问题 访问reids客户端github链接 GitHub - sewenew/redis-plus-plus: Redis client written in CRedis client written in C. Contribute to sewenew/redis-p…