(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

news2025/1/10 21:35:36
  1. 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
# 查看数据集
def read_csv_labels(fname):
    """读取‘fname’来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环
        tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)
        return dict((name, label) for name, label in tokens)

labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

在这里插入图片描述
将验证集从原始的训练集钟拆分出来

# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。
    shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹

def reorg_train_valid(data_dir, labels, valid_ratio):
    """将验证集从原始训练集钟拆分出来"""
    # 训练数据集中样本数量最少的类别中的样本数
    # Counter: 计数器,返回一个字典,键为元素,值为元素个数;
    # .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序
    # [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的样本数
    n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整
    label_count = {}

    # 遍历原始训练集中的每个样本
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        label = labels[train_file.split('.')[0]]  # 从文件名中提取标签
        fname = os.path.join(data_dir, 'train', train_file)
        copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))
        # 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))

    return n_valid_per_label

# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """在预测期间整理测试集,以方便读取"""
    # 遍历测试集中的每个样本
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        # 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
transform_train = torchvision.transforms.Compose([
    # 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道 : 消除评估结果中的随机性
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
  1. 加载数据集
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test
    ) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():
    num_classes = 10
    net = d2l.resnet18(num_classes, in_channels=3)
    return net
# 查看网络模型
get_net()

在这里插入图片描述

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0]/ metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch+1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc{metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'
                     f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述
10. 对测试集进行分类并提交结果

net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.to(devices[0]))
    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1231939.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

内存学习(4):内存分类与常用概念3(ROM)

1 ROM介绍 ROM即为只读存储器&#xff0c;全拼是Read Only Memory。 1.1 “只读”的由来 ROM叫只读存储器是因为最早的ROM&#xff08;MROM&#xff09;确实是只能读取不能写入&#xff0c;一旦出厂不能再写&#xff0c;需要在出厂之前预设好它的数据&#xff0c;并且它是掉…

华为---OSPF网络虚连接(Virtual Link)简介及示例配置

OSPF网络虚连接&#xff08;Virtual Link&#xff09;简介 为了避免区域间的环路&#xff0c;OSPF规定不允许直接在两个非骨干区域之间发布路由信息&#xff0c;只允许在一个区域内部或者在骨干区域和非骨干区域之间发布路由信息。因此&#xff0c;每个ABR都必须连接到骨干区域…

Fourier分析导论——第6章——R^d 上的Fourier变换(E.M. Stein R. Shakarchi)

第6章 上的 Fourier 变换 It occurred to me that in order to improve treatment planning one had to know the distribution of the at- tenuation coefficient of tissues in the body. This in- formation would be useful for diagnostic purposes and would con…

[github配置] 远程访问仓库以及问题解决

作者&#xff1a;20岁爱吃必胜客&#xff08;坤制作人&#xff09;&#xff0c;近十年开发经验, 跨域学习者&#xff0c;目前于新西兰奥克兰大学攻读IT硕士学位。荣誉&#xff1a;阿里云博客专家认证、腾讯开发者社区优质创作者&#xff0c;在CTF省赛校赛多次取得好成绩。跨领域…

密码加密解密之路

1.背景 做数据采集&#xff0c;客户需要把他们那边的数据库连接信息存到我们系统里&#xff0c;那我们系统就要尽可能的保证这部分数据安全&#xff0c;不被盗。 2.我的思路 1.需要加密的地方有两处&#xff0c;一个是新增的时候前端传给后端的时候&#xff0c;一个是存到数…

浅析RSA非对称加密算法

目录 引言 凯撒密码 对称加密 非对称加密 ​编辑总结 引言 几月前在知乎上看到一个关于RSA公钥与私钥加解密的提问甚感兴趣&#xff0c;却一直没有时间去探究&#xff0c;今日浅得闲时以文记之。 在文章正式开始之前先讲一个小故事&#xff0c;在公元前58年时&#xff0c…

css 实现文字流光效果

经过调研发现大多滑块验证码中&#xff0c;有一些文字流光效果&#xff0c;因此在这里简单实现一下。 实现主要利用background 渐变背景以及backgorund-clip:text实现。具体代码如下 css部分 .slide {width: 300px;height: 40px;border: 1px solid #ccc;border-radius: 8px;…

猫12分类:使用yolov5训练检测模型

前言&#xff1a; 在使用yolov5之前&#xff0c;尝试过到百度飞桨平台&#xff08;小白不建议&#xff09;、AutoDL平台&#xff08;这个比较友好&#xff0c;经济实惠&#xff09;训练模型。但还是没有本地训练模型来的舒服。因此远程了一台学校电脑来搭建自己的检测模型。配置…

【鸿蒙最新全套教程】<HarmonyOS第一课>1、运行Hello World

下载与安装DevEco Studio 在HarmonyOS应用开发学习之前&#xff0c;需要进行一些准备工作&#xff0c;首先需要完成开发工具DevEco Studio的下载与安装以及环境配置。 进入DevEco Studio下载官网&#xff0c;单击“立即下载”进入下载页面。 DevEco Studio提供了Windows版本和…

pyQt主界面与子界面切换简易框架

本篇来介绍使用python中是Qt功能包&#xff0c;设置一个简易的多界面切换框架&#xff0c;实现主界面和多个子界面直接的切换显示。 1 主界面 设计的Demo主界面如下&#xff0c;主界面上有两个按钮图标&#xff0c;点击即可切换到对应的功能界面中&#xff0c;进入子界面后&a…

猫12分类:使用多线程爬取图片的Python程序

本文目标 对于猫12目标检测部分的数据集&#xff0c;采用网络爬虫来制作数据集。 在网络爬虫中&#xff0c;经常需要下载大量的图片。为了提高下载效率&#xff0c;可以使用多线程来并发地下载图片。本文将介绍如何使用Python编写一个多线程爬虫程序&#xff0c;用于爬取图片…

飞翔的小鸟

运行游戏如下&#xff1a; 碰到柱子就结束游戏 App GameApp类 package App;import main.GameFrame;public class GameApp {public static void main(String[] args) {//游戏的入口new GameFrame();} } main Barrier 类 package main;import util.Constant; import util.Ga…

Linux--网络编程

一、网络编程概述1.进程间通信&#xff1a; 1&#xff09;进程间通信的方式有**&#xff1a;管道&#xff0c;消息队列&#xff0c;共享内存&#xff0c;信号&#xff0c;信号量这么集中 2&#xff09;特点&#xff1a;依赖于linux内核&#xff0c;基本是通过内核来实现应用层…

线上bug-接口速度慢

&#x1f47d;System.out.println(“&#x1f44b;&#x1f3fc;嗨&#xff0c;大家好&#xff0c;我是代码不会敲的小符&#xff0c;双非大四&#xff0c;Java实习中…”); &#x1f4da;System.out.println(“&#x1f388;如果文章中有错误的地方&#xff0c;恳请大家指正&a…

十四、Docker的基本操作

目录 &#xff08;一&#xff09;镜像命令 一、拉取Nginx 二、查看镜像 三、导出文件 四、删除镜像 五、加载镜像 &#xff08;二&#xff09;容器命令 一、例子&#xff1a;运行一个nginx容器 1、输入运行命令 2、使用命令查看宿主机ip 3、在外部浏览器访问 4、查看…

函数调用分析

目录 函数相关的汇编指令 JMP指令 call指令 ret指令 VS2019正向分析main函数 总结调用函数堆栈变化规律 x64dbg分析调用函数 IDA分析调用函数 函数相关的汇编指令 JMP指令 JMP 指令表示的是需要跳转到哪个内存地址&#xff0c;相当于是间接修改了 EIP 。 call指令 ca…

NX二次开发UF_CAM_ask_blank_matl_db_object 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;里海NX二次开发3000例专栏 UF_CAM_ask_blank_matl_db_object Defined in: uf_cam.h int UF_CAM_ask_blank_matl_db_object(UF_CAM_db_object_t * db_obj ) overview 概述 This function provides the database object which …

使ros1和ros2的bag一直互通

很多文章都是先source ros1 然后source ros2,再play bag source /opt/ros/noetic/setup.bash source /opt/ros/foxy/setup.bash ros2 bag play -s rosbag_v2 kitti_raw00.bag 但实测会出问题: 为使ros1和ros2的bag一直互通 sudo apt update sudo apt install ros-foxy-ro…

axios的原理及实现一个简易版axios

面试官&#xff1a;你了解axios的原理吗&#xff1f;有看过它的源码吗&#xff1f; 一、axios的使用 关于axios的基本使用&#xff0c;上篇文章已经有所涉及&#xff0c;这里再稍微回顾下&#xff1a; 发送请求 import axios from axios;axios(config) // 直接传入配置 axio…

hdfsClient_java对hdfs进行上传、下载、删除、移动、打印文件信息尚硅谷大海哥

Java可以通过Hadoop提供的HDFS Java API来控制HDFS。通过HDFS Java API&#xff0c;可以实现对HDFS的文件操作&#xff0c;包括文件的创建、读取、写入、删除等操作。 具体来说&#xff0c;Java可以通过HDFS Java API来创建一个HDFS文件系统对象&#xff0c;然后使用该对象来进…