240622_昇思学习打卡-Day4-ResNet50迁移学习

news2024/11/26 10:17:48

240622_昇思学习打卡-Day4-ResNet50迁移学习

我们对事物的认知都是一点一点积累出来的,往往借助已经认识过的东西,可以更好地理解和认识新的有关联的东西。比如一个人会骑自行车,我们让他去骑摩托车他也很快就能学会,比如已经学会C++,现在让他去学python他也很容易就能理解。这种情况我们一般称为举一反三。反言之,我们从原始部落找出来一个人(仅作举例),指着摩托车让他骑,可能是一件特别难的事,因为他对这个领域没有丝毫的认知和理解,在实现这件事上就会特别困难。

映射到神经网络上也是一样的道理,如果我们在训练时不导入预训练权重,他就像一个没有见过现代社会的原始人,学任何东西都特别慢,学习成本特别高,但如果我们导入了相似模型结构下针对别的任务的训练权重(比如训练识别自行车),用来训练识别摩托车,我们只需要改变网络最后的分类层,即可得到比较好的训练效果,可以大大缩小模型训练的时间。

原理是我在这么多层神经网络的训练下,已经明白了轱辘(车轮)长什么样,把手长什么样,最后的分类层只是区分出来什么是自行车,你现在给我一堆摩托的照片,我就可以去寻找两者的相似处,我对轱辘和把手的认知就不用从0开始重新学习,只需要进行微调,比如摩托车的轱辘比自行车大一点,摩托车的车把手比自行车大一点。基于以前已经学习到的东西,可以大大缩小训练成本。

迁移学习就是这个道理。我们在神经网络技术的发展中,针对不同的任务,不可能每个网络都从0开始训练,那样需要的数据集及成本都是不可承受的。

本文以ResNet50迁移学习为例展开讲解,在MindSpore架构下运行。

数据准备

下载数据集

本文用到狗与狼分类数据集,使用download接口下载,也可自行下载放在项目当前目录下

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"

download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

数据目录结构如下:

text
datasets-Canidae/data/
└── Canidae
    ├── train
    │   ├── dogs
    │   └── wolves
    └── val
        ├── dogs
        └── wolves

首先定义一些超参数:

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 5                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

加载数据集以及做一些数据增强(本文所用的狼狗数据集属于ImageNet数据集,其典型mean和std值分别为[0.485, 0.456, 0.406]和[0.229, 0.224, 0.225],所以代码中直接使用):

# 导入MindSpore库,用于深度学习框架
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 定义训练数据集和验证数据集的路径
# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"

# 定义函数,用于创建Canidae分类任务的训练集或验证集
# 参数dataset_path: 数据集路径,usage: 数据集的用途,"train"或"val"
# 返回处理后的数据集
# 创建训练数据集
def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    # 初始化ImageFolderDataset,使用多线程并打乱数据顺序
    # 使用mindspore.dataset.ImageFolderDataset接口来加载数据集
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 定义数据预处理的参数
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    # 根据数据集的用途(训练或验证)选择不同的数据增强操作
    if usage == "train":
        # 训练集的数据增强操作,包括随机裁剪、水平翻转、归一化等
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # 验证集的数据增强操作,主要包括解码、缩放、中心裁剪、归一化等
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]

    # 对数据集应用预处理操作
    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    # 将数据集分批处理,指定批大小
    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set

# 创建训练数据集和验证数据集,并获取每个数据集的步长(即数据集的大小)
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()

数据集可视化

mindspore.dataset.ImageFolderDataset接口中加载的训练数据集返回值为字典,用户可通过 create_dict_iterator 接口创建数据迭代器,使用 next 迭代访问数据集。前面 batch_size 设为18,所以使用 next 一次可获取18个图像及标签数据。

# 获取训练数据集的第一个批次数据,是18张图像及标签数据。
data = next(dataset_train.create_dict_iterator())
images = data["image"]
labels = data["label"]

print("Tensor of image", images.shape)
print("Labels:", labels)

'''
执行结果为
Tensor of image (18, 3, 224, 224)    # 意思是这一批有18张3通道(RGB通道)的长224宽224的图像
Labels: [1 1 0 1 1 0 1 0 0 0 0 0 0 0 0 1 0 1]    # 因为该任务是一个二分类任务,所以类别只有简单的0和1
'''

目前拿到的数据我们可以先看看长什么样,展示图像及标题,标题为对应的label名称

# 导入matplotlib.pyplot库用于绘图
import matplotlib.pyplot as plt
# 导入numpy库用于处理数组
import numpy as np

# 定义一个字典,映射类别编号到类别名称
# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "dogs", 1: "wolves"}

# 创建一个5x5大小的画布
plt.figure(figsize=(5, 5))
# 循环遍历4个图像
for i in range(4):
    # 获取当前图像的数据和标签
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 将图像数据从HWC格式转换为RGB格式
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    # 对图像进行预处理,包括归一化和颜色空间转换
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 在画布上创建子图,并显示当前图像
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    # 设置子图标题为图像的类别名称
    plt.title(class_name[int(labels[i].asnumpy())])
    # 关闭子图的坐标轴显示
    plt.axis("off")

# 显示画布上的所有图像
plt.show()

image-20240622221036216

训练模型

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal


weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

今日就先写这些,明天学完之后汇总。

不知道为什么我在这篇里面点编辑然后发表之后变成了一篇新的博客,那这里就附上后续的博客地址吧
240622_昇思学习打卡-Day4-5-ResNet50迁移学习

打卡图片:
在这里插入图片描述







本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1875347.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32开发方式的演变与未来展望

一、STM32开发方式的演变 自2007年STM32微控制器首次亮相以来,其开发方式经历了从寄存器到标准库,再到HAL(硬件抽象层)的演变。 1.寄存器开发(2007年-2010年代初) 最初,由于初期缺乏足够的软…

Cyuyanzhong的内存函数

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、memcpy函数的使用与模拟实现二、memmove函数的使用和模拟实现三、memset函数与memcmp函数的使用(一)、memset函数(内存块…

Qt creator实现一个简单计算器

目录 1 界面设计 2 思路简介 3 代码 目录 1 界面设计 ​2 思路简介 3 代码 3.1 widget.h 3.2 widget.c 4 完整代码 在这里主要记载了如何使用Qt creator完成一个计算器的功能。该计算器可以实现正常的加减乘除以及括号操作,能实现简单的计算器功能。 1 界…

.NET使用CsvHelper快速读取和写入CSV文件

前言 在日常开发中使用CSV文件进行数据导入和导出、数据交换是非常常见的需求,今天我们来讲讲在.NET中如何使用CsvHelper这个开源库快速实现CSV文件读取和写入。 CsvHelper类库介绍 CsvHelper是一个.NET开源、快速、灵活、高度可配置、易于使用的用于读取和写入C…

Spring Boot集成vavr快速入门demo

1.什么是vavr? 初闻vavr,感觉很奇怪,咋这个名字,后面看到它的官网我沉默了,怀疑初创团队付费资讯了UC震惊部如何取名字,好家伙,vavr就是java这四个字倒过来,真的是’颠覆’了java……

如何成为-10x工程师:反向教学大数据开发实际工作中应如何做

10x 工程师可能是神话,但 -10x 工程师确实存在。要成为 -10x 工程师,只需每周浪费 400 小时的工程时间。结合以下策略: 目录 如何使 10 名工程师的输出无效化改变需求大数据开发示例 创建 400 小时的繁忙工作任务示例大数据开发示例 创建 400…

心理辅导平台系统

摘 要 中文本论文基于Java Web技术设计与实现了一个心理辅导平台。通过对国内外心理辅导平台发展现状的调研,本文分析了心理辅导平台的背景与意义,并提出了论文研究内容与创新点。在相关技术介绍部分,对Java Web、SpringBoot、B/S架构、MVC模…

Unable to get expected results using BM25 or any search functions in Weaviate

题意:使用 Weaviate 中的 BM25 或任何搜索函数都无法获得预期结果 问题背景: I have created a collection in Weaviate, and ingested some documents into the Weaviate database using LlamaIndex. When I used the default search, I found that it…

高精度除法的实现

高精度除法与高精度加法的定义、前置过程都是大致相同的,如果想了解具体内容,可以移步至我的这篇博客:高精度加法计算的实现 在这里就不再详细讲解,只讲解主体过程qwq 主体过程 高精度除法的原理和小学学习的竖式除法是一样的。 …

Chrome导出cookie的实战教程

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

备份SQL Server数据库并还原到另一台服务器

我可以将SQL Server数据库备份到另一台服务器吗? 有时您可能希望将 SQL数据库从一台服务器复制到另一台服务器,或者将计算机复制到计算机。可能的场景包括测试、检查一致性、从崩溃的机器恢复数据库、在不同的机器上处理同一个项目等。 是的&#xff0c…

Vue+Proj4Leaflet实现地图瓦片(Nginx代理本地地图瓦片为网络url)加载并实现CRS投影转换(附资源下载)

场景 Leaflet中加载离线OSM瓦片地图(使用OfflineMapMaker切割下载离线png地图文件): Leaflet中加载离线OSM瓦片地图(使用OfflineMapMaker切割下载离线png地图文件)_offline map maker-CSDN博客 Leaflet快速入门与加载OSM显示地图: Leaflet快速入门与…

等保测评练习卷14

等级保护初级测评师试题14 姓名: 成绩: 判断题(10110分) 1. 方案编制活动中测评对象确定、测评指…

sql想查询一个数据放在第一个位置

sql想查询一个数据放在第一个位置 背景:比如在查询后台账号的时候想将管理员账号始终放在第一个,其他账号按照创建时间倒序排序, 可以这样写sql: SELECTid,create_time FROMuser ORDER BY CASEWHEN id 1 THEN1 ELSE 2 END ASC, create_time DESC 运行截图: 可以看到id…

企业源代码加密软件丨透明加密技术是什么

在一个繁忙的软件开发公司中,两位员工小李和小张正在讨论源代码安全的问题。 “小张,你有没有想过我们的源代码如果被泄露了怎么办?”小李担忧地问。 “是啊,这是个大问题。源代码是我们的核心竞争力,一旦泄露&#…

CentOS 8 Stream 上安装 Docker 遇到的一些问题

curl 命令无法连接到 URL,可能是由于网络问题或 IPv6 配置问题。我们可以使用以下方法来解决这个问题: 强制使用 IPv4: 尝试使用 curl 强制使用 IPv4 进行连接: curl -4 -fsSL https://get.docker.com -o get-docker.sh 检查网络…

Python28-2 机器学习算法之SVM(支持向量机)

SVM(支持向量机) 支持向量机(Support Vector Machine,SVM)是一种用于分类和回归分析的监督学习模型,在机器学习领域中被广泛应用。SVM的目标是找到一个最佳的分割超平面,将不同类别的数据分开&…

【详细】CNN中的卷积计算是什么-实例讲解

本文来自《老饼讲解-BP神经网络》https://www.bbbdata.com/ 目录 一、 CNN的基础卷积计算1.1.一个例子了解CNN的卷积计算是什么1.2.卷积层的生物意义 二、卷积的拓展:多输入通道与多输出通道2.1.多输入通道卷积2.2.多输出通道卷积 三、卷积的实现3.1.pytorch实现卷积…

夏令营1期-对话分角色要素提取挑战赛-第①次打卡

零基础入门大模型技术竞赛 简介: 本次学习是 Datawhale 2024 年 AI 夏令营第一期,学习活动基于讯飞开放平台“基于星火大模型的群聊对话分角色要素提取挑战赛”开展实践学习。 适合想 入门并实践大模型 API 开发、了解如何微调大模型的学习者参与 快来…

Windows系统开启自带虚拟机功能Hyper-V

前言 最近有小伙伴咨询:Windows系统上有自带的虚拟机软件吗? 答案肯定是有的。它就是Hyper-V,但很多小伙伴都不知道怎么打开这个功能。 今天小白就带大家来看看如何正确打开这个Windows自带的虚拟机功能Hyper-V。 开始之前,你…