kaggle叶子分类比赛(易理解)

news2025/1/11 13:04:13

说实话网上很多关于叶子分类比赛的代码能取得的成绩都很好,但对于我这个业余人员太专业了,而且很多文章都有自己的想法,这让我这个仿写沐神代码的小菜鸡甚是头痛。
但好在我还是完成了,虽然结果并不是很好,但是如果跟着沐神走的同学在学习上应该没什么大问题。于是这篇文章的重点不是调参获得一个好成绩,而是把牵扯到的难点与思路好好的解释一下,方便同学们模仿。

竞赛地址:https://www.kaggle.com/c/classify-leaves

文章目录

  • 第一部分 加载并读取数据
  • 第二部分 定义网络
  • 第三部分 损失函数,验证函数,优化器
  • 第四部分 训练
  • 可能出现的bug
  • 拓展内容
    • 正常加载图像数据的其他方式
    • 类别索引在做什么
    • train_iter迭代器在迭代时__getitem_在干什么

第一部分 加载并读取数据

难点:如何接受并处理图像数据–使用自定义函数进行处理

import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        """
        初始化数据集。
        Args:
        csv_file (str): 数据集的csv文件路径,其中包含图像的文件名和标签。
        root_dir (str): 图像文件的根目录路径。
        transform (callable, optional): 一个可选的转换函数,用来对图像进行处理。
        """
        # 读取csv文件,并将数据存储到pandas DataFrame中。
        self.data_frame = pd.read_csv(csv_file)
        # 存储图像文件的根目录路径。
        self.root_dir = root_dir
        # 存储可选的图像转换函数。
        self.transform = transform

        # 将字符串类型的标签转换为整数索引,同时获取标签到整数索引的映射。
        self.data_frame['label'], self.label_mapping = pd.factorize(self.data_frame['label'])

    def __len__(self):
        """
        返回数据集中的样本数。
        """
        return len(self.data_frame)

    def __getitem__(self, idx):
        """
        根据给定的索引idx获取对应的数据项。
        Args:
        idx (int): 数据项的索引。
        
        Returns:
        tuple: 包含图像和其对应标签的元组。
        """
        # 如果idx是torch tensor类型,先转换为列表。
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # 构建图像文件的完整路径。
        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
        # 打开图像文件。
        image = Image.open(img_name)

        # 获取对应的标签(整数索引)。
        label = self.data_frame.iloc[idx, 1]

        # 如果有转换函数,应用之。
        if self.transform:
            image = self.transform(image)

        # 返回图像和标签。
        return image, label
    
    def get_num_classes(self):
        """
        返回数据集中不同类别的总数。
        """
        return len(self.label_mapping)



transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(224),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 创建图像数据集实例
dataset = CustomDataset(csv_file='C:/Users/xiaox/pytorch/SucTest/train.csv',
                        root_dir='C:/Users/xiaox/pytorch/SucTest',
                        transform=transform)

num_classes = dataset.get_num_classes()
print(f"Total number of classes: {num_classes}")

# 数据加载和划分
from torch.utils.data import DataLoader, random_split
total_size = len(dataset)
train_size = int(total_size * 0.8)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# 加载数据集
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)
1.如何自定义Dataset以用来灵活处理图像数据[数据的变化]
    1.定义__init__函数(读取csv文件,图像文件,transform,标签编码) [相当于将csv文件读取到Dataframe数据类型中,将标签映射为整数] 
    2.定义__getitem__函数(获取图像并转换,与图像对应的标签的索引)[相当于返回一张被转换的图片 与图片对应的Label对应的整数索引]

2.为什么使用类别索引将字符串映射成整数
    最重要的一点:神经网络中字符串无法转化为tensor类型,无法加入到net网络中

3.为什么选择类别索引而不是独热编码[独热编码就是预测房价中对于各个字符串标签的处理方法]
    独热编码在交叉熵损失函数中不适用

可拓展内容:
1.正常加载图像数据的其他方式(dataset,compose,data_loader的关系)
2.类别索引在做什么
3.train_iter迭代器在迭代时__getitem_在干什么

第二部分 定义网络

使用了Resnet50

from torch import nn
from d2l import torch as d2l
from torch.nn import functional as F
import torchvision.models as models


model = models.resnet50(weights=None)  # 使用预训练的ResNet-50


# 首先获取全连接层的输入特征数量
num_ftrs = model.fc.in_features

# 使用Dropout层和新的全连接层创建一个新的Sequential模块
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, 176)
)

第三部分 损失函数,验证函数,优化器

这里我使用了Adam作为优化器

#这是评估模型平均准确率的函数
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        #1
        net.eval()  # 设置为评估模式
        #2
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    #3
    metric = d2l.Accumulator(2)
    #4
    ## 4.1
    with torch.no_grad():
        ## 4.2
        for X, y in data_iter:
            ### 4.2.1
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            ### 4.2.2
            y = y.to(device)
            ### 4.2.3 注意:d2l原有库可能表示:acc = d2l.accuracy(net(X), y) metric.add(acc * y.numel(), y.numel())
            print(d2l.accuracy(net(X), y))
            metric.add(d2l.accuracy(net(X), y), y.numel())
             
    #5
    return metric[0] / metric[1]

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device,weight_decay):
    """用GPU训练模型(在第六章定义)"""
    #1
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    #2
    net.to(device)
    #更改了优化器
    #optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    #4
    loss = nn.CrossEntropyLoss()
    #5
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    #6
    timer, num_batches = d2l.Timer(), len(train_iter)
    #7
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        #7.1
        metric = d2l.Accumulator(3)
        #7.2
        net.train()
        #7.3
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                #7.4
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            #7.5
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            #7.6
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

第四部分 训练

## 开始训练
lr = 1e-4
batch_size = 128
num_epochs = 20
weight_decay = 1e-3

train_ch6(model, train_loader, test_loader, num_epochs, lr, d2l.try_gpu(),weight_decay)

可能出现的bug

在这里插入图片描述

CUDA错误
1.检查数据类型与形状是否合理(断言测试)
2.检查网络输出种类是否正常(获取类别个数)
3.检查网络是否正常(前向输出测试)
4.检查网络每一层是否正常(循环测试)

首先:可以尝试重启,有可能是把内存用完了,重启试一下在进行下面的排查

拓展内容

正常加载图像数据的其他方式


#### 当图片所在文件夹代表一个标签时使用或数据集有对应的加载函数
import torch 
from torchvision import transforms,datasets
from torch import nn
from d2l import torch as d2l

# 0.定义载入图像的格式 AlexNet的输入是227
transform = transforms.Compose([
    transforms.Resize(256),                    # 将图像缩放,使最短边为256像素
    transforms.CenterCrop(227),                # 从图像中心裁剪224x224大小的图像
    transforms.ToTensor(),                     # 将图像转换为PyTorch张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化处理
])

# 例子:读取图像数据(图片所在文件夹代表一个标签)
dataset = datasets.ImageFolder(root='C:\\Users\\xiaox\\pytorch\\SucTest\\img\\', transform=transform)

# 例子:加载 CIFAR-10 数据集(数据集有对应的加载函数)
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# 2.定义迭代器
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

类别索引在做什么

## 标签编码定义
在处理分类问题时,尤其是在使用机器学习或深度学习模型时,通常需要将文本或字符串类型的标签(labels)转换成整数索引。这是因为大多数算法都优化以处理数值数据,而不是文本数据。在你的代码中,这个转换是通过 Pandas 的 `factorize` 函数实现的。

### `pd.factorize()`
这个函数用于将一个具有重复值的数组转换为一个整数数组,其中每个唯一值都被分配一个整数标识符。它还返回一个包含原始数据中唯一值的数组,这可以作为标签到整数的映射。

#### 示例解释

假设你有一个CSV文件,其中包含如下的数据,其中每行代表一个样本,第一列是图像的文件名,第二列是图像的标签(如动物种类):

```
image_name, label
cat001.jpg, cat
dog001.jpg, dog
cat002.jpg, cat
bird001.jpg, bird
```

使用 `pd.factorize()` 函数处理 `label` 列时,会发生以下操作:


labels, label_mapping = pd.factorize(['cat', 'dog', 'cat', 'bird'])
```

结果:
- `labels` 会是 `[0, 1, 0, 2]`。这里,'cat' 被映射为 0,'dog' 被映射为 1,'bird' 被映射为 2。注意,第一个出现的标签('cat')是第一个被赋予新索引的。
- `label_mapping` 会是 `['cat', 'dog', 'bird']`,这是一个数组,其中索引位置对应于在 `labels` 中分配给每个唯一标签的整数。

通过这种方式,原始的字符串标签被转换为整数,使得它们可以更容易地被模型处理,同时你还保持了一个从整数索引回到原始标签的映射,这在模型预测结束后,将预测的整数标签转换回人类可读的标签时非常有用。

train_iter迭代器在迭代时__getitem_在干什么

for X,y in train_iter:做了什么

DataLoader 创建一个迭代器。

每次迭代时,从数据集(通过 Dataset 对象)中请求下一批数据。
(既向dataset对象随即指定batch_size个索引,来获取数据)

数据集的 __getitem__ 方法按索引获取数据和标签,这通常是随机访问,支持数据的随机打乱和批处理。
(dataset通过idx与__getitem__获得指定的数据然后返回给dataloader直到所有的batch_size个数据都被返回)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1651162.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring+Vue的卓越托管中心管理系统的设计与实现+PPT+论文+讲解+售后

相比于以前的传统手工管理方式,智能化的管理方式可以大幅降低运营人员成本,实现了卓越托管中心管理系统的标准化、制度化、程序化的管理,有效地防止了卓越托管中心管理系统的随意管理,提高了信息的处理速度和精确度,能…

PTA|小字辈

题目 本题给定一个庞大家族的家谱,要请你给出最小一辈的名单。 输入格式: 输入在第一行给出家族人口总数 N(不超过 100 000 的正整数) —— 简单起见,我们把家族成员从 1 到 N 编号。随后第二行给出 N 个编号&#x…

顺序表的实现(迈入数据结构的大门)

什么是数据结构 数据结构是由:“数据”与“结构”两部分组成 数据与结构 数据:如我们所看见的广告、图片、视频等,常见的数值,教务系统里的(姓名、性别、学号、学历等等); 结构:当…

三分钟了解计算机网络核心概念-数据链路层和物理层

计算机网络数据链路层和物理层 节点:一般指链路层协议中的设备。 链路:一般把沿着通信路径连接相邻节点的通信信道称为链路。 MAC 协议:媒体访问控制协议,它规定了帧在链路上传输的规则。 奇偶校验位:一种差错检测方…

【RT-DETR有效改进】 主干篇 | 2024.5全新的移动端网络MobileNetV4改进RT-DETR(含MobileNetV4全部版本改进)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、本文介绍 本文给大家带来的改进机制是MobileNetV4,其发布时间是2024.5月。MobileNetV4是一种高度优化的神经网络架构,专为移动设备设计。它最新的改动总结主要有两点&…

【intro】图注意力网络(GAT)

论文阅读 https://arxiv.org/pdf/1710.10903 abstract GAT,作用于图结构数据,采用masked self-attention layers来弥补之前图卷积或类似图卷积方法的缺点。通过堆叠layers,让节点可以添加其邻居的特征,我们就可以给不同的邻居节…

java-串口通讯-连接硬件

串口通信(Serial Communications)的概念非常简单,串口按位(bit)发送和接收字节。尽管比按字节(byte)的并行通信慢,但是串口可以在使用一根线发送数据的同时用另一根线接收数据。它很…

04.2.配置应用集

配置应用集 应用集的意思就是:将多个监控项添加到一个应用集里面便于管理。 创建应用集 填写名称并添加 在监控项里面找到对应的自定义监控项更新到应用集里面 选择对应的监控项于应用集

45 套接字

本节重点 认识ip地址,端口号,网络字节序等网络编程中的基本概念 学习scoket,api的基本用法 能够实现一个简单的udp客户端/服务端 能够实现一个简单的tcp客户端/服务器(但链接版本,多进程版本,多线程版本&a…

时间复杂度与空间复杂度(上篇)

目录 前言时间复杂度 前言 算法在运行的过程中要消耗时间资源和空间资源 所以衡量一个算法的好坏要看空间复杂度和时间复杂度, 时间复杂度衡量一个算法的运行快慢 空间复杂度是一个算法运行所需要的额外的空间 一个算法中我们更关心的是时间复杂度 时间复杂度 时…

【快捷部署】023_HBase(2.3.6)

📣【快捷部署系列】023期信息 编号选型版本操作系统部署形式部署模式复检时间023HBase2.3.6Ubuntu 20.04tar包单机2024-05-07 注意:本脚本非全自动化脚本,有2次人工干预,第一次是确认内网IP,如正确直接回车即可&#…

什么软件能在桌面提醒我 电脑桌面提醒软件

在这个信息爆炸的时代,我们每个人每天都需要处理海量的信息和任务。有时候,即便是再细心的人,也难免会因为事情太多而忘记一些重要的细节。 我就经常遇到这样的问题,明明记得自己有个重要的会议要参加,或者有个关键的…

扭蛋机小程序在互联网浪潮中的崛起与发展

随着互联网的快速发展,各种线上娱乐方式层出不穷,其中扭蛋机小程序凭借其独特的魅力,在互联网浪潮中迅速崛起并发展壮大。扭蛋机小程序不仅打破了传统扭蛋机的地域限制和操作不便,还融入了丰富的互动元素和便捷性,满足…

纯血鸿蒙APP实战开发——自定义安全键盘案例

介绍 金融类应用在密码输入时,一般会使用自定义安全键盘。本示例介绍如何使用TextInput组件实现自定义安全键盘场景,主要包括TextInput.customKeyboard绑定自定义键盘、自定义键盘布局和状态更新等知识点。 效果图预览 实现思路 1. 使用TextInput的cu…

为什么你的企业需要微信小程序?制作微信小程序有什么好处?

什么是小程序? WeChat小程序作为更大的WeChat生态系统中的子应用程序。它们就像更小、更基本的应用程序,在更大的应用程序(WeChat)中运行。这些程序为用户提供了额外的高级功能,以便在使用WeChat服务时加以利用。根据…

linux系统 虚拟机的安装详细步骤

window: (1) 个人:win7 win10 win11 winxp (2)服务器:windows server2003 2008 2013 linux: (1)centos7 5 6 8 (2)redhat (3)ubuntu (4)kali 什么是linux: 主要是基于命令来完成各种操作,类似于DO…

0基础学PHP有多难?

php作为web端最佳的开发语言,没有华而不实,而是经受住了时间考验,是一门非常值得学习的编程语言。 目前市场上各种网站、管理系统、小程序、APP等,基本都是使用PHP开发的,也侧面反映了PHP的需求以及学习的必要性&…

UTONMOS:真正的“游戏元宇宙”还有多遥远?

元宇宙来源于科幻小说的概念,已成为真实世界中的流行语。围绕这一新兴概念,一场产、学、研的实践正在展开。 数字化转型中,元宇宙能否担当大任?这些新概念在中国语境下如何落地?本文将深入挖掘国内元宇宙游戏产业的发…

数据结构-线性表-应用题-2.2-6

从有序顺序表中删除所有其值重复的元素,使表中的元素的值均不同 有序顺序表,值相同的元素一定在连续的位置上,初始时将第一个元素是为非重复的有序表,之后依次判断后面的元素是否与前面的非重复表的最后一个元素相同,…

当AI遇见现实:数智化时代的人类社会新图景

文章目录 一、数智化时代的机遇二、数智化时代的挑战三、如何适应数智化时代《图解数据智能》内容简介作者简介精彩书评目录精彩书摘强化学习什么是强化学习强化学习与监督学习的区别强化学习与无监督学习的区别 前言/序言 随着科技的日新月异,我们步入了一个前所未…