034、test

news2025/1/10 20:33:41

之——全纪录

目录

之——全纪录

杂谈

正文

1.下载处理数据

2.数据集概览

3.构建自定义dataset

4.初始化网络

5.训练


杂谈

        综合方法试一下。


leaves

1.下载处理数据

        从官网下载数据集:Classify Leaves | Kaggle

        解压后有一个图片集,一个提交示例,一个测试集,一个训练集。

        images,27153个树叶图片:

        test.csv,8800个:

        train.csv,18353个:


2.数据集概览

        训练集、测试集、类别:

#导包
import random
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision
import pandas as pd
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Image

train_data=pd.read_csv(r"D:\apycharmblackhorse\leaves\train.csv")
test_data=pd.read_csv(r"D:\apycharmblackhorse/leaves/test.csv")

train_images=train_data.iloc[:,0].values #把所有的训练集图片路径读进来成list
print("训练集数量:",len(train_images))
n_train=len(train_images)
test_images=test_data.iloc[:,0].values
print("测试集数量:",len(test_images))
n_test=len(test_images)

train_labels = pd.get_dummies(train_data.iloc[:, 1]).values.astype(int).argmax(1)
#独热编码后找到每行最大的索引记下来就是类别号,而顺序与独热编码colums,也就是与下方排序一致
# print(len(train_labels),train_labels)

#记录并排序所有的类别名
train_labels_header = pd.get_dummies(train_data.iloc[:, 1]).columns.values
print("总类别:",len(train_labels_header))
classes=len(train_labels_header)


3.构建自定义dataset

       继承 torch.utils.Dataset 类,自定义树叶分类数据集:

#继承 torch.utils.Dataset 类,自定义树叶分类数据集
class leaves_dataset(torch.utils.data.Dataset):
    #root数据目录, images图片路径, labels图片标签, transform数据增强
    def __init__(self, root, images, labels, transform):
        super(leaves_dataset, self).__init__()
        self.root = root
        self.images = images
        if labels is None:
            self.labels = None
        else:
            self.labels = labels
        self.transform = transform
    #获得指定样本
    def __getitem__(self, index):
        image_path = self.root + self.images[index]
        image = Image.open(image_path)
        #预处理
        image = self.transform(image)
        if self.labels is None:
            return image
        label = torch.tensor(self.labels[index])
        return image, label
    #获得数据集长度
    def __len__(self):
        return self.images.shape[0]

        构建读取数据与预处理:

def load_data(images, labels, batch_size, train):
    aug = []
    normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if (train):
        aug = [torchvision.transforms.CenterCrop(224),
               transforms.RandomHorizontalFlip(),
               transforms.RandomVerticalFlip(),
               transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
               transforms.ToTensor(),
               normalize]
    else:
        aug = [torchvision.transforms.Resize([256, 256]),
               torchvision.transforms.CenterCrop(224),
               transforms.ToTensor(),
               normalize]
    transform = transforms.Compose(aug)
    dataset = leaves_dataset(r"D:\apycharmblackhorse\leaves\\", images, labels, transform=transform)
    if train==True:type="训练"
    else:type="测试"
    print("载入:",dataset.__len__(),type)
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0, shuffle=train)

train_iter = load_data(train_images, train_labels, 512, train=True)

4.初始化网络

        使用官方预训练模型初始化网络,并修改输出类别数:

#初始化网络
net = torchvision.models.resnet18(pretrained=True)

net.fc = nn.Linear(net.fc.in_features, classes)
nn.init.xavier_uniform_(net.fc.weight)
net.fc


5.训练

         定义迭代器、优化器以及其他超参数,进行训练:

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=64, num_epochs=20,
                      param_group=True):
    train_slices = random.sample(list(range(n_train)), 15000)
    test_slices = list(set(range(n_train)) - set(train_slices))

    train_iter = load_data(train_images[train_slices], train_labels[train_slices], batch_size, train=True)
    test_iter = load_data(train_images[test_slices], train_labels[test_slices], batch_size, train=False)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        #别的层不变,最后一层10倍学习率
        trainer = torch.optim.Adam([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    print(111)
    try:
        d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
    except Exception as e:
        print(e)



#%%

#较小的学习率,通过微调预训练获得的模型参数
train_fine_tuning(net, 1e-3)

        小破脑跑得慢,之前不用预训练5个epoch后acc大概只能到0.3  ,使用预训练后到了0.6,但实际上感觉对于树叶的针对性分类还是需要从头开始才是最好的选择,资源不够这里就不做尝试了,大概尝试情况:


CIFAR-10

1.数据集


2.未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1221002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaWeb[总结]

文章目录 一、Tomcat1. BS 与 CS 开发介绍1.1 BS 开发1.2 CS 开发 2. 浏览器访问 web 服务过程详解(面试题)2.1 回到前面的 JavaWeb 开发技术栈图2.2 浏览器访问 web 服务器文件的 UML时序图(过程) ! 二、动态 WEB 开发核心-Servlet1. 为什么会出现 Servlet2. 什么是…

【C++】模版-初阶

目录 泛型编程--模版 函数模版 类模版 泛型编程--模版 函数模版 如何实现一个通用的交换函数呢?void Swap(int& left, int& right){int temp left;left right;right temp;}void Swap(double& left, double& right){double temp left;left right;righ…

jbase虚拟M层的设计

对于只是自己产品内部使用的打印程序来说(比如打印收费单,打印结算单等),打印逻辑写在js,获取其他层都是没毛病的。但是对于类型检验报告这种打印来说,打印格式控制逻辑写在js层是百分百不行的。因为检验报…

数据结构-哈希表(C语言)

哈希表的概念 哈希表就是: “将记录的存储位置与它的关键字之间建立一个对应关系,使每个关键字和一个唯一的存储位置对 应。” 哈希表又称:“散列法”、“杂凑法”、“关键字:地址法”。 哈希表思想 基本思想是在关键字和存…

Express.js 与 Nest.js对比

Express.js 与 Nest.js对比 自从 Node.js 发布以来,Javascript 在后端领域的使用有所增加。由于 Node.js 的使用越来越多,每天都会有新的框架和工具发布。Express 和 Nest 是使用 Node.js 创建后端应用程序的最著名的框架之一,在本文中&…

数据结构与算法之美学习笔记:20 | 散列表(下):为什么散列表和链表经常会一起使用?

目录 前言LRU 缓存淘汰算法Redis 有序集合Java LinkedHashMap解答开篇 & 内容小结 前言 本节课程思维导图: 今天,我们就来看看,在这几个问题中,散列表和链表都是如何组合起来使用的,以及为什么散列表和链表会经常…

【代码随想录】算法训练计划21、22

day 21 1、530. 二叉搜索树的最小绝对差 题目: 给你一个二叉搜索树的根节点 root ,返回 树中任意两不同节点值之间的最小差值 。 差值是一个正数,其数值等于两值之差的绝对值。 思路: 利用了二叉搜索树的中序遍历特性用了双指…

线性表的概念

目录 1.什么叫线性表2.区分线性表的题 1.什么叫线性表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串… 线性表在逻辑上是…

2.4 矩阵的运算法则

矩阵是数字或 “元素” 的矩形阵列。当矩阵 A A A 有 m m m 行 n n n 列,则是一个 m n m\times n mn 的矩阵。如果矩阵的形状相同,则它们可以相加。矩阵也可以乘上任意常数 c c c。以下是 A B AB AB 和 2 A 2A 2A 的例子,它们都是 …

JSplacement丨随机生成置换贴图

界面很简单,虽然是英文,但基本也能看懂,参数调一调,随机生成不重复的8K高清图片。 这种图片可能对普通人感觉很奇怪,有什么用呢?会C4D建模渲染的同学应该会明白,特别是建一些科技类的场景背景&a…

短期经济波动:均衡国民收入决定理论(三)

短期经济波动:国民收入决定理论(三) 文章目录 短期经济波动:国民收入决定理论(三)[toc]1 总需求曲线及其变动1.1 总需求曲线含义1.2 总需求曲线推导1.2.1 代数推导1.2.2 几何推导 1.3 AD曲线及其变动1.3.1 扩张性财政政策1.3.2 扩张性货币政策 2 总供给曲…

2023年AI生成音频研究报告

第一章 行业概况 1.1 定义 AI音频生成行业,作为人工智能生成内容(AIGC)技术渗透的关键领域,正迅速成为技术革新的前沿阵地。这一领域专注于运用先进的人工智能技术和复杂算法来创造音频内容,覆盖了语音合成、音乐制作…

常见面试题-HashMap源码

了解 HashMap 源码吗? 参考文章:https://juejin.cn/post/6844903682664824845 https://blog.51cto.com/u_15344989/3655921 以下均为 jdk1.8 的 HashMap 讲解 首先,HashMap 的底层结构了解吗? 底层结构为:数组 链…

C语言--给定一行字符串,获取其中最长单词【图文详解】

一.问题描述 给定一行字符串,获取其中最长单词。 比如:给定一行字符串: hello wo shi xiao xiao su 输出:hello 二.题目分析 “打擂台算法”,具体内容小伙伴们可以参考前面的内容。 三.代码实现 char* MaxWord(const char* str)…

初始MySQL(七)(MySQL表类型和存储引擎,MySQL视图,MySQL用户管理)

目录 MySQL表类型和存储引擎 MyISAM MEMORY MySQL视图 我们先说说视图的是啥? 视图的一些使用细节 MySQL用户管理 原因 常见操作 MySQL表类型和存储引擎 -- 查看所有的存储引擎 SHOW ENGINES 我们常见的表有MyISAM InnoDB MEMORY 1.MyISAM不支持事务,也不支持外…

群晖7.2版本安装CloudDriver2(套件)挂载alist(xiaoya)到本地

CloudDrive是一个强大的多云盘管理工具,为用户提供包含云盘本地挂载的一站式的多云盘解决方案。挂载到本地后,可以像本地文件一样进行操作。 一、套件库添加矿神源 二、安装CloudDriver2 1、搜索安装 搜索框输入【clouddrive】,搜索到Clou…

抖音快手判断性别、年龄自动关注脚本,按键精灵开源代码!

这个是支持抖音和快手两个平台的,可以进入对方主页然后判断对方年龄和性别,符合条件的关注,不符合条件的跳过下一个ID,所以比较精准,当然你可以二次开发加入更多的平台,小红书之类的,仅供学习&a…

YOLO目标检测——PCB缺陷数据集下载分享【含对应voc、coco和yolo三种格式标签】

实际项目应用:电子制造过程的质量控制、生产线的自动化检测、以及产品可靠性验证等方面数据集说明:PCB缺陷检测数据集,真实场景的高质量图片数据,数据场景丰富标签说明:使用lableimg标注软件标注,标注框质量…

修完这个 Bug 后,MySQL 性能提升了 300%

最近 MySQL 官方在 8.0.35 上修复了一个 bug: 这个 bug 是由 Mark Callaghan 发现的。Mark 早年在 Google MySQL 团队,后来去了 Meta MySQL,也主导了 RocksDB 的开发。 Mark 在 #109595 的 bug report 给出了非常详细的复现步骤 在官方修复后…

反转字符串中的单词

给你一个字符串 s ,请你反转字符串中 单词 的顺序。 单词 是由非空格字符组成的字符串。s 中使用至少一个空格将字符串中的 单词 分隔开。 返回 单词 顺序颠倒且 单词 之间用单个空格连接的结果字符串。 注意:输入字符串 s中可能会存在前导空格、尾随空格…