cifar-10数据集+ResNet50

news2025/1/17 23:07:57

CIFAR-10-ObjectRecognition

作为一个古老年代的数据集,用ResNet来练一下手也是不错的。
比赛链接:CIFAR-10 - Object Recognition in Images | Kaggle

1. 预设置处理

创建各类超参数,其中如果是在Kaggle上训练的话batch_size是可以达到4096的。

同时对于CIFAR-10数据集中含有10个类别,通过字典与反字典生成相应映射。

'''
**************************************************
@File   :kaggle -> ResNet
@IDE    :PyCharm
@Author :TheOnlyMan
@Date   :2023/4/14 23:06
**************************************************
'''
seed = 998244353
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.autograd.set_detect_anomaly(True) 检测梯度计算失败位置


class ArgParse:
    def __init__(self) -> None:
        self.batch_size = 16
        self.lr = 0.001
        self.epochs = 10
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
args = ArgParse()
dic = {"airplane": 0, "automobile": 1, "bird": 2, "cat": 3, "deer": 4,
       "dog": 5, "frog": 6, "horse": 7, "ship": 8, "truck": 9}
rev_dic = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
           5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

2. 创建Dataset

对于CIFAR-10数据集而言,图片的大小为(3,32,32),若是采用ResNet的话也不必要重新插值生成(224,224)尺寸的图像, 目前觉得好像(32,32)会更好点,而且模型体量更小点。

由于数据集是比较大的,可以不采用一次性读入内存的形式,在需要时在读取。

最后对于正确标签而言,采用onehot编码的形式,将正确标签概率设置为1,因此模型的输出只需要判断概率问题即可,较为经典的多分类问题。

class DataSet(Dataset):
    def __init__(self, flag='train') -> None:
        self.flag = flag
        self.trans = transforms.Compose([
            # trans.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            trans.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        assert flag in ['train', 'test'], 'not implement'
        if self.flag == 'train':
            self.path = "dataset/cifar-10/train"
            self.dtf = pd.read_csv("dataset/cifar-10/trainLabels.csv")
        else:
            self.path = "dataset/cifar-10/test"
            self.dtf = None
        self.len = len(os.listdir(self.path))

    def __getitem__(self, item):
        image = Image.open(os.path.join(self.path, f"{item + 1}.png"))
        if self.flag == 'train':
            label = [0 for _ in range(10)]
            label[dic.get(self.dtf.iloc[item, 1], -1)] = 1
            return self.trans(image), torch.tensor(label, dtype=torch.float)
        else:
            return self.trans(image), torch.tensor(item, dtype=torch.int32)

    def __len__(self):
        return self.len

3. 模型

CNN

采取类似VGG架构,卷积核均为3,通道数在卷积层由3>8>16>32>64>128,在全连接层则是由一层512节点数构成,最后的输出层10采用Softmax函数进行多分类处理。(注:nn.Softmax()中dim需要设定为1)

训练轮数只有5-10轮,提交上去后有百分之73的准确率,对于一个较为少层的神经网络来说效果还是不错的,轻量级,同时如果训练轮数在多点应该也可以到达较高的准确率。

相比ResNet50而言,在训练10轮之后可以很快达到60%多的准确率。

在这里插入图片描述

10-20轮次CNN趋于稳定,验证集准确率不再上升。

在这里插入图片描述

可能是设置的参数与第一次设置的不同,从20-50轮次的训练中,CNN的效果也并没有什么起色,应该是已经达到了该网络的上限。已经无法稳定地突破70%的准确率。
在这里插入图片描述

在这里插入图片描述

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(
                in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(
                in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(
                in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(128), nn.ReLU(),
        )
        self.layer = nn.Sequential(
            nn.Linear(4 * 4 * 128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.layer(x)
        return x

ResNet50

采用由torchvision中包含的resnet50模型进行训练,同时去除掉resnet50的全连接层直接由卷积层连向输出层。

训练50轮次后可以达到82.7%的准确率。应该是轮次不足,或是仅仅使用原ResNet网络模型的效果无法达到90%以上等。其中采用20层初始预训练,经过微调后继续30层训练。感觉是还有上升的余地的,把训练轮数再翻倍应该可以达到85%以上准确率。

在这里插入图片描述

在这里插入图片描述

class ResNet50(nn.Module):
    def __init__(self):
        super(ResNet50, self).__init__()
        self.model = torchvision.models.resnet50(pretrained=True)
        in_channel = self.model.fc.in_features
        self.model.fc = nn.Linear(in_channel, 10)
        self.layer = nn.Sequential(
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.model(x)
        return self.layer(x)

MyResNet50

参考另一篇博客:ResNet残差网络

对于该模型而言,训练其50轮。

在20轮时的并没有20轮的CNN效果好。不过CNN本身参数虽然少,但是其效果之直逼ResNet50。

在这里插入图片描述

当对MyResNet50训练到第50轮时,其准确率并没有达到ResNet源码的效果,只在72%附近徘徊(比CNN略好一点,上限略高一点),因为MyResNet只是个简化的版本。这就使得网络比较浅的CNN效果和深层网络ResNet效果类似。

在这里插入图片描述

在这里插入图片描述

尝试再此基础上再次进行50轮次训练,并将学习率减少为原来的一半。此时验证集准确率已经到达极限,无法上升。

在这里插入图片描述

4. train

模型的创建或者加载由以上三个模型进行确定。

def train(pretrain=None):
    dataset = DataSet('train')
    train_size = int(len(dataset) * 0.95)
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size, shuffle=True)
    model = ResNet(Bottleneck, [3, 4, 6, 3], 10).to(args.device)
    if pretrain:
        model.load_state_dict(torch.load(pretrain))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    train_epochs_loss = []
    valid_epochs_loss = []
    train_acc = []
    valid_acc = []

    for epoch in tqdm(range(args.epochs)):
        model.train()
        train_epoch_loss = []
        acct, numst = 0, 0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            outputs = model(inputs)
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_epoch_loss.append(loss.item())
            acct += sum(outputs.max(axis=1)[1] == labels.max(axis=1)[1]).cpu()
            numst += labels.size()[0]
        train_epochs_loss.append(np.average(train_epoch_loss))
        train_acc.append(100 * acct / numst)

        with torch.no_grad():
            model.eval()
            val_epoch_loss = []
            acc, nums = 0, 0
            for inputs, labels in tqdm(valid_loader):
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_epoch_loss.append(loss.item())

                acc += sum(outputs.max(axis=1)[1] == labels.max(axis=1)[1]).cpu()
                nums += labels.size()[0]

            valid_epochs_loss.append(np.average(val_epoch_loss))
            valid_acc.append(100 * acc / nums)
            print("train acc = {:.3f}%, loss = {}".format(100 * acct / numst, np.average(train_epoch_loss)))
            print("epoch = {}, valid acc = {:.2f}%, loss = {}".format(epoch, 100 * acc / nums,
                                                                          np.average(val_epoch_loss)))

    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(train_acc, '-o', label="train_acc")
    plt.plot(valid_acc, '-o', label="valid_acc")
    plt.title("epochs_acc")
    plt.subplot(122)
    plt.plot(train_epochs_loss, '-o', label="train_loss")
    plt.plot(valid_epochs_loss, '-o', label="valid_loss")
    plt.title("epochs_loss")
    plt.legend()
    plt.show()
    torch.save(model.state_dict(), 'model.pth')

5. predict

最后预测将其生成指定submission.csv文件即可进行提交。

def pred():
    model = ResNet50().to(args.device)
    model.load_state_dict(torch.load('model.pth'))
    model.eval()

    test_data = DataSet('test')
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False)
    data = []

    for inputs, labels in tqdm(test_loader):
        ans = model(inputs.to(args.device)).cpu()
        ans = ans.max(axis=1)[1].numpy()
        for number, res in zip(labels, ans):
            data.append([number + 1, rev_dic.get(res)])

    dtf = pd.DataFrame(data, columns=['id', 'label'])
    dtf['id'] = dtf['id'].astype(np.int64)
    dtf.sort_values(by='id', ascending=True)
    print(dtf.info())
    dtf.to_csv('submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/435235.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安全狗深度参与编写《数据安全产品与服务观察报告》发布!

4月11日,由中国通信标准化协会联合数据安全推进计划主办的《数据安全产品与服务观察报告》发布会在北京顺利开展。 作为国内云原生安全领导厂商,安全狗也参与了数据安全推进计划《数据安全产品与服务观察报告》撰写。 此次活动针对数据安全产业、技术、…

排序算法合集(1)

前言: 今天我们正式开始讲述算法中的排序。排序算法是我们十分重要的算法,为什么呢? 排序是在各种情况下都非常重要的,无论是在人类社会还是在计算机科学中。以下是一些排序的重要性: 数据分析:在数据分析…

多连接数据库管理Navicat Premium 中文

Navicat Premium 是一款强大的数据库管理工具,它支持多种关系型数据库,包括 MySQL、MariaDB、Oracle、SQL Server、PostgreSQL 等等。 以下是 Navicat Premium 的一些主要功能: 连接管理:可以在一个用户界面中同时连接到多个数据库…

HCIP-6.7BGP的路径选择

BGP的路径选择 1、BGP路径属性1.1、路由选择1.1.1、BGP路由选择过程1.1.2、BGP选路参数2、BGP的路由策略2.1、Preferred-Value相当权重weight2.2、local-preference本地优先级2.3、AS_PATH经过的AS号 不常用2.4、Origin起源属性修改2.5、MED多出口鉴别器3、BGP非策略性选路原则…

【C++】右值引用(极详细版)

在讲右值引用之前,我们要了解什么是右值?那提到右值,就会想到左值,那左值又是什么呢? 我们接下来一起学习! 目录 1.左值引用和右值引用 1.左值和右值的概念 2.左值引用和右值引用的概念 2.左值引用和右…

C++linux高并发服务器项目实践 day2

Clinux高并发服务器项目实践 day2 静态库的制作静态库命名规则静态库的制作 动态库的制作命名规则制作使用动态库与静态库的区别解决动态库连接失败问题静态库和动态库的对比静态库的优缺点动态库的优缺点 Makefile什么是MakefileMakefile文件命名和规则Makefile的使用工作原理…

SpringSpringBoot常用注解总结

0.前言 可以毫不夸张地说,这篇文章介绍的 Spring/SpringBoot 常用注解基本已经涵盖你工作中遇到的大部分常用的场景。对于每一个注解我都说了具体用法,掌握搞懂,使用 SpringBoot 来开发项目基本没啥大问题了! 为什么要写这篇文章…

【分享】Excel表格的密码忘记了怎么办?附解决办法

我们知道通过设置密码可以保护Excel表格,可有时候设置后很久没用就把密码忘记了,而Excel并没有找回密码的选项,那要怎么办呢?今天小编就来分享一下忘记Excel密码的解决方法。 Excel表格可以设置多种密码,不同密码对应…

短视频平台-小说推文(Lofter)推广任务详情

​Lofter日结内测中,可能暂只对部分优质会员开放! 注意 Lofter 关键词7天未使用,可能会被下线。 Lofter 不再需要回填视频链接了。 接Lofter官方通知 关于近期部分博主反馈播放量高但搜索量很低的问题尤其是快手平台,我们做了代码、服务器…

No.040<软考>《(高项)备考大全》【第24章】成熟度模型

【第24章】成熟度模型 1 考试相关2 第一维四个阶梯3 项目成熟度模型OPM3CMMI过程域 4 成熟度级别级别区别 5 练习题参考答案: 1 考试相关 选择可能考0-1分,案例论文不考。 2 第一维四个阶梯 3 项目成熟度模型OPM3 CMMI过程域 CMMI过程域可以分为4类&a…

智能对话机器人Rasa学习资料

文章目录 背景收集的Rasa学习资料官网B站其他 类似产品教学机器人售后咨询效果手机推荐效果 背景 最近做了一个Ros2项目,界面如下图: 客户要求能够使用语音快速执行特定动作如:打开视频窗口、显示小车1视频、无人机1返航等,这就涉及到了自然…

C++ : 整体工程构架设计流程

重点: 1.一个项目通常分为bin(存放项目生成的dll和整体工程的exe),code(存每个项目的代码),lib(存每个项目生成的lib),pdb(存放项目生成的pdb文件),sln(解决方案) 整体创建流程: 一个主干项目,其他若干依赖…

Java接口自动化测试框架系列:提升测试效率的自动化测试框架

目录:导读 一、什么是自动化测试 二、自动化测试的缺点 三、自动化测试框架选型 原则 对比 四、框架构建 【自动化测试工程师学习路线】 一、什么是自动化测试 自动化测试是把以人为驱动的测试行为转化为机器执行的一种过程。 通常,在设计了测试…

【UE】暂停游戏界面及功能实现

效果 步骤 1. 首先在项目设置中添加一个暂停的操作映射 2. 新建一个控件蓝图,命名为“PauseMenuWidget” 3. 打开“ThirdPersonCharacter”,添加一个布尔类型变量,命名为“isScreenShow”,用于判断当前玩家是否打开了暂停界面 在…

【Linxu网络服务】DHCP

DHCP 一、DHCP工作原理1.1背景1.2优点1.3 DHCP分配方式1.4DHCP工作原理 二、使用DHCP动态配置主机地址2.1实验一:动态配置主机地址2.2给Linux客户机配置动态地址**2.4设置一个外网口,给客户端设置一个固定的ip地址 一、DHCP工作原理 作为服务端负责集中…

uniapp 之 将marker 渲染在地图上 点击弹层文字时显示当前信息

目录 效果图 总代码 分析 1.template 页面 地图显示代码 2. onload ①经纬度 ②取值 ③注意 ④ 3.methods ① 先发送 getStationList 请求 获取 数组列表信息 ② regionChange 视野发生变化时 触发 分页逻辑 ③ callouttap 点击气泡时触发 查找 当前 marker id 等…

基于第一性原理DFT密度泛函理论的计算项目

随着计算机技术的不断发展,计算材料科学的方法也日益成熟。其中,基于第一性原理的密度泛函理论(DFT)计算方法,因其准确性、可靠性和高效性而广受欢迎。本文将介绍基于DFT的密度泛函理论的计算项目,包括电子…

云内基于 SRv6 的 SFC 方案

1. 基于 SRv6 的 SFC 服务链 为满足用户的业务数据安全、稳定等需求,提供各种基础保障或增值优化服务,在传统网络中,经常使用业务功能节点(如负载均衡、防火墙等)实现服务供应。但这些业务功能节点往往与网络拓扑和硬件…

Fortinet Accelerate 2023全球网安大会成功举办 加速推进网络安全行业融合与整合

近日,Fortinet全球网络安全大会——Fortinet Accelerate 2023 在美国奥兰多成功举办。在对企业数字化转型挑战及网络威胁趋势等行业热点进行深入探讨的同时,Fortinet全新发布了以融合与整合为核心设计理念的增强型产品和服务,帮助企业从容应对…

第2章 时间空间复杂度计算

1时间复杂度计算 时间复杂度是什么? 一个函数,用大O表示,例如:O(1), O(N), O(logN). 定性描述算法的运行时间。 时间复杂度常见图: 案例: O(1) let i 0 i 1 解释:每次执行这段代码&#…