人工智能(Pytorch)搭建模型7-改造后的新型RegNet设计空间模型的搭建与训练

news2024/10/7 18:52:34

大家好,我是微学AI,今天给大家带来人工智能(Pytorch)搭建模型7-新型的卷积神经网络RegNet模型的搭建与训练,RegNet是一种新颖的卷积神经网络架构,它的设计理念是通过稀疏网络结构和精细的正则化来实现高效的计算和更好的泛化能力。RegNet最初是用于图像分类任务,在ImageNet上实现了较好的性能,同时也受到了广泛的关注。

一、RegNet设计空间介绍

RegNet是一种用于图像分类的不同卷积神经网络模型架构组成的设计空间。RegNet的设计原则是通过正则化来平衡网络的深度和宽度,从而实现更好的性能。下面列出了RegNet模型的主要特点:

RegNet架构具有可配置的深度(层数)和宽度(每层的通道数)。

RegNet架构使用了分组卷积和跳跃连接来提高计算效率和模型性能。

RegNet架构可以通过调整参数来在不同的计算资源和性能要求之间进行权衡。

RegNet的稀疏结构是通过对网络宽度、深度和分辨率进行优化,减少了过度的冗余特征和层。同时,RegNet采用了一个特定的正则化方法,称为“网络级别L2正则化”,以帮助网络获得更好的泛化能力。此正则化方法在网络的结构定义阶段进行,为每一个层次参数引入一定的正则化强度。

改造RegNet架构后的特点:

1.高效的计算速度:该模型具有较小的参数量和计算成本,能够在移动设备等较低配置的硬件上快速运行。

2.较好的泛化能力:该模型通过使用Residual Block和L2正则化等技术来防止过拟合,从而提高了模型的泛化性能。

3.可拓展性:该模型的网络设计空间灵活,可以通过搜索算法来寻找最优的网络结构,同时也可以通过调整网络深度和宽度等参数进行适应不同任务。

4.简单易懂的结构:该模型只包含几个简单但非常重要的模块,并且模块的形式和功能就和传统的卷积神经网络非常相似。

构建的模型结构图:

二、PyTorch实现

2.1 准备数据

首先,我们生成假数据以进行模型训练和测试。这里我们使用torch.randn()函数生成随机数据,模拟图像数据。为了简化问题,我们假设输入图像的尺寸为3x32x32,类别数为4。

import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
# 生成假数据
num_samples = 200
input_shape = (3, 32, 32)
num_classes = 4

image_data = np.random.rand(num_samples, 3, 32, 32).astype(np.float32)
labels = np.random.randint(0, 4, size=num_samples, dtype=np.int64)

# 创建数据集和数据加载器
train_data = TensorDataset(torch.from_numpy(image_data), torch.from_numpy(labels))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)

2.2 构建RegNet架构

接下来,我们使用PyTorch构建RegNet模型。首先,我们定义一个基本的卷积块ConvBlock,它由一个卷积层、一个批量归一化层和一个ELU激活函数组成。然后,我们定义一个残差块ResidualBlock,它包含两个卷积块,并使用跳跃连接将输入添加到输出。

import torch.nn as nn
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.ELU = nn.ELU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.ELU(x)
        return x

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x += identity
        return x


class RegNet(nn.Module):
    def __init__(self, input_shape, num_classes, num_blocks):
        super(RegNet, self).__init__()
        self.in_channels = 64
        self.conv1 = ConvBlock(input_shape[0], self.in_channels, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stage2 = self._make_stage(num_blocks[0], self.in_channels * 1)
        self.stage3 = self._make_stage(num_blocks[1], self.in_channels * 1)
        #self.stage4 = self._make_stage(num_blocks[2], self.in_channels * 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channels * 1, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def _make_stage(self, num_blocks, out_channels):
        layers = []
        for _ in range(num_blocks):
            layers.append(ResidualBlock(self.in_channels, out_channels))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        #x = self.stage4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.softmax(x)
        return x

2.3 训练和测试

现在我们已经构建好RegNet模型,接下来我们使用随机生成的数据进行训练和测试。我们使用交叉熵损失函数和随机梯度下降优化器。

# 设置超参数
num_epochs = 200
learning_rate = 0.001
momentum = 0.8
#weight_decay = 0.0001

# 初始化模型、损失函数和优化器
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = RegNet(input_shape, num_classes, num_blocks=[2, 2, 2]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / (i + 1)}")


# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    #print(f"Train Accuracy: {100 * correct / total}%")
    print(f"Accuracy of the model on the {total} test images: {100 * correct / total}%")

运行结果:

Epoch [188/200], Loss: 0.7604680061340332
Epoch [189/200], Loss: 0.7604513049125672
Epoch [190/200], Loss: 0.7604348838329316
Epoch [191/200], Loss: 0.7604187309741974
Epoch [192/200], Loss: 0.7604028165340424
Epoch [193/200], Loss: 0.760387197136879
Epoch [194/200], Loss: 0.760371795296669
Epoch [195/200], Loss: 0.7603566676378251
Epoch [196/200], Loss: 0.7603417694568634
Epoch [197/200], Loss: 0.7603270918130874
Epoch [198/200], Loss: 0.7603126436471939
Epoch [199/200], Loss: 0.7602984338998795
Epoch [200/200], Loss: 0.7602844357490539
Accuracy of the model on the 200 test images: 98.0%

大家可以自己放入图片数据进行训练哦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/572950.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小白看了也会的Redux编程

目录 介绍 演示 异步action react-redux 多组件管理的react-redux 扩展 介绍 redux是专门用于集中式管理状态的javascript库,并不是react的插件库。 比如你有多个组件A-E都想要用同一个组件D中的状态: 1)像以前我们可以通过父子组件通…

前几天面了个30岁左右的测试员,年薪50w问题基本都能回答上,必是刷了不少八股文···

互联网行业竞争是一年比一年严峻,作为测试工程师的我们唯有不停地学习,不断的提升自己才能保证自己的核心竞争力从而拿到更好的薪水,进入心仪的企业(阿里、字节、美团、腾讯等大厂.....) 所以,大家就迎来了…

TCP三次握手四次挥手(幽默版)

三次握手: 假设你是一只鸟,你要与另一只鸟进行交流。(你是客户端) 1.首先你会问候:“你好,我是一只鸟,你可以听到我说话吗?”(一次会话) 2.另一只鸟回答&am…

一文带你了解MySQL之Explain执行计划

前言: 一条查询语句在经过MySQL查询优化器的各种基于成本和规则的优化会后生成一个所谓的执行计划,这个执行计划展示了接下来具体执行查询的方式,比如多表连接的顺序是什么,对于每个表采用什么访问方法来具体执行查询等等。MySQL…

MySQL---show profile分析SQL、trace分析优化器执行计划

1. show profile分析SQL Mysql从5.0.37版本开始增加了对 show profiles 和 show profile 语句的支持。show profiles 能够 在做SQL优化时帮助我们了解时间都耗费到哪里去了。 通过 have_profiling 参数,能够看到当前MySQL是否支持profile: select ha…

3年软件测试经验月薪7k,只会“点点点”,我该如何破局?

经常听到一些行业内的朋友说 “做测试,有手就行” 但事实真的是如此嘛? 随着测试行业的发展,越来越多的测试岗位对自动化测试,性能测试都有所要求,这对于很多只会功能测试的职场老人们来说,有了一丝丝的危…

Druid连接池技术实践

什么是Druid连接池? Druid连接池是阿里巴巴开源的数据库连接池项目。 Druid连接池为监控而生,内置强大的监控功能,监控特性不影响性能。功能强大,能防SQL注入,内置Loging能诊断Hack应用行为。 哦,首先Dru…

2023ACP世界大赛教育者论坛:让职业教育直面AI机遇与挑战

“AI技术的普及对创意行业和教育带来的影响和变革-2023 Adobe Certified Professional教育者论坛”在苏州西交利物浦大学成功举办。 本次论坛,由Adobe Certified Professional 世界大赛中国赛区组委会主办,联动了来自院校、海内外杰出的创意公司及国际知…

搭建飞书早报机器人

飞书是字节跳动推出的一款企业级通讯及协作平台,于2016年正式上线。它是一款基于云计算技术的软件工具,可以帮助企业实现快速高效的沟通和协作,提升工作效率,降低沟通成本。下面将详细介绍飞书的功能、特点以及使用体验。 功能介…

Android动画深入分析(View动画)

Android动画深入分析(View动画) Android的动画我其实在View的滑动里面写过,主要还是分为2点。 一个就是View动画,还有一个是属性动画 先讲述View动画 View动画 View动画主要分为4种,平移动画,缩放动画,旋转动画,透明度动画。 还有一个叫帧动画,但是表现方式和…

python+vue旅游攻略分享推荐网站p0667

基于Python语言设计并实现了旅游分享网站。该系统基于B/S即所谓浏览器/服务器模式,应用Django框架,选择MySQL作为后台数据库。系统主要包括用户、景点信息、攻略分类、旅游攻略、门票购买、留言反馈、论坛管理、系统管理等功能模块。 软件开发前的需求分…

某渣渣企业平台相关加密参数

网址 aHR0cHM6Ly93d3cucWNjLmNvbS93ZWIvZWxpYi90ZWNsaXN0P3RlYz1UX1RTTUVT抓包 GET /api/elib/getTecList?countyCode110101&flag&industry&isSortAsc&pageIndex2&pageSize20&provinceBJ&registCapiBegin&registCapiEnd&searchKey&…

修改git已经push到远端的最近一次提交的commit

需求: 最新一次提交的message写错了且已经push到远程仓库,但是又不想重新创建一个commit记录。 注意: 如果是多人协同开发,使用强推前一定确保当前版本最新,期间无人提交代码。 使用git Bash进入命令行窗口 git co…

基于langChain 的privateGPT 文档问答 研究

参考:gihtub代码 https://github.com/imartinez/privateGPT 官网 privateGPT可以在断网的情况下,借助GPT和文档进行交互,有利于保护数据隐私。 privateGPT可以有四个用处: 1.增强知识管理:私有LLMs自动化&#xff0c…

《Spring Guides系列学习》guide26 - guide30

要想全面快速学习Spring的内容,最好的方法肯定是先去Spring官网去查阅文档,在Spring官网中找到了适合新手了解的官网Guides,一共68篇,打算全部过一遍,能尽量全面的了解Spring框架的每个特性和功能。 接着上篇看过的gu…

numpy库报错has no attribute ‘_no_nep50_warning‘的解决

本文介绍在Python中,numpy库出现报错module numpy has no attribute _no_nep50_warning的解决方法。 一次,在运行一个Python代码时,发现出现报错module numpy has no attribute _no_nep50_warning,如下图所示。 其中,这…

华为nova11系列:一个月的深度体验感受,告诉你值不值得入手

作为一个追求时尚风格的年轻人, nova系列手机一直是我的关注重点。nova 11 Pro发布之后,独特少见的11号色一下子就戳中了我,于是第一时间我给我自己和我老婆分别下单了一台nova 11和nova 11 Pro。 作为主力机深度使用一个月后,可以…

如何做好建筑行业的信息化建设?

如何做好建筑行业的信息化建设? 首先,我们来了解一下,什么是信息化转型? 信息化转型是指企业或组织通过应用信息技术,以提高业务效率和创新能力,实现组织战略目标的过程。 随着数字技术的发展&#xff0…

把字节大佬花3个月时间整理的软件测试面经偷偷给室友,差点被他开除了···

写在前面 “这份软件测试面经看起来不错,等会一起发给他吧”,我看着面前的面试笔记自言自语道。 就在这时,背后传来了leder“阴森森”的声音:“不错吧,我可是足足花了三个月整理的” 始末 刚入职字节的我收到了大学室…

Junit常见用法

一.Junit的含义 Junit是一种Java编程语言的单元测试框架。它提供了一些用于编写和运行测试的注释和断言方法,并且可以方便地执行测试并生成测试报告。Junit是开源的,也是广泛使用的单元测试框架之一。 二.Junit常用注解 1.Test 表示执行此测试用例 T…