【人工智能】MindSpore Hub

news2024/12/29 9:14:21

目录

前言

一、什么是MindSpore Hub

1.简单介绍

2.MindSpore Hub包含功能

3.MindSpore Hub使用场景

二、安装MindSpore Hub

1.确认系统环境信息

2.安装

 3.下载源码

 4.进行验证

三、加载模型

1.介绍

2.推理验证

3.迁移学习

四、模型发布


前言

MindSpore着重提升易用性并降低AI开发者的开发门槛,MindSpore原生适应每个场景包括端、边缘和云,并能够在按需协同的基础上,通过实现AI算法即代码,使开发态变得更加友好,显著减少模型开发时间,降低模型开发门槛。通过MindSpore自身的技术创新及MindSpore与华为昇腾AI处理器的协同优化,实现了运行态的高效,大大提高了计算性能;MindSpore也支持GPU、CPU等其它处理器。

一、什么是MindSpore Hub

1.简单介绍

官方版本的预训练模型中心库---MindSpore Hub

 mindspore_hub 是一个Python库

下载网址:点击跳转

2.MindSpore Hub包含功能

  • 即插即用的模型加载

  • 简单易用的迁移学习

import mindspore
import mindspore_hub as mshub
from mindspore import set_context, GRAPH_MODE

set_context(mode=GRAPH_MODE,
            device_target="Ascend",
            device_id=0)

model = "mindspore/1.6/googlenet_cifar10"

# Initialize the number of classes based on the pre-trained model.
network = mshub.load(model, num_classes=10)
network.set_train(False)

# ...

3.MindSpore Hub使用场景

· 推理验证:mindspore_hub.load用于加载预训练模型,可以实现一行代码完成模型的加载。

· 迁移学习:通过mindspore_hub.load完成模型加载后,可以增加一个额外的参数项只加载神经网络的特征提取部分,这样就能很容易地在之后增加一些新的层进行迁移学习。

· 发布模型:可以将自己训练好的模型按照指定的步骤发布到MindSpore Hub中,以供其他用户进行下载和使用。

二、安装MindSpore Hub

1.确认系统环境信息

  • 硬件平台支持Ascend、GPU和CPU。

  • 确认安装Python 3.7.5版本。

  • MindSpore Hub与MindSpore的版本需保持一致。

  • MindSpore Hub支持使用x86 64位或ARM 64位架构的Linux发行版系统。

  • 在联网状态下,安装whl包时会自动下载setup.py中的依赖项,其余情况需自行安装。

2.安装

在命令行中输入下面代码进行下载MindSpore Hub whl包

pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/{version}/Hub/any/mindspore_hub-{version}-py3-none-any.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

 3.下载源码

 

从Gitee下载源码。

git clone https://gitee.com/mindspore/hub.git -b r1.9

编译安装MindSpore Hub。

cd hub   ##切换到hub文件下
python setup.py install   ## 下载

 4.进行验证

在能联网的环境中执行以下命令,验证安装结果。

import mindspore_hub as mshub

model = mshub.load("mindspore/1.6/lenet_mnist", num_class=10)

如果出现下列提示,说明安装成功:

Downloading data from url https://gitee.com/mindspore/hub/raw/r1.9/mshub_res/assets/mindspore/1.6/lenet_mnist.md

Download finished!
File size = 0.00 Mb
Checking /home/ma-user/.mscache/mindspore/1.6/lenet_mnist.md...Passed!

三、加载模型

1.介绍

于个人开发者来说,从零开始训练一个较好模型,需要大量的标注完备的数据、足够的计算资源和大量训练调试时间。使得模型训练非常消耗资源,提升了AI开发的门槛,针对以上问题,MindSpore Hub提供了很多训练完成的模型权重文件,可以使得开发者在拥有少量数据的情况下,只需要花费少量训练时间,即可快速训练出一个较好的模型。

2.推理验证

##使用url完成模型的加载

import mindspore_hub as mshub
import mindspore
from mindspore import Tensor, nn, Model, set_context, GRAPH_MODE
from mindspore import dtype as mstype
import mindspore.dataset.vision as vision

set_context(mode=GRAPH_MODE,
                        device_target="Ascend",
                        device_id=0)

model = "mindspore/1.6/googlenet_cifar10"

# Initialize the number of classes based on the pre-trained model.
network = mshub.load(model, num_classes=10)
network.set_train(False)

 最后使用MindSpore进行推理

3.迁移学习

#使用url进行MindSpore Hub模型的加载,注意:include_top参数需要模型开发者提供。

import os
import mindspore_hub as mshub
import mindspore
from mindspore import Tensor, nn, set_context, GRAPH_MODE
from mindspore.nn import Momentum
from mindspore import save_checkpoint, load_checkpoint,load_param_into_net
from mindspore import ops
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
from mindspore import dtype as mstype
from mindspore import Model
set_context(mode=GRAPH_MODE, device_target="Ascend", device_id=0)

model = "mindspore/1.6/mobilenetv2_imagenet2012"
network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
network.set_train(False)

#在现有模型结构基础上,增加一个与新任务相关的分类层。

class ReduceMeanFlatten(nn.Cell):
      def __init__(self):
         super(ReduceMeanFlatten, self).__init__()
         self.mean = ops.ReduceMean(keep_dims=True)
         self.flatten = nn.Flatten()

      def construct(self, x):
         x = self.mean(x, (2, 3))
         x = self.flatten(x)
         return x

# Check MindSpore Hub website to conclude that the last output shape is 1280.
last_channel = 1280

# The number of classes in target task is 10.
num_classes = 10

reducemean_flatten = ReduceMeanFlatten()

classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(True)

train_network = nn.SequentialCell([network, reducemean_flatten, classification_layer])

#定义数据集加载函数。



def create_cifar10dataset(dataset_path, batch_size, usage='train', shuffle=True):
    data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle)

    # define map operations
    trans = [
        vision.Resize((256, 256)),
        vision.RandomHorizontalFlip(prob=0.5),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        vision.HWC2CHW()
    ]

    type_cast_op = transforms.TypeCast(mstype.int32)

    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set

# Create Dataset
dataset_path = "/path_to_dataset/cifar-10-batches-bin"
dataset = create_cifar10dataset(dataset_path, batch_size=32, usage='train', shuffle=True)

#为模型训练选择损失函数、优化器和学习率。

def generate_steps_lr(lr_init, steps_per_epoch, total_epochs):
    total_steps = total_epochs * steps_per_epoch
    decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_init
        elif i < decay_epoch_index[1]:
            lr = lr_init * 0.1
        elif i < decay_epoch_index[2]:
            lr = lr_init * 0.01
        else:
            lr = lr_init * 0.001
        lr_each_step.append(lr)
    return lr_each_step

# Set epoch size
epoch_size = 60

# Wrap the backbone network with loss.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
loss_net = nn.WithLossCell(train_network, loss_fn)
steps_per_epoch = dataset.get_dataset_size()
lr = generate_steps_lr(lr_init=0.01, steps_per_epoch=steps_per_epoch, total_epochs=epoch_size)

# Create an optimizer.
optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5)
train_net = nn.TrainOneStepCell(loss_net, optim)

#开始重训练。

for epoch in range(epoch_size):
    for i, items in enumerate(dataset):
        data, label = items
        data = mindspore.Tensor(data)
        label = mindspore.Tensor(label)

        loss = train_net(data, label)
        print(f"epoch: {epoch}/{epoch_size}, loss: {loss}")
    # Save the ckpt file for each epoch.
    if not os.path.exists('ckpt'):
       os.mkdir('ckpt')
    ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt"
    save_checkpoint(train_network, ckpt_path)

#在测试集上测试模型精度。

model = "mindspore/1.6/mobilenetv2_imagenet2012"

network = mshub.load(model, num_classes=500, pretrained=True, include_top=False, activation="Sigmoid")
network.set_train(False)
reducemean_flatten = ReduceMeanFlatten()
classification_layer = nn.Dense(last_channel, num_classes)
classification_layer.set_train(False)
softmax = nn.Softmax()
network = nn.SequentialCell([network, reducemean_flatten, classification_layer, softmax])

# Load a pre-trained ckpt file.
ckpt_path = "./ckpt/cifar10_finetune_epoch59.ckpt"
trained_ckpt = load_checkpoint(ckpt_path)
load_param_into_net(classification_layer, trained_ckpt)

loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

# Define loss and create model.
eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False)
eval_metrics = {'Loss': nn.Loss(),
                 'Top1-Acc': nn.Top1CategoricalAccuracy(),
                 'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(network, loss_fn=loss, optimizer=None, metrics=eval_metrics)
metrics = model.eval(eval_dataset)
print("metric: ", metrics)

四、模型发布

#将你的预训练模型托管在可以访问的存储位置。参照模板,在你自己的代码仓中添加模型生成文件mindspore_hub_conf.py,文件放置的位置如下:

googlenet
├── src
│   ├── googlenet.py
├── script
│   ├── run_train.sh
├── train.py
├── test.py
├── mindspore_hub_conf.py

#参照模板,在hub/mshub_res/assets/mindspore/1.6文件夹下创建{model_name}_{dataset}.md文件,其中1.6为MindSpore的版本号,hub/mshub_res的目录结构为:

hub
├── mshub_res
│   ├── assets
│       ├── mindspore
│           ├── 1.6
│               ├── googlenet_cifar10.md
│   ├── tools
│       ├── get_sha256.py
│       ├── load_markdown.py
│       └── md_validator.py

本次分享完

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/20848.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

营造激发自驱力注重培养学习力的想法一

目录背景过程第一节&#xff1a;第二节&#xff1a;第三节&#xff1a;总结升华背景 小编做的是教育类公司&#xff0c;其实无论是做公司的产品&#xff0c;还是对于公司团队人员的培养&#xff0c;都需要去思考教育这件事&#xff0c;尤其是激发自驱力培养学习力&#xff1b;…

常用的框架07-消息中间件-RabbitMQ

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录1.消息中间件概述1.1 为什么学习消息队列1.2 什么是消息中间件1.3 消息队列应用场景1.3.1 异步处理1.3.2 应用程序解耦合1.3.3 削峰填谷1.3.4 什么是QPS1.3.5 什么是…

servlet和vue的增删改查

1.servlet实现步骤 Servlet->新增 servlet获取请求参数&#xff0c;将参数转化为对象&#xff0c;调用service WebServlet("/addService") public class addAllService extends HttpServlet {private BrandService brandService new BrandServiceimpl() ;Over…

云计算之虚拟化技术学习(KVM/Xen/Hyper-V/VMware)

文章目录虚拟化技术什么是虚拟化服务器虚拟化cpu的虚拟化内存虚拟化管理硬盘的虚拟化网络虚拟化IO虚拟化Intel虚拟化技术主流的虚拟化技术虚拟化技术对比XenKVMHyper-VVMware ESX/ESXi虚拟化服务平台Libvirt基于KVM的虚拟化服务平台虚拟化技术 什么是虚拟化 虚拟化是云计算的…

最长公共子序列长度

求两个字符串的最长公共子序列长度。 输入格式: 输入长度≤100的两个字符串。 输出格式: 输出两个字符串的最长公共子序列长度。 输入样例1: ABCBDAB BDCABA输出样例1: 4输入样例2: ABACDEF PGHIK输出样例2: 0 (1条消息) HBU训练营【动态规划DP】——最长公共子序列长…

力扣(LeetCode)799. 香槟塔(C++)

动态规划 设 iii 是行 , jjj 是列 &#xff0c; f[i][j]f[i][j]f[i][j] 表示经过杯子的酒量 &#xff0c;初始 f[0][0]pouredf[0][0]pouredf[0][0]poured &#xff0c; 为了理解&#xff0c;当做每个杯子有无限容量。 当香槟溢出时&#xff0c;f[i][j]f[i][j]f[i][j] 保留自己的…

放大镜-第12届蓝桥杯Scratch选拔赛真题精选

[导读]&#xff1a;超平老师计划推出Scratch蓝桥杯真题解析100讲&#xff0c;这是超平老师解读Scratch蓝桥真题系列的第80讲。 蓝桥杯选拔赛每一届都要举行4~5次&#xff0c;和省赛、国赛相比&#xff0c;题目要简单不少&#xff0c;再加上篇幅有限&#xff0c;因此我精挑细选…

SpringCloud系列(一)Eureka 注册中心

本文主要介绍 Eureka 用来做什么&#xff1f; 如何搭建以及测试&#xff1b;  微服务框架区分于普通的单体架构项目&#xff0c;它是一种经过良好架构设计的分布式架构方案&#xff0c;根据业务功能对系统进行拆分&#xff0c;将每个业务模块都当做是一个独立的项目进行开发&a…

session共享问题及四种解决方案-前端存储、session的复制 、session粘性、后端存储(Mysql、Redis等)

&#x1f468;‍&#x1f4bb;个人主页&#xff1a; 才疏学浅的木子 &#x1f647;‍♂️ 本人也在学习阶段如若发现问题&#xff0c;请告知非常感谢 &#x1f647;‍♂️ &#x1f4d2; 本文来自专栏&#xff1a; 常用工具类以及常见问题处理方法 &#x1f308; 每日一语&…

Alos PALSAR 12.5米免费DEM下载教程

Alos PALSAR 12.5米免费DEM下载教程ALOS 12.5米数据简介2. 下载2.1 搜索数据2.2 下载数据3. 使用数据ALOS 12.5米数据简介 ALOS 12.5m DEM 数据&#xff0c;是使用ALOS&#xff08;Advanced Land Observing Satellite&#xff09;卫星相控阵型L波段合成孔径雷达&#xff08;PA…

SpringBoot + EasyExcel 实现表格数据导入

1. 准备 导入依赖 <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.0.5</version><scope>compile</scope> </dependency><dependency><groupId>org.proj…

使用vi、vim、sed、echo、cat操作文件

记录&#xff1a;324 场景&#xff1a;在CentOS 7.9操作系统上&#xff0c;使用vi编辑器、vim编辑器、sed编辑器操作文件读、写、删、替换等操作&#xff1b;使用echo命令和cat命令将内容输出文件并查看内容。 版本&#xff1a; 操作系统&#xff1a;CentOS 7.9 1.vi编辑器…

同花顺_代码解析_技术指标_R

本文通过对同花顺中现成代码进行解析&#xff0c;用以了解同花顺相关策略设计的思想 目录 RAD RADER RCCD ROC ROCFS RSI RSIFS RAD 威力雷达 大盘指标。 RAD的判断基准法与传统指标相似: 白线上穿黄线为金叉,示强势,为买入建仓机会参考&#xff1b; 白线下穿黄线为…

红黑树的插入(C++实现)

1. 红黑树 1.1 概念 红黑树是一种二叉搜索树&#xff0c;它是AVL树的优化版本。红黑树是每个节点都带有颜色属性的二叉搜索树&#xff0c;颜色为红色或黑色。 之所以选择“红色”是因为这是作者在帕罗奥多研究中心公司Xerox PARC工作时用彩色雷射列印机可以产生的最好看的颜色…

Java学习之包访问修饰符

基本介绍 java 提供四种访问控制修饰符号&#xff0c;用于控制方法和属性(成员变量)的访问权限&#xff08;范围&#xff09; 公开级别:用 public 修饰,对外公开受保护级别:用 protected 修饰,对子类和同一个包中的类公开默认级别:没有修饰符号,向同一个包的类公开.私有级别:…

采用sFlow工具实现流量监控--实验

采用sFlow工具实现流量监控--实验采用sFlow工具实现流量监控---实验学习目标学习内容实验原理实验拓扑实验仿真启动sFlow-rt以及floodlight控制器创建拓扑部署sFlow agent步骤1.步骤2.步骤3步骤4步骤5.步骤6.总结申明&#xff1a; 未经许可&#xff0c;禁止以任何形式转载&…

C++模拟OpenGL库——图形学状态机接口封装(一):用状态模式重构部分代码及接口定义

目录 什么是状态机&#xff1f; 基于状态机模式进行重构 Canvas.h源码 什么是状态机&#xff1f; 回顾之前两部分内容&#xff0c;我们做了&#xff1a; 绘制点绘制线&#xff08;Brensenham&#xff09;绘制三角形&#xff08;拆分法&#xff09;图片操作&#xff08;stb…

RabbitMQ------延迟队列(整合SpringBoot以及使用延迟插件实现真正延时)(七)

RabbitMQ------延迟队列&#xff08;七&#xff09; 延迟队列 延迟队列&#xff0c;内部是有序的&#xff0c;特点&#xff1a;延时属性。 简单讲&#xff1a;延时队列是用来存放需要在指定时间被处理的元素队列。 是基于死信队列的消息过期场景。 适用场景 1.订单在十分钟…

Linux(centos7)安装MySQL5.7

Linux 安装MySQL5.7 数据库 所有的安装方式是基于手动式的安装&#xff0c;也就是整体的下载然后配置 rpm与yum之间的关系 rpm 是Linux 免除编译安装带来的安装方式&#xff0c;而yum 是在rpm 上面的进一步的优化&#xff0c;换句话说yum 既包含了rpm 的简单安装&#xff0c…

百度地图自定义覆盖物(html)格式

<style type"text/css"> body, html{ width: 100%; height: 100%; overflow: hidden; margin: 0; font-family: "微软雅黑"; display: flex; justify-content: space-between; } #cont…