《昇思25天学习打卡营第6天|ResNet50图像分类》

news2024/12/24 20:21:57

写在前面

从本次开始,接触一些上层应用。
本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。

resnet

通过对论文的结论进行展示,说明了模型的功能,解决了卷积网络层数加大后模型的退化问题。20层和56层相比,层数越大,模型效果越差,因此resnet主要解决这种问题。hekaiming是真的强呀。

基本流程

  1. 整理模型数据
  2. 构建模型网络核心逻辑(ResidualBlockBase/ResidualBlock)
  3. 创建模型一层

构建网络的代码

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

class ResidualBlockBase(nn.Cell):
    expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlockBase, self).__init__()
        if not norm:
            self.norm = nn.BatchNorm2d(out_channel)
        else:
            self.norm = norm

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """ResidualBlockBase construct."""
        identity = x  # shortcuts分支

        out = self.conv1(x)  # 主分支第一层:3*3卷积层
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

创建模型一层

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None  # shortcuts分支

    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion
    # 堆叠残差网络
    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)

创建模型

搭建一个4层的网络。

from mindspore import load_checkpoint, load_param_into_net


class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = nn.AvgPool2d()
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):

        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x

接下来,连接数据和模型网络,开始构建容易使用的网络。在这里设置了,模型残差的方法和每个block。

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrained_ckpt, replace=True)
        param_dict = load_checkpoint(pretrained_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    """ResNet50模型"""
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

模型训练和评估

并没有完全训练,使用了预训练的方法,下载了预训练的模型。

# 定义ResNet50网络
network = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc

有了模型网络,接下来需要进行模型训练。训练的过程要设置学习率、优化器和损失函数。

# 设置学习率
num_epochs = 1
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,
                        step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = network(inputs)
    loss = loss_fn(logits, targets)
    return loss


grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)


def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

之后进行多个epoch的迭代,实现模型训练的目标。

import mindspore.ops as ops


def train(data_loader, epoch):
    """模型训练"""
    losses = []
    network.set_train(True)

    for i, (images, labels) in enumerate(data_loader):
        loss = train_step(images, labels)
        if i % 100 == 0 or i == step_size_train - 1:
            print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %
                  (epoch + 1, num_epochs, i + 1, step_size_train, loss))
        losses.append(loss)

    return sum(losses) / len(losses)


def evaluate(data_loader):
    """模型验证"""
    network.set_train(False)

    correct_num = 0.0  # 预测正确个数
    total_num = 0.0  # 预测总数

    for images, labels in data_loader:
        logits = network(images)
        pred = logits.argmax(axis=1)  # 预测结果
        correct = ops.equal(pred, labels).reshape((-1, ))
        correct_num += correct.sum().asnumpy()
        total_num += correct.shape[0]

    acc = correct_num / total_num  # 准确率

    return acc

# 开始循环训练
print("Start Training Loop ...")

for epoch in range(num_epochs):
    curr_loss = train(data_loader_train, epoch)
    curr_acc = evaluate(data_loader_val)

    print("-" * 50)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, curr_loss, curr_acc
    ))
    print("-" * 50)

    # 保存当前预测准确率最高的模型
    if curr_acc > best_acc:
        best_acc = curr_acc
        ms.save_checkpoint(network, best_ckpt_path)

print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)

进行多轮训练之后,达到训练的目的,模型开始进行收敛,并且能够获取到最终的结果。

最后进行评估,这个并不复杂。

打开

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1951571.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第2章 编译SDK

安装编译依赖 sudo apt-get update sudo apt-get install clang-format astyle libncurses5-dev build-essential python-configparser sconssudo apt-get install repo git ssh make gcc libssl-dev liblz4-tool \ expect g patchelf chrpath gawk texinfo chrpath diffstat …

springboot促进高等教育可持续发展管理平台-计算机毕业设计源码36141

摘 要 随着全球对可持续发展的日益关注,高等教育作为培养未来领导者和创新者的摇篮,其在推动可持续发展中的角色日益凸显。然而,传统的高等教育管理模式在应对复杂多变的可持续发展挑战时,显得力不从心。因此,构建一个…

stm32入门-----USART串口实现数据包的接收和发送

目录 前言 数据包 1.HEX数据包 2.文本数据包 C编程实现stm32收发数据包 1.HEX数据包的收发 2.文本数据包的收发 前言 前面几期讲解了USART串口发送数据和接收数据的原理,那本期在前面的基础上学习stm32 USART串口发送和接收数据包。本期包括两个项目&a…

数据库作业四

1. 修改 student 表中年龄( sage )字段属性,数据类型由 int 改变为 smallint : ALTER TABLE student MODIFY Sage SMALLINT; 2. 为 Course 表中 Cno 课程号字段设置索引,并查看索引: ALTER TABLE…

Linux系统下非root用户自行安装的命令切换为root权限时无法使用,提示comman not found解决办法

今天在开发的时候遇上了一个问题就是要去我们数据平台中进行数据的提取,数据存储用的是minio,一个MinIO部署由一组存储和计算资源组成,运行一个或多个 minio server 节点,共同作为单个对象存储库。独立的MinIO实例由具有单个 mini…

多区域DNS以及主从DNS的搭建

搭建多域dns服务器: 搭建DNS多区域功能(Multi-Zone DNS)主要是为了满足复杂网络环境下的多样化需求,提高DNS服务的灵活性、可扩展性和可靠性。 适应不同网络环境: 在大型组织、跨国公司或跨地域服务中,网…

微服务安全——SpringSecurity6详解

文章目录 说明SpringSecurity认证快速开始设置用户名密码基于application.yml方式基于Java Bean配置方式 设置加密方式自定义用户加载方式自定义登录页面前后端分离认证认证流程 SpringSecurity授权web授权:基于url的访问控制自定义授权失败异常处理方法授权:基于注解的访问控制…

2024上半年热门网络安全产品和工具TOP10_wiz安全产品

今年上半年,利用生成式人工智能(GenAI)的网络安全工具继续激增。许多供应商正在利用GenAI的功能来自动化安全运营中心(SOC)的工作,特别是在自动化日常活动方面,如收集威胁信息和自动创建查询。 …

element 结合 {} 实现自适应布局

通过el-row el-col 实现 例如 :xl“{ 1: 24, 2: 12, 3: 8, 4: 6 }[tableData.length] || 6” length 1 2 3 4 、代码数量为 1 2 3 4 >4 时不同卡片数量时尺寸的配置

LLM模型之基于MindSpore通过GPT实现情感分类

前言 # 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行!pip install mindnlp0.3.1 !pip install mindnlp !pip install jieba %env HF_ENDPOINThttps://hf-mirror.com 导入对应的包 import osimport m…

【因数之和】python求解方法

输入两个整数A和B,求A的B次方的因子和,结果对1000000007取模。 def mod_exp(base, exp, mod):result 1while exp > 0:if exp % 2 1:result (result * base) % modbase (base * base) % modexp // 2return resultdef sum_of_factors(n):total 0…

讨逆猴子剪切板,浏览器复制失败?

讨逆猴子剪切板,复制失败? 问题:本地开发情况下可以直接复制,公网就不行了…触发了安全机制。 const link 内容;navigator.clipboard.writeText(link);报错: 解决方案: if (navigator.clipboard &&…

使用大型语言模型进行文档解析(附带代码)

动机 多年来,正则表达式一直是我解析文档的首选工具,我相信对于许多其他技术人员和行业来说也是如此。 尽管正则表达式在某些情况下功能强大且成功,但它们常常难以应对现实世界文档的复杂性和多变性。 另一方面,大型语言模型提供了…

5.CSS学习(浮动)

浮动(float) 是一种传统的网页布局方式,通过浮动,可以使元素脱离文档流的控制,使其横向排列。 其编写在CSS样式中。 float:none(默认值) 元素不浮动。 float:left 设置的元素在其包含…

【计算机网络】单臂路由实现VLAN间路由实验

一:实验目的 1:掌握如何在路由器端口上划分子接口,封装dot1q协议,实现VLAN间的路由。 二:实验仪器设备及软件 硬件:RCMS-C服务器、网线、Windows 2019/2003操作系统的计算机等。具体为:路由器…

git学习笔记(总结了常见命令与学习中遇到的问题和解决方法)

前言 最近学习完git,学习过程中也遇到了很多问题,这里给大家写一篇总结性的博客,主要大概讲述git命令和部分难点问题(简单的知识点这里就不再重复讲解了) 一.git概述 1.1什么是git Git是一个分布式的版本控制软件。…

SQL Developer 连接 MySQL

服务: $ sudo systemctl list-unit-files | grep enabled | grep mysql mysqld.service enabled日志: mysqld.log # sudo grep temporary password /var/log/mysqld.log 2024-05-28T00:57:16.676383Z 6 [Note] [MY…

GB28181摄像头管理平台WVP视频平台SQL注入漏洞复现 [附POC]

文章目录 GB28181摄像头管理平台WVP视频平台SQL注入漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现 GB28181摄像头管理平台WVP视频平台SQL注入漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内…

AI带货直播软件开发会用到哪些源代码?

随着人工智能技术的迅猛发展,AI带货直播软件已成为电商领域的一股新兴力量,这类软件通过结合AI技术、虚拟形象和实时互动功能,极大地提升了用户体验和购买转化率。 在开发这类软件的过程中,源代码的编写至关重要,本文…

20 B端产品的数据分析

数据分析的价值 数据衡量业务:通过管理数据报表,可以快速衡量业务发展状态。 数据洞察业务:通过数据分析,可以找到业务发展的机遇。 数据驱动指导业务:基于数据,驱动业务决策,数据支撑决策。 …