昇思25天学习打卡营第13天|应用实践之ResNet50迁移学习

news2024/9/29 1:22:13

基本介绍

        今日的应用实践的模型是计算机实践领域中十分出名的模型----ResNet模型。ResNet是一种残差网络结构,它通过引入“残差学习”的概念来解决随着网络深度增加时训练困难的问题,从而能够训练更深的网络结构。现很多网络极深的模型或多或少都受此影响。今日的主要任务是使用ResNet50进行迁移学习,所谓的迁移学习是不从头到尾训练一个模型,而是在一个特别大的数据集上面训练得到一个预训练模型,然后使用该模型的权重作为初始化参数,最后应用于特定的任务。对特定的任务来说也可以对模型进行训练,但是需要冻结模型的某些参数,一般只训练模型的分类器部分,不会训练特征提取部分。下面我们详细讲讲今日的应用实践。

数据集准备

        本次使用的数据集是来自ImageNet数据集中抽取出来的狼狗数据集,每个类别大概有120张训练图像与30张验证图像,数据集可通过华为云OBS和相关API直接下载,然后进行加载。可借助MindSpore提供的API对数据集进行加载,同时,因为数据集在送入模型之前需做些处理,这些可封装到一个函数内,具体代码如下:

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]


    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)


    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set

通过上述操作后,我们可以很方便的调用数据集,无论是进行训练还是可视化,都可以。

模型搭建

        准备好数据集后,自然就是进行模型搭建,ResNet50是一个非常常见的模型,具体模型结构和不同AI框架下的代码都很容易获取,MindSpore官方也有相关的实现代码,我们直接使用,模型的代码如下:

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义,
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = nn.AvgPool2d()
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):

        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

训练模型

        我们采用定特征进行训练,需要冻结除最后一层之外的所有网络层,最后一层其实就是一个分类层,前面是特征提取层,MindSpore可以通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度,具体代码如下:

import mindspore as ms
import matplotlib.pyplot as plt
import os
import time

net_work = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool

# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

一切准备妥当后,便可以进行训练,由于有预训练模型,模型的训练速度非常快,比从头到尾训练快了好几倍。训练了5轮,效果就非常好了:

模型可视化预测

        有了训练好的模型,自然要看看训练得好不好,除了看评价指标,最直观的就是实际使用一下模型,由于这是图像分类任务,所以将其可视化。这一整个流程的代码如下:

def visualize_model(best_ckpt_path, val_ds):
    net = resnet50()
    # 全连接层输入层的大小
    in_channels = net.fc.in_channels
    # 输出通道数大小为狼狗分类数2
    head = nn.Dense(in_channels, 2)
    # 重置全连接层
    net.fc = head
    # 平均池化层kernel size为7
    avg_pool = nn.AvgPool2d(kernel_size=7)
    # 重置平均池化层
    net.avg_pool = avg_pool
    # 加载模型参数
    param_dict = ms.load_checkpoint(best_ckpt_path)
    ms.load_param_into_net(net, param_dict)
    model = train.Model(net)
    # 加载验证集的数据进行验证
    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    class_name = {0: "dogs", 1: "wolves"}
    # 预测图像类别
    output = model.predict(ms.Tensor(data['image']))
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()

调用该函数后,可视化结果如下:可以看出还是很准的

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1909061.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ScreenAI ——能理解从信息图表到用户界面的图像和文本算法解析

概述 论文地址:https://arxiv.org/pdf/2402.04615.pdf 信息图表(图表、示意图、插图、地图、表格、文档布局等)能够将复杂的数据和想法转化为简单的视觉效果,因此一直以来都被视为传播的重要元素。这种能力来自于通过布局和视觉线…

C#——序列化和反序列化概念

(1)序列化 在编程中,序列化是指将对象转换为可存储或传输的格式,例如将对象转换为 JSON 字符串或字节流。 (2)反序列化 在编程中,反序列化则是将存储或传输的数据转换回对象的过程。 序列化和反序列化经常用于数据的持久化、数据交换以及…

【计算机毕业设计】013新闻资讯微信小程序

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

【高中数学/指数函数】比较a=0.6^0.9 b=0.6^1.5 c=1.5^0.6的大小

【问题】 比较a0.6^0.9 b0.6^1.5 c1.5^0.6的大小 【解答】 指数函数y0.6^x是减函数&#xff0c;因为0.9<1.5,所以0.6^0.9>0.6^1.5,即a>b; 指数函数y1.5^x是增函数&#xff0c;1.5^0.6>1.5^01>0.6^0.9,即c>a; 综上&#xff0c;得出c>a>b的结论。 …

移动UI:该如何给页面降噪,给你支8招。

为什么我的移动UI页面看来这么复杂、为什么用户很迷茫&#xff0c;不知道如何操作&#xff0c;为什么很拥挤&#xff0c;核心原因还是页面噪声过多了&#xff0c;需要适当的降噪。 降低页面噪音&#xff0c;提高页面的简洁高效性是一个重要的设计目标。以下是一些降噪的设计方…

音频demo:使用opencore-amr将PCM数据与AMR-NB数据进行相互编解码

1、README a. 编译 编译demo 由于提供的.a静态库是在x86_64的机器上编译的&#xff0c;所以仅支持该架构的主机上编译运行。 $ make编译opencore-amr 如果想要在其他架构的CPU上编译运行&#xff0c;可以使用以下命令&#xff08;脚本&#xff09;编译opencore-amr[下载地…

终于搞定了通过两路蓝牙接收数据

一直想做无线传感器&#xff0c;通过蓝牙来接收数据&#xff0c;无奈因为arduino接收串口数据的一些问题&#xff0c;一直搁到现在。因为学校里给学生开了选修课&#xff0c;所以手边有一些nano和mega可以使用&#xff0c;所以就做了用两个nano加上两个蓝牙模块来发射数据&…

solidity:构造函数和修饰器、事件

构造函数​ 构造函数&#xff08;constructor&#xff09;是一种特殊的函数&#xff0c;每个合约可以定义一个&#xff0c;并在部署合约的时候自动运行一次。它可以用来初始化合约的一些参数&#xff0c;例如初始化合约的owner地址&#xff1a; address owner; // 定义owner变…

web前端开发——标签一

今天我来针对web前端开发讲解标签一 Html标签_标题&段落&换行 注释标签&#xff1a;Ctrl/ Ctrl/ &#xff0c;用户可能会获取到注释标签 注释的原则: •和代码逻辑一致 •尽量使用中文 •正能量 标题标签&#xff1a;<h1></h1> h1-h6 标题标签有6…

Nacos2.X源码分析:服务注册、服务发现流程

文章目录 Nacos2.1.X源码源码下载服务注册NacosClient端NacosServer端 服务发现NacosClient端NacosServer端 Nacos2.1.X源码 源码下载 源码下载地址 服务注册 官方文档&#xff0c;对于NamingService接口服务注册方法的说明 Nacos2.X 服务注册总流程图 NacosClient端 一个…

2024年浙江省高考分数一分一段数据可视化

下图根据 2024 年浙江高考一分一段表绘制&#xff0c;可以看到&#xff0c;竞争最激烈的分数区间在620分到480分之间。 不过&#xff0c;浙江是考两次取最大&#xff0c;不是很有代表性。看看湖北的数据&#xff0c;580分到400分的区段都很卷。另外&#xff0c;从这个图也可以…

【Mac】Folder Icons for mac(文件夹个性化图标修改软件)软件介绍

软件介绍 Folder Icons for Mac 是一款专为 macOS 设计的应用程序&#xff0c;主要用于个性化和定制你的文件夹图标。以下是它的主要特点和使用方法&#xff1a; 主要特点&#xff1a; 个性化文件夹图标 Folder Icons for Mac 允许用户为 macOS 上的任何文件夹定制图标。你…

k8s集群如kubeadm init和kube-flannel.yam问题

查看k8s中角色内容kubectl get all (显示pod和server以及delment) 删除应用资源选择删除先删除部署查看部署和pod没了服务还在&#xff0c;但资源和功能以及删除&#xff0c;删除服务kubectl delete 服务名&#xff08;部署名&#xff09;&#xff0c;get pods 获取默认空间的容…

Android C++系列:Linux进程(二)

1. fork #include <unistd.h> pid_t fork(void);子进程复制父进程的0到3g空间和父进程内核中的PCB,但id号不同。 fork调用一次返回两次 父进程中返回子进程ID子进程中返回0读时共享,写时复制#include <sys/types.h> #include <unistd.h> #include <…

高颜值官网(4):酒店民宿网站12个,看着看着就醉了。

对于高星级酒店或者高端酒店来说&#xff0c;拥有一个高颜值的官方网站是非常重要的。一个精美、专业的网站设计可以有效地展现酒店的品牌形象和服务质量&#xff0c;吸引目标客户群体并提高预订转化率。 这次分享12个&#xff0c;都是超高颜值的。

机器学习中的可解释性

「AI秘籍」系列课程&#xff1a; 人工智能应用数学基础 人工智能Python基础 人工智能基础核心知识 人工智能BI核心知识 人工智能CV核心知识 为什么我们需要了解模型如何进行预测 我们是否应该始终信任表现良好的模型&#xff1f;模型可能会拒绝你的抵押贷款申请或诊断你患…

break 和 continue 的区别与用法

break 和 continue 的区别与用法 1、break 语句2、continue 语句3、总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在JAVA中&#xff0c;break 和 continue 是两种常用的控制流语句&#xff0c;它们主要用于在循环结构中改变程序的执行…

怎样卸载电脑上自带的游戏?

卸载电脑上自带的游戏通常是一个简单的过程&#xff0c;以下是几种常见的方法&#xff0c;您可以根据自己的操作系统版本选择相应的步骤进行操作&#xff1a; 方法一&#xff1a;通过“设置”应用卸载&#xff08;适用于Windows 10和Windows 11&#xff09; 1. 点击开始菜单&…

【链表】- 链表相交

1. 对应力扣题目连接 链表相交 2. 实现思路 链表详情&#xff1a; 考虑使用双指针&#xff1a; 解法一&#xff1a; 具体代码&#xff0c;详见3. 实现案例代码解析&#xff1a; 思路&#xff1a;因为链表按照如图的箭头走向&#xff0c;走的总路程是相等的&#xff0c;一…

泰迪智能科技受邀北京物资学院共讨校企合作交流

为落实“访企拓岗促就业”专项行动工作要求&#xff0c;推动科研成果向实际应用转化&#xff0c;培养适应新时代需求的高素质人才&#xff0c;拓宽毕业生就业渠道&#xff0c;提升就业竞争力。7月1日&#xff0c;广东泰迪智能科技股份有限公司区域总监曹玉红到访北京物资学院开…