人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型

news2024/11/20 4:41:17

大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型,在本文中,我们将学习如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。本文将涵盖这些内容:ResNet模型简介、ResNet模型结构、生成假数据、实现ResNet模型、训练与测试模型。

一、ResNet模型简介

ResNet(残差网络)模型是由何恺明等人在2015年提出的一种深度卷积神经网络。它的主要创新是引入了残差结构,通过这种结构,ResNet可以有效地解决深度神经网络难以训练的问题。ResNet在多个图像分类任务上取得了非常好的效果,包括ImageNet大规模视觉识别挑战赛。ResNet模型使得卷积神经网络不受到层数的限制,解决了层数越深,模型预测结果越差的情况。

二、ResNet模型结构

ResNet的核心思想是引入残差块,残差块的输入不仅直接传给下一层,而且还通过一个跳跃连接(Skip Connection)直接连接到后面的层。这种结构可以有效地缓解梯度消失问题,使得网络可以被有效地训练。

一个典型的残差块包含两个卷积层、两个激活函数和一个跳跃连接。ResNet网络的层数可以通过堆叠不同数量的残差块来实现,例如常见的ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。

三、生成假数据

为了演示模型的训练和测试,我们将在本文中使用生成的假数据。我们将生成一个包含1000个3通道32x32图像的数据集,使用随机数来表示图像像素值,并为每个图像分配一个介于0到9之间的类别标签。

四、实现ResNet模型

接下来,我们将使用PyTorch框架实现一个简化版的ResNet-18模型。首先,我们需要导入所需的库:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.skip_connection = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.skip_connection(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, stride=1)
        self.layer2 = self._make_layer(block, 128, stride=2)
        self.layer3 = self._make_layer(block, 256, stride=2)
        self.layer4 = self._make_layer(block, 512, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, stride):
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

残差连接图: 

五、训练与测试模型

现在我们创建数据集、模型、损失函数和优化器,然后进行训练:

# 生成假数据
num_samples = 1000
image_data = np.random.rand(num_samples, 3, 32, 32).astype(np.float32)
labels = np.random.randint(0, 10, size=num_samples, dtype=np.int64)

# 创建数据集和数据加载器
train_data = TensorDataset(torch.from_numpy(image_data), torch.from_numpy(labels))
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

# 创建模型、损失函数和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(ResidualBlock).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / (i + 1)}")

# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy of the model on the {total} test images: {100 * correct / total}%")

六、总结

文章详细介绍了如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。通过实现简化版的ResNet-18模型,我们了解了残差块的结构和原理,以及如何利用残差结构有效地训练深度神经网络。在实际应用中,可以通过调整模型结构和参数,以及使用真实的数据集来进一步提升模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/568857.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux---vi/vim编辑器、查阅命令

1. vi\vim编辑器三种模式 vim 是 vi 的加强版本,兼容 vi 的所有指令,不仅能编辑文本,而且还具有 shell 程序编辑的功能, 可以不同颜色的字体来辨别语法的正确性,极大方便了程序的设计和编辑性。 命令模式&#xff08…

整数智能重磅推出集成SAM的智能标注工具2.0

前言 图像语义分割一直是数据标注中最繁琐、最耗时的标注任务之一,利用钢笔工具手动描边的标注方式所带来的时间成本和低准确率都将影响模型的生产速度和模型性能。整数智能ABAVA数据工程平台最新发布了基于SAM(Segement Anything Model)改进…

【面试题】面试题总结

加油加油 文章目录 1. TCP与UDP的区别2. TCP为什么是四次挥手机制3. HTTP与HTTPS的区别4. HTTPS加密机制5. 简要介绍SSL/TSL协议6. GET与POST的区别7. cookie与session的区别8. JVM内存区域划分9. 程序运行时候内存不足,会出现什么状况10. 显式调用GC会立即执行吗11…

【Python json】零基础也能轻松掌握的学习路线与参考资料

Python中的JSON模块主要用于将Python对象序列化成JSON数据或解析包含JSON数据的字符串。JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,易于人阅读和编写,同时也易于机器解析和生成。由于JSON在Web应用中的广泛使用…

SpringCloud Sentinel实战限流熔断降级应用

目录 1 Sentinel核心库1.1 Sentinel介绍1.2 Sentinel核心功能1.2.1 流量控制1.2.2 熔断降级 2 Sentinel 限流熔断降级2.1 SentinelResource定义资源2.2 Sentinel的规则2.2.1 流量控制规则 (FlowRule)2.2.2 熔断降级规则 (DegradeRule)2.2.3 系统保护规则 (SystemRule)2.2.4 访问…

Tomcat配置https协议证书-阿里云,Nginx配置https协议证书-阿里云,Tomcat配置https证书pfx转jks

Tomcat/Nginx配置https协议证书 前言Tomcat配置https协议证书-阿里云方式一 pfx配置证书重启即可 方式二 jkspfx生成jks配置证书重启即可 Nginx配置https协议证书-阿里云实现方式重启即可 其他Tomcat相关配置例子如下nginx配置相关例子如下 前言 阿里云官网:https:…

探索Java面向对象编程的奇妙世界(三)

⭐ 垃圾回收机制(Garbage Collection)⭐ JVM 调优和 Full GC⭐ this 关键字⭐ static 关键字 ⭐ 垃圾回收机制(Garbage Collection) Java 引入了垃圾回收机制,令 C程序员最头疼的内存管理问题迎刃而解。Java 程序员可以将更多的精力放到业务逻辑上而不是内存管理工…

网安面试只要掌握这十点技巧,绝对轻轻松松吊打面试官

结合工作经验,在这里笔者给企业网管员提供一些保障企业网络安全的建议,帮助他们用以抵御网络入侵、恶意软件和垃圾邮件。 定义用户完成相关任务的恰当权限 拥有管理员权限的用户也就拥有执行破坏系统的活动能力,例如: ・偶然对系…

挂耳式耳机品牌排行榜,良心推荐这四款蓝牙耳机

在蓝牙耳机越来越普及的同时,大家开始重视佩戴耳机时的舒适度,市面上的耳机形式也逐步迭代,目前流行的开放式耳机不仅很好的避免长期佩戴耳机产生的酸痛感,而且对耳道健康问题的处理也具有极佳的效果。那么,面对市面上…

VB显示“shell32.dll”中的图标

在Form上添加一个ListBox列表控件 代码如下: Option Explicit Private Declare Function ExtractIconEx Lib “shell32.dll” Alias “ExtractIconExA” (ByVal lpszFile As String, ByVal nIconIndex As Long, phiconLarge As Long, phiconSmall As Long, ByVal nI…

奇偶分频电路

目录 偶数分频 寄存器级联法 计数器法 奇数分频 不满足50%占空比 50%占空比 偶数分频 寄存器级联法 寄存器级联法能实现2^N的偶数分频,具体做法是采用寄存器结构的电路,每当时钟上升沿到来的时候对输出结果进行翻转,以此来实现偶数分…

拥有自我意识的AI:AutoGPT | 得物技术

1.引言 ChatGPT在当下已经风靡一时,作为自然语言处理模型的佼佼者,ChatGPT的优势在于其能够生成流畅、连贯的对话,同时还能够理解上下文并根据上下文进行回答。针对不同的应用场景可以进行快速定制,例如,在客服、教育…

【Unity-UGUI控件全面解析】| TextMeshPro 控件详解

🎬【Unity-UGUI控件全面解析】| TextMeshPro控件详解一、组件介绍二、组件属性面板三、代码操作组件四、组件常用方法示例4.1 Font Asset Creator 面板介绍4.2 制作中文字体库五、组件相关扩展使用5.1 软化/扩张 效果5.2 描边效果5.3 投影效果5.4 光照效果5.5 外发光效果💯…

5.25 费解的开关

思路&#xff0c;枚举 开关按下两次就复原&#xff0c;所以一个开关只有两种情况&#xff0c;按下和不按下&#xff0c;5*5的开关&#xff0c;一共25个开关&#xff0c;一共有2^25种情况&#xff0c;for (int i 0;i < 2^25;i)进行操作&#xff0c;计算按下开关次数&#x…

mac 安装 MongoDB

一.官网下载安装包 1.1 下载安装包 Download MongoDB Community Server | MongoDB 1.2 将下载好的 MongoDB 安装包解压缩&#xff0c;并将文件夹名改为 mongodb&#xff08;可改成自己想要的任何名字&#xff09;。 1.3 按快捷键 Command Shift G 打开前往文件夹弹窗&#…

没有经验能做产品经理吗?

没有经验能做产品经理吗&#xff1f;这是一个经常被讨论的问题&#xff0c;因为很多人想转行成为产品经理&#xff0c;但他们没有相关的工作经验。这里我也给出一些解答。 一、产品经理的职责和技能 首先&#xff0c;让我们看一下产品经理的职责和技能。产品经理是负责产品开…

LeetCode:相交链表(java)

相交链表 题目描述指针法解题 #LeetCode 160题&#xff1a;相交链表&#xff0c;原题链接 原题链接。相交链表–可以打开测试 题目描述 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返…

【Python Power BI】零基础也能轻松掌握的学习路线与参考资料

Python和Power BI是现代数据分析和可视化领域中最受欢迎的工具之一&#xff0c;Python是一种高级编程语言&#xff0c;广泛用于数据科学和分析&#xff0c;而Power BI是一种业务智能工具&#xff0c;用于创建交互式大屏幕和实时报表。Python和Power BI的结合使用可以为数据科学…

【布隆过滤器】BitMap与布隆过滤器

1.案例&#xff1a;40亿个QQ号&#xff0c;限制1G内存&#xff0c;如何去重&#xff1f; 40亿个unsigned int&#xff0c;如果直接用内存存储的话&#xff0c;需要&#xff1a; 4*4000000000 /1024/1024/1024 14.9G &#xff0c;考虑到其中有一些重复的话&#xff0c;那1G的…

【P31】JMeter 循环控制器(Loop Controller)

这文章目录 一、循环控制器&#xff08;Loop Controller&#xff09;参数说明二、测试计划设计2.1、设置循环次数2.2、勾选永远2.3、设置线程组的持续时间 一、循环控制器&#xff08;Loop Controller&#xff09;参数说明 可以对部分逻辑按常量进行循环迭代 选择线程组右键 …