mindspore mindcv图像分类算法;昇腾NPU加速使用;模型保存与加载

news2024/11/17 12:36:12

参考:
https://www.mindspore.cn/tutorials/en/r1.3/save_load_model.html
https://github.com/mindspore-lab/mindcv/blob/main/docs/zh/tutorials/finetune.md

1、mindspore mindcv图像分类算法

import os
from mindcv.utils.download import DownLoad
import os
import mindspore as ms


os.environ['DEVICE_ID']='0'
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  ##指定cpu
#ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)  ##需要使用才能npu加速



dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
root_dir = "./"

if not os.path.exists(os.path.join(root_dir, 'data/Canidae')):
    DownLoad().download_and_extract_archive(dataset_url, root_dir)
    
    
    
##加载数据

from mindcv.data import create_dataset, create_transforms, create_loader

num_workers = 8

# 数据集目录路径
data_dir = "./data/Canidae/"

# 加载自定义数据集
dataset_train = create_dataset(root=data_dir, split='train', num_parallel_workers=num_workers)
dataset_val = create_dataset(root=data_dir, split='val', num_parallel_workers=num_workers)



# 定义和获取数据处理及增强操作
trans_train = create_transforms(dataset_name='ImageNet', is_training=True)
trans_val = create_transforms(dataset_name='ImageNet',is_training=False)

loader_train = create_loader(
    dataset=dataset_train,
    batch_size=16,
    is_training=True,
    num_classes=2,
    transform=trans_train,
    num_parallel_workers=num_workers,
)
loader_val = create_loader(
    dataset=dataset_val,
    batch_size=5,
    is_training=True,
    num_classes=2,
    transform=trans_val,
    num_parallel_workers=num_workers,
)


#模型微调

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)


#训练
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindspore import Model, LossMonitor, TimeMonitor

# 定义优化器和损失函数
opt = create_optimizer(network.trainable_params(), opt='adam', lr=1e-4)
loss = create_loss(name='CE')

# 实例化模型
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'accuracy'})
model.train(10, loader_train, callbacks=[LossMonitor(5), TimeMonitor(5)], dataset_sink_mode=False)

res = model.eval(loader_val)
print(res)

import matplotlib.pyplot as plt
import mindspore as ms
import numpy as np

def visualize_model(model, val_dl, num_classes=2):
    # 加载验证集的数据进行验证
    images, labels= next(val_dl.create_tuple_iterator())
    # 预测图像类别
    output = model.predict(images)
    pred = np.argmax(output.asnumpy(), axis=1)
    # 显示图像及图像的预测值
    images = images.asnumpy()
    labels = labels.asnumpy()
    class_name = {0: "dogs", 1: "wolves"}
    plt.figure(figsize=(15, 7))
    for i in range(len(labels)):
        plt.subplot(3, 6, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()
    
visualize_model(model, loader_val)

在这里插入图片描述

上面改成昇腾NPU计算

##就改这一句

os.environ['DEVICE_ID']='0'
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)  ##需要使用才能npu加速

训练速度回提升很多,但注意编译时间也先需要挺久
在这里插入图片描述

在这里插入图片描述

2、模型保存与加载

##另外mindscope 2.2后保存会新增save_mindir接口,参考:https://www.mindspore.cn/docs/zh-CN/r2.2/api_python/mindspore.html
在这里插入图片描述

## 保存模型
import mindspore as ms

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)

ms.save_checkpoint(network, "model1.ckpt")
## 加载模型
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Model

param_dict = load_checkpoint("model1.ckpt")
param_not_load = load_param_into_net(network, param_dict)
print(param_not_load)

model1 = Model(network, loss, metrics={"accuracy"})
加载训练模型验证
images1, labels1= next(loader_val.create_tuple_iterator())

output = model1.predict(images1)
pred = np.argmax(output.asnumpy(), axis=1)
# 显示图像及图像的预测值
images = images1.asnumpy()
labels = labels1.asnumpy()

pred,labels

在这里插入图片描述

##预测单张,增加一个batch维度unsqueeze
model1.predict(images1[0].unsqueeze(0))

model1.predict(ms.ops.unsqueeze(images1[0], dim=0))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1201714.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法通关村第十五关白银挑战——海量数据场景下的热门算法题

大家好,我是怒码少年小码。 最近超级忙,很多实验报告,已经四五天没搞了,但是我还是回来了! 海量数据场景下的热门算法题 本篇的题目不要求写代码,面试的时候能很清楚的说出思路就可以了。 1. 从40个亿中…

Flutter开发中的一些Tips(四)

最近接手了一个flutter项目,整体感觉代码质量不高,感觉有些是初学者容易犯的问题。几年前写的前三篇,我是站在我自己开发遇到问题的角度,这篇是站在别人遇到问题的角度,算是一种补充。下面我整理一下遇到的小问题&…

Linux安装微信

Linux安装微信 环境:ubuntu 20.04 https://archive.ubuntukylin.com/ubuntukylin/pool/partner/weixin_2.1.4_amd64.deb sudo dpkg -i weixin_2.1.4_amd64.deb完成 参考文章

【Spring进阶系列丨第一篇】初识Spring框架

前言 小伙伴们大家好,我是陈橘又青,今天起 《Spring进阶系列》 开始更新。本专栏将涵盖Spring框架的核心概念、配置管理、Web开发、AOP、Boot、Security、Data、Integration和Batch等多个主题。通过理论讲解和实际案例的剖析,帮助读者深入理解…

k8s的Init Containers容器实现代码版本升级发布和deployment版本回退:实战操作版

Pod中的初始化容器:Init Containers initContainers实现理论前提:同一个Pod内的容器共享 网络、volume等资源 Init Containers 在Kubernetes中,init容器是在同一个Pod中的其他容器之前启动和执行的容器。它的目的是为Pod上托管的主应用程序执行初始化…

【C++】STL容器适配器——priority_quene(堆/优先级队列)类的使用指南(含代码使用)(19)

前言 大家好吖,欢迎来到 YY 滴C系列 ,热烈欢迎! 本章主要内容面向接触过C的老铁 主要内容含: 欢迎订阅 YY滴C专栏!更多干货持续更新!以下是传送门! 目录 一.priority_quene的文档介绍二、prior…

爬虫实战:基于urllib和mysql爬取苏州公交线路信息

文章目录 写在前面实验环境实验描述实验目标实验内容1. 确定并分析目标网页结构2. 编写urllib代码爬取公交信息3. 保存公交数据到csv文件中4. 保存公交数据到mysql数据库中 写在后面 写在前面 本文将基于python的urllib模块,爬取北京公交线路的信息,最后…

web基础和http协议(粗糙版)

服务部署,集训,分布式,数据库,日志系统,等二阶段 web基础和http协议: web的相关基础知识,包括域名 dns解析 网页的概念以及http协议 1.网络当中通信:端口 ip 协议 tcp/ip 传输过程…

CAD转换器:CAD Exchanger SDK --Crack

转换器 目录 概述读取文件 增量加载写入文件格式特定的详细信息进度状态支持例子 读取和写入多种 CAD 和 BIM 文件格式。 概述 读取(导入)和写入(导出)文件是使用 CAD Exchanger SDK 时的主流场景。支持的格式列表可在此处获取。 …

Misc | bucket 第二届“奇安信”杯网络安全技能竞赛

题目描述: 解密Base全家桶。 密文: 下载附件,解压得到一个txt文本,打开如下。 3441344134363435344435323442344534423441343635353334353333323442343935413442353434393535354135333441344534353536353535333332353534413436…

酷柚易汛ERP-自定义打印整体介绍

1、产品介绍 每种单据系统预设常用模板,提供A4纸张、三等分、二等分,销货单额外提供80mm、58mm供用户选择;每张单据可设置一个默认模板和多个常用模;除默认模板外,其他模板都允许删除,用户可以根据公司业务…

ArcGIS实现矢量区域内所有要素的统计计算

1、任务需求:统计全球各国所有一级行政区相关属性的总和。 (1)有一个全球一级行政区的矢量图,包含以下属性(洪灾相关属性 province.shp) (2)需要按照国家来统计各个国家各属性的总值…

前端前沿技术

文章目录 网站静态化PWA - Progressive Web APP, 渐进式 Web 应用PWA 简介解决了哪些问题?PWA 的优势浏览器支持情况参考文档 Weex 是一个可以使用现代化的 Web 技术开发高性能原生应用的框架。高性能跨平台贴近前端生态被大规模的使用 GraphQL[一种用于 API 的查询语言](http…

并发事务下,不同隔离级别可能出现的问题

并发事务下,不同隔离级别可能出现的问题 1、事务的 ACID2、并发事务下,不同隔离级别可能出现的问题2.1、脏写2.2、脏读2.3、不可重复读2.4、幻读 3、SQL 中的四种隔离级别 1、事务的 ACID 原子性(Atomicity):原子性意味…

数据结构 | 栈的实现

数据结构 | 栈的实现 文章目录 数据结构 | 栈的实现栈的概念及结构栈的实现 Stack.h初始化栈入栈出栈获取栈顶元素获取栈中有效元素个数检测栈是否为空销毁栈 Stack.c 栈的概念及结构 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。…

勘察设计考试公共基础之数学篇

1、数学 向量点积:向量叉积:平面的法向量为n(A,B,C),则该平面的点法式方程为: A(x-x0)B(y-y0)C(z-z0)0 两平…

爬虫,TLS指纹 剖析和绕过

当你欲爬取某网页的信息数据时,发现通过浏览器可正常访问,而通过代码请求失败,换了随机ua头IP等等都没什么用时,有可能识别了你的TLS指纹做了验证。 解决办法: 1、修改 源代码 2、使用第三方库 curl-cffi from curl…

【T690 之十一】基于方寸EVB2开发板,结合 Eclipse+gdb+gdbserver 调试 CCAT 的流程总结

目录 1. 准备工作1.1 Eclipse1.2 工程编译1.3 烧写固件 2. 创建工程2.1 搭建调试工程2.2 配置Dbug调试信息 3. 调试4. 手动调试过程4. 总结 备注: 1,假设您已对方寸微电子的T690系列芯片的使用方式都有了一定的了解,可以根据此文的配置进行Li…

3D模型人物换装系统二(优化材质球合批降低DrawCall)

3D模型人物换装系统 介绍原理合批材质对比没有合批材质核心代码完整代码修改总结 介绍 本文使用2018.4.4和2020.3.26进行的测试 本文没有考虑法线贴图合并的问题,因为生成法线贴图有点问题,放在下一篇文章解决在进行优化 如果这里不太明白换装的流程可以…

基于物理的多偏置射频大信号氮化镓HEMT建模和参数提取流程

标题:Physics-Based Multi-Bias RF Large-Signal GaN HEMT Modeling and Parameter Extraction Flow 来源:JOURNAL OF THE ELECTRON DEVICES SOCIETY 摘要 本文展示了一种一致的Al镓氮化物(AlGaN)/氮化镓(GaN&#x…