【OpenMMLab AI实战营二期笔记】第五天 MMPretrain代码课

news2024/12/21 19:17:11

1.环境安装

conda activate mmpre # 激活创建好的环境,确保安装好pytorch,可以使用gpu
git clone https://github.com/open-mmlab/mmpretrain.git # 下载mmpre源码
cd mmpretrain # 进入mmpretrian目录
pip install openmim # 安装管理工具
mim install -e ".[multimodal]"

2.代码演示

import mmpretrain
print(mmpretrain.__version__)
from mmpretrain import get_model,list_models,inference_model
print(list_models(task="Image Classification",pattern='resnet18'))#打印分类任务相关且名字中包含resnet18的模型
print(list_models(task="Image Caption",pattern='blip'))#打印图像描述任务相关且名字中包含blip的模型

2.1 构建模型部分:

#获取模型
model=get_model('resnet18_8xb16_cifar10')
print(type(model))# 查看模型类型
 
model =get_model('resnet18_8xb32_in1k')
print(type(model.backbone))#查看模型的backbone的类型

2.2 模型推理部分:

#未加载预训练权重的情况下模型推理
inference_model(model,'demo/bird.jpg',show=True)
#加载预训练权重
list_model(task='Image Caption',pattern='blip')
inference_model('blip-base_3rdparty_caption','demo/cat-dog.png',show=True)

3.基于分类数据集的微调训练

3.1 数据集准备:

从kaggle上找到一个类似的数据集,下载地址:https://www.kaggle.com/datasets/esuarez7/cats_dogs_dataset/download?datasetVersionNumber=1
预训练权重的下载地址:https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth

mkdir data # 创建data文件夹
cd data # 进入data文件夹
tar -xff ~/Downloads/cats_dogs_dataset.tar #将下载好的数据集解压到data文件夹下
cd cats_dogs_dataset #进入解压后的文件夹
ls #列出当前目录下的文件
tree ./ --filelimit=10 #列出文件目录结构

请添加图片描述

3.2 配置文件

介绍:

#回到mmpretrain文件夹下后
ls conmfig #列出config目录下的文件
ls configs/resnet18 #查看resnet18相关的配置文件

配置文件主要分为4部分:
(1)model(backbone、neck、head)
(2)dataset(数据预处理、训练、验证、测试数据流程配置)
(3)schedules(优化器配置等)
(4)runtime(包括日志配置、权重保存配置、随机性可指定随机种子)

配置自定义配置文件:

mkdir projects/cat_dog #创建cat_dog文件夹
cd projects/cat_dog #进入文件夹
vim resnet18_finetune.py #新建配置文件

以下是完整的配置文件中的内容

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=256)

# defaults to use registries in mmpretrain
default_scope = 'mmpretrain'

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),

    # print log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),

    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),

    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1),

    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type='DistSamplerSeedHook'),

    # validation results visualization, set True to enable it.
    visualization=dict(type='VisualizationHook', enable=False),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,

    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),

    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

根据需求修改部分

#模型部分
head=dict(
	num_classes=2#修改
)
backbone=dict(
	init_cfg=dict(type='Pretrained',checkpoint='文件路径')#添加
)
# 数据集部分
dataset_type = 'CustomDataset'
train_dataloader=dict(
	dataset=dict(
		data_root="../../data/cats_dogs_dataset/training_set"#修改
	)
)
val_dataloader=dict(
	dataset=dict(
		data_root="../../data/cats_dogs_dataset/val_set"#修改
	)
)
val_evaluator=dict(type='Accuracy',topk=1)
optim_wrapper=dict(optimizer=dict(type='SGD',lr=0.01,momentum=0.9,weight_decay=0.0001))
train_cfg=dict(by_epoch=True,max_eopchs=5,val_interval=1)

3.3 训练

mim train mmpretrain resnet18_finetune.py --work-dir=./exp

3.4 评估

mim test mmpretrain resnet18_finetune.py --checkpoint exp/epoch_5.pth

mim test mmpretrain resnet18_finetune.py --checkpoint exp/epoch_5.pth --out result.pkl #把结果保存在.pkl文件中

请添加图片描述

3.5 结果分析

mim run mmpretrain analyze_results resnet18_finetune.py result.pkl --out_dir analyze
mim run mmpretrain confusion_matrix resnet18_finetune.py result.pkl --show --include-values # 画出分类的混淆矩阵

请添加图片描述

3.6推理

from mmpretrain import ImageClassificationInferencer
inferencer=ImageClassificationInferencer('./resnet18_finetune.py',pretrained='exp/epoch_5.pth')
inferencer("../../data/cats_dogs_dataset/val_set/cat_or_dog_1.jpg")

推理结果如下:
请添加图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVA开发运维(系统上到生产环境准备工作)

一、前言 java项目在开发环境开发完成,在测试环境测试没有问题后,就需要发布到生产环境,如果系统是对公众的,那就需要很多工作了。比如服务器申请,域名申请,渗透测试,漏洞扫描,公网…

第二章(第二节):导数与微分

1.导数与微分 1.导数概念 设曲线 L 的方程 y=f(x),a ≤ x ≤ b,x0 ∈ (a, b),在曲线 L 上的点 M0(x0, y0) 附近任取一点 M(x0 + Δx, y0 + Δy),过 M0 与 M 作曲线的割线M~0~M,的斜率为:当 x→x0 时,点 M 沿着曲线 L 趋向 M0,与此同时,割线 M0M 趋向一个极限位置 M0T…

想要转行的一定要看软件测试发展简史+学习路线

迄今为止,软件测试的发展一共经历了五个重要时期: 调试为主 20世纪50年代,计算机刚诞生不久,只有科学家级别的人才会去编程,需求和程序本身也远远没有现在这么复杂多变,相当于开发人员一人承担需求分析&am…

idea设置注释模板

目录 设置注释文件模板设置模板 设置注释文件模板 Ctrl Alt S 打开设置,Editor - File and Code Templates 选择class、interface、enum根据自己需要选择需要添加注释的文件,依次添加如下配置内容 /**1. ClassName ${NAME}2. Description TODO3. Aut…

BUUCTF Unencode 1

题目描述&#xff1a; 密文&#xff1a; 89FQA9WMD<V1A<V1S83DY.#<W3$Q,2TM]解题思路&#xff1a; 1、观察密文&#xff0c;尝试Base85、Base91等编码&#xff0c;均失败。 2、结合题目&#xff0c;联想到UUencode编码&#xff0c;尝试后成功&#xff0c;得到flag。 …

驱动LSM6DS3TR-C实现高效运动检测与数据采集(5)----上报匿名上位机实现可视化

概述 lsm6ds3trc包含三轴陀螺仪与三轴加速度计。 姿态有多种数学表示方式&#xff0c;常见的是四元数&#xff0c;欧拉角&#xff0c;矩阵和轴角。他们各自有其自身的优点&#xff0c;在不同的领域使用不同的表示方式。在四轴飞行器中使用到了四元数和欧拉角。 姿态解算选用的…

SpringBoot配置多数据源

SpringBoot配置多数据源 最近在做一个SpringBoot项目时需要关联两个数据库,于是乎我就研究了下关于springboot的多数据源配置,记录配置过程,分享一下 一、基础配置 (这里只展示主要配置) JDK1.8springBoot2.3.4.RELEASEmybatis2.1.0mysql-connector-java 8.0.21maven仓…

知乎家居产品种草营销怎么做?

近年来&#xff0c;家居产品种草营销已经成为了一种新型营销方式。知乎作为全球最大的中文问答社区&#xff0c;拥有着海量的用户和优质内容&#xff0c;逐渐成为了家居产品种草营销中不可忽视的平台。那么&#xff0c;在这个平台上如何进行家居产品种草营销呢&#xff1f;接下…

Python之函数【三】(高阶函数和闭包)

文章目录 前言一、高阶函数二、闭包&#xff08;也称之为&#xff1a;闭包函数&#xff09; 1、浅谈闭包函数 1.1、划重点1.2、注意点2、怎么判断是不是闭包函数呢&#xff1f; 2.1、那接下来&#xff0c;我们就细细的拆开解释2.2、对于这个作用域&#xff0c;在JavaSc…

【MySQL数据库基础】

MySQL数据库基础 1. 数据库的操作1.1 显示当前的数据库1.2 创建数据库1.3 使用数据库1.4 删除数据库 2. 常用数据类型2.1整数&#xff08;xxxint&#xff09;2.2日期时间类型2.3字符串型 3. 表的操作3.1 查看表结构3.2 创建表3.3 删除表 1. 数据库的操作 1.1 显示当前的数据库…

Es索引中时间字段是字符串Range查询的正确姿势

文章目录 [toc] 1. 问题2. Es索引的mapping模式2.1 dynamic动态宽松模式&#xff08;动态映射&#xff09;2.2 strict严格模式&#xff08;静态映射&#xff09; 3. text类型和keyword类型的区别3.1 text类型3.2 keyword类型 4.正确姿势5. 总结 1. 问题 由于之前搞了一个使用fl…

230616安装SqlServer2017Express

230616安装SqlServer2017Express 下载地址 选择语言 Microsoft SQL Server 2017 Express 下载地址: 简体中文 感谢下载 Microsoft SQL Server 2017 Express 我将下载的文件的名称加上了SHA256值, 一长串 是一个 .exe 的自解压文件, 双击后,默认解压到同根文件夹\同名文件夹下,…

那些可以当源码学习的优质开源项目分享

本篇收集的是自己平时逛 Github 发现的一些优质的开源项目&#xff0c;为什么收集它&#xff1f; 借助优质的开源项目&#xff0c;我们不仅可以拿来二次开发快速实现想要的功能&#xff0c;而且还可以学习里面优秀的代码&#xff0c;提高我们的编程能力。读&#xff08;拆解&am…

vue实现elementUI table表格树形结构-使用懒加载时-解决子节点增删改后,不刷新子节点数据问题

问题发现 在使用element-ui的table组件时&#xff0c;使用树形结构&#xff0c;并使用了懒加载&#xff0c;可出现了一个问题&#xff0c;在对当前节点添加一个子节点数据&#xff0c;或删除一个子节点数据时&#xff0c;当前节点的子节点数据并不自动刷新出来。element-ui官方…

景联文科技:一文详解关键点标注

关键点标注是计算机视觉领域的一种任务&#xff0c;指的是在图像或视频序列中标注出特定目标的关键点&#xff0c;这些关键点通常是目标的重要特征点或轮廓点&#xff0c;包括但不限于人体关节、面部特征点、车辆零部件等。通过对关键点的标注&#xff0c;可以为后续的目标跟踪…

19. 算法之分治算法

1. 概念 分治算法&#xff08;divide and conquer&#xff09;的核心思想其实就是四个字&#xff0c;分而治之 &#xff0c;也就是将原问题划分成n个规模较小&#xff0c;并且结构与原问题相似的子问题&#xff0c;递归地解决这些子问题&#xff0c;然后再合并其结果&#xff…

微信小程序开发(1)

10分钟入门 - 微信小程序开发 微信小程序详细教程 小程序简介 小程序是一种全新的连接用户与服务的方式&#xff0c;它可以在微信内被便捷地获取和传播&#xff0c;同时具有出色的使用体验。 小程序技术发展史 WeixinJSBridge.invoke(imagePreview, { 2. current: http://i…

大数据之路书摘:走近大数据——从阿里巴巴学习大数据系统体系架构

文章目录 1.数据采集层2.数据计算层3.数据服务层4.数据应用层 在大数据时代&#xff0c;人们比以往任何时候更能收集到更丰富的数据。但是如果不能对这些数据进行有序、有结构地分类组织和存储&#xff0c;如果不能有效利用并发掘它&#xff0c;继而产生价值&#xff0c;那么它…

SNMP软件及性能监控

SNMP&#xff08;Simple Network Management Protocol&#xff09;是一种用于网络管理的协议。通过SNMP&#xff0c;我们可以监测和管理网络设备、服务器等重要设备的性能和状况&#xff0c;从而确保网络的正常运行。但在开始使用之前&#xff0c;需要进行配置&#xff0c;以便…

计算机未来五年最吃香的4个职位,对女生超级友好!

今年计算机毕业的学弟学妹对于找工作感觉到非常焦虑&#xff0c;不知道该哪个方向就业才有出路。很多同学感觉在学校好像什么都学了&#xff0c;又好像什么都没学到&#xff0c;先不说企业会不会招&#xff0c;自己就连投简历的勇气都没有&#xff0c;生怕大把的简历投出去就石…