pytorch-ResNet18简单复现

news2024/10/5 22:23:01

目录

  • 1. ResNet block
  • 2. ResNet18网络结构
  • 3. 完整代码
    • 3.1 网络代码
    • 3.2 训练代码

1. ResNet block

ResNet block有两个convolution和一个short cut层,如下图:
在这里插入图片描述
代码:

class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self. bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = self.extra(x) + out
        out = F.relu(out)

        return out

2. ResNet18网络结构

在这里插入图片描述
在这里插入图片描述
从上图可以看出,resnet18有1个卷积层,4个残差层和1一个线性输出层,其中每个残差层有2个resnet块,每个块有2个卷积层。
对于cifar10数据来说,输入层[b, 64, 32,32],输出是10分类
代码:

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # 四个残差层
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # 全连接层
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    # 创建一个残差层
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = F.avg_pool2d(out, 4)
        out = F.adaptive_avg_pool2d(out, [1, 1])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

3. 完整代码

3.1 网络代码

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    expansion = 1

    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = self.extra(x) + out
        out = F.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # 四个残差层
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # 全连接层
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    # 创建一个残差层
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = F.avg_pool2d(out, 4)
        out = F.adaptive_avg_pool2d(out, [1, 1])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(ResBlk, [2, 2, 2, 2], 10)


if __name__ == '__main__':
    model = ResNet18()
    print(model)

3.2 训练代码

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import sys

sys.path.append('.')
#from Lenet5 import Lenet5
from resnet import ResNet18


def main():
    batchz = 128
    cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)

    device = torch.device('cuda')
    #model = Lenet5().to(device)
    model = ResNet18().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1888415.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

锻炼 读书笔记 番外 身体激素及神经递质

最近在读《锻炼》的时候,对于各种激素很感兴趣,多巴胺、内啡肽、荷尔蒙、肾上腺素、褪黑素、皮质醇、糖化、氧化等等。索性认真梳理下它们是什么,思考当处于心流状态时,人体发生什么样的变化,分泌什么激素?…

Milvus ConnectionRefusedError: how to connect locally

题意:怎样在本地连接到 Milvus 数据库。连接 Milvus 数据库被拒绝的错误 问题背景: I am trying to run a RAG pipeline using haystack & Milvus. 我正在尝试使用 haystack 和 Milvus 运行一个 RAG(检索增强型生成)管道。 …

10 - Python文件编程和异常

文件和异常 在实际开发中,常常需要对程序中的数据进行持久化操作,而实现数据持久化最直接简单的方式就是将数据保存到文件中。说到“文件”这个词,可能需要先科普一下关于文件系统的知识,对于这个概念,维基百科上给出…

重温被Mamba带火的SSM:HiPPO的一些遗留问题

©PaperWeekly 原创 作者 | 苏剑林 单位 | 科学空间 研究方向 | NLP、神经网络 书接上文,在上一篇文章《重温被Mamba带火的SSM:线性系统和HiPPO矩阵》中,我们详细讨论了 HiPPO 逼近框架其 HiPPO 矩阵的推导,其原理是通过正交…

MySQL的Geometry数据处理之WKB方案

MySQL的Geometry数据处理之WKT方案:https://blog.csdn.net/qq_42402854/article/details/140134357 MySQL的Geometry数据处理之WKT方案中,介绍WTK方案的优点,也感受到它的繁琐和缺陷。比如: 需要借助 ST_GeomFromText和 ST_AsTex…

Gradio 教程四:Building Generative AI Applications with Gradio

文章目录 一、使用interface构建NLP应用1.1 构建文本摘要应用1.1.1 设置API密钥1.1.2 调用文本摘要API1.1.3 运行本地模型获取响应1.1.4 使用interface构建应用 1.2 构建命名实体识别应用1.2.1 调用NER任务API1.2.2 使用interface构建应用1.2.3 加入额外函数,合并to…

C语言实战 | 用户管理系统

近期推出的青少年防沉迷系统,采用统一运行模式和功能标准。在“青少年模式”下,未成年人的上网时段、时长、功能和浏览内容等方面都有明确的规范。防沉迷系统为青少年打开可控的网络空间。 01、综合案例 防沉迷系统的基础是需要一个用户管理系统管理用户…

Unity3d C#实现基于UGUI ScrollRect的轮播图效果功能(含源码)

前言 轮播功能是一种常见的页面组件,用于在页面中显示多张图片/素材并自动或手动进行切换,以提高页面的美观度和用户体验。主要的功能是:自动/手动切换;平滑的切换效果;导航指示器等。可惜Unity的UGUI系统里没有现成的实现该功能&#xff0c…

[Labview] 改写表格内容并储存覆盖Excel

在上一个功能的基础上,新增表格改写保存功能 [Labview] Excel读表 & 输出表单中选中的单元格内容https://blog.csdn.net/Katrina419/article/details/140120584 Excel修改前: 修改保存后,动态改写储存Excel,并重新写入新的表…

使用antd的<Form/>组件获取富文本编辑器输入的数据

前端开发中,嵌入富文本编辑器时,可以通过富文本编辑器自身的事件处理函数将数据传输给后端。有时候,场景稍微复杂点,比如一个输入页面除了要保存富文本编辑器的内容到后端,可能还有一些其他输入组件获取到的数据也一并…

Go - 8.func 函数使用

目录 一.引言 二.func 定义 三.func 实践 1.多个返回值 2.命名返回值 3.可变参数 四.总结 一.引言 函数是编程语言中的基本构建块,用于将代码组织成可重用的逻辑单元。函数可以接受输入参数,执行特定的操作,并返回结果。在 Go 语言&a…

设计IC行业SAP软件如何处理芯片成本计算

在集成电路(IC)设计与制造行业中,精确的成本计算对于维持健康的财务状况、优化生产流程以及保持市场竞争力至关重要。SAP软件,作为一种全面的企业资源规划(ERP)解决方案,为IC行业提供了强大且灵活的成本计算工具。以下是SAP软件如何处理芯片成…

【python】OpenCV—Feature Detection and Matching

参考学习来自OpenCV基础(23)特征检测与匹配 文章目录 1 背景介绍2 Harris角点检测3 Shi-Tomasi角点检测4 Fast 角点检测5 BRIEF 特征描述子6 ORB(Oriented Fast and Rotated Brief) 特征描述子7 SIFT(Scale Invariant Feature Transform) 特征描述子8 SU…

​​​​​​​​​​​​​​Spark Standalone集群环境

目录 Spark Standalone集群环境 修改配置文件 【workers】 【spark-env.sh】 【配置spark应用日志】 【log4j.properties】 分发到其他机器 启动spark Standalone 启动方式1:集群启动和停止 启动方式2:单独启动和停止 连接集群 【spark-shel…

基于Hadoop平台的电信客服数据的处理与分析③项目开发:搭建基于Hadoop的全分布式集群---任务1:运行环境说明

任务描述 项目的运行环境是基于Hadoop的全分布式模式集群。 任务的主要内容是规划集群节点及网络使用,准备初始环境,关闭防火墙和Selinux。 任务指导 1. 基于Hadoop的全分布式模式集群,如下图所示; 2. 硬软件环境:…

Android性能优化面试题经典之ANR的分析和优化

本文首发于公众号“AntDream”,欢迎微信搜索“AntDream”或扫描文章底部二维码关注,和我一起每天进步一点点 造成ANR的条件 以下四个条件都可以造成ANR发生: InputDispatching Timeout:5秒内无法响应屏幕触摸事件或键盘输入事件 …

《企业实战分享 · MyBatis 使用合集》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 近期刚转战 CSDN,会严格把控文章质量,绝不滥竽充数,如需交流&#xff…

Seatunnel本地模式快速测验

前言 SeaTunnel(先前称为WaterDrop)是一个分布式、高性能、易于扩展的数据集成平台,旨在实现海量数据的同步和转换。它支持多种数据处理引擎,包括Apache Spark和Apache Flink,并在某个版本中引入了自主研发的Zeta引擎…

虚拟机交叉编译基于ARM平台的opencv(ffmpeg/x264)

背景: 由于手上有一块rk3568的开发板,需要运行yolov5跑深度学习模型,但是原有的opencv不能对x264格式的视频进行解码,这里就需要将ffmpegx264编译进opencv。 但是开发板算力有限,所以这里采用在windows下,安…

2024年07年01日 Redis数据类型以及使用场景

String Hash List Set Sorted Set String,用的最多,对象序列化成json然后存储 1.对象缓存,单值缓存 2.分布式锁 Hash,不怎么用到 1.可缓存经常需要修改值的对象,可单独对对象某个属性进行修改 HMSET user {userI…