深度学习入门:自建数据集完成花鸟二分类任务

news2024/10/5 15:31:08

自建数据集完成二分类任务(参考文章)

1 图片预处理

1 .1 统一图片格式

找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错

RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1

pytorch的torchvision.transforms模块提供了许多用于图片变换/增强的函数。

1.1.1 把图片不等比例压缩为固定大小
transforms.Resize((600,600)),
1.1.2 裁剪保留核心区

因为主体要识别的图像一般在中心位置,所以使用CenterCrop,这里设置为(400, 400)

transforms.CenterCrop((400,400)),
1.1.3 处理成统一数据类型

这里统一成torch.float64方便神经网络计算,也可以统一成其他比如uint32等类型

transforms.ConvertImageDtype(torch.float64),
1.1.4 归一化进一步缩小图片范围

对于图片来说0~255的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)

  • 我们可以使用[0.5,0.5,0.5]mean,std来把数据归一化至[-1,1]
  • 也可以手动计算出所有的图片mean,std来归一化至均值为0,标准差为1的正态分布,
  • 一些深度学习代码常常使用mean=[0.485, 0.456, 0.406] ,std=[0.229, 0.224, 0.225]的归一化数据,这是在ImageNet的几百万张图片数据计算得出的结果
  • BN等方法也具有很出色的归一化表现,我们也会使用到

Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)

我们这里使用简单的[0.5,0.5,0.5]归一化方法,更新cls_dataset,加入transform操作 ,作为图片裁剪的预处理。

transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

关于transforms的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。

1.2 数据增强

当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms中的函数 ,也可以混合使用来根据已有数据创造新数据

        self.data_enhancement = transforms.Compose([
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomRotation(30)
        ])

2 创建自制数据集

2.1 以Dataset类接口为模版

class cls_dataset(Dataset):
    def __init__(self) -> None:
       # initialization
        
    def __getitem__(self, index):
        # return data,label in set 
    
    def __len__(self):
        # return the length of the dataset

2.2 创建set

2.2.1定义两个空列表data_list和target_list
2.2.2遍历文件夹
2.2.3读取图片对象,对每一个图片对象预处理后,分别将图片对象和对应的标签加入data_list和target_list中
2.2.4将data_list和target_list加入h5df_ile中
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_image

train_pic_path = 'test-set'
test_pic_path = 'training-set'

def create_h5_file(file_name):
    all_type = ['flower', 'bird']
    h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空

    #图片统一化处理
    transform = transforms.Compose([
        transforms.Resize((600, 600)),
        transforms.CenterCrop((400, 400)),
        transforms.ConvertImageDtype(torch.float64),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]
    )
    #数据增强

    data_list = []   #建立一个保存图片张量的空列表
    target_list = [] #建立一个保存图片标签的空列表

    #遍历文件夹建立数据集
    '''
    文件夹组成
    | —— train
    |   | —— flower
    |   |   | —— 图片1
    |   | —— bird
    |   | —— | —— 图片2
    | —— test
    |   | —— flower
    |   | —— bird
    '''

    dataset_kind = file_name.split('.')[0]
    #先判断缺失的文件是训练集还是测试集
    if dataset_kind == 'train':
        pic_file_name = train_pic_path
    else:
        pic_file_name = test_pic_path

    #再循环遍历文件夹
    for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):
        target = file_name_dir.split('/')[-1]
        if target in all_type:
            for file in files:
                pic = read_image(os.path.join(file_name_dir, file))  #以张量形式读取图片对象
                pic = transform(pic)    #预处理图片
                pic = np.array(pic).astype(np.float64)
                data_list.append(pic)   #将pic对象添加到列表里
                target_list.append(target.encode()) #将target编码后添加到列表里

    h5df_file.create_dataset("image", data=data_list)
    h5df_file.create_dataset("target", data=target_list)
    h5df_file.close()

class h5py_dataset(Dataset):
    def __init__(self, file_name) -> None:
        super().__init__()
        self.file_name = file_name    #指向文件的路径名
        #如果file_name指向的h5文件不存在,就新建一个
        if not os.path.exists(file_name):
            create_h5_file(file_name)

        
    def __getitem__(self, index):
        with h5py.File(self.file_name, 'r') as f:
            if f['target'][index].decode() == 'bird':   #如果在f文件的target列表中查找到index下标对应的标签是bird
                target = torch.tensor(0)
            else:
                target = torch.tensor(1)
        return f['image'][index], target

    def __len__(self):
        with h5py.File(self.file_name, 'r') as f:
            return len(f['target'])

def h5py_loader():
    train_file = 'train.hdf5'
    test_file = 'test.hdf5'

    train_dataset = h5py_dataset(train_file)
    test_dataset = h5py_dataset(test_file)

    train_data_loader = DataLoader(train_dataset, batch_size=4)
    test_data_loader = DataLoader(test_dataset, batch_size=4)

    return train_data_loader, test_data_loader


2.3 创建loader

实例化set对象后利用torch.utils.data.DataLoader

3 搭建网络

3.1 网络结构

在这里插入图片描述

3.2 参数计算

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1

参考文章

3.3 不成文规定

池化参数一般就是(2, 2)

中间的channel数量都是自己设定的,二的次方就行

kernelsize一般3或者5之类的

4 训练

加深对前面数据集组成理解

    for _, data in enumerate(train_loader):
        if isinstance(data, list):
            image = data[0].type(torch.FloatTensor).to(device)
            target = data[1].to(device)
        elif isinstance(data, dict):
            image = data['image'].type(torch.FloatTensor).to(device)
            target = data['target'].to(device)
        else:
            print(type(data))
            raise TypeError

for 循环中data的组成来源于构建set时,

    h5df_file.create_dataset("image", data=data_list)
    h5df_file.create_dataset("target", data=target_list)

写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合

在这里插入图片描述

5 测试

6 保存模型

改进

投影概率放到网络里面

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5.基于飞蛾扑火算法(MFO)优化的VMD参数(MFO-VMD)

代码的使用说明 基于飞蛾扑火算法优化的VMD参数 优化算法代码原理 飞蛾扑火优化算法(Moth-Flame Optimization,MFO)是一种新型元启发式优化算法,该算法是受飞蛾围绕火焰飞行启发而提出的,具有搜索速度快、寻优能力强的…

git常常用命令

这篇文章中,一些简单的,大家都知道的git 命令我就不再赘述,我只写出来最近在项目中常用到的一些命令。这些命令可以帮助我更好的开发。 git stash 请大家设想下面的场景,你的本地有两个分支,develop,fix分支&#xf…

万界星空科技QMS质量管理系统功能

QMS质量管理系统结合质量决策、综合质量管理、过程质量控制三个层次要素,帮助企业实现产品全寿命周期质量数据的及时、灵活、准确和全面采集。 通过质量管理软件能够实现质量数据科学处理和应用,包括数据的系统化组织、结构化存贮、便捷式查询、定制化统…

使用USB转JTAG芯片CH347在Vivado下调试

简介 高速USB转接芯片CH347是一款集成480Mbps高速USB接口、JTAG接口、SPI接口、I2C接口、异步UART串口、GPIO接口等多种硬件接口的转换芯片。 通过XVC协议,将CH347应用于Vivado下,简单尝试可以成功,源码如下,希望可以一起共建&a…

全链路压测的步骤及重要性

全链路压测是一种系统性的性能测试方法,旨在模拟真实用户场景下的完整操作流程,全面评估软件系统在不同压力下的性能表现。这种测试方法对于保证应用程序的高可用性、稳定性和可扩展性至关重要。 1. 全链路压测概述 全链路压测是在模拟实际用户使用场景的…

OpenAI 董事会宫斗始作俑者?一窥伊尔亚·苏茨克维内心世界

OpenAI 董事会闹剧应该是暂告一个段落了,Sam Altman和Greg Brockman等一众高管均已加入微软,还有员工写联名信逼宫董事会的戏码,关注度已经降下来了。 但是,这场宫斗闹剧的中心人物Ilya Sutskever大家关注度不算太高。他本人是纯粹的技术男,极少抛头露面透露其内心世界。…

TransmittableThreadLocal - 线程池中也可以传递参数了

一、InheritableThreadLocal的不足 InheritableThreadLocal可以用于主子线程之间传递参数,但是它必须要求在主线程中手动创建的子线程才可以获取到主线程设置的参数,不能够通过线程池的方式调用。 但是现在我们实际的项目开发中,一般都是采…

深度学习之三(卷积神经网络--Convolutional Neural Networks,CNNs)

概念 卷积神经网络(Convolutional Neural Networks,CNNs)是一种特殊的神经网络结构,专门用于处理具有网格状结构(如图像、音频)的数据。CNN 在计算机视觉领域取得了巨大成功,广泛应用于图像识别、物体检测、图像生成等任务。以下是 CNN 的主要理论概念: 在数学中,卷…

短视频配音软件有哪些?这些常用的短视频配音软件

短视频行业近年来发展得很快,几乎闯入了我们每个现代人的生活,它以其独有的特点和乐趣,也收获了大批短视频爱好者,配音是短视频创作过程中不可或缺的环节,今天,我们就来聊聊短视频配音及好用的配音软件。 短…

京东大数据(京东数据采集):2023年Q3线上投影仪品类销售数据分析报告

11月初,某知名投影仪企业发布了2023年三季度财报。数据显示,今年第三季度,公司营收依然不客观,连续第五个季度业绩持续下滑。 从鲸参谋数据也可以看出,今年Q3,京东平台上该品牌的销量环比下滑约35%&#x…

审计dvwa高难度命令执行漏洞的代码,编写实例说明如下函数的用法

审计dvwa高难度命令执行漏洞的代码 &#xff0c;编写实例说明如下函数的用法 代码&#xff1a; <?phpif( isset( $_POST[ Submit ] ) ) {// Get input$target trim($_REQUEST[ ip ]);// Set blacklist$substitutions array(& > ,; > ,| > ,- > ,$ …

第一次参加算法比赛是什么感受?

大家好&#xff0c;我是怒码少年小码。 冬日暖阳&#xff0c;好日常在。今天中午在食堂干饭的时候&#xff0c;我的手机&#x1f4f1;收到了一条收货信息。 阿&#xff1f;什么玩意儿&#xff1f;我又买啥了&#xff1f; 个败家玩意&#xff0c;我都准备好叨叨我自己&#x…

SpringCloud原理-OpenFeign篇(二、OpenFeign包扫描和FeignClient的注册原理)

文章目录 前言正文一、从启动类开始二、EnableFeignClients 的源码分析三、Import FeignClientsRegistrar 的作用四、FeignClientsRegistrar#registerFeignClients(...)五、饥饿注册&懒注册 FeignClientsRegistrar#registerFeignClient(...)六、通过Holder真正注册beanDefi…

一文概括AxureRP的优缺点和替代软件

AxureRP是目前流行的设计精美的用户界面和交互软件。AxureRP根据其应用领域提供了一组丰富的UI控制。 Axure是什么软件&#xff1f; Axure是目前流行的设计精美的用户界面和交互软件。Axure已经存在了近十年&#xff0c;让UX设计师轻松了解创建软件原型的细节。作为一种原型设…

【阿里云】图像识别 摄像模块 语音模块

USB 摄像头模块测试及配置 一、首先将 USB 摄像头插入到 Orange Pi 开发板的 USB 接口中二、然后通过 lsmod 命令可以看到内核自动加载了下面的模块三、通过 v4l2-ctl 命令可以看到 USB 摄像头的设备节点信息为 /dev/video0四、使用 fswebcam 测试 USB 摄像头五、使用 motion …

【SA8295P 源码分析】132 - GMSL2 协议分析 之 GPIO/SPI/I2C/UART 等通迅控制协议带宽消耗计算

【SA8295P 源码分析】132 - GMSL2 协议分析 之 GPIO/SPI/I2C/UART 等通迅控制协议带宽消耗计算 一、GPIO 透传带宽消耗计算二、SPI 通迅带宽消耗计算三、I2C 通迅带宽消耗计算四、UART 通迅带宽消耗计算系列文章汇总见:《【SA8295P 源码分析】00 - 系列文章链接汇总》 本文链接…

ROS2中Executors对比和优化

目录 SingleThreadExecutorEventExecutor SingleThreadExecutor 执行流程 EventExecutor 通信图

现在的发票有发票专用章吗?如何验证发票真伪?百望云为您详解!

大部分企业的财务都开始真正用上数电票了&#xff0c;但目前还是处于税控发票与数电票并行的阶段&#xff0c;一些财务朋友并没有深入理解二者的区别&#xff0c;就总会遇到以下的问题&#xff1a; 收到一张数电票&#xff0c;发现没有发票专用章&#xff0c;询问销售方为什么不…

什么是办公RPA?办公RPA解决什么问题?办公RPA实施难点在哪里?

什么是办公RPA&#xff1f; 办公RPA是一种能够模拟人类在计算机上执行任务的自动化软件。它可以在没有人工干预的情况下&#xff0c;执行重复的、规则化的任务&#xff0c;例如数据输入、网页爬取、电子邮件管理等。办公RPA可以帮助企业提高工作效率&#xff0c;降低人力成本&…

图像处理中常用的相似度评估指标

导读 有时候我们想要计算两张图片是否相似&#xff0c;而用来衡量两张图片相似度的算法也有很多&#xff0c;例如&#xff1a;RMSE、PSNR、SSIM、UQI、SIFT以及深度学习等。这篇文章主要介绍&#xff0c;RMSE、PSNR、SSIM、UQI这些指标的计算和应用&#xff0c;关于SIFT算法来…