mindspore mindcv图像分类算法;模型保存与加载

news2024/10/6 22:19:05

参考:
https://www.mindspore.cn/tutorials/en/r1.3/save_load_model.html
https://github.com/mindspore-lab/mindcv/blob/main/docs/zh/tutorials/finetune.md

1、mindspore mindcv图像分类算法

import os
from mindcv.utils.download import DownLoad
import os
import mindspore as ms


os.environ['DEVICE_ID']='0'
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  ##指定cpu
#ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)  ##需要使用才能npu加速



dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
root_dir = "./"

if not os.path.exists(os.path.join(root_dir, 'data/Canidae')):
    DownLoad().download_and_extract_archive(dataset_url, root_dir)
    
    
    
##加载数据

from mindcv.data import create_dataset, create_transforms, create_loader

num_workers = 8

# 数据集目录路径
data_dir = "./data/Canidae/"

# 加载自定义数据集
dataset_train = create_dataset(root=data_dir, split='train', num_parallel_workers=num_workers)
dataset_val = create_dataset(root=data_dir, split='val', num_parallel_workers=num_workers)



# 定义和获取数据处理及增强操作
trans_train = create_transforms(dataset_name='ImageNet', is_training=True)
trans_val = create_transforms(dataset_name='ImageNet',is_training=False)

loader_train = create_loader(
    dataset=dataset_train,
    batch_size=16,
    is_training=True,
    num_classes=2,
    transform=trans_train,
    num_parallel_workers=num_workers,
)
loader_val = create_loader(
    dataset=dataset_val,
    batch_size=5,
    is_training=True,
    num_classes=2,
    transform=trans_val,
    num_parallel_workers=num_workers,
)


#模型微调

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)


#训练
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindspore import Model, LossMonitor, TimeMonitor

# 定义优化器和损失函数
opt = create_optimizer(network.trainable_params(), opt='adam', lr=1e-4)
loss = create_loss(name='CE')

# 实例化模型
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'accuracy'})
model.train(10, loader_train, callbacks=[LossMonitor(5), TimeMonitor(5)], dataset_sink_mode=False)

res = model.eval(loader_val)
print(res)

import matplotlib.pyplot as plt
import mindspore as ms
import numpy as np

def visualize_model(model, val_dl, num_classes=2):
    # 加载验证集的数据进行验证
    images, labels= next(val_dl.create_tuple_iterator())
    # 预测图像类别
    output = model.predict(images)
    pred = np.argmax(output.asnumpy(), axis=1)
    # 显示图像及图像的预测值
    images = images.asnumpy()
    labels = labels.asnumpy()
    class_name = {0: "dogs", 1: "wolves"}
    plt.figure(figsize=(15, 7))
    for i in range(len(labels)):
        plt.subplot(3, 6, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()
    
visualize_model(model, loader_val)

在这里插入图片描述

在这里插入图片描述

2、模型保存与加载

## 保存模型
import mindspore as ms

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)

ms.save_checkpoint(network, "model1.ckpt")
## 加载模型
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Model

param_dict = load_checkpoint("model1.ckpt")
param_not_load = load_param_into_net(network, param_dict)
print(param_not_load)

model1 = Model(network, loss, metrics={"accuracy"})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1188939.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

举个栗子!Tableau 技巧(259):文本表中省市县数据的灵活逐级下钻「方法一」

之前,我们分享过 🌰:实现地图中的省市县逐级下钻。有数据粉提出问题:如果不是地图,而是文本表,有什么办法可以像这样,实现地理位置逐级下钻呢? 文本表也是可以的。但是,…

Count-based exploration with neural density models论文笔记

Count-based exploration with neural density models[J]. International Conference on Machine Learning,International Conference on Machine Learning, 2017. 基于计数的神经密度模型探索 0、问题 这篇文章的关键在于弄懂pseudo-count的概念,以及是如何运用…

【Leetcode】202. 两数之和

给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以按任意顺序返回…

Java 身份证号校验,根据身份证号识别出生地

Java 身份证号校验: import org.apache.commons.lang.StringUtils;import java.util.Calendar; import java.util.Collections; import java.util.HashMap; import java.util.Map;/*** desc 身份证工具类* auth llp* date 2022/7/7 16:13*/ public class IdCardNum…

Java算法(三): 判断两个数组是否为相等 → (要求:长度、顺序、元素)相等

Java算法(三) 需求: 1. 定义一个方法,用于比较两个数组是否相同2. 需求:长度,内容,顺序完全相同package com.liujintao.compare;public class SameArray {public static void main (String[] a…

JAVA微信端医院3D智能导诊系统源码

医院智能导诊系统利用高科技的信息化手段,优化就医流程。让广大患者有序、轻松就医,提升医疗服务水平。 随着人工智能技术的快速发展,语音识别与自然语言理解技术的成熟应用,基于人工智能的智能导诊导医逐渐出现在患者的生活视角中…

小红书达人投放比例是多少合适?品牌方必看

品牌做小红书种草推广想要产生更好的效果,是需要素人和达人按照一定比例去进行投放的,素人铺量可以让产品产生迅速曝光的效果,少量达人投放可以让产品产生更好的转化效果。 小红书达人投放具有较高的互动性和口碑传播效果。达人通过自身的影…

打开pr提示找不到vcomp100.dll无法继续执行代码怎么办?5种dll问题解决方案全解析

vcomp100.dll是一个由Microsoft开发的动态链接库(DLL)文件,它对于许多基于图形的应用程序(如Photoshop)和多个游戏(如《巫师3》)至关重要。以下是关于vcomp100.dll的属性介绍以及找不到vcomp100…

小程序如何部署SSL证书

微信小程序开发前提必须拥有一本SSL证书,办理SSL证书之前确保好指定的微信小程序开发接口使用的域名,如果没有域名的提前申请好,并且到国内服务器提供商去办理备案。 了解微信小程序使用SSL证书的作用,包括以下三个方面&#xff1…

[C语言基础]文件读取模式简析

文件操作 打开方式介绍r / rb模式w / wb模式 打开方式介绍 函数fopen可打开一个文件,返回值是文件指针FILE * 第一个参数是文件路径,第二个参数是打开方式mode 参数可为以下几种: r/w/a/r/w/a/rb/wb/ab/rb/wb/ab 其中, r 为只读&…

求臻医学MRD产品喜获北京市新技术新产品(服务)证书

近日,北京市科学技术委员会、中关村科技园区管理委员会、北京市发展和改革委员会等五大部门联合公示了2023年度第一批(总第十八批)北京市新技术新产品(服务)名单。凭借领先的技术能力、产品创新能力及质量可靠性等优势…

大数据毕业设计选题推荐-河长制大数据监测平台-Hadoop-Spark-Hive

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

VM虚拟机安装

想编译一个 c 代码 windows 转成 linux 安装一个vm 准备一个虚拟机安装包,双击,开始安装 下一步 缸盖安装位置路径,添加PATH,下一步 下一步 添加到桌面,加入开始菜单,下一步 打开桌面的软件图标&#…

Panorama SCADA平台的警报通知功能配置详解

1. 前言 SCADA系统的主要目标是采集与监控工业过程数据,以确保工业生产正常运行。通过实时警报通知功能,操作人员可以立即获取有关潜在问题的信息,从而能够快速采取行动解决问题,防止进一步的损害或生产中断。因此,及…

小程序版本审核未通过,需在开发者后台「版本管理—提交审核——小程序订单中心path」设置订单中心页path,请设置后再提交代码审核

小程序版本审核未通过,需在开发者后台「版本管理—提交审核——小程序订单中心path」设置订单中心页path,请设置后再提交代码审核 因小程序尚未发布,订单中心不能正常打开查看,请先发布小程序后再提交订单中心PATH申请 初次提交…

03【远程协作开发、TortoiseGit、IDEA绑定Git插件的使用】

上一篇:02【Git分支的使用、Git回退、还原】 下一篇:【已完结】 目录:【Git系列教程-目录大纲】 文章目录 一、远程协作开发1.1 远程仓库简介1.1.1 Github1.1.2 Gitee1.1.3 其他托管平台 1.2 发布远程仓库1.2.1 创建项目1) 新…

deeplog中输出某个 event 的概率

1 实现之后效果 # import DeepLog and Preprocessor import numpy as np from deeplog import DeepLog import torch# Create DeepLog object deeplog DeepLog(input_size 10, # Number of different events to expecthidden_size 64 , # Hidden dimension, we suggest 64…

K8s----资源管理

目录 一、Secret 1、创建 Secret 1.1 用kubectl create secret命令创建Secret 1.2 内容用 base64 编码,创建Secret 2、使用方式 2.1 将 Secret 挂载到 Volume 中,以 Volume 的形式挂载到 Pod 的某个目录下 2.2 将 Secret 导出到环境变量中 二、Co…

大数据之LibrA数据库系统告警处理(ALM-12033 慢盘故障)

告警解释 系统每一秒执行一次iostat命令,监控磁盘I/O的系统指标,如果在60s内,svctm大于100ms的周期数大于30次则认为磁盘有问题,产生该告警。 更换磁盘后,告警自动恢复。 告警属性 告警ID 告警级别 可自动清除 1…

99% 用户都不知道的 Power BI / Power Query 隐藏功能

Power Query 有一个被糟糕的翻译耽误了的宝藏功能,我估计绝大多数的用户都没发现。 在 Power Query —— 视图 —— 数据预览 下,有几个奇怪的选项 “列分发”、“列配置文件”、“列质量”,从名字根本看不出来是做什么的! 看英文…