昇思25天学习打卡营第13天 | ResNet50迁移学习

news2024/9/24 19:14:32

昇思25天学习打卡营第13天 | ResNet50迁移学习

文章目录

  • 昇思25天学习打卡营第13天 | ResNet50迁移学习
    • 数据集
      • 加载数据集
      • 数据集可视化
    • 模型训练
      • 固定特征
    • 总结
    • 打卡

在实际应用场景中,由于训练数据集不足,很少从头开始训练整个网络。普遍做法是在一个非常大的基础数据集上训练得到一个 预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务。

数据集

使用ResNet50在狗与狼分类数据集上进行训练,数据集中图片来自ImageNet,每个分类大约120张训练图像和30张验证图像。

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"

download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

加载数据集

使用mindspore.dataset.ImageFolderDataset接口来加载数据集,并进行数据增强::

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 5                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"

# 创建训练数据集

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]


    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)


    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set


dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()

数据集可视化

通过next迭代访问数据集,一次获取batch_size个图像及标签:

data = next(dataset_train.create_dict_iterator())
images = data["image"]
labels = data["label"]

print("Tensor of image", images.shape)
print("Labels:", labels)

import matplotlib.pyplot as plt
import numpy as np

# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "dogs", 1: "wolves"}

plt.figure(figsize=(5, 5))
for i in range(4):
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    plt.title(class_name[int(labels[i].asnumpy())])
    plt.axis("off")

plt.show()

模型训练

使用ResNet50网络,通过将pretrained设置为True来加载预训练模型:

class ResNet(nn.Cell):
	 def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:
	 	# ...
	 def construct(self, x):
	 	# ...

	def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)    

通过resnet50接口创建网络模型,如果设置pretrained=True,则下载并加载预训练模型到ResNet50中。

固定特征

使用固定特征进行训练时,需要冻结除最后一层之外的所有网络层。通过设置requires_grad=False冻结参数,使其不在反向传播中计算梯度。

import mindspore as ms
import matplotlib.pyplot as plt
import os
import time

net_work = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool

# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

总结

这一小节对预训练模型引入的原因进行了说明,通过加载预训练模型的参数到ResNet50模型中进行参数初始化,从而加速网络的训练过程。可以通过设置参数requires_grad=False来冻结参数,使其作为固定特征,不参与梯度计算与参数优化。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全自主巡航无人机项目思路:STM32/PX4 + ROS + AI 实现从传感融合到智能规划的端到端解决方案

1. 项目概述 本项目旨在设计并实现一款高度自主的自动巡航无人机系统。该系统能够按照预设路径自主飞行,完成各种巡航任务,如电力巡线、森林防火、边境巡逻和灾害监测等。 1.1 系统特点 基于STM32F4和PX4的高性能嵌入式飞控系统多传感器融合技术实现精…

【Caffeine】⭐️SpringBoot 项目整合 Caffeine 实现本地缓存

目录 🍸前言 🍻一、Caffeine 🍺二、项目实践 2.1 环境准备 2.2 项目搭建 2.3 接口测试 ​💞️三、章末 🍸前言 小伙伴们大家好,缓存是提升系统性能的一个不可或缺的工具,通过缓存可以避免大…

【答疑】8080或其他端口被占用如何解决?

我们在做项目时总会遇到各式各样千奇百怪的问题,但基本上每个刚接触tomcat的小白早晚都会遇到一个问题——8080端口被占用: 报错信息很容易理解,端口8080已经被使用了,那么这时我们该如何知道是谁使用了这个端口并关掉它呢&#x…

c++基础语法之内联函数

引言: 在C编程中,性能优化是一个永恒的话题。内联函数(Inline Functions)作为提高程序执行效率的一种重要手段,在编译器优化过程中扮演着关键角色。 一、内联函数的基本概念 定义:内联函数是C中一种特殊…

C#可空类型与数组

文章目录 可空类型NULL合并运算符(??)数组数组声明数组初始化数组赋值数组访问多维数组交错数组数组类数组类的常用属性数组类的常用方法 可空类型 C#提供了一种特殊的数据类型,nullable类型(可空类型),可…

k8s字段选择器

文章目录 一、概述二、基本语法三、支持的字段1、错误示例2、支持的字段列表 四、支持的操作符1、示例 五、跨多种资源类型使用字段选择器 一、概述 在Kubernetes中,字段选择器(Field Selectors)和标签选择器(Label Selectors&am…

MySQL更新和删除(DML)

DML-修改数据 UPDATE 表名 SET 字段1 值1,字段2值2,....[WHERE 条件] 例如 1.这个就是把employee表中的这个name字段里面并且id字段为1的名字改为itheima update employee set nameitheima where id 1; 2.这个就是把employee这个表中的name字段和…

string 的完整介绍

1.string类 还记得我们数据结构学的串吗,现在在c中,我们有了c提供的标准库,它是一个写好的类,非常方便使用 1. string是表示字符串的字符串类 2. 该类的接口与常规容器的接口基本相同,再添加了一些专门用来操作stri…

被边缘化后:飞猪“摆烂”,庄卓然求变?

文:互联网江湖 作者:刘致呈 走向独立的飞猪,在最近两年是越来越放飞自我了。 从“酱香大床房”的硬蹭热度,到“攻城价”被京都威斯汀酒店声明“打假”; 从年初的大数据杀熟争议,到最近被12036退票点名&a…

VLM技术介绍

1、背景 视觉语言模型(Visual Language Models)是可以同时从图像和文本中学习以处理许多任务的模型,从视觉问答到图像字幕。 视觉识别(如图像分类、物体保护和语义分割)是计算机视觉研究中一个长期存在的难题&#xff…

据传 OpenAI秘密研发“Strawberry”项目

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

ollama + lobechat 搭建自己的多模型助手

背景 人工智能已经推出了快2年了,各种模型和插件,有渐渐变成熟的趋势,打造一个类似 hao123网站的人工智能模型入口,也变得有需求了。用户会去比较多个ai给出的答案,作为程序员想拥有一台自己的GPU服务器来为自己服务。…

GuLi商城-商品服务-API-品牌管理-统一异常处理

每个方法都加这段校验太麻烦了 准备做一个统一异常处理@ControllerAdvice 后台代码: package com.nanjing.gulimall.product.exception;import com.nanjing.common.exception.BizCodeEnum; import com.nanjing.common.utils.R; import lombok.extern.slf4j.Slf4j; import org…

【UE5.1】Chaos物理系统基础——06 子弹破坏石块

前言 在前面我们已经完成了场系统的制作(【UE5.1】Chaos物理系统基础——02 场系统的应用_ue5)以及子弹的制作(【UE5.1 角色练习】16-枪械射击——瞄准),现在我们准备实现的效果是,角色发射子弹来破坏石柱。…

【算法】单调队列

一、什么是单调队列 单调队列是一种数据结构,其特点是队列中的元素始终保持单调递增或递减,主要用于维护队列中的最小值或最大值。 不同于普通队列只能从队头出队、队尾入队,单调队列为了维护其特征,还允许从队尾出队 不管怎么…

【学习笔记】4、组合逻辑电路(上)

数字电路的分类:组合逻辑电路,时序逻辑电路。本章学习组合逻辑电路。 4.1 组合逻辑电路的分析 给定的逻辑电路,确定其逻辑表达式,列出真值表,得到简化后的逻辑表达式,分析得到其功能。 3位奇校验电路 &…

OSPF.综合实验

1、首先将各个网段基于172.16.0.0 16 进行划分 1.1、划分为4个大区域 172.16.0.0 18 172.16.64.0 18 172.16.128.0 18 172.16.192.0 18 四个网段 划分R4 划分area2 划分area3 划分area1 2、进行IP配置 如图使用配置指令进行配置 ip address x.x.x.x /x 并且将缺省路由…

MQTT——Mosquitto使用(Linux订阅者+Win发布者)

前提:WSL(Ubuntu22)作为订阅者,本机Win10作为发布者。 1、Linux安装Mosquitto 命令行安装。 sudo apt-get install mosquitto 以上默认只安装了mosquitto的服务,不带测试客户端工具mosquitto_sub和mosquitto_pub。如…

持续学习中避免灾难性遗忘的Elastic Weight Consolidation Loss数学原理及代码实现

训练人工神经网络最重要的挑战之一是灾难性遗忘。神经网络的灾难性遗忘(catastrophic forgetting)是指在神经网络学习新任务时,可能会忘记之前学习的任务。这种现象特别常见于传统的反向传播算法和深度学习模型中。主要原因是网络在学习新数据…

全网最详细单细胞保姆级分析教程(二) --- 多样本整合

上一节我们研究了如何对单样本进行分析,这节我们就着重来研究一下如何对多样本整合进行研究分析! 1. 导入相关包 library(Seurat) library(tidyverse) library(patchwork)2. 数据准备 # 导入单样本文件 dir c(~/Desktop/diversity intergration/scRNA_26-0_filtered_featur…