鲜花数据集实验结果总结

news2025/1/11 9:52:49

从read_split_data中得到:训练数据集,验证数据集,训练标签,验证标签。的所有的具体详细路径

数据集位置:https://download.csdn.net/download/guoguozgw/87437634

import os
#一种轻量级的数据交换格式,
import json
#文件读/写操作
import pickle
import random
import matplotlib.pyplot as plt
def read_split_data(root:str,val_rate:float = 0.2):
    random.seed(0)#保证随机结果可重复出现
    assert os.path.exists(root),'dataset root:{} does not exist.'.format(root)

    #遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root,cla))]
    #排序,保证顺序一致
    flower_class.sort()
    #生成类别名称以及对应的数字索引,将数据转换为字典的类型。将标签分好类之后,其类别是key,对应的唯一值是value
    class_indices = dict((k,v) for v,k in enumerate(flower_class))
    #将数据编写成json文件
    json_str = json.dumps(class_indices,indent=4)
    with open('json_str','w') as json_file:
        json_file.write(json_str)

    train_images_path = [] #存储训练集的所有图片路径
    train_images_label = [] #存储训练集所有图片的标签
    val_images_path = [] #存储验证机所有图片的路径
    val_images_label = [] #存储验证机所有图片的标签
    every_class_num = [] #存储每个类别的样本总数

    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    #遍历每一个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root,cla)
        #遍历获取supported支持的所有文件路径,得到所有图片的路径地址。针对的是某一个类别。
        images = [os.path.join(root,cla,i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported]
        #获取该类别对应的索引,此时对应就是数字了。对应的只是一个数字
        image_class = class_indices[cla]
        #记录该类别的样本数量
        every_class_num.append(len(images))
        #按比例随机采样验证样本,按照0.2的比例来作为测试集。
        val_path = random.sample(images,k=int(len(images)*val_rate))

        for img_path in images:
            #如果该路径在采样的验证集样本中则存入验证集。否则的话存入到训练集当中。其中label和image是相互对应的。
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print('该数据集一共有{}多张图片。'.format(sum(every_class_num)))
    print('一共有{}张图片是训练集'.format(len(train_images_path)))
    print('一共有{}张图片是验证集'.format(len(val_images_path)))
    #输出每一个类别对应的图片个数
    for i in every_class_num:
        print(i)

    plot_image = False
    if plot_image:
        #绘制每一种类别个数柱状图
        plt.bar(range(len(flower_class)),every_class_num,align='center')
        #将横坐标0,1,2,3,4替换成相应类别的名称
        plt.xticks(range(len(flower_class)),flower_class)
        #在柱状图上添加数值标签
        for i,v in enumerate(every_class_num):
            plt.text(x=i,y=v+5,s=str(v),ha='center')
        #设置x坐标
        plt.xlabel('image class')
        plt.ylabel('number of images')
        #
        plt.title('flower class distribution')
        plt.show()

    return train_images_path,train_images_label,val_images_path,val_images_label
if __name__ == '__main__':
    root = '../11Flowers_Predict/flower_photos'
    read_split_data(root)

最后得到的数据信息分别如此,代码中的路径需要进行更换(替换为自己的路径)。
请添加图片描述

从写Dataset类

from PIL import Image
import torch
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    '''
    自定义数据集
    '''
    def __init__(self,images_path:list,images_classes:list,transform = None):
        super(MyDataSet, self).__init__()
        self.images_path = images_path
        self.images_classes = images_classes
        self.transform = transform
    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        #RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            #直接在这里终止程序的运行
            raise ValueError('image :{} is not RGB mode.'.format(self.images_path[item]))
        label = self.images_classes[item]

        if self.transform is not None:
            img = self.transform(img)

        return img , label

对数据集的预处理部分

import os
import torch
from torchvision import transforms
from utils import read_split_data
from my_dataset import MyDataSet
#数据集所在的位置
root = '../11Flowers_Predict/flower_photos'
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('using {} device.'.format(device))
    #接下来这一行是对数据的读取
    train_images_path,train_images_label,val_images_path,val_images_label = read_split_data(root)

    #设置transform,compose立main必须是列表
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    train_data_set = MyDataSet(images_path=train_images_path,
                               images_classes=train_images_label,
                               transform=data_transform['train'])
    val_data_set = MyDataSet(images_path=val_images_path,
                             images_classes=val_images_label,
                             transform=data_transform['val'])
    batch_size = 32
    #number of workers
    #nw = min([os.cpu_count() , batch_size if batch_size>1 else 0,8])
    #print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers = 0
                                               )
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers = 0)

    for step,data in enumerate(train_loader):
        images,labels = data
        #print(images.shape)
        #print(labels)
        #print(labels.shape)
    return train_loader,val_loader
if __name__ == '__main__':
    main()

开始对数据集进行训练

import torch
from torch import nn
import torchvision
from torchvision import transforms,models
from tqdm import tqdm
from main import *
import time
HP = {
    'epochs':25,
    'batch_size':32,
    'learning_rate':1e-3,
    'momentum':0.9,
    'test_size':0.05,
    'seed':1
}

#创建一个残差网络34层结果,使用预训练参数
model = models.resnet34(pretrained=True)
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.1),
    torch.nn.Linear(model.fc.in_features,5)
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.backends.cudnn.benchmark = True
print(f'using {device} device')
#将模型添加到gpu当中
model = model.to(device)

#分类问题使用交叉熵函数损失
criterion = torch.nn.CrossEntropyLoss()
#优化器使用SGD随机梯度下降法
optimizer = torch.optim.SGD(model.parameters(),lr=HP['learning_rate'],momentum=HP['momentum'])

train_loader,val_loader = main()

def train(model,criterion,optimizer,train_loader,val_loader):
    #设置总的训练损失和验证损失,以及训练准确度和验证准确度。
    total_train_loss = 0
    total_val_loss = 0
    total_train_accracy = 0
    total_val_accracy = 0

    model.train()#设置为训练模式
    loop = tqdm(enumerate(train_loader),total=len(train_loader))
    loop.set_description(f'training')
    for step,data in loop:
        images,labels = data
        #将数据添加到GPU当中
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        #单个损失
        loss = criterion(output,labels)
        #计算准确率
        accracy = (output.argmax(1)==labels).sum()
        #将所有的损失进行相加
        total_train_loss += loss.item()
        #将所有正确的全部相加起来
        total_train_accracy += accracy
        #开始进行层数更新
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    loop_val = tqdm(enumerate(val_loader),total=len(val_loader))
    loop_val.set_description(f'valuing')
    for step,data in loop_val:
        images,labels = data
        images = images.to(device)
        labels = labels.to(device)

        output = model(images)
        loss = criterion(output,labels)
        accracy_val = (output.argmax(1)==labels).sum()
        total_val_loss += loss.item()
        total_val_accracy += accracy_val

    train_acc = total_train_accracy/(2939)
    val_acc = total_val_accracy/(731)
    train_loss = total_train_loss/(2939)
    val_loss = total_val_loss/(731)

    print(f'训练集损失率: {train_loss:.4f} 训练集准确率: {train_acc:.4f}')
    print(f'验证集损失率: {val_loss:.4f} 验证集准确率: {val_acc:.4f}')

if __name__ == '__main__':
    time_start = time.time()
    for i in  range(HP['epochs']):
        print(f"Epoch {i+1}/{HP['epochs']}")
        train(model, criterion, optimizer, train_loader, val_loader)

    time_end = time.time()
    print(time_end-time_start)

json_str

{
    "daisy": 0,
    "dandelion": 1,
    "roses": 2,
    "sunflowers": 3,
    "tulips": 4
}

训练结束之后,可以得出来训练出来的结果。

总结部分:

一:针对全部是目录,且目录里面是已经分好类的数据集,且数据没有分成训练集和测试集
1:函数参数设置为:路径,划分的概率
2:设置一定的随机结果
3:判断该路径是否存在,使用assert
4:根据传过来的root,来判断当前路径下所有的文件夹,如果是文件夹将其写入到列表当中
5:同时这个列表也是所有的类别,将该列表进行排序
6:使用enumerate来使其成为字典,其中key对应的是分类,value对应的是数值
7:(可以选择)使用json可以将其写入到文件当中
8:创建训练集图片路径,训练集标签路径,验证集图片路径,验证集标签路径,每个类别的数目,都是列表形式
9:开始对文件进行遍历,然后将其存放到上面的集合当中
10:以根据类别以及root使用join将其连接起来。根据类别来进行循环,然后进行拼接
11:接这这个类别循环的时候,使用随机数来将其划分验证数据集和训练数据集

二:如果数据已经分好训练集和测试集的情况下,如果存在csv的文件情况下,可以使用pandas来进行数据处理
(shuffle函数是sklearn utils里面的类),
(对csv文件读取,主要使用到的是pandas库)
1:对读取到的csv文件可以首先使用head查看前几个数据
2:使用sklearn里面的shuffle方法来进行打乱顺序
3:使用pandas里面的factorize对标签进行数据化显示(把复杂计算分解为基本运算),其返回值为元祖
4:使用unique返回的是列表,将标签封装成列表
5:再将其相互对应封装为字典:key是类别,value是数字
6:使用sklearn中的train_test_split方法来对数据集进行划分,传入参数为(DataFrame,比例)
7:使用value_count来对标签进行计数

对DataSet的重写:
1:主要是实现其中的三个方法,init,getitem,len
2:init主要是接受参数,路径,类别,以及transforms,在这里一定要吧image处理到对应的每一张图片的身上
3:返回的是image格式的图片,以及一个标签数字

部分测试代码

#
import os


def main(root:int,images_class: list,transform = None):
    print('root:',root)
    print('int:', int)
    print('images_class:', images_class)
    print('list:', list)

def read_split_data(root:str,val_rate:float = 0.2):
    print('root:', root)
    print('str:', str)
    print('val_rate:', val_rate)
    print('float:', float)


root = '../11Flowers_Predict/flower_photos'
#遍历文件夹
'''
os.listdir是展示当前所在层的所有文件
os.isdir判断当前这个文件是否属于文件夹
os.path.join()将两个字符串进行连接中间用/
os.path.splittext()返回的是一个元祖
'''
flowers_classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root,cla))]
print(flowers_classes)
flowers_classes_copy = flowers_classes.copy()
flowers_classes.sort()
print(os.path.isdir('../11Flowers_Predict/flower_photos'))
print(os.path.join(root,'roses'))
print(flowers_classes)
class_ind = dict((k, v) for v, k in enumerate(flowers_classes))
for v,k in enumerate(flowers_classes):
    print('此时标号{},对应的类别是{}.'.format(v,k))
for v,k in class_ind.items():
    print(v,k)
import json
json_str = json.dumps(class_ind,indent=2)
print(json_str)
with open('json_str','w') as json_file:
    json_file.write(json_str)

AA = os.path.splitext('123.jpg')
print(type(os.path.splitext('123.jpg')))
supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
print(AA[-1] in supported)
list = [1,2,3,4]
#main(root,list)
for cla in flowers_classes:
    image_class = class_ind[cla]

print(image_class)
import matplotlib.pyplot as plt
every_class_num = [633,898,641,699,799]
plt.bar(flowers_classes,every_class_num,align='center')
#   这个东西就是用来替换的
#plt.xticks(range(len(flowers_classes)),[10,11,12,13,14])
for i,v in enumerate(every_class_num):
    plt.text(x=i,y=v,s=str(v))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/338398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见漏洞之 struts2+ jboss

数据来源 本文仅用于信息安全的学习,请遵守相关法律法规,严禁用于非法途径。若观众因此作出任何危害网络安全的行为,后果自负,与本人无关。 01 Struts2相关介绍 》Struts2概述 》Struts2历史漏洞(1) 》…

【Linux】Linux多线程(下)

前言 大家好呀,欢迎来到我的Linux学习笔记~ 本篇承上Linux多线程创建,线程互斥(互斥锁),线程同步(条件变量),继下接着学习线程同步的另一个信号量,以及后序的线程池,线程的懒汉单例模式和其他锁相关知识。(注意本篇博客代码居多) Linux多线程…

C++005-C++选择与分支2

文章目录C005-C选择与分支2条件语句C实现else if 语句题目描述 根据成绩输出成绩等级ABCDEif嵌套语句题目描述 输出三个数中的最大值题目描述 模拟游戏登录switch语句三元运算符题目描述 输出三个数中的最大值-基于3元运算符题目描述 根据1-7输出星期1-星期日案例练习题目描述 …

php的api系统,php api 框架

本文目录一览: 1、php如何开发API接口2、什么是API?PHP的API怎么写?3、API和PHP是什么关系4、php中的API接口怎么写 ?5、如何使用PHP搭建一个restFul风格的API系统6、PHP 的API接口 php如何开发API接口 比如一个自定义函数:fun…

【遇见青山】项目难点:缓存击穿问题解决方案

【遇见青山】项目难点:缓存击穿问题解决方案1.缓存击穿互斥锁🔒方案逻辑过期方案2.基于互斥锁方案的具体实现3.基于逻辑过期方案的具体实现1.缓存击穿 缓存击穿问题也叫热点Key问题,就是一个被高并发访问并且缓存重建业务较复杂的key突然失效…

RuoYi-Cloud 部署

RuoYi-Cloud部署 1. 下载 点击右侧链接可以进入gitee的源码下载地址: 偌依微服务源码gitee下载地址 2. 数据库部署 依据如下步骤创建系统所需数据环境,脚本执行没有先后次序要求: 在Mysql 中创建 ry-cloud 主数据库,并执行 …

初学者必读:讲解 VC 下如何正确的创建、管理及发布项目

Visual C 的项目文件组成,以及如何正确的创建及管理项目。 本内容是初学者必须要掌握的。不能正确的管理项目,就不能进一步写有规模的程序。 一、项目下各种常见文件类型的功能 1. 代码文件 扩展名为 .cpp、.c、.h 等。 通常情况下,项目…

【Java】Help notes about JAVA

JAVA语言帮助笔记Java的安装与JDKJava命名规范JAVA的数据类型自动类型转换强制类型转换JAVA的运算符取余运算结果的符号逻辑运算的短路运算三元运算符运算符优先级JAVA的流程控制分支结构Java的安装与JDK JDK安装网站:https://www.oracle.com/java/technologies/do…

[项目设计]高并发内存池

目录 1、项目介绍 2、高并发内存池整体框架设计 3、thread cache <1>thread cache 哈希桶对齐规则 <2>Thread Cache类设计 4、Central Cache <1>Central Cache类设计 5、page cache <1>Page Cache类设计 6、性能分析 <1>定长内存池实现…

更换主板开机logo

更换主板开机logo前言详细操作步骤可能遇到的问题素材链接前言 在使用刀锋钛主板后发现&#xff0c;开机logo有些不符合个人喜好&#xff0c;如下图&#xff1a; 于是就有了更换主板logo的想法&#xff0c;确定用刷bios这一方法&#xff0c;注&#xff1a;刷BIOS之前一定要做…

MS14-064(OLE远程代码执行漏洞复现)

✅作者简介&#xff1a;CSDN内容合伙人、信息安全专业在校大学生&#x1f3c6; &#x1f525;系列专栏 &#xff1a;内网安全-漏洞复现 &#x1f4c3;新人博主 &#xff1a;欢迎点赞收藏关注&#xff0c;会回访&#xff01; &#x1f4ac;舞台再大&#xff0c;你不上台&#xf…

Java测试——selenium常见操作(2)

这篇博客继续讲解一些selenium的常见操作 selenium的下载与准备工作请看之前的博客&#xff1a;Java测试——selenium的安装与使用教程 先创建驱动 ChromeDriver driver new ChromeDriver();等待操作 我们上一篇博客讲到&#xff0c;有些时候代码执行过快&#xff0c;页面…

Axios异步请求 json格式

Axios是Ajax的一个框架,简化Ajax操作。需要axios.min.js 和vue.js的jar。发送普通参数异步请求以及相应异常情况客户端向服务器端异步发送普通参数值&#xff1a;- 基本格式&#xff1a; axios().then().catch()- 示例&#xff1a;axios({ // axios表示要发送一个异步请求metho…

12月无情被辞:想给还不会自动化测试的技术人提个醒

公司前段时间缺人&#xff0c;也面了不少测试&#xff0c;结果竟没有一个合适的。一开始瞄准的就是中级的水准&#xff0c;也没指望来大牛&#xff0c;提供的薪资在10-20k&#xff0c;面试的人很多&#xff0c;但是平均水平很让人失望。基本能用一句话概括就是&#xff1a;3年测…

火遍全网的ChatGPT,可免费使用啦

啰嗦几句最近最最最火爆的莫过于ChatGPT了&#xff0c;感觉你不知道ChatGPT是什么做什么&#xff0c;你都没法跟人交流了&#xff01;ChatGPT是美国OpenAI研发的聊天机器人程序&#xff0c;跟小冰、小爱、小度一样&#xff0c;但是不一样的是它拥有强大的信息整合能力&#xff…

【性能】性能测试理论篇_学习笔记_2023/2/11

性能测试的目的验证系统是否能满足用户提出的性能指标发现性能瓶颈&#xff0c;优化系统整体性能性能测试的分类注&#xff1a;这些测试类型其实是密切相关&#xff0c;甚至无法区别的&#xff0c;例如几乎所有的测试都有并发测试。在实际中不用纠结具体的概念。而是要明确测试…

子比主题v6.9.2 免费版源码下载及其激活步骤详解

本人版权所有&#xff0c;请勿打回&#xff01; 文章目录一&#xff0c;子比主题v6.9.2 免费版源码下载及其激活步骤1.1什么是Zibll子比主题&#xff1f;1.2特点二.效果展示2.1 部分源码2.2 效果展示三.源码下载及其视频演示3.1源码下载3.2视频演示一&#xff0c;子比主题v6.9.…

Golang map笔记

map定义三种方式package mainimport "fmt"func main() {// map 的基本定义// 第一种方式 使用make分配数据空间var map1 map[string]stringmap1 make(map[string]string, 3)map1["no1"] "北京"map1["no2"] "天津"map1[&q…

Mysql 增删改查(二)—— 增(insert)、删(delete)、改(update)

目录 一、插入 1、insert 2、replace&#xff08;插入否则更新&#xff09; 二、更新&#xff08;update&#xff09; 三、删除 1、delete 2、truncate&#xff08;截断表&#xff0c;慎用&#xff09; 一、插入 1、insert (1) 单行 / 多行插入 全列插入&#xff1a;…

可能是最强的Python可视化神器,建议一试!

数据分析离不开数据可视化&#xff0c;我们最常用的就是Pandas&#xff0c;Matplotlib&#xff0c;Pyecharts当然还有Tableau&#xff0c;看到一篇文章介绍Plotly制图后我也跃跃欲试&#xff0c;查看了相关资料开始尝试用它制图。 1.Plotly Plotly是一款用来做数据分析和可视…