动手学深度学习(Pytorch版)代码实践 -计算机视觉-39实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

news2025/1/11 8:58:47

39实战Kaggle比赛:狗的品种识别(ImageNet Dogs

比赛链接:Dog Breed Identification | Kaggle

1.导入包
import torch
from torch import nn
import collections
import math
import os
import shutil
import torchvision
from d2l import torch as d2l
import matplotlib.pyplot as plt
import liliPytorch as lp
2.数据集处理
# 精简数据集
# file_path = '../data/kaggle_dog_tiny/'
# 原数据集
file_path = '../data/dog-breed-identification/'

# 整理数据集
# 从原始训练集中拆分验证集,然后将图像移动到按标签分组的子文件夹中。
#@save
def read_csv_labels(fname):
    """读取CSV文件中的标签,它返回一个字典,该字典将文件名中不带扩展名的部分映射到其标签"""
    with open(fname, 'r') as f:
        # 跳过文件头行(列名)
        lines = f.readlines()[1:]
    tokens = [l.rstrip().split(',') for l in lines]
    return dict(((name, label) for name, label in tokens))

# labels = read_csv_labels(os.path.join(file_path, 'labels.csv'))
# print(labels) # {'0097c6242c6f3071762d9f85c3ef1b2f': 'bedlington_terrier', '00a338a92e4e7bf543340dc849230e75': 'dingo'}
# print('训练样本 :', len(labels)) # 训练样本 : 1000
# print('类别 :', len(set(labels.values()))) # 类别 : 120

# 定义reorg_train_valid函数来将验证集从原始的训练集中拆分出来
#@save
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)

#@save
def reorg_train_valid(data_dir, labels, valid_ratio):
    """将验证集从原始的训练集中拆分出来"""
    # 训练数据集中样本最少的类别中的样本数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的样本数
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    for train_file in os.listdir(os.path.join(data_dir, 'train')): # 遍历训练集文件夹中的所有文件。
        label = labels[train_file.split('.')[0]] # 获取文件名(去掉扩展名)
        fname = os.path.join(data_dir, 'train', train_file) # 构建完整的文件路径

        copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                     'train_valid', label))
        
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'train', label))
    return n_valid_per_label


# reorg_test函数用来在预测期间整理测试集
#@save
def reorg_test(data_dir):
    """在预测期间整理测试集,以方便读取"""
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test',
                              'unknown'))

def reorg_dog_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'labels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)


reorg_dog_data(file_path, valid_ratio = 0.1)
3.数据集加载
# 数据图像增广
# 训练
transform_train = torchvision.transforms.Compose([
    # 随机裁剪图像,所得图像为原始面积的0.08~1之间,高宽比在3/4和4/3之间。
    # 然后,缩放图像以创建224x224的新图像
    torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),
                                             ratio=(3.0/4.0, 4.0/3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    # 随机更改亮度,对比度和饱和度
    torchvision.transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
    # 添加随机噪声
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])])
# 测试
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    # 从图像中心裁切224x224大小的图片
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])])

# 读取数据集
# 创建数据集对象
# 通常用于定义数据源及其预处理方法。
train_dataset, train_valid_dataset = [
    # ImageFolder 创建数据集时,它会遍历指定目录下的所有子文件夹,
    # 并将每个子文件夹的名称作为一个类别标签。然后,它会按字母顺序给每个类别分配一个索引
    torchvision.datasets.ImageFolder(
        os.path.join(file_path, 'train_valid_test', folder),
        transform=transform_train
    ) for folder in ['train', 'train_valid']]

valid_dataset, test_dataset = [
    torchvision.datasets.ImageFolder(
        os.path.join(file_path, 'train_valid_test', folder),
        transform=transform_test
    ) for folder in ['valid', 'test']]

# 显示每个类别名称和对应的索引
# print(train_dataset.class_to_idx) 4
# {'affenpinscher': 0, 'afghan_hound': 1, 'african_hunting_dog': 2}

batch_size = 128
# 创建数据加载器
# 通常用于训练过程中按批次提供数据,具有更高效的数据加载和处理能力。
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_dataset, train_valid_dataset)]

valid_iter = torch.utils.data.DataLoader(
    valid_dataset, batch_size, shuffle=False,drop_last=True)

test_iter = torch.utils.data.DataLoader(
    test_dataset, batch_size, shuffle=False,drop_last=False)

4.预训练模型resnet34
# 用于创建和配置训练模型
def get_net(devices):
    # 创建一个空的 nn.Sequential 容器
    finetune_net = nn.Sequential()
    # 加载预训练的 ResNet-34 模型,并将其特征层(features)部分添加到 finetune_net 中
    finetune_net.features = torchvision.models.resnet34(pretrained=True)
    # 定义一个新的输出网络
    finetune_net.output_new = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120)
    )
    # 将模型参数分配到指定的设备(如 GPU 或 CPU)
    finetune_net = finetune_net.to(devices[0])
    # 冻结预训练的特征层参数,以避免在训练过程中更新这些参数
    for param in finetune_net.features.parameters():
        param.requires_grad = False
    # 返回配置好的模型
    return finetune_net
5.模型训练
def train_batch(net, X, y, loss, trainer, devices):
    """使用多GPU训练一个小批量数据。
    参数:
    net: 神经网络模型。
    X: 输入数据,张量或张量列表。
    y: 标签数据。
    loss: 损失函数。
    trainer: 优化器。
    devices: GPU设备列表。
    返回:
    train_loss_sum: 当前批次的训练损失和。
    train_acc_sum: 当前批次的训练准确度和。
    """
    # 如果输入数据X是列表类型
    if isinstance(X, list):
        # 将列表中的每个张量移动到第一个GPU设备
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备
    y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备
    net.train() # 设置网络为训练模式
    trainer.zero_grad()# 梯度清零
    pred = net(X) # 前向传播,计算预测值
    l = loss(pred, y) # 计算损失
    l.sum().backward()# 反向传播,计算梯度
    trainer.step() # 更新模型参数
    train_loss_sum = l.sum()# 计算当前批次的总损失
    train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度
    return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和


def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):

    trainer = torch.optim.SGD(
        # net.parameters():返回模型 net 中所有参数。
        # if param.requires_grad:仅选择那些 requires_grad 为 True 的参数。
        # 这些参数是需要进行梯度更新的(即未冻结的参数)
        (param for param in net.parameters()if param.requires_grad), 
        # momentum用于加速 SGD 的收敛速度,通过在更新参数时考虑之前的更新方向,减少震荡
        # weight_decay权重衰减用于防止过拟合
        lr=lr,momentum=0.9, weight_decay=wd)
    # trainer = torch.optim.Adam(net.parameters(), lr=lr,weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    loss = nn.CrossEntropyLoss(reduction="none")
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = lp.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch(net, features, labels,loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            # train_l = metric[0] / metric[2] # 计算训练损失
            # train_acc = metric[1] / metric[2] # 计算训练准确率
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[2],None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
        # print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
        #       f'valid_acc {valid_acc:.3f}')
        
    measures = (f'train loss {metric[0] / metric[2]:.3f}, '
                f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
6.模型预测
def predict(file_path_module):
    # 预测
    net = get_net(d2l.try_all_gpus())
    net.load_state_dict(torch.load(file_path_module + 'imageNet_Dogs.params'))

    # 初始化一个空列表preds用于存储预测结果
    preds = []

    # 遍历测试集中的每一个数据和标签
    for data, label in test_iter:
        # 使用神经网络(net)对数据进行预测,并使用softmax函数将输出转化为概率分布
        output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=1)
        # 将预测结果从GPU中取出,转换为NumPy数组后,添加到preds列表中
        preds.extend(output.cpu().detach().numpy())

    # 获取测试数据文件夹中所有文件的id,并按字典顺序排序
    ids = sorted(os.listdir(
        os.path.join(file_path, 'train_valid_test', 'test', 'unknown')))

    # 打开一个新的CSV文件submission.csv用于写入预测结果
    with open(file_path + 'submission.csv', 'w') as f:
        # 写入CSV文件的表头,包含'id'和所有类别标签
        f.write('id,' + ','.join(train_valid_dataset.classes) + '\n')
        # 遍历文件id和对应的预测结果
        for i, output in zip(ids, preds):
            # 写入每个文件的id和对应的预测概率
            f.write(i.split('.')[0] + ',' + ','.join(
                [str(num) for num in output]) + '\n')
7.定义超参数并保存训练参数
# 定义模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 1e-4, 1e-4
lr_period, lr_decay, net = 10, 0.1, get_net(devices)
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 4, 0.9 (简略数据集)
# train loss 0.750, train acc 0.814, valid acc 0.646
# 647.4 examples/sec on [device(type='cuda', index=0)]

# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 10, 0.1 (原数据集)
# train loss 0.863, train acc 0.759, valid acc 0.844
# 830.8 examples/sec on [device(type='cuda', index=0)]
plt.show()

net = get_net(devices)
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,lr_decay)
# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 4, 0.9 (简略数据集)
# train loss 0.721, train acc 0.815
# 704.9 examples/sec on [device(type='cuda', index=0)]

# num_epochs, lr, wd, lr_period, lr_decay = 20, 1e-4, 1e-4, 10, 0.1 (原数据集)
# train loss 0.865, train acc 0.758
# 845.4 examples/sec on [device(type='cuda', index=0)]

plt.show()
# 保存模型参数
file_path_module = '../limuPytorch/module/'
torch.save(net.state_dict(), file_path_module + 'imageNet_Dogs.params')

简略数据集:
在这里插入图片描述
在这里插入图片描述

原始数据集:
在这里插入图片描述
在这里插入图片描述

8.预测提交kaggle
predict(file_path_module)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1892905.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

nacos开启鉴权后,springboot注册失败

1.确认Nacos版本 我的Nacos版本是1.4.2 2.确认Nacos相关依赖的版本之间兼容&#xff0c;一下是我的一些pom.xml依赖 <!--父级项目的--><parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifa…

MySQL:MySQL总结

文章目录 MySQL思维导图基础实际在 Innodb 存储引擎中&#xff0c;会用一个特殊的记录来标识最后一条记录&#xff0c;该特殊的记录的名字叫 supremum pseudo-record &#xff0c;所以扫描第二行的时候&#xff0c;也就扫描到了这个特殊记录的时候&#xff0c;会对该主键索引加…

深化产教融合“桥梁”作用!蓝卓携手宁波4大院校共育数智人才

建强“三支队伍”赋能新质生产力&#xff0c;为进一步加强新时代教师队伍建设改革&#xff0c;促进人才培养能力和服务企业能力“双提升”&#xff0c;7月2日&#xff0c;“2024企业实践工业互联网职业教育师资培训班”在蓝卓顺利开班。 来自宁波城市职业技术学院、宁波职业技…

代理IP和VPN有什么区别?该怎么选择?

今天我们来聊聊很多人关心的一个问题——代理IP和VPN到底有什么区别&#xff1f;虽然它们听起来差不多&#xff0c;但其实有很大的不同。这篇文章&#xff0c;小编就带大家一起了解一下吧&#xff01; 什么是代理IP&#xff1f; 代理IP是一种通过代理服务器替换用户真实IP地址…

c进阶篇(四):内存函数

内存函数以字节为单位更改 1.memcpy memcpy 是 C/C 中的一个标准库函数&#xff0c;用于内存拷贝操作。它的原型通常定义在 <cstring> 头文件中&#xff0c;其作用是将一块内存中的数据复制到另一块内存中。 函数原型&#xff1a;void *memcpy(void *dest, const void…

UE5 修改项目名称 类的名称

修改类的名称 这里推荐使用Rider编辑器修改&#xff0c;它会给你遍历所有的引用&#xff0c;然后一次性修改&#xff0c;并自动添加DefaultEngine.ini。接下来&#xff0c;我将给大家演示如何实现。 我们在一个类的文件上面选择重构此 然后选择重命名 在弹框内修改为新的名称…

PG实践|内置函数之GENERATE_SERIES之深入理解(二)

&#x1f4eb; 作者简介&#xff1a;「六月暴雪飞梨花」&#xff0c;专注于研究Java&#xff0c;就职于科技型公司后端工程师 &#x1f3c6; 近期荣誉&#xff1a;华为云云享专家、阿里云专家博主、腾讯云优秀创作者、ACDU成员 &#x1f525; 三连支持&#xff1a;欢迎 ❤️关注…

【Whisper】WhisperX: Time-Accurate Speech Transcription of Long-Form Audio

Abstract Whisper 的跨语言语音识别取得了很好的结果&#xff0c;但是对应的时间戳往往不准确&#xff0c;而且单词级别的时间戳也不能做到开箱即用(out-of-the-box). 此外&#xff0c;他们在处理长音频时通过缓冲转录

c++类模板--无法解析的外部符号

解决办法 文章目录 解决办法方法1(推荐).在主函数包含头文件时将实现模板类的函数也包含进来方法2.将模板类的实现方法写在头文件里面方法3.函数模板声明前加inline 可能错误2&#xff0c;类内实现友元输出重载 方法1(推荐).在主函数包含头文件时将实现模板类的函数也包含进来 …

七、函数练习

目录 1. 写一个函数可以判断一个数是不是素数。&#xff08;素数只能被1或其本身整除的数&#xff09; 2. 一个函数判断一年是不是闰年。 3.写一个函数&#xff0c;实现一个整形有序数组的二分查找。 4. 写一个函数&#xff0c;每调用一次这个函数&#xff0c;使得num每次增…

养老院人员定位系统如何实现

养老院人员定位系统应反应养老公寓情况、增加老人安全防范级别、加强安全保障措施&#xff0c;部署物联网设备及配套集成平台软件&#xff0c;实时定位人员信息及时反应老人救助行为&#xff0c;实现与视频、门禁一卡通等自动化监管设施联合动作&#xff0c;提高应急响应速度和…

【vite创建项目】

搭建vue3tsvitepinia框架 一、安装vite并创建项目1、用vite构建项目2、配置vite3、找不到模块 “path“ 或其相对应的类型声明。 二、安装element-plus1、安装element-plus2、引入框架 三、安装sass sass-loader1、安装sass 四、安装vue-router-next 路由1、安装vue-router42搭…

Mybatis入门の基础操作

1 Mybatis概述 MyBatis 是支持定制化 SQL、存储过程以及高级映射的优秀的持久层框架。MyBatis避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以对配置和原生Map使用简单的 XML 或注解&#xff0c;将接口和 Java 的 POJOs(Plain Old Java Objects,普通的…

vue3+vue-router+vite 实现动态路由

文章中出现的代码是演示版本&#xff0c;仅供参考&#xff0c;实际的业务需求会更加复杂 什么是动态路由 什么场景会用到动态路由 举一个最常见的例子&#xff0c;比如说我们要开发一个后台管理系统&#xff0c;一般来说后台管理系统都会分角色登录&#xff0c;这个时候也就涉…

幻兽帕鲁卡顿严重、延迟高怎么办?幻兽帕鲁卡顿问题处理

幻兽帕鲁更是一款支持多人游戏模式的生存制作游戏。玩家们可以邀请好友一同加入这个充满奇幻与冒险的世界&#xff0c;共同挑战强大的敌人&#xff0c;分享胜利的喜悦。在多人模式中&#xff0c;玩家之间的合作与竞争将成为游戏的一大看点。玩家们需要充分发挥自己的智慧和策略…

centos7的yum命令无法使用解决方案

文章目录 问题排查流程解决方案总结 问题 今天新建了个centos7的虚拟机发现yum无法正常使用 已加载插件&#xff1a;fastestmirror Determining fastest mirrors Could not retrieve mirrorlist http://mirrorlist.centos.org/?release7&archx86_64&repoos&infra…

单片机学习(14)--DS18B20温度传感器

DS18B20温度传感器 13.1DS18B20温度传感器基础知识1.DS18B20介绍2.引脚及应用电路3.内部结构框图4.存储器框图5.单总线介绍6.单总线电路规范7.单总线时序结构8.DS18B20操作流程9.DS18B20数据帧 13.2DS18B20温度读取和温度报警器代码1.DS18B20温度读取&#xff08;1&#xff09;…

linux模拟aix盘19c单机asm安装补丁

linux模拟盘aix盘vi /etc/rc.d/rc.local/bin/ln /dev/sda /dev/rhdisk2/bin/ln /dev/sdb /dev/rhdisk3 /bin/chown grid:oinstall /dev/rhdisk*chmod 660 /dev/rhdisk* 一、19c安装GI&#xff08;Standalone Oracle Restart&#xff09; su - grid配置环境变量vi .profileex…

【Linux】Linux用户,用户组,其他人

1.文件拥有者 初次接触Linux的朋友大概会觉得很怪异&#xff0c;怎么“Linux有这么多用户&#xff0c;还分什么用户组&#xff0c;有什用呢&#xff1f;”&#xff0c;这个“用户与用户组”的功能可是相当健全而且好用的一个安全防护措施。 怎么说呢&#xff1f;由于Linux是个…

人生感悟 | 努力奋斗和内卷不是一个意思。

哈喽&#xff0c;你好啊&#xff0c;我是雷工&#xff01; 有个很有趣的话题&#xff0c;是不是努力奋斗导致的内卷&#xff1f; 如果每个人都躺平&#xff0c;各行各业的内卷是不是就不存在了&#xff1f; 01 有关联不尽同 两者有关量&#xff0c;但无绝对的导向关系。 努力…