MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

news2024/11/15 23:50:31

MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
在这篇博客中,我们将探索如何使用 MindSpore 框架和 Gradio 库来创建一个基于深度学习的图像分类应用。我们将使用预训练的 ResNet50 模型,以 CIFAR-10 数据集为例进行训练,并通过 Gradio 接口进行图像分类预测。下面是一个简单、直观的指南,适用于希望将深度学习模型转换为交互式应用的开发者。

训练模型

环境设置

首先,我们需要设置 GPU 作为训练的目标设备。MindSpore 提供了一个便捷的方式来配置环境。

from mindspore import context
context.set_context(device_target="GPU")

解析参数

我们使用 argparse 来解析命令行参数。这样可以方便地在训练时调整参数,例如数据集路径、学习率和训练周期数。

import argparse
def parse_args():
    """
    解析命令行参数。

    返回:
        argparse.Namespace: 包含命令行参数的命名空间。
    """
    parser = argparse.ArgumentParser(description="训练 ResNet 模型",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrain_path', type=str, default='',
                        help='预训练文件的路径')
    parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',
                        help='训练数据的路径')
    parser.add_argument('--output_path', default='train/resnet/', type=str,
                        help='模型保存路径')
    parser.add_argument('--epochs', default=10, type=int, help='训练周期数')
    parser.add_argument('--lr', default=0.0001, type=int, help='学习率')
    return parser.parse_args()

创建数据集

使用 MindSpore 的 create_dataset 方法,我们可以轻松创建和预处理 CIFAR-10 训练数据集。

from mindcv.data import create_dataset, create_transforms, create_loader


def create_training_dataset(data_path, batch_size):
    """
    创建训练数据集。

    参数:
        data_path (str): 数据集的路径。
        batch_size (int): 批量大小。

    返回:
        Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。
    """
    dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)
    transform_train = create_transforms(dataset_name='cifar10', image_resize=224)
    train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,
                                 num_classes=10, transform=transform_train)
    num_batches = train_loader.get_dataset_size()
    return train_loader, num_batches

模型训练

接下来,我们定义 train_model 函数来实现模型的训练逻辑。这包括模型的初始化、损失函数、优化器的设置,以及训练过程的启动。

from mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net

def train_model(args):
    """
    训练模型。

    参数:
        args (argparse.Namespace): 包含命令行参数的命名空间。
    """
    train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)

    net = create_model(model_name='resnet50', num_classes=10)

    if args.pretrain_path:
        param_dict = load_checkpoint(args.pretrain_path)
        load_param_into_net(net, param_dict)

    loss_fn = create_loss(name='CE', reduction='mean')

    lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)

    optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)

    model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})

    checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)
    checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,
                                          config=checkpoint_config)

    model.train(args.epochs, train_loader,
                callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])

构建 Gradio 接口

预测函数

在 Gradio 接口中,我们定义一个 predict_image 函数来处理图像输入并返回预测结果。

import gradio as gr
import numpy as np
from mindspore import Tensor
import cv2

def predict_image(img):
    # 创建模型实例
    net = create_model(model_name='resnet50', num_classes=NUM_CLASS)
    param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')
    load_param_into_net(net, param_dict)

    # 封装模型为 Model 类实例
    model = Model(net)
    # 调整图像格式和大小
    img = cv2.resize(img, (224, 224))
    img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32

    # 如果图像是 BGR 格式,转换为 RGB 格式
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 标准化处理
    img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # 转换维度 - 通道优先格式 (C, H, W)
    img = np.transpose(img, (2, 0, 1))

    # 添加批次维度 (N, C, H, W)
    img = np.expand_dims(img, axis=0)

    # 将图像数据转换为 MindSpore 张量
    img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型

    # 预测图像
    output = model.predict(img_tensor)

    # 应用 Softmax 获取概率
    softmax = Softmax(axis=1)
    predict_probability = softmax(output).asnumpy()
    predict_probability = predict_probability[0]  # 获取批量中的第一个元素

    # 将预测概率映射到类别名称
    return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}

Gradio 界面

使用 Gradio,我们可以快速构建一个交互式界面。用户可以上传图片,模型将返回图像分类的预测结果。

image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)

gr.Interface(css=".footer {display:none !important}",
             fn=predict_image,
             inputs=image,
             live=False,
             description="Please upload a image in JPG, JPEG or PNG.",
             title='Image Classification by ResNet50',
             outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),
             examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg',
                       './example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg',
                       './example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg',
                       './example_img/truck.jpg']
             ).launch(share=True)

image-20231121192446268

完整代码

import argparse

from mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net
from mindcv.data import create_dataset, create_transforms, create_loader
from mindspore import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint

# 设置GPU
from mindspore import context

context.set_context(device_target="GPU")


def parse_args():
    """
    解析命令行参数。

    返回:
        argparse.Namespace: 包含命令行参数的命名空间。
    """
    parser = argparse.ArgumentParser(description="训练 ResNet 模型",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrain_path', type=str, default='',
                        help='预训练文件的路径')
    parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',
                        help='训练数据的路径')
    parser.add_argument('--output_path', default='train/resnet/', type=str,
                        help='模型保存路径')
    parser.add_argument('--epochs', default=10, type=int, help='训练周期数')
    parser.add_argument('--lr', default=0.0001, type=int, help='学习率')
    return parser.parse_args()


def create_training_dataset(data_path, batch_size):
    """
    创建训练数据集。

    参数:
        data_path (str): 数据集的路径。
        batch_size (int): 批量大小。

    返回:
        Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。
    """
    dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)
    transform_train = create_transforms(dataset_name='cifar10', image_resize=224)
    train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,
                                 num_classes=10, transform=transform_train)
    num_batches = train_loader.get_dataset_size()
    return train_loader, num_batches


def train_model(args):
    """
    训练模型。

    参数:
        args (argparse.Namespace): 包含命令行参数的命名空间。
    """
    train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)

    net = create_model(model_name='resnet50', num_classes=10)

    if args.pretrain_path:
        param_dict = load_checkpoint(args.pretrain_path)
        load_param_into_net(net, param_dict)

    loss_fn = create_loss(name='CE', reduction='mean')

    lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)

    optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)

    model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})

    checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)
    checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,
                                          config=checkpoint_config)

    model.train(args.epochs, train_loader,
                callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])


if __name__ == '__main__':
    train_model(parse_args())
import gradio as gr
import numpy as np
from mindspore import Tensor
from mindspore.nn import Softmax
import cv2
from typing import Type, Union, List, Optional
from mindspore import nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindcv.models import create_model
import mindspore

print(mindspore.__version__)

NUM_CLASS = 10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def predict_image(img):
    # 创建模型实例
    net = create_model(model_name='resnet50', num_classes=NUM_CLASS)
    param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')
    load_param_into_net(net, param_dict)

    # 封装模型为 Model 类实例
    model = Model(net)
    # 调整图像格式和大小
    img = cv2.resize(img, (224, 224))
    img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32

    # 如果图像是 BGR 格式,转换为 RGB 格式
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 标准化处理
    img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # 转换维度 - 通道优先格式 (C, H, W)
    img = np.transpose(img, (2, 0, 1))

    # 添加批次维度 (N, C, H, W)
    img = np.expand_dims(img, axis=0)

    # 将图像数据转换为 MindSpore 张量
    img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型

    # 预测图像
    output = model.predict(img_tensor)

    # 应用 Softmax 获取概率
    softmax = Softmax(axis=1)
    predict_probability = softmax(output).asnumpy()
    predict_probability = predict_probability[0]  # 获取批量中的第一个元素

    # 将预测概率映射到类别名称
    return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}


image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)

gr.Interface(css=".footer {display:none !important}",
             fn=predict_image,
             inputs=image,
             live=False,
             description="Please upload a image in JPG, JPEG or PNG.",
             title='Image Classification by ResNet50',
             outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),
             examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg',
                       './example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg',
                       './example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg',
                       './example_img/truck.jpg']
             ).launch(share=True)

总结

通过 MindSpore 和 Gradio,我们可以不仅训练强大的深度学习模型,还可以将这些模型转化为交互式应用,使非专业人士也能轻松体验 AI 的魅力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234599.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

场景中的解剖学方向标记_vtkAnnotatedCubeActor

开发环境: Windows 11 家庭中文版Microsoft Visual Studio Community 2019VTK-9.3.0.rc0vtk-example参考代码 demo解决问题:显示标记当前视角、空间的方位,关键对象vtkAnnotatedCubeActor: vtkAnnotatedCubeActor 是一个混合3D 演员&#xf…

Python如何实现模板方法设计模式?什么是模板方法设计模式?Python 模板方法设计模式示例代码

什么是模板方法(Template Method)设计模式? 模板方法(Template Method)是一种行为型设计模式,它定义了一个算法的骨架,将一些步骤延迟到子类中实现。这种模式允许子类为一个算法的特定步骤提供…

远程桌面访问MATLAB 2018B,提示License Manger Error -103,终极解决方案

通过远程桌面方位Windows Server系统下的MATLAB2018B,报错License Manger Error -103,Crack文件夹下的dll文件已经替换,同时也已经输出了lic文件,但是仍然无法打开。但是在本地桌面安装就没有问题。初步怀疑MATLAB的License使用机…

Java实现象棋算法

象棋算法包括搜索算法、评估函数和剪枝算法。以下是一个简单的实现: 搜索算法:使用极大极小值算法,即每个玩家都会做出最好的选择,考虑到对方也会做出最好的选择,所以需要搜索多层。 public int search(int depth, i…

UE5 操作WebSocket

插件:https://www.unrealengine.com/marketplace/zh-CN/product/websocket-client 参考:http://dascad.net/html/websocket/bp_index.html 1. 安装Plugings 2.测试websocket服务器 http://www.websocket-test.com/ 3.连接服务器 如果在Level BP里使用&a…

武汉站--ChatGPT/GPT4科研技术应用与AI绘图及论文高效写作

2023年随着OpenAI开发者大会的召开,最重磅更新当属GPTs,多模态API,未来自定义专属的GPT。微软创始人比尔盖茨称ChatGPT的出现有着重大历史意义,不亚于互联网和个人电脑的问世。360创始人周鸿祎认为未来各行各业如果不能搭上这班车…

传输层协议 - TCP(Transmission Control Protocol)

文章目录: TCP 协议关于可靠性TCP 协议段格式序号与确认序号六个标志位16位窗口大小 确认应答(ACK)机制超时重传机制连接管理机制连接建立(三次握手)连接终止(四次挥手)TIME_WAIT 状态CLOSE_WAI…

5.2 Windows驱动开发:内核取KERNEL模块基址

模块是程序加载时被动态装载的,模块在装载后其存在于内存中同样存在一个内存基址,当我们需要操作这个模块时,通常第一步就是要得到该模块的内存基址,模块分为用户模块和内核模块,这里的用户模块指的是应用层进程运行后…

经典双指针算法试题(一)

📘北尘_:个人主页 🌎个人专栏:《Linux操作系统》《经典算法试题 》《C》 《数据结构与算法》 ☀️走在路上,不忘来时的初心 文章目录 一、移动零1、题目讲解2、讲解算法原理3、代码实现 二、复写零1、题目讲解2、讲解算法原理3、…

Spring-IOC-@Import的用法

1、Car.java package com.atguigu.ioc; import lombok.Data; Data public class Car {private String cname; }2、 MySpringConfiguration2.java package com.atguigu.ioc; import org.springframework.context.annotation.Bean; import org.springframework.context.annotatio…

VBA技术资料MF85:将工作簿批量另存为PDF文件

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…

香港科技大学广州|机器人与自主系统学域博士招生宣讲会—同济大学专场!!!(暨全额奖学金政策)

在机器人和自主系统领域实现全球卓越—机器人与自主系统学域 硬核科研实验室,浓厚创新产学研氛围! 教授亲临现场,面对面答疑解惑助攻申请! 一经录取,享全额奖学金1.5万/月! 🕙时间:…

mac 和 windows 相互传输文件【共享文件夹】

文章目录 前言创建共享文件夹mac 连接共享文件夹 前言 温馨提示:mac 电脑和 windows 电脑必须处于同一局域网下 本文根据创建共享文件夹的方式实现文件互相传输,所以两台电脑必须处于同一网络 windows 创建共享文件夹,mac 电脑通过 windows…

Mrakdown Nice:格式

标题 缩进 删除线 斜体 加粗

EANet:用于医学图像分割的迭代边缘注意力网络

EANet: Iterative edge attention network for medical image segmentation EANet:用于医学图像分割的迭代边缘注意力网络背景贡献实验方法Dynamic scale-aware context module(动态规模感知上下文模块)Edge attention preservation module&a…

【日常总结】Swagger-ui 导入 showdoc (优雅升级Swagger 2 升至 3.0)

一、场景 环境: 二、存在问题 三、解决方案 四、实战 - Swagger 2 升至 3.0 (Open API 3.0) Stage 1:引入Maven依赖 Stage 2:Swagger 配置类 Stage 3:访问 Swagger 3.0 Stage 4:获取 js…

pycharm 控制台中文乱码处理

今天使用pycharm,发现控制台输出又中文乱码了,看网上很多资料说把编码改为UTF-8,设置为并未生效,特此在此记录下本地设置。 1. 修改文件编码:Setting -> Editor ->File Encodings,修改配置如下: 2. …

Windows10环境下Python解析pacp文件

Windows10环境下Python解析pacp文件 一、背景 在Python中,你可以使用scapy库来解析pcap文件。scapy是一个功能强大的网络分析工具,可以用于解析、构建和发送网络数据包。 二、环境安装 命令在终端中安装: pip install scapy由于我使用的Pycharm,所以我就直接在Python Int…

释放搜索潜力:基于Docker快速搭建ES语义检索系统(快速版),让信息尽在掌握

搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术细节以及项目实战(含码源) 专栏详细介绍:搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排)、系统架构、常见问题、算法项目实战总结、技术…

聚焦操作系统迁移痛点,麒麟信安受邀参加 openEuler Meetup苏州站分享迁移实践干货

随着数字化转型持续深入,操作系统正在向支持多样性计算、支持全场景等方向不断发展。日前,由开放原子开源基金会指导,openEuler社区、移动云联合主办的迁移主题Meetup在苏州举办,邀请来自不同领域的技术专家分享系统迁移实践案例。…