【代码实验】CNN实验——利用Imagenet子集训练分类网络(AlexNet/ResNet)

news2025/1/11 2:57:29

文章目录

  • 前言
  • 一、数据准备
  • 二、训练
  • 三、结果


前言

Imagenet是计算机视觉的经典分类比赛,但是Imagenet数据集本身太大了,我们穷学生没有这么大的算力,2016年google DeepMind团队从Imagnet数据集中抽取的一小部分(大小约3GB)制作了Mini-Imagenet数据集(也就是Imagenet的子集),共有100个类别,每个类别都有600张图片,共60000张图片。这个大小的数据集是可以训练得动的。


一、数据准备

首先下载数据集,链接:miniimagenet

在这里插入图片描述
下载完成后,文件里面应该有4个文件夹,iamges文件夹包含了60000张从Imagenet中选出来的图片。还有三个csv文件:train、val、test,对应训练集、验证集、测试集,csv文件里面是图片的文件名和对应的标签。

├── mini-imagenet
     ├── images
     ├── train.csv
     ├── val.csv
     └── test.csv

在这里插入图片描述

但是这三个csv文件不能直接使用,因为train.csv包含38400张图片,共64个类别。val.csv包含9600张图片,共16个类别。可以看出作者将60000张图片共100类分在了这三个csv文件,所以我们肯定不能按照他的CSV文件来划分训练集和验证集(自己跑实验就没必要分三个了,分训练集和验证集就足够了)。

下面是60000张图片划分为训练集和验证集的脚本,只需要修改imagenet的根目录和验证集的比例即可。

import csv
import glob
import os
import random
import shutil
# 找出所有csv文件


def concat_csv(csv_list):
    file_with_label = {}
    for csv_path in csv_list:
        with open(csv_path) as csvfile:
            reader = csv.reader(csvfile)
            header = next(reader)
            for line in reader:
                if (line[1] not in file_with_label.keys()):
                    file_with_label[line[1]] = [line[0]]
                else:
                    file_with_label[line[1]].append(line[0])
    return file_with_label


def split_dataset(file_with_label, root, ratio):
    for label in file_with_label.keys():
        if not os.path.exists(os.path.join(root, "train", label)) and not os.path.exists(os.path.join(root, "val", label)):
            os.makedirs(os.path.join(root, "train", label))
            os.makedirs(os.path.join(root, "val", label))
        for file_name in file_with_label[label]:
            shutil.move(os.path.join(root, "images", file_name),
                        os.path.join(root, "train", label))
    for label in os.listdir(os.path.join(root, "train")):
        samples = random.sample(os.listdir(
            os.path.join(root, "train", label)), int(len(os.listdir(os.path.join(root, "train", label))) * ratio))
        for files in samples:
            shutil.move(os.path.join(root, "train", label, files),
                        os.path.join(root, "val", label))
    print("数据集划分完成!")

def main():
    root = "F:/Mini-ImageNet"  # 修改imagenet根目录
    csv_list = glob.glob(os.path.join(root, "*.csv"))  # 获取三个csv文件
    file_with_label = concat_csv(csv_list)  # 整合csv文件
    split_dataset(file_with_label, root, 0.2)  # 分成训练集和验证集,修改验证集比例,默认0.2

if __name__ == "__main__":
    main()

运行代码后会看到增加了两个文件夹,train文件夹是训练图片,存放格式是:

├── train
     ├── label 1
     	├── image
     ├── label 2
     	├── iamge
...

在这里插入图片描述

train文件夹里面有100个文件夹,对应100分类。每个文件夹的文件名就是对应的label。每个label下存放着训练图片。这样我们可以直接用Pytorch的内置数据模块torchvision.datasets.ImageFolder来加载数据。val文件夹同理。

使用torchvision.datasets.ImageFolder加载数据集时,是按照文件夹顺序来索引的。比如经过softmax后输出最大概率值索引是0,那么就对应类别为n01532829;最大概率值索引是1,那么就对应类别为n01558993

在这里插入图片描述
我们可以看到label都是n01532829这种,我们可以通过Imagenet的json文件来查找对应的具体类别。
json文件参考:label
我把100类对应的具体类别挑了出来:

label_class = {'n01532829': 'house_finch',
 'n01558993': 'robin',
 'n01704323': 'triceratops',
 'n01749939': 'green_mamba',
 'n01770081': 'harvestman',
 'n01843383': 'toucan',
 'n01855672': 'goose',
 'n01910747': 'jellyfish',
 'n01930112': 'nematode',
 'n01981276': 'king_crab',
 'n02074367': 'dugong',
 'n02089867': 'Walker_hound',
 'n02091244': 'Ibizan_hound',
 'n02091831': 'Saluki',
 'n02099601': 'golden_retriever',
 'n02101006': 'Gordon_setter',
 'n02105505': 'komondor',
 'n02108089': 'boxer',
 'n02108551': 'Tibetan_mastiff',
 'n02108915': 'French_bulldog',
 'n02110063': 'malamute',
 'n02110341': 'dalmatian',
 'n02111277': 'Newfoundland',
 'n02113712': 'miniature_poodle',
 'n02114548': 'white_wolf',
 'n02116738': 'African_hunting_dog',
 'n02120079': 'Arctic_fox',
 'n02129165': 'lion',
 'n02138441': 'meerkat',
 'n02165456': 'ladybug',
 'n02174001': 'rhinoceros_beetle',
 'n02219486': 'ant',
 'n02443484': 'black-footed_ferret',
 'n02457408': 'three-toed_sloth',
 'n02606052': 'rock_beauty',
 'n02687172': 'aircraft_carrier',
 'n02747177': 'ashcan',
 'n02795169': 'barrel',
 'n02823428': 'beer_bottle',
 'n02871525': 'bookshop',
 'n02950826': 'cannon',
 'n02966193': 'carousel',
 'n02971356': 'carton',
 'n02981792': 'catamaran',
 'n03017168': 'chime',
 'n03047690': 'clog',
 'n03062245': 'cocktail_shaker',
 'n03075370': 'combination_lock',
 'n03127925': 'crate',
 'n03146219': 'cuirass',
 'n03207743': 'dishrag',
 'n03220513': 'dome',
 'n03272010': 'electric_guitar',
 'n03337140': 'file',
 'n03347037': 'fire_screen',
 'n03400231': 'frying_pan',
 'n03417042': 'garbage_truck',
 'n03476684': 'hair_slide',
 'n03527444': 'holster',
 'n03535780': 'horizontal_bar',
 'n03544143': 'hourglass',
 'n03584254': 'iPod',
 'n03676483': 'lipstick',
 'n03770439': 'miniskirt',
 'n03773504': 'missile',
 'n03775546': 'mixing_bowl',
 'n03838899': 'oboe',
 'n03854065': 'organ',
 'n03888605': 'parallel_bars',
 'n03908618': 'pencil_box',
 'n03924679': 'photocopier',
 'n03980874': 'poncho',
 'n03998194': 'prayer_rug',
 'n04067472': 'reel',
 'n04146614': 'school_bus',
 'n04149813': 'scoreboard',
 'n04243546': 'slot',
 'n04251144': 'snorkel',
 'n04258138': 'solar_dish',
 'n04275548': 'spider_web',
 'n04296562': 'stage',
 'n04389033': 'tank',
 'n04418357': 'theater_curtain',
 'n04435653': 'tile_roof',
 'n04443257': 'tobacco_shop',
 'n04509417': 'unicycle',
 'n04515003': 'upright',
 'n04522168': 'vase',
 'n04596742': 'wok',
 'n04604644': 'worm_fence',
 'n04612504': 'yawl',
 'n06794110': 'street_sign',
 'n07584110': 'consomme',
 'n07613480': 'trifle',
 'n07697537': 'hotdog',
 'n07747607': 'orange',
 'n09246464': 'cliff',
 'n09256479': 'coral_reef',
 'n13054560': 'bolete',
 'n13133613': 'ear'}

二、训练

搭建AlexNet/ResNet或者其他网络可以自己写模型也可以直接加载torchvision.models 里写好的网络架构。如果是自己搭建网络,按照自己写的模板来就行,下面是我自己写的一个模板例子:

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models


def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight,mode='fan_out', nonlinearity='relu')
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)



class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 4, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))


        self.fc=nn.Sequential(nn.Linear(256*6*6, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 100),
                                )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        output=self.fc(x.view(-1, 256*6*6))
        return output


def train(epoch):
    global train_loss
    train_loss=0
    for idx, (inputs, label) in enumerate(train_loader):
        optimizer.zero_grad()
        inputs, label=inputs.to(device), label.to(device)
        outputs=model(inputs)
        loss=criteon(outputs, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    print("epoch %d train_loss:%.6f" %
          (epoch+1, train_loss/len(train_set)*256))


def test(epoch):
    global test_loss, correct
    test_loss=0
    correct=0
    for idx, (inputs, label) in enumerate(test_loader):
        with torch.no_grad():
            inputs, label=inputs.to(device), label.to(device)
            outputs=model(inputs)
            test_loss += criteon(outputs, label)
            predict=torch.max(outputs, dim=1)[1]
            correct += torch.eq(predict, label).sum().item()
    print("test_acc: %.4f  val_loss:%.4f " %
          ((correct/len(test_set)), test_loss*256/len(test_set)))
    

if __name__=="__main__":
    transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Resize((224,224)),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
	# 数据集路径
	# 根据数据集保存的格式我们用torchvision.datasets.ImageFolder加载数据集
    train_set = torchvision.datasets.ImageFolder("F:/DLdata/mini-imagenet/train", transform)
    test_set = torchvision.datasets.ImageFolder("F:/DLdata/mini-imagenet/val", transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True) 
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False) 
    
    # 写入tensorboard
	writer = SummaryWriter()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model =AlexNet()
    model.to(device)
    model.apply(weight_init)

    optimizer = optim.SGD(model.parameters(),lr=1e-2,momentum=0.9,weight_decay=0.0005)
    criteon = nn.CrossEntropyLoss().to(device)
    scheduler =optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0.0001, last_epoch=-1)
    for epoch in range(2):
        model.train()
        train(epoch)
        model.eval()
        test(epoch)
        scheduler.step()
        writer.add_scalar('loss', train_loss/len(train_loader), epoch)
        writer.add_scalar('acc', correct/len(test_set), epoch)
     writer.close()

如果是调用models模块的网络结构,则可以省略很多工作,但是要记得修改最后一层softmax的输出维度。官方模型是1000分类,这里是100分类。修改模型网络结构可以参考另一篇:加载预训练模型与修改网络结构

三、结果

在ResNet34中训练了80多个epoch,达到了74%的准确率。其实也试了ResNet50感觉模型太大了,容易过拟合,最后精度也差不多。AlexNet就要差一些了,只有62%,毕竟是很早之前的模型了,也可以再调调参。
ResNet34:
在这里插入图片描述
AlexNet:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DBCO-PEG-Methacrylate_DBCO-PEG-MA_二苯并环辛炔-PEG-甲基丙烯酸酯

一、试剂基团反应特点(Reagent group reaction characteristics):DBCO(二苯并环辛炔)是一种环炔烃,可以通过在水溶液中通过应变促进的1,3-偶极环加成反应与叠氮化物反应,这种生物正交反应也称为…

Crack:MindFusion.Diagramming for ASP.NET V7.0

MindFusion.Diagramming for ASP.NET V7.0 MindFusion.Diagramming for ASP.NET 为 Web 应用程序提供图表功能。它包括丰富的预定义图表节点以及具有无限行数和列数的表节点。您可以在容器中组织节点,形状设计器 UI 工具可帮助您快速轻松地生成自己的图表节点。 添加…

GD32F450工程模板创建

一、新建工程目录 序号文件夹备注1Project存放工程文件,编译文件等。2Firmware存放ARM内核文件,标准外设库文件等。3Hardware存放开发板的硬件驱动文件。4App存放应用层文件。5User存放main函数,gd32f4xx_it文件,systick文件。6D…

【GD32F427开发板试用】07-硬件SPI驱动0.96LCD

本篇文章来自极术社区与兆易创新组织的GD32F427开发板评测活动,更多开发板试用活动请关注极术社区网站。作者:Stark_GS SPI 简介及特点 具有全双工、半双工和单工模式的主从操作。16位宽度,独立的发送和接收缓冲区。8位或16位数据帧格式。低…

火出圈的《中国奇谭》,如果浪浪山的小妖怪们也用WorkPlus

你会被一只小妖怪感动到破防吗? 最近,由上海美术电影制片厂和B站联合推出的动画片《中国奇谭》火了,仅仅一上线就被打出9.9的高分,频繁登上微博热搜。 其中,第一集《小妖怪的夏天》取材自《西游记》,却从…

Unity编辑器右键菜单实现多平台游戏资源打包—AssetBundle的构建

文章目录👉一、初识AssetBundle👉二、创建AssetBundle👉三、动手操作:实现右键菜单打包AssetBundle👉一、初识AssetBundle AssetBundle是Unity提供的一种打包资源的文件格式,比如模型、纹理和音频文件等的各…

大数据导论

数据是什么数据是指对客观事件进行记录并可以鉴别的符号,是对客观事物的性质、状态以及相互关系等进行记载的物理符号或这些物理符号的组合,它是可识别的、抽象的符号。它不仅指狭义上的数字,还可以是具有一定意义的文字、字母、数字符号的组…

python学习 --- 元组基础

目录 一、什么是元组 1、不可变序列和可变序列 2、元组 二、 元组的创建方式 1、小括号(可省略) 2、内置函数tuple() 三、元组的获取与遍历 1、元组的获取 2、元组的遍历 一、什么是元组 1、不可变序列和可变序列 不可变序列:没有增…

使用streamlit记录

官方网址:https://docs.streamlit.io/ 官方网址:https://discuss.streamlit.io/t/streamlit-components-community-tracker/4634 官方网址:https://github.com/streamlit/streamlit 第三方插件网址:https://github.com/arnaudmiri…

Solidity 中的数学(第 2 部分:溢出)

本文是关于在 Solidity 中进行数学运算的系列文章中的第二篇。这次的主题是:溢出。 介绍 每次我看到、*或**审计另一个 Solidity 智能合约时,我都会开始写以下评论:“这里可能会溢出”。我需要几秒钟来写这四个字,在这几秒钟内&a…

【Pandas】18 小练习

#【Pandas】18 小练习 2023.1.16 两个pandas小练习 18.1 疫情数据分析 18.1.1 观察数据 import pandas as pd import osdf pd.read_csv("data/covid19_day_wise.csv") dfDateConfirmedDeathsRecoveredActiveNew casesNew deathsNew recoveredDeaths / 100 CasesR…

日常渗透刷洞的一些小工具

SecurityServiceBox:一个Windows平台下既可以满足安服仔日常渗透工作也可以批量刷洞的工具盒子 0x00 更新题外话—终端选取 在盒子的tools当中,很多工具运行都是带有颜色标识的,例如nuclei, vulmap,原生的cmd终端虽然…

MD5有哪些特性,常用的MD5加密真的安全吗

在密码学中,MD5是比较常用的算法之一。大家都知道MD5曾一度被认为十分安全,并且在国内外得到广泛适用。然而,王小云教授的研究证明利用MD5算法的磕碰能够严重威胁信息体系安全,因此引发了密码学界的轩然大波。那么,关于…

为什么JDK中String类的indexof不使用KMP或者Boyer-Moore等时间复杂度低的算法编辑器

indexOf底层使用的方法是典型的BF算法。 1、KMP算法 由来 外国人: Knuth,Morris和Pratt发明了这个算法,然后取它三个的首字母进行了命名。所以叫做KMP。 KMP真的很难理解,建议多看几遍 B站代码随想录,文章也的再好 …

【蓝桥杯备赛系列 | 真题 | 简单题】2014年第五届真题-分糖果

🤵‍♂️ 个人主页: 计算机魔术师 👨‍💻 作者简介:CSDN内容合伙人,全栈领域优质创作者。 蓝桥杯竞赛专栏 | 简单题系列 (一) 作者: 计算机魔术师 版本: 1.0 &#xff08…

【博客597】iptables如何借助连续内存块通过xt_table结构管理流量规则

iptables如何借助连续内存块通过xt_table结构管理流量规则 1、iptables 分为两部分: 用户空间的 iptables 命令向用户提供访问内核 iptables 模块的管理界面。内核空间的 iptables 模块在内存中维护规则表,实现表的创建及注册。 2、iptables如何管理众…

第十二章 数据库设计

前言 本文章为看视频所写。 视频链接:168. 14.1 数据库设计前言_哔哩哔哩_bilibili 目录 前言 章节提要 一、数据库设计过程 二、E-R模型 三、答题技巧 四、案例分析 1、案例1 二、案例2 章节提要 一、数据库设计过程 ER模型:是实体联系模型&#x…

第一章 数据结构绪论

数据结构:是相互之间存在一种或多种特定关系的数据元素的集合。数据结构是一门研究非数值计算的程序设计问题中的操作对象,以及它们之间关系和操作等相关问题的学科。程序设计数据结构算法数据:是描述客观事物的符号,是计算机中可…

2.2、进程的状态与转换

整体框架 1、三种基本状态 进程是程序的一次执行。在这个执行过程中,有时进程正在被 CPU 处理,有时又需要等待 CPU 服务, 可见进程的状态是会有各种变化。 为了方便对各个进程的管理,操作系统需要将进程合理地划分为几种状态 ①…

随机梯度下降法的数学基础

梯度是微积分中的基本概念,也是机器学习解优化问题经常使用的数学工具(梯度下降算法)。因此,有必要从头理解梯度的来源和意义。本文从导数开始讲起,讲述了导数、偏导数、方向导数和梯度的定义、意义和数学公式&#xf…