人工智能(pytorch)搭建模型15-手把手搭建MnasNet模型,并实现模型的训练与预测

news2024/11/18 21:28:00

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型15-手把手搭建MnasNet模型,并实现模型的训练与预测,本文将介绍MnasNet模型的原理,并使用PyTorch框架构建一个MnasNet模型用于图像分类任务,让大家充分了解该模型。

文章将分为以下几个部分:

  1. MnasNet模型简介
  2. MnasNet模型的实现
  3. 数据集准备
  4. 模型训练
  5. 模型测试
  6. 结论

1. MnasNet模型简介

MnasNet(Mobile Neural Architecture Search Network)是一种通过搜索得到的高效卷积神经网络,最早由Google于2018年提出。MnasNet的主要特点是在保证模型的高性能的同时,尽量降低计算复杂度和参数数量,适用于移动设备等资源有限的场景。

MnasNet的核心思想是利用神经结构搜索(Neural Architecture Search, NAS)技术来找到最优的模型结构。NAS的目标是在给定任务和硬件平台的约束下,自动搜索出性能最佳的神经网络结构。MnasNet采用的搜索空间主要包括卷积层、深度可分离卷积层以及倒残差结构等基本操作。
在这里插入图片描述
MnasNet模型的数学原理可以用以下公式表示:

假设输入图像为 x ∈ R H × W × C x\in \mathbb{R}^{H\times W\times C} xRH×W×C,其中 H H H W W W C C C分别表示图像的高、宽和通道数。MnasNet模型可以看作一个函数 f ( x ; θ ) f(x;\theta) f(x;θ),其中 θ \theta θ表示模型的参数,包括卷积核、批量归一化参数、全连接层参数等。

MnasNet模型的核心是自动化神经架构搜索技术,可以自动搜索最佳的神经网络架构。假设搜索得到的最佳神经网络架构为 A A A,则模型的输出可以表示为:

f ( x ; θ A ) = f A ( x ; θ A ) f(x;\theta_A)=f_A(x;\theta_A) f(x;θA)=fA(x;θA)

其中 f A ( x ; θ A ) f_A(x;\theta_A) fA(x;θA)表示使用神经网络架构 A A A搭建的模型, θ A \theta_A θA表示模型 A A A的参数。模型 A A A的参数由两部分组成,即共享参数和非共享参数,可以表示为:

θ A = { w , α } \theta_A=\{w,\alpha\} θA={w,α}

其中 w w w表示共享参数, α \alpha α表示非共享参数。共享参数在不同的神经网络架构之间是共享的,而非共享参数则是每个神经网络架构独立的。

由于自动化神经架构搜索技术的存在,MnasNet模型可以在保证准确率的前提下,显著减小模型的参数量和计算量,从而在移动设备上得到更好的应用。

2. MnasNet模型的实现

以下是用PyTorch实现MnasNet模型的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

class _InvertedResidual(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor):
        super(_InvertedResidual, self).__init__()
        hidden_dim = round(in_ch * expansion_factor)
        self.use_residual = in_ch == out_ch and stride == 1
        
        layers = []
        if expansion_factor != 1:
            layers.append(nn.Conv2d(in_ch, hidden_dim, 1, 1, 0, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.ReLU6(inplace=True))
        
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, kernel_size//2, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden_dim, out_ch, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_ch),
        ])

        self.layers = nn.Sequential(*layers)
    
    def forward(self, x):
        if self.use_residual:
            return x + self.layers(x)
        else:
            return self.layers(x)

class MnasNet(nn.Module):
    def __init__(self, num_classes=1000, alpha=1.0):
        super(MnasNet, self).__init__()
        self.alpha = alpha
        self.num_classes = num_classes

        def conv_dw(in_ch, out_ch, stride):
            return _InvertedResidual(in_ch, out_ch, 3, stride, 1)

        def conv_pw(in_ch, out_ch, stride):
            return _InvertedResidual(in_ch, out_ch, 1, stride, 6)

        def make_layer(in_ch, out_ch, num_blocks, stride):
            layers = [conv_pw(in_ch, out_ch, stride)]
            for _ in range(num_blocks - 1):
                layers.append(conv_pw(out_ch, out_ch, 1))
            return nn.Sequential(*layers)

        # 构建MnasNet模型
        self.model = nn.Sequential(
            nn.Conv2d(3, int(32 * alpha), 3, 2, 1, bias=False),
            nn.BatchNorm2d(int(32 * alpha)),
            nn.ReLU6(inplace=True),
            make_layer(int(32 * alpha), int(16 * alpha), 1, 1),
            make_layer(int(16 * alpha), int(24 * alpha), 2, 2),
            make_layer(int(24 * alpha), int(40 * alpha), 3, 2),
            make_layer(int(40 * alpha), int(80 * alpha), 4, 2),
            make_layer(int(80 * alpha), int(96 * alpha), 2, 1),
            make_layer(int(96 * alpha), int(192 * alpha), 4, 2),
            make_layer(int(192 * alpha), int(320 * alpha), 1, 1),
            nn.Conv2d(int(320 * alpha), 1280, 1, 1, 0, bias=False),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True),
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(1280, self.num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        x = self.classifier(x)
        return x

3. 数据集准备

我们将在CIFAR-10数据集上训练和测试我们的MnasNet模型。CIFAR-10数据集包含10个类别的60000张32x32彩色图像,每个类别有6000张图像。其中50000张用于训练,10000张用于测试。

准备数据集的代码如下:

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

4. 模型训练

接下来,我们将使用训练集对MnasNet模型进行训练,并在每个epoch后输出训练损失值和准确率。训练代码如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
mnasnet = MnasNet(num_classes=num_classes, alpha=0.5)
mnasnet = mnasnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(mnasnet.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# 学习率调整策略
def adjust_learning_rate(optimizer, epoch):
    lr = 0.1
    if epoch >= 80:
        lr = 0.01
    if epoch >= 120:
        lr = 0.001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

num_epochs = 150
for epoch in range(num_epochs):
    adjust_learning_rate(optimizer, epoch)
    mnasnet.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = mnasnet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print(f'Epoch: {epoch+1}, Loss: {train_loss/(batch_idx+1)}, Acc: {100.*correct/total}')

5. 模型测试

训练完成后,我们使用测试集对模型进行测试,输出测试损失值和准确率。测试代码如下:

mnasnet.eval()
test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = mnasnet(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

print(f'Test Loss: {test_loss/(batch_idx+1)}, Acc: {100.*correct/total}')

6. 结论

本文主要介绍了MnasNet模型的搭建与训练,测试,MnasNet特点在于使用了自动化神经架构搜索技术,可以在保证准确率的前提下,显著减小模型的参数量和计算量,从而在移动设备上得到更好的应用。

MnasNet模型的搭建与训练的总结:

数据集准备:首先需要准备适当的数据集,包括训练集、验证集和测试集。可以使用公共数据集,如ImageNet,也可以使用自己收集的数据集。

模型架构搜索:MnasNet采用神经架构搜索技术,可以自动搜索最佳的神经网络架构。这个过程需要使用大量的计算资源和时间,可以在GPU或者云端进行。

模型搭建:在得到最佳的神经网络架构之后,需要将其搭建成一个可以训练的模型。这个过程可以使用深度学习框架来实现,如TensorFlow、PyTorch等。

模型训练:训练模型需要选择合适的优化算法和损失函数,同时设置合适的超参数,如学习率、批量大小等。训练过程中,可以使用一些技巧来提高模型的性能,如数据增强、学习率调整等。

模型评估:在训练完成后,需要使用验证集或测试集来评估模型的性能。可以使用一些指标来评估模型的准确率、召回率、F1值等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/691935.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1、电商数仓(用户行为采集平台)数据仓库概念、用户行为日志、业务数据、模拟数据、用户行为数据采集模块、日志采集Flume

1、数据仓库概念 数据仓库( Data Warehouse ),是为企业制定决策,提供数据支持的。可以帮助企业,改进业务流程、提高产品质量等。 数据仓库的输入数据通常包括:业务数据、用户行为数据和爬虫数据等。 业务数…

Linux——使用第三方库链接的方式——动态式

回顾上文: (122条消息) Linux使用第三方库链接的使用方式——静态式_橙予清的zzz~的博客-CSDN博客https://blog.csdn.net/weixin_69283129/article/details/131414804?spm1001.2014.3001.5502 上篇文章中,我讲到了关于链接第三方库作为静态库的使…

股票技术分析方法综述

文章目录 K线均线MACDKDJ和RSIBOLL线趋势理论、支撑位和压力位形态理论量价关系理论道氏理论波浪理论江恩理论缠论自定义指标 K线 K线的组合形态是K线技术分析中的重要部分,包括早晨之星、黄昏之星、红三兵、黑三兵等。 早晨之星:由三根K线组成&#x…

OpenGL 抗锯齿

1.简介 你可以看到,我们只是绘制了一个简单的立方体,你就能注意到它存在锯齿边缘。 可能不是非常明显,但如果你离近仔细观察立方体的边缘,你就应该能够看到锯齿状的图案。如果放大的话,你会看到下面的图案&#xff1a…

家校互动、班级管理系统

最近做了一款使用若依开源框架搭建的一款家校互动、班级管理的平台,采用uniapp作为APP端,原生小程序作为小程序的家长端。

软件测试的概念与过程(软件测试的历史、概念、结构、过程)

软件测试的概念与过程----软件测试的历史 软件测试的历史软件的概念软件的结构软件测试的过程 软件测试的历史 1、早期的的软件开发过程中,将测试“调试”,目的是纠正软件已经知道的故障,常常有开发人员自己去完成这部分工作。 2、1957年&…

使用数据集工具

一.数据集工具介绍 HuggingFace通过API提供了统一的数据集处理工具,它提供的数据集如下所示: 该界面左侧可以根据不同的任务类型、类库、语言、License等来筛选数据集,右侧为具体的数据集列表,其中有经典的glue、super_glue数据集…

Unity | HDRP高清渲染管线学习笔记:材质系统Lit着色器

目录 一、Lit着色器 1. Surface Options 2. Surface Inputs(表面输入) 3. Transparency Inputs 二、HDRP渲染优先级 我们可以把现实世界中的物体分成不透明物体和透明物体(其中包括透明或者半透明)。在实时渲染时&#xff0c…

Debian二次开发网关支持Docker+RS485+网口

随着物联网技术的不断发展,瑞芯微边缘计算网关作为一种集成多种接口和功能的智能网关,逐渐成为了物联网领域中的热门产品。本文将详细介绍瑞芯微边缘计算网关的特点和优势,并探讨其在实际应用中的广泛应用。 瑞芯微Linux系统边缘计算网关是一…

【Java】 Java 私有接口方法的使用

本文仅供学习参考! 相关教程地址: https://www.baeldung.com/java-interface-private-methods https://www.geeksforgeeks.org/private-methods-java-9-interfaces/ https://www.runoob.com/java/java9-private-interface-methods.html 接口是定义一组方…

java之路—— SpringMVC的常用注解解析以及作用、应用

创作不易,真的希望能给个免费的小 文章目录 1、Controller2、RequestMapping3.GetMapping、PostMapping、PutMapping、DeleteMapping4. RequestParam5.PathVariable6.RequestHeader7.CookieValue8.RequestBody9.ResponseBody10.SessionAttribute11.ControllerAdvice…

二层、三层交换机是什么?有什么区别?

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 前言 本文将讲解二层交换机和三层交换机是什么,以及他们的区别。 目录 一、二层交换机是什么? 二、二层交换机的主…

本地生活多城市合伙人系统开发

本地生活多城市合伙人项目是一种基于本地生活服务的创业项目,旨在为各个城市的居民提供方便、实惠、高品质的生活服务。该项目通过招募多个城市的合伙人,建立完整的本地生活服务平台和供应链体系,覆盖不同类型的本地生活服务,如餐…

Nginx的Rewrite(地址重定向)

目录 前言 一、Rewrite 跳转场景 二、Rewrite 跳转实现 三、Rewrite实际场景 3.1Nginx跳转需求的实现方式 3.2rewrite放在 server{},if{},location{}段中 3.3对域名或参数字符串 四、Rewrite正则表达式 五、Rewrite语法格式 5.1rewrite语法格式…

互联网常见架构接口压测性能分析及调优手段建议

目录 互联网常见架构接口压测性能分析及调优手段建议 1 接口名称: 获取列表 1.1 压测现象:单台tps700多,应用cpu高负载 1.1.1 问题分析: 1.1.2 改进措施: 1.1.3 改进效果: 1.2 压测现象:数据库资源利用率高 1.2.1 问题分析: 1.2.2 改进措施: 1.2.3 改…

SciencePub学术 | 计算机科学类重点SCIEEI征稿中

SciencePub学术 刊源推荐:计算机科学类重点SCIE&EI征稿中!信息如下,录满为止: 一、期刊概况: 计算机科学类重点SCIE&EI 【期刊简介】IF:3.0-3.5,JCR 2区,中科院4区; 【检…

使用R绘制气泡图、带有显著性标记的热力图、渐变曲线图

大家好,我是带我去滑雪! 一幅精美的科研绘图会有诸多益处,精美的图像可以更好地传达研究结果和数据分析的重要信息。通过使用清晰、直观和易于理解的图像,可以更好地向读者展示研究的发现,有助于读者理解和解释数据。还…

JAVA开发(记一次504 gateway timeout错误排查过程)

一、问题与背景: 最近在发布一个web项目,在测试环境都是可以的,发布到生产环境通过IP访问也是可以的,但是通过域名访问就出现504 gateway timeout。通过postman去测试接口也是一样。ip和端口都可以通,域名却不行&…

如何在矩池云上运行 AI 图像编辑工具 DragGAN

5 月,DragGAN 横空出世,在开源代码尚未公布前,就在Github上斩获近 20000 Star,彼时,页面上只有效果图和一句“Code will be released in June”,然而这也足够带给人们无限期待。 在6月末,在若干…

SpringBoot最多可以处理多少个请求?

SpringBoot最多可以处理多少个请求? SpringBoot夺命连环14问,1天刷完别人半个月的springboot面试内容,比啃书效果好多了!_哔哩哔哩_bilibili 最小线程数:最少的厨师的量,饭店人不多的时候的量。 最大线程数…