【学习】使用PyTorch训练与评估自己的ResNet网络教程

news2024/11/15 8:54:18

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客

项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

视频手把手教程:我将维护一个集成各主干网络的图像分类项目_哔哩哔哩_bilibili

主要是复现和训练测试自己的数据集

复现部分

0.环境问题

pytorch官网里面找个合适的CUDA11.0安装一下,然后把requirements.txt安装一下

pip install -r requirements.txt

 参考版本:

pip list
Package                Version
---------------------- ---------------
certifi                2021.5.30
cycler                 0.11.0
dataclasses            0.8
importlib-resources    5.4.0
joblib                 1.1.1
kiwisolver             1.3.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
numpy                  1.19.2
olefile                0.46
opencv-contrib-python  4.0.1.24
opencv-python          4.0.1.24
opencv-python-headless 4.0.1.24
packaging              21.3
Pillow                 8.4.0
pip                    21.3.1
pyparsing              3.0.7
python-dateutil        2.9.0.post0
scikit-learn           0.24.2
scipy                  1.5.4
setuptools             36.4.0
six                    1.16.0
terminaltables         3.1.10
threadpoolctl          3.1.0
torch                  1.7.1
torchaudio             0.7.0a0+a853dff
torchvision            0.8.2
tqdm                   4.64.1
typing_extensions      4.1.1
wheel                  0.37.1
zipp                   3.6.0

  • 下载MobileNetV3-Small权重至datas
  • 利用项目里的猫狗图片检验一下安装情况
    python tools/single_test.py datas/cat-dog.png models/mobilenet/mobilenet_v3_small.py --classes-map datas/imageNet1kAnnotation.txt
    

    成功的话大概这样:

 1.数据集问题

 先下载花卉数据集(0zat):flower_photos.zip_免费高速下载|百度网盘-分享无限制 (baidu.com)

 原始地址在项目的资料部分:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

 目录结构,按照花卉类型存放

├─flower_photos
│  ├─daisy
│  │      100080576_f52e8ee070_n.jpg
│  │      10140303196_b88d3d6cec.jpg
│  │      ...
│  ├─dandelion
│  │      10043234166_e6dd915111_n.jpg
│  │      10200780773_c6051a7d71_n.jpg
│  │      ...
│  ├─roses
│  │      10090824183_d02c613f10_m.jpg
│  │      102501987_3cdb8e5394_n.jpg
│  │      ...
│  ├─sunflowers
│  │      1008566138_6927679c8a.jpg
│  │      1022552002_2b93faf9e7_n.jpg
│  │      ...
│  └─tulips
│  │      100930342_92e8746431_n.jpg
│  │      10094729603_eeca3f2cb6.jpg
│  │      ...
  • datas/中创建标签文件annotations.txt,按行将类别名的索引写入文件(应该已经写好了);即
    daisy 0
    dandelion 1
    roses 2
    sunflowers 3
    tulips 4
    

    之后进行数据集划分,随机分为训练和测试集。

  • 在tools/split_data.py中修改原始数据集地址和划分后的数据集地址。(new_datasets最好别更改)

    init_dataset = './flower_photos'
    new_dataset = './Awesome-Backbones/datasets'
    

    终端使用命令:

    python tools/split_data.py
    

    划分后的数据集格式大概为:

    ├─...
    ├─datasets
    │  ├─test
    │  │  ├─daisy
    │  │  ├─dandelion
    │  │  ├─roses
    │  │  ├─sunflowers
    │  │  └─tulips
    │  └─train
    │      ├─daisy
    │      ├─dandelion
    │      ├─roses
    │      ├─sunflowers
    │      └─tulips
    ├─...
    

    查看tools/get_annotation.py,看看路径要不要更改:

  • datasets_path   = '你的数据集路径'
    

 终端使用命令:

python tools/get_annotation.py

 该命令应该会在datas/下形成train.txt和test.txt,里面是具体照片的位置

2.修改配置文件

/models下有许多的模型配置文件

 以resnet为例

 挑一个顺眼的改改

以resnet101为例

# model settings

model_cfg = dict(
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=5,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),))

# dataloader pipeline
img_lighting_cfg = dict(
    eigval=[55.4625, 4.7940, 1.1475],
    eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203]],
    alphastd=0.1,
    to_rgb=True)
policies = [
    dict(type='AutoContrast', prob=0.5),
    dict(type='Equalize', prob=0.5),
    dict(type='Invert', prob=0.5),
    dict(
        type='Rotate',
        magnitude_key='angle',
        magnitude_range=(0, 30),
        pad_val=0,
        prob=0.5,
        random_negative_prob=0.5),
    dict(
        type='Posterize',
        magnitude_key='bits',
        magnitude_range=(0, 4),
        prob=0.5),
    dict(
        type='Solarize',
        magnitude_key='thr',
        magnitude_range=(0, 256),
        prob=0.5),
    dict(
        type='SolarizeAdd',
        magnitude_key='magnitude',
        magnitude_range=(0, 110),
        thr=128,
        prob=0.5),
    dict(
        type='ColorTransform',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Contrast',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Brightness',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Sharpness',
        magnitude_key='magnitude',
        magnitude_range=(-0.9, 0.9),
        prob=0.5,
        random_negative_prob=0.),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='horizontal',
        random_negative_prob=0.5),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='vertical',
        random_negative_prob=0.5),
    dict(
        type='Cutout',
        magnitude_key='shape',
        magnitude_range=(1, 41),
        pad_val=0,
        prob=0.5),
    dict(
        type='Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='horizontal',
        random_negative_prob=0.5,
        interpolation='bicubic'),
    dict(
        type='Translate',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        pad_val=0,
        prob=0.5,
        direction='vertical',
        random_negative_prob=0.5,
        interpolation='bicubic')
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandAugment',
        policies=policies,
        num_policies=2,
        magnitude_level=12),
    dict(
        type='RandomResizedCrop',
        size=224,
        efficientnet_style=True,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='Lighting', **img_lighting_cfg),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=False),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=224,
        efficientnet_style=True,
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# train
data_cfg = dict(
    batch_size = 32,
    num_workers = 0,
    train = dict(
        pretrained_flag = False,
        pretrained_weights = '',
        freeze_flag = False,
        freeze_layers = ('backbone',),
        epoches = 150,
    ),
    test=dict(
        ckpt = './logs/ResNet/2024-06-26-10-37-00/Last_Epoch150.pth',
        metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'confusion'],
        metric_options = dict(
            topk = (1,5),
            thrs = None,
            average_mode='none'
    )
    )
)

# optimizer
optimizer_cfg = dict(
    type='SGD',
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4)

# learning 
lr_config = dict(type='StepLrUpdater', step=[30, 60, 90])

主要改model_cfg里面的num_classes,data_cfg里的batch_size与num_workers

若有预训练权重则可以将pretrained_weights设置为True并将预训练的路径赋值给pretrained_weights

optimizer_cfg中修改初始学习率,根据batch_size调试

3.训练

终端运行

python tools/train.py models/resnet/resnet101.py

 运行结果

4.评估

在实际使用的配置文件中将ckpt修改

ckpt = '你的训练权重路径'

终端运行

python tools/evaluation.py models/resnet/resnet101.py

 运行结果

 我跑出来的准确率不高哈

5.测试

单张测试

python tools/single_test.py datasets/test/dandelion/14283011_3e7452c5b2_n.jpg models/resnet/resnet101.py

多张测试

使用batch_test.py,路径使用文件夹路径。

----------------------------------------------------------------------------------------------

使用自己的数据集

1.数据集准备

2.配置文件

3.训练

4.评估

5.测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1865325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java医院绩效考核系统源码:考核目标、考核指标、考核方法、考核结果与奖惩措施

Java医院绩效考核系统源码:考核目标、考核指标、考核方法、考核结果与奖惩措施 随着我国医疗体制的改革广大人民群的看病难,看病贵的问题一直没有得到有效地解决医疗费用的上涨,远远大于大多数家庭收入的增长速度。医院的改革已经势在必行&am…

左右旋分辨

从端头看,切削路径顺时针是右旋,反时针左旋。

OpenCL在移动端GPU计算中的应用与实践

一、引言 移动端芯片性能的不断提升为在手机上进行计算密集型任务,如计算机图形学和深度学习模型推理,提供了可能。在Android设备上,GPU,尤其是高通Adreno和华为Mali,因其卓越的浮点运算能力,成为了异构计…

计算机SCI期刊,中科院3区,易录用,收稿广泛

一、期刊名称 The Journal of Supercomputing 二、期刊简介概况 期刊类型:SCI 学科领域:计算机科学 影响因子:3.3 中科院分区:3区 三、期刊征稿范围 《超级计算杂志》发表有关超级计算各个方面的技术、架构和系统、算法、语…

【PromptCC】遥感图像变化字幕的解耦范式

摘要 以往的方法忽略了任务的显著特异性:对于不变和变化的图像对,RSICC难度是不同的,以一种耦合的方式处理未变化和变化的图像对,这通常会导致变化字幕的混淆。论文链接:https://ieeexplore.ieee.org/stamp/stamp.jsp…

CircuitBreaker断路器-Resilience4j

目录 背景分布式架构面临的问题:服务雪崩如何解决? CircuitBreakerResilience4jCircuitBreaker 服务熔断服务降级三种状态转换例子参数配置案例demo作业 BulkHead隔离特性SemaphoreBulkhead使用了信号量FixedThreadPoolBulkhead使用了有界队列和固定大小…

非root用户crontab定时任务不执行

前言 有一个sh脚本,通过crontab -l写入后,发现并没有执行,手动执行脚本却正常,怀疑是权限上的问题。 排查 在/var/log/cron查看日志发现有" FAILED to authorize user with PAM (Module is unknown)"的报错 解决 …

基于阿里云 OpenAPI 插件,让 Grafana 轻松实现云上数据可视化

作者:徽泠 引言 Grafana 作为市场上领先的开源监控解决方案之一,使得数据监控和可视化变得触手可及。作为一款开源的数据可视化和分析软件,Grafana 支持查询、可视化、提醒和探索您的各种数据,无论它们存储在何处。Grafana 通过…

HTTP协议中的各种请求头、请求类型的作用以及用途

目录 一、http协议介绍二、http协议的请求头三、http协议的请求类型四、http协议中的各种请求头、请求类型的作用以及用途 一、http协议介绍 HTTP(HyperText Transfer Protocol,超文本传输协议)是一种用于分布式、协作式和超媒体信息系统的应…

兰州市红古区市场监管管理局调研食家巷品牌,关注细节,推动进步

近日,兰州市红古区市场监管管理局临平凉西北绿源电子商务有限公司进行了深入视察,为企业发展带来了关怀与指导。 食家巷品牌作为平凉地区特色美食的代表之一,一直以来凭借其纯手工工艺和独特的风味,在市场上占据了一席之地。领导…

0.7 模拟电视标准 PAL 简介

0.7 模拟电视标准PAL PAL 是一种用于模拟电视的彩色编码系统,全名为逐行倒相(Phase Alternating Line)。它是三大模拟彩色电视标准之一,另外两个标准是 NTSC 和 SECAM。“逐行倒相”的意思是每行扫描线的彩色信号会跟上一行倒相&…

读写内部闪存FLASH读取芯片ID

读写内部闪存FLASH 右下角是OLED,然后左上角在PB1和PB11两个引脚,插上两个按键用于控制。下一个代码读取芯片ID,这个也是接上一个OLED,能显示测试数据就可以了。 STM32-STLINK Utility 本节的代码调试,使用辅助软件…

什么是云服务器镜像,如何选择?

云服务器镜像是一种用于业务连续性、灾难恢复和备份的技术手段,其本质是云端创建的服务器数据副本。 这些镜像内容可以涵盖系统、光盘、软件、网站甚至整个服务器,主要用于创建容错和冗余服务器计算基础架构,为用户提供了一个方便且可靠的解…

银河麒麟桌面操作系统V10SP1【FTP服务器】配置手册

简介: FTP是一个文件传输协议,主要是在互联网上提供文件储存和访问服务的计算机,一个FTP服务器可以对多个客户端提供服务。本文主要介绍在银河麒麟桌面操作系统V10SP1上如何搭建FTP服务器以及在客户端如何访问FTP服务器的操作方法。 正文: 一、操作环境 服务端:银河麒…

推荐系统(LLM去偏?) | (WSDM24)预训练推荐系统:因果去偏视角

::: 大家好!今天我分享的文章是来自威斯康星大学麦迪逊分校和亚马逊AWS AI实验室的最新工作,文章所属领域是推荐系统和因果推理,作者针对跨域推荐中的偏差问题提出了一种基于因果去偏的预训练推荐系统框架PreRec。 ::: 原文:Pre-t…

logstash配置文件中明文密码加密

1 案例背景 应用配置文件中禁止使用明文密码,需要加密处理 上图中,红框打码位置为es的明文密码,需要对其进行处理 2 创健keystore文件 /rpa/logstash/bin/logstash-keystore --path.settings /rpa/isa/conf/logstash/ create 注&#xff1…

3d渲染软件有哪些(2),渲染100邀请码1a12

3D渲染软件有很多,上次我们介绍了几个,这次我们接着介绍。 1、Arnold Arnold渲染器是一款基于物理算法的电影级渲染引擎,它具有渲染质量高、材质系统丰富、渲染速度快等特点,是3D设计师的极佳选择。2、Octane Render Octane Ren…

一文详解:什么是企业邮箱?最全百科

什么是企业邮箱?企业邮箱即绑定企业自有域名作为邮箱后缀的邮箱,是企业用于内部成员沟通和客户沟通的邮箱系统。 一、企业邮箱概念拆解 1.什么是企业邮箱? 企业邮箱即使用企业域名作为后缀的邮箱系统。它不仅提供专业的电子邮件收发功能&a…

JFreeChart 生成Word图表

文章目录 1 思路1.1 概述1.2 支持的图表类型1.3 特性 2 准备模板3 导入依赖4 图表生成工具类 ChartWithChineseExample步骤 1: 准备字体文件步骤 2: 注册字体到FontFactory步骤 3: 设置图表具体位置的字体柱状图:饼图:折线图:完整代码&#x…

韩顺平0基础学java——第30天

p600-611 坦克大战! 艰难推进中 坦克大战-子弹 发射子弹 1.当发射一颗子弹后,就相当于启动一个线程 2.玩家拥有子弹对象,当按下J时,就启动发射行为(线程),让子弹不停移动,形成…