(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

news2025/1/11 16:41:46

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',
                            '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')

# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:
    data_dir = d2l.download_extract('dog_tiny')
else:
    data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):
    labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))
    d2l.reorg_train_valid(data_dir, labels, valid_ratio)
    d2l.reorg_test(data_dir)

batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])

5. 读取数据

train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test
    ) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):
    finetune_net = nn.Sequential()
    finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
    # 定义一个新的输出网络,共有120个输出类别
    finetune_net.output_new = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120)
    )
    finetune_net = finetune_net.to(devices[0])
    # 冻结参数
    for param in finetune_net.features.parameters():
        param.requires_grad = False

    return finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')

def evaluate_loss(data_iter, net, device):
    l_sum, n = 0.0, 0
    for features, labels in data_iter:
        features, labels = features.to(device[0]), labels.to(device[0])
        outputs = net(features)
        l = loss(outputs, labels)
        l_sum += l.sum()
        n += labels.numel()
        return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    # 只训练小型定义输出网络
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad),
        lr=lr, momentum=0.9, weight_decay=wd
    )
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(2)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features, labels = features.to(devices[0]), labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(
                    epoch + (i + 1) / num_batches, (metric[0] / metric[1], None)
                )
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None :
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach().cpu()))
        scheduler.step()
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
                     f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1240799.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文《Unsupervised Dialog Structure Learning》笔记:详解DD-VRNN

D-VRNN模型和DD-VRNN模型 总体架构 离散-可变循环变分自编码器(D-VRNN)和直接-离散-可变循环变分自编码器(DD-VRNN)概述。D-VRNN和DD-VRNN使用不同的先验分布来建模 z t z_t zt​之间的转换,如红色实线所示。 x t x_t…

爱创科技总裁谢朝晖荣获“推动医药健康产业高质量发展人物”

中国医药市场规模已经成为全球第二大医药市场,仅次于美国。近年来,随着中国经济的持续增长和人民生活水平的提高,医药市场需求不断扩大。政府对医疗卫生事业的投入也在不断加大,为医药行业的发展创造了良好的政策环境。为推动医药…

基于顺序表实现通讯录

1.功能实现 功能要求 1)至少能够存储100个人的通讯信息 2)能够保存用户信息:名字、性别、年龄、电话、地址等 3)增加联系人信息 4)删除指定联系人 5)查找制定联系人 6)修改指定联系人 7&#xf…

Sentinel 监控数据持久化(mysql)

Sentinel 实时监控仅存储 5 分钟以内的数据,如果需要持久化,需要通过调用实时监控接口来定制,即自行扩展实现 MetricsRepository 接口(修改 控制台源码)。 本文通过使用Mysql持久化监控数据。 1.构建存储表&#xff08…

java-String

String 1. String引入 1.1 构造方法 public static void main1(String[] args) {//构造方法String s1 "hello world";String s2 new String("yuanwei");char[] values {a,b,c};String s3 new String(values);System.out.println(s1);System.out.printl…

看不惯AI版权作品被白嫖!Stability AI副总裁选择了辞职,曾领导开发Stable Audio

近日,OpenAI的各种大瓜真是让人吃麻了。 而就在Sam Altmam被开除前两天,可能没太多人注意到Stability AI副总裁Newton—Rex因看不惯StabilityAI在版权保护上的行为选择辞职一事。 大模型研究测试传送门 GPT-4传送门(免墙,可直接…

记录一次因内存不足而导致hiveserver2和namenode进程宕机的排查

背景 最近发现集群主节点总有进程宕机,定位了大半天才找到原因,分享一下 排查过程 查询hiveserver2和namenode日志,都是正常的,突然日志就不记录了,直到我重启之后又恢复工作了。 排查各种日志都是正常的&#xff0…

windows搭建gitlab教程

1.安装gitlab 说明:由于公司都是windows服务器,这里安装以windows为例,先安装一个虚拟机,然后安装一个docker(前提条件) 1.1搜索镜像 docker search gitlab #搜索所有的docker search gitlab-ce-zh #搜索…

【css】Google第三方登录按钮样式修改

文章目录 场景前置准备修改样式官方属性修改样式CSS修改样式按钮的高度height和border-radiusLogo和文字布局 场景 需要用到谷歌的第三方登录,登录按钮有自己的样式。根据官方文档:概览 | Authentication | Google for Developers,提供两种第…

SPASS-ARIMA模型

基本概念 在预测中,对于平稳的时间序列,可用自回归移动平均(AutoRegres- sive Moving Average, ARMA)模型及特殊情况的自回归(AutoRegressive, AR)模型、移动平均(Moving Average, MA)模型等来拟合,预测该时间序列的未来值,但在实际的经济预测中,随机数据序列往往…

HarmonyOS ArkTS Video组件的使用(七)

概述 在手机、平板或是智慧屏这些终端设备上,媒体功能可以算作是我们最常用的场景之一。无论是实现音频的播放、录制、采集,还是视频的播放、切换、循环,亦或是相机的预览、拍照等功能,媒体组件都是必不可少的。以视频功能为例&a…

6-使用nacos作为注册中心

本文讲解项目中集成nacos,并将nacos作为注册中心使用的过程。本文不涉及nacos的原理。 1、项目简介 以一个演示项目为例,项目包含三个服务,调用及依赖如下图: 由图中可以看出,coupon-customer-serv为服务的消费者&a…

Python基础教程: sorted 函数

嗨喽,大家好呀~这里是爱看美女的茜茜呐 sorted 可以对所有可迭代的对象进行排序操作, sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作。 从新排序列表。 👇 👇 👇 更多精彩机密、教程…

9.4 Windows驱动开发:内核PE结构VA与FOA转换

本章将继续探索内核中解析PE文件的相关内容,PE文件中FOA与VA,RVA之间的转换也是很重要的,所谓的FOA是文件中的地址,VA则是内存装入后的虚拟地址,RVA是内存基址与当前地址的相对偏移,本章还是需要用到《内核解析PE结构导…

【论文阅读笔记】Emu Edit: Precise Image Editing via Recognition and Generation Tasks

【论文阅读笔记】Emu Edit: Precise Image Editing via Recognition and Generation Tasks 论文阅读笔记论文信息摘要背景方法结果额外 关键发现作者动机相关工作1. 使用输入和编辑图像的对齐和详细描述来执行特定的编辑2. 另一类图像编辑模型采用输入掩码作为附加输入 。3. 为…

第三节-Android10.0 Binder通信原理(三)-ServiceManager篇

1、概述 在Android中,系统提供的服务被包装成一个个系统级service,这些service往往会在设备启动之时添加进Android系统,当某个应用想要调用系统某个服务的功能时,往往是向系统发出请求,调用该服务的外部接口。在上一节…

Vue批量全局处理undefined和null转为““ 空字符串

我们在处理后台返回的信息,有的时候返回的是undefined或者null,这种字符串容易引起用户的误解,所以需要我们把这些字符串处理一下。 如果每个页面都单独处理,那么页面会很冗余,并且后期如果有修改容易遗漏&#xff0c…

生成式AI与大语言模型,东软已经准备就绪

伴随着ChatGPT的火爆全球,数以百计的大语言模型也争先恐后地加入了这一战局,掀起了一场轰轰烈烈的“百模大战”。毋庸置疑的是,继方兴未艾的人工智能普及大潮之后,生成式AI与大语言模型正在全球开启新一轮生产力革新的科技浪潮。 …

PostgreSQL (Hologres) 日期生成

PostgreSQL 生成指定日期下一个月的日期 (在Hologres中,不支持递归查询) SELECTto_char(T, YYYYMMDD)::int4 AS date_int,date(T) AS date_str,date_part(year, T)::int4 AS year_int,date_part(month, T)::int4 AS month_int,date_part(da…

中职组网络安全B模块-渗透提权2

任务五:渗透提权2 任务环境说明: 仅能获取xxx的IP地址 用户名:test,密码:123456 访问服务器主机,找到主机中管理员名称,将管理员名称作为Flag值提交; Flag:doyoudoyoudo 访问服…