计算机视觉框架OpenMMLab开源学习(三):图像分类实战

news2024/12/28 11:39:35

前言:本篇主要偏向图像分类实战部分,使用MMclassification工具进行代码应用,最后对水果分类进行实战演示,本次环境和代码配置部分省略,具体内容建议参考前一篇文章:计算机视觉框架OpenMMLab开源学习(二):图像分类

计算机视觉框架OpenMMLab开源学习(三):图像分类实战

一、安装OpenMMLab v2.0

Step 1. Install MMCV

mim install "mmcv>=2.0.0rc0"

Step 2. Install MMClassification and MMDetection

mim install "mmcls>=1.0.0rc0" "mmdet>=3.0.0rc0"

代码模版讲解:

model = dict(
    type='ImageClassifier',     # 分类器类型
    backbone=dict(
        type='ResNet',          # 主干网络类型
        depth=50,               # 主干网网络深度, ResNet 一般有18, 34, 50, 101, 152 可以选择
        num_stages=4,           # 主干网络状态(stages)的数目,这些状态产生的特征图作为后续的 head 的输入。
        out_indices=(3, ),      # 输出的特征图输出索引。越远离输入图像,索引越大
        frozen_stages=-1,       # 网络微调时,冻结网络的stage(训练时不执行反相传播算法),若num_stages=4,backbone包含stem 与 4 个 stages。frozen_stages为-1时,不冻结网络; 为0时,冻结 stem; 为1时,冻结 stem 和 stage1; 为4时,冻结整个backbone
        style='pytorch'),       # 主干网络的风格,'pytorch' 意思是步长为2的层为 3x3 卷积, 'caffe' 意思是步长为2的层为 1x1 卷积。
    neck=dict(type='GlobalAveragePooling'),    # 颈网络类型
    head=dict(
        type='LinearClsHead',     # 线性分类头,
        num_classes=1000,         # 输出类别数,这与数据集的类别数一致
        in_channels=2048,         # 输入通道数,这与 neck 的输出通道一致
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0), # 损失函数配置信息
        topk=(1, 5),              # 评估指标,Top-k 准确率, 这里为 top1 与 top5 准确率
    ))

二、Pytorch图像分类任务

本次任务训练数据为FashionMNIST,完整代码如下:

# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Training

## Construct Dataset and Dataloader
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

## Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)

## Define loss function and Optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

## Inner loop for training
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Output Logs
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

## Inner loop for test
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


## Launch training / test loops#
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")


## Saving Models

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

# Deployment

## Loading Models
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

# Predict new images

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

三、利用MMClassification提供的预训练模型推理:

安装环境:

pip install openmim, mmengine
mim install mmcv-full mmcls

Inference using high-level API

from mmcls.apis import init_model, inference_model

model = init_model('mobilenet-v2_8xb32_in1k.py', 
                   'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth', 
                   device='cuda:0')
load checkpoint from local path: mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
result = inference_model(model, 'banana.png')
result
{'pred_label': 954, 'pred_score': 0.9999284744262695, 'pred_class': 'banana'}
from mmcls.apis import show_result_pyplot

show_result_pyplot(model, 'banana.png', result)

PyTorch codes under the hood

Let write some raw PyTorch codes to do the same thing.

These are actual codes wrapped in high-level APIs.

construct an ImageClassifier

Note: current implementation only allow configs of backbone, neck and classification head instead of Python objects.

from mmcls.models import ImageClassifier

classifier = ImageClassifier(
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280)
)

Load trained parameters

import torch

ckpt = torch.load('mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth')
classifier.load_state_dict(ckpt['state_dict'])

Construct data preprocessing pipeline

Important: A models work only if image preprocessing pipelines is correct.

from mmcls.datasets.pipelines import Compose

test_pipeline = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1), backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
])
data = dict(img_info=dict(filename='banana.png'), img_prefix=None)
data = test_pipeline(data)
data
{'img_metas': DataContainer({'filename': 'banana.png', 'ori_filename': 'banana.png', 'ori_shape': (403, 393, 3), 'img_shape': (224, 224, 3), 'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), 'std': array([58.395, 57.12 , 57.375], dtype=float32), 'to_rgb': True}}),
 'img': tensor([[[ 0.3309,  0.2967,  0.3138,  ...,  2.0263,  2.0092,  1.9920],
          [ 0.3481,  0.3309,  0.2282,  ...,  2.0263,  2.0092,  1.9920],
          [ 0.2796,  0.2967,  0.2967,  ...,  1.9920,  2.0263,  1.9749],
          ...,
          [ 0.1939,  0.1768,  0.2282,  ...,  0.3994,  0.3309,  0.3823],
          [ 0.1426,  0.1254,  0.2111,  ...,  0.5878,  0.5364,  0.5536],
          [-0.0116, -0.0801,  0.1597,  ...,  0.5707,  0.5536,  0.5364]],
 
         [[ 0.3803,  0.3803,  0.3803,  ...,  2.1660,  2.1485,  2.1134],
          [ 0.4153,  0.4153,  0.3102,  ...,  2.1835,  2.1310,  2.1134],
          [ 0.3452,  0.3803,  0.3803,  ...,  2.1134,  2.1485,  2.1134],
          ...,
          [ 0.2752,  0.2577,  0.3102,  ...,  0.5028,  0.4328,  0.4328],
          [ 0.2227,  0.1877,  0.3102,  ...,  0.6604,  0.6254,  0.5728],
          [ 0.0301, -0.0049,  0.2402,  ...,  0.6604,  0.6254,  0.5728]],
 
         [[ 0.5485,  0.5485,  0.5485,  ...,  2.3437,  2.3263,  2.2914],
          [ 0.5834,  0.5834,  0.4788,  ...,  2.3611,  2.3088,  2.2914],
          [ 0.5136,  0.5485,  0.5485,  ...,  2.3088,  2.3437,  2.3088],
          ...,
          [ 0.4091,  0.3916,  0.4439,  ...,  0.5834,  0.5136,  0.5311],
          [ 0.3568,  0.3045,  0.4265,  ...,  0.7576,  0.7228,  0.7054],
          [ 0.1651,  0.1128,  0.3742,  ...,  0.7576,  0.7402,  0.7054]]])}

equivalent in torchvision

from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor

tv_transform = Compose([Resize(256), 
                        CenterCrop(224), 
                        ToTensor(),
                        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                        ])

image = Image.open('banana.png').convert('RGB')
tv_data = tv_transform(image)

Forward through the model

## IMPORTANT: set the classifier to eval mode
classifier.eval()

imgs = data['img'].unsqueeze(0)
imgs = tv_data.unsqueeze(0)

with torch.no_grad():
    # class probabilities
    prob = classifier.forward_test(imgs)[0]
    # features
    feat = classifier.extract_feat(imgs, stage='neck')[0]
    
print(len(prob))
print(prob.argmax().item())
print(feat.shape)
1000
954
torch.Size([1, 1280])

3.使用MMClassificaiton完整进行水果分类实战:

数据集下载:

GitHub - TommyZihao/MMClassification_Tutorials: Jupyter notebook tutorials for MMClassificationJupyter notebook tutorials for MMClassification. Contribute to TommyZihao/MMClassification_Tutorials development by creating an account on GitHub.https://github.com/TommyZihao/MMClassification_Tutorials

代码框架: 

def main():
    model = build_classifier(cfg.model)
    model.init_weights()

    datasets = [build_dataset(cfg.data.train)]

    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        device=cfg.device,
        meta=meta)
mmcls/apis/train_model.py

def train_model(model,
                dataset,
                cfg):

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer))

    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    runner.run(data_loaders, cfg.workflow)
mmcv/runner/epoch_based_runner.py

class EpochBasedRunner(BaseRunner):

    def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.data_loader = data_loader
        for i, data_batch in enumerate(self.data_loader):
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
mmcls/models/classifiers/base.py

class BaseClassifier(BaseModule, metaclass=ABCMeta):

    def forward(self, img, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, img and img_meta are single-nested (i.e. Tensor and
        List[dict]), and when `resturn_loss=False`, img and img_meta should be
        double nested (i.e.  List[Tensor], List[List[dict]]), with the outer
        list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, **kwargs)
        else:
            return self.forward_test(img, **kwargs)

    def train_step(self, data, optimizer=None, **kwargs):
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

        return outputs
mmcls/models/classifiers/image.py

class ImageClassifier(BaseClassifier):

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 pretrained=None,
                 train_cfg=None,
                 init_cfg=None):
        super(ImageClassifier, self).__init__(init_cfg)

        if pretrained is not None:
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if head is not None:
            self.head = build_head(head)

    def extract_feat(self, img):
        x = self.backbone(img)

        if self.with_neck:
            x = self.neck(x)

        return x

    def forward_train(self, img, gt_label, **kwargs):
        x = self.extract_feat(img)

        losses = dict()
        loss = self.head.forward_train(x, gt_label)

        losses.update(loss)

        return losses
mmcv/runner/hooks/optimizer.py

class OptimizerHook(Hook):
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        runner.optimizer.step()

总结:本篇主要偏向图像分类实战部分,使用MMclassification工具进行代码应用,熟悉其框架应用,为后续处理不同场景下分类问题提供帮助。 

本文参考:GitHub - wangruohui/sjtu-openmmlab-tutorial

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/337015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络整理-问答

1. 程序工作的时候网络各层的状态 如下图所示: 1. TCP 在进行三次握手的时候,IP 层和 MAC 层对应都有什么操作呢? TCP 三次握手是通过在传输层建立连接的一个过程,在这个过程中,TCP 和 IP 层、MAC 层都起到了重要的…

ChatGPT API 本地如何调用

本文将会介绍,如何找到API文档和相应语言SDK,并使用PHP调用SDK实现本地请求API的完成过程及遇到的问题和解决方法。 API文档 1.打开官网 ChatGPT: Optimizing Language Models for Dialogue 2.找到API 3.查看文档 4.找到sdk库 OpenAI API 5.主流语言 …

2023年软考什么时候考试?

2023年软考各科目考试时间安排已确定!中国计算机技术职业资格网发布了《2023年度计算机技术与软件专业技术资格(水平)考试工作计划》,具体见下文。2023年度计算机软件资格考试(初级、中级、高级)上半年考试…

树莓派4b Raspberry Pi 4安装以前内置Python3.7版本的系统出现的一系列问题记录

今天想要重装树莓派系统,想装那种内置Python3.7版本的系统,从网上找到镜像源后烧录进去出现一系列问题: 烧录系统开机后,首先就出现报错: 上面显示一个问题就是:start4x.elf: is not compatible&#xff0…

Django自定义模板标签的使用详解

目录 1.创建子应用:python manage.py startapp test01 2.进行相关的配置 3.在新建的test01文件下创建urls.py(此处名称可变但注意上图) 4.在test01文件下创建名称为templatetags的文件夹 5.templatetags文件下继续创建几个py文件如下图​编辑 6.views视图函数…

走进独自开,带你轻松干副业

今天给大家分享一个开发者的福利平台——独自开(点击直接注册),让你在家就能解决收入问题。 文章目录一、平台介绍二、系统案例三、获取收益四、使用平台1、用户注册2、用户认证3、任务报价五、文末总结一、平台介绍 简单说明 独自开信息科技…

人工智能的未来———因果推理what if 第11章(统计模型) 文章解读

我们在观察数据当中,一般使用样本均值去估计目标人群的均值 在所有情况都是理想的情况下: 平均因果效应

Linux环境运行Maven 生成的hadoop jar包

运行命令: hadoop jar ./jar包名字 class对象路径 输入路径 输出路径 linux内部jar包测试 cd 到以下目录,创建以下文件夹 [rootreagan180 ~]# cd /opt/soft/hadoop313/share/hadoop/mapreduce/ 创建文件夹(读取路径) [roo…

ETL基础概念及要求详解

ETL基础概念及要求详解概念ETL与ELT数据湖与数据仓库ETL应用场景ETL具体流程及操作要求抽取清洗转换加载ETL设计模式SQL脚本语言ETL工具设计ETL工具SQLETL接口设计要求明确接口属性约定接口形式确定接口抽取方法规范接口格式概念 ETL即Extract(抽取)Tra…

Python学习-----无序序列1.0(字典的创建、查看、添加、修改、删除/替换)

目录 前言: 字典是什么 字典的特点 1.字典的创建 (1)直接创建{} (2)dict() 函数创建 2.字典的查询 (1)get()函数 (2)获取字典一组内容 3.字典键值对的添加 &a…

1CN/Jaccard/PA/AA/RA/Katz/PageRank/SimRank

common neighbors(CN) 公共邻居的数量。 Jaccard 用于比较有限样本集之间的相似性与差异性。Jaccard系数值越大,样本相似度越高。 preferential attachment(PA) 节点倾向于连接到节点度较高的节点上,&…

BSN-DDC基础网络详解(二):快速接入指南

本文将为大家介绍BSN算力中心方和DDC网络平台方接入DDC网络的基本流程,如下图所示,算力中心方和平台方依次执行图内左侧流程,右侧流程由DDC网络运营人员操作。01注册门户账号注册在接入之前,算力中心方和平台方需要先注册一个官方…

Android性能优化:getResources()与Binder交火导致的界面卡顿优化

欢迎:https://juejin.cn/post/7198430801851531324/ 欢迎:https://nasdaqgodzilla.github.io/2023/02/10/Android%E6%80%A7%E8%83%BD%E4%BC%98%E5%8C%96%EF%BC%9AgetResources-%E4%B8%8EBinder%E4%BA%A4%E7%81%AB%E5%AF%BC%E8%87%B4%E7%9A%84%E7%95%8C%E…

Neurosynth元分析——认知解码工具,软件包安装以及使用

Neurosynth元分析——认知解码工具,软件包安装以及使用 NeuroSynth 基本简介基本原理例子Neurosynth package安装及使用创建虚拟环境安装Dependencies:安装neurosynthNeurosynth使用加载必要的包下载neurosynth数据参考如上图所示。NeuroSynth 元分析感兴趣的区域沿功能连接梯…

玩转黑科技|ChatGPT保姆级注册指南(含免费手机号福利)

前言最近爆火的ChatGPT大家都应该多多少少的有所听说,各种渠道得知大家应该见识到他的强大,是不是很想上手玩一玩?但是由于其不支持中国电话号码进行注册,导致【注册ChatGPT】成了众多玩家头疼的事,也无法体验这个机器…

开源免费的WEB应用防火墙

开源免费的WEB应用防火墙 排名不分前后 资源宝分享:www.httple.net 1、南墙WEB应用防火墙(简称:)是有安科技推出的一款全方位网站防护产品。通过有安科技专有的WEB入侵异常检测等技术,结合有安科技团队多年应用安全的…

小白该从哪方面入手学习大数据

大数据本质上是海量数据。 以往的数据开发,需要一定的Java基础和工作经验,门槛高,入门难。 如果零基础入门数据开发行业的小伙伴,可以从Python语言入手。 Python语言简单易懂,适合零基础入门,在编程语言…

vue 回调函数(callback)的用法

一、介绍: 1、前提:在 js 中,函数也是对象,可以赋值给变量,可以作为参数放在函数的参数列表中,如: var doSomething function(a,b){return a b; } console.log(doSomething(2,3));2、概念&a…

神经网络基础部件-BN层详解

一,数学基础 1.1,概率密度函数 随机变量(random variable)是可以随机地取不同值的变量。随机变量可以是离散的或者连续的。简单起见,本文用大写字母 XXX 表示随机变量,小写字母 xxx 表示随机变量能够取到…

Zabbix 构建监控告警平台(二)--

Apache监控示例(图形监控)模板TemplateZabbix Items 1.Apache监控示例(图形监控) 1.1创建主机组 在“配置”->“主机群组”->“创建主机群组” 填入组名“webserver_test” 创建完成之后可以在“配置”->"主机群组&…